[mlir-gen] Add mlir builders for llama3.1 and tests #13
Conversation
Should this be in `examples`?

The e2e should be, yup, but this is mostly tests and getters.

I moved the whole thing to examples and added attention to the list of tests.
rengolin left a comment:
Thanks!
rolfmorel left a comment:
Nice! Have left some comments inline.
```yaml
- name: Run pytest-enabled examples as tests
  run: |
    uv run pytest python/examples
```
Could we instead integrate with lit? That is, would it work if we just added the line `# RUN: %PYTHON pytest %s` at the top of test_llama3.py?

There's value in trying to preserve being able to just `lit $PATH_WITHIN_PROJECT` to run the respective tests (including `PATH_WITHIN_PROJECT=.` in the root).
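For concreteness, a sketch of what the top of test_llama3.py might then look like (hedged: `%PYTHON` expands to the Python interpreter, so pytest is likely best invoked as a module via `-m`; the exact RUN line may need tweaking):

```python
# RUN: %PYTHON -m pytest %s
#
# lit executes the RUN line above; pytest then collects the tests in
# this file as usual, so the same file works under both runners.

import pytest


def test_smoke():
    assert True
```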
```python
        (get_matmul, (16, 16), "f32"),
        (get_outer, (16,), "f32"),
    ],
)
```
My suggestion would be to define a decorator, something like:

```python
from typing import Callable
from mlir import ir

def with_context_and_unknown_location(f: Callable):
    def wrapped(*args, **kwargs):
        with ir.Context(), ir.Location.unknown():
            f(*args, **kwargs)
    return wrapped
```

Just annotate all your pytest tests with that and completely forget about needing to deal with context and location (and entering context managers) throughout the rest of the code. That is, most APIs will pick up the current context and location automatically. For the APIs that don't, there are the `mlir.ir.Context.current` and `mlir.ir.Location.current` escape hatches.
Some context: #20 (comment) and #20 (comment)
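For illustration, a test annotated with the decorator could then be as simple as (test body hypothetical, not from this PR):

```python
from mlir import ir


@with_context_and_unknown_location
def test_module_roundtrip():
    # ir.Context.current and ir.Location.current are set implicitly here,
    # so no explicit `with` blocks are needed in the test body.
    module = ir.Module.parse("module {}")
    assert module.operation.verify()
```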
```python
def create_pass_pipeline(ctx: ir.Context) -> PassManager:
    with ctx:
        pm = PassManager("builtin.module")
```

Suggested change:

```diff
-def create_pass_pipeline(ctx: ir.Context) -> PassManager:
-    with ctx:
-        pm = PassManager("builtin.module")
+def create_pass_pipeline() -> PassManager:
+    pm = PassManager("builtin.module")
```
See comment below for how to elide dealing with contexts in most places.
The above holds for many functions in this file.
For this particular function, its passes should just become the suffix of the schedule, so that we end up with a single end-to-end schedule for the entire MLIR lowering.
See https://github.com/libxsmm/tpp-mlir/blob/37a498bd1e320e00fa50e3323cbaac2867cd7a1e/python/mlir/tpp/sched/bundles.py#L41-L43 for an example of dealing with passes that expect to run on particular ops w.r.t. the root module.
```python
# Create entry point transformation sequence.
with ir.InsertionPoint(schedule.body):
    named_seq = transform.NamedSequenceOp(
```
Suggested change:

```diff
-    named_seq = transform.NamedSequenceOp(
+    named_seq = transform.named_sequence(
```

As a general principle we can set for ourselves right at the start of the project: if the upstream bindings already have a workable `snake_case` version of an op, let's use it over the `CamelCaseOp` version. The crux of the argument is that this makes the Python code look closer to a terse version of the MLIR textual format.
```python
    [xq_scores_map, keys_scores_map, scores_map],
    [parallel, parallel, parallel, parallel, reduction],
)
def compute_scores(q_val, k_val, score_val):
```
Could be written as a `linalg.contract`, right?
```python
    [parallel] * 4,
)
def scale_scores(score, _out):
    return arith.MulFOp(score, scale_const).result
```
Suggested change:

```diff
-    return arith.MulFOp(score, scale_const).result
+    return arith.mulf(score, scale_const)
```
If you use `arith.mulf` (or whatever the snake_case version is called), you should be able to elide the `.result`. This holds generally (for single-result ops, that is).
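To spell out the difference, a minimal self-contained sketch (function and variable names are illustrative, not from this PR):

```python
from mlir import ir
from mlir.dialects import arith, func

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    f32 = ir.F32Type.get()
    with ir.InsertionPoint(module.body):
        @func.FuncOp.from_py_func(f32, f32)
        def scale(score, scale_const):
            # CamelCaseOp builder: returns the op, so .result is needed.
            a = arith.MulFOp(score, scale_const).result
            # snake_case builder: returns the Value directly (single-result op).
            return arith.mulf(a, scale_const)
```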
```python
        anytype, mod, "convert-linalg-to-loops"
    )
    # Cleanup.
    transform.ApplyCommonSubexpressionEliminationOp(mod)
```
I believe `transform.apply_cse` is the snake_case version, though I might be wrong.
```python
module = generate_module(ctx, ir_type)
bufferize_module(ctx, module)
schedule = create_schedule(ctx)
apply_schedule(module, schedule)
pm = create_pass_pipeline(ctx)
pm.run(module.operation)
```
Suggested change:

```diff
-module = generate_module(ctx, ir_type)
-bufferize_module(ctx, module)
-schedule = create_schedule(ctx)
-apply_schedule(module, schedule)
-pm = create_pass_pipeline(ctx)
-pm.run(module.operation)
+module = generate_module(ctx, ir_type)
+schedule = create_schedule(ctx)
+apply_schedule(module, schedule)
```
Just move the passes from inside `bufferize_module(ctx, module)` and `create_pass_pipeline(ctx)` into the start and end of the schedule, i.e. with `transform.apply_registered_pass`. I know this antipattern originates in an example script we merged, but we should not let it proliferate. It is clearly already confusing people.
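A sketch of what that could look like (hedged: pass names and the `__transform_main` entry point are illustrative; the `(result_type, target, pass_name)` argument order follows the `ApplyRegisteredPassOp` call already quoted in this PR, and exact builder signatures may differ across MLIR versions):

```python
from mlir import ir
from mlir.dialects import transform


def create_schedule() -> ir.Module:
    # Assumes an active ir.Context / ir.Location, e.g. via the decorator
    # suggested earlier in this review.
    schedule = ir.Module.create()
    schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
    anytype = transform.AnyOpType.get()
    with ir.InsertionPoint(schedule.body):
        named_seq = transform.named_sequence("__transform_main", [anytype], [])
        with ir.InsertionPoint(named_seq.body):
            mod = named_seq.bodyTarget
            # ... the existing tiling/fusion transforms on `mod` go here ...
            # What used to be bufferize_module and create_pass_pipeline
            # becomes the suffix of the schedule:
            mod = transform.apply_registered_pass(anytype, mod, "one-shot-bufferize")
            mod = transform.apply_registered_pass(anytype, mod, "convert-linalg-to-loops")
            mod = transform.apply_registered_pass(anytype, mod, "canonicalize")
            transform.yield_([])
    return schedule
```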
```python
    return schedule


def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
    interpreter.apply_named_sequence(
        payload_root=kernel,
        transform_root=schedule.body.operations[0],
        transform_module=schedule,
    )
```
Suggested change:

```diff
-    return schedule
-
-
-def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
-    interpreter.apply_named_sequence(
-        payload_root=kernel,
-        transform_root=schedule.body.operations[0],
-        transform_module=schedule,
-    )
+    return named_seq
```
If we do this, you can simply do:

```python
schedule = create_schedule()
schedule.apply(module)
```

If you need access to the Module around the named_sequence, just ask for its `.parent`.
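Presumably that `.apply` would be a thin helper around the interpreter, something like the following hypothetical sketch (the keyword arguments mirror the `apply_named_sequence` call quoted above; `named_seq.parent` is the enclosing transform module, though its exact typing may vary across binding versions):

```python
from mlir import ir
from mlir.dialects.transform import interpreter


def apply(named_seq, payload: ir.Module) -> None:
    # Run the transform interpreter with `named_seq` as the entry point
    # and its parent module as the transform module.
    interpreter.apply_named_sequence(
        payload_root=payload,
        transform_root=named_seq,
        transform_module=named_seq.parent,
    )
```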
```diff
@@ -1,2 +1,2 @@
 import ctypes
 import torch
```
I know this PR didn't introduce it, but looking at it now, I feel we should think about compartmentalizing code that depends on heavy dependencies a bit more. That is, not keep it in the same module as code that doesn't need the dependency, e.g. `get_packed_arg`.
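For instance, a hypothetical split (module and function names are placeholders) would keep torch behind its own module, so that ctypes-only helpers like `get_packed_arg` stay importable without it:

```python
# torch_ref.py (hypothetical module): the only file that imports torch.
# Reference implementations used for numerical checks live here.
import torch


def reference_rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Naive RMS norm to compare against the MLIR-generated kernel.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
```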
Putting up this dirty draft for early feedback/questions. I'm putting together some tests to run an e2e llama3.1 model going through linalg on tensors. The goal is to generate nice linalg that is optimization-friendly. At the moment there are just functional blocks and pieces that are only smoke-tested: naive implementations of rotary embeddings, feed-forward, RMS norm, and a bunch of other small snippets that are useful for implementing the model. These are already enough to put an attention block together. It would be nice to test it against the original implementation, but that would require fairscale as a dependency. For now I only added pytest and kept the pipeline as simple as possible. I also reused the example with the schedule, so it is now part of every test.