Conversation

@kurapov-peter

Putting up this dirty draft for early feedback/questions. I'm putting together some tests to run an e2e llama3.1 through linalg on tensors. The goal is to generate clean linalg that is optimization-friendly. At the moment, there are only functional blocks and pieces that are smoke-tested. These include naive implementations of rotary embeddings, feed-forward, RMS norm, and a handful of other small snippets that are useful for implementing the model. These are already enough to put an attention block together. It would be nice to test it against the original implementation, but that would require fairscale as a dependency. For now I only added pytest and kept the pipeline as simple as possible. I also reused the example with the schedule, so it is now part of every test.

@rengolin (Member)

Should this be in examples?

@kurapov-peter (Author)

The e2e should be, yup, but this is mostly tests and getters.

@kurapov-peter (Author)

I moved the whole thing to examples and added attention to the list of tests.

@rengolin (Member) left a comment

Thanks!

@rolfmorel (Contributor) left a comment

Nice! Have left some comments inline.

```yaml
- name: Run pytest-enabled examples as tests
  run: |
    uv run pytest python/examples
```

Could we instead integrate with lit? That is, would it work if we just added the line `# RUN: %PYTHON pytest %s` at the top of `test_llama3.py`?

There's value in trying to preserve being able to just `lit $PATH_WITHIN_PROJECT` to run the respective tests (including `PATH_WITHIN_PROJECT=.` in the root).

```python
        (get_matmul, (16, 16), "f32"),
        (get_outer, (16,), "f32"),
    ],
)
```

My suggestion would be to define a decorator, something like:

```python
def with_context_and_unknown_location(f: Callable):
    def wrapped(*args, **kwargs):
        with ir.Context(), ir.Location.unknown():
            f(*args, **kwargs)
    return wrapped
```

Just annotate all your pytest tests with that and completely forget about needing to deal with context and location (and entering context managers) throughout the rest of the code. That is, most APIs will pick up the current context and location automatically. For the APIs that don't, there's the `mlir.ir.Context.current` and `mlir.ir.Location.current` escape hatch.

Some context: #20 (comment) and #20 (comment)

Comment on lines +45 to +47
```python
def create_pass_pipeline(ctx: ir.Context) -> PassManager:
    with ctx:
        pm = PassManager("builtin.module")
```

Suggested change:

```diff
-def create_pass_pipeline(ctx: ir.Context) -> PassManager:
-    with ctx:
-        pm = PassManager("builtin.module")
+def create_pass_pipeline() -> PassManager:
+    pm = PassManager("builtin.module")
```

See comment below for how to elide dealing with contexts in most places.


The above holds for many functions in this file.

For this particular function, it should just become the suffix of the schedule, so that we have end-to-end schedules covering the entire MLIR lowering that is happening.

@rolfmorel (Contributor) · Nov 25, 2025

See https://github.com/libxsmm/tpp-mlir/blob/37a498bd1e320e00fa50e3323cbaac2867cd7a1e/python/mlir/tpp/sched/bundles.py#L41-L43 for an example of dealing with passes that expect to run on particular ops w.r.t. the root module.


```python
# Create entry point transformation sequence.
with ir.InsertionPoint(schedule.body):
    named_seq = transform.NamedSequenceOp(
```

Suggested change:

```diff
-    named_seq = transform.NamedSequenceOp(
+    named_seq = transform.named_sequence(
```

#10 (review):

As a general principle we can set for ourselves for the project right at the start: in case the upstream bindings already have a workable snake_case version of an op, let's use that over the CamelCaseOp version. The crux of the argument is that this makes the Python code look closer to a terse version of the MLIR textual format.

```python
    [xq_scores_map, keys_scores_map, scores_map],
    [parallel, parallel, parallel, parallel, reduction],
)
def compute_scores(q_val, k_val, score_val):
```

Could be written as a `linalg.contract`, right?

```python
    [parallel] * 4,
)
def scale_scores(score, _out):
    return arith.MulFOp(score, scale_const).result
```

Suggested change:

```diff
-    return arith.MulFOp(score, scale_const).result
+    return arith.mulf(score, scale_const)
```

If you use `arith.mulf` (or whatever the snake_case version is called), you should be able to elide the `.result`.

This holds generally (for single-result ops, that is).
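The distinction can be shown with a dependency-free sketch (the classes below are stand-ins, not the real bindings): the CamelCase op class yields an operation object whose single result must be projected out via `.result`, while the snake_case value builder hands back the result value directly.

```python
class MulFOp:
    """Stand-in for a CamelCase op class with a single result."""
    def __init__(self, lhs, rhs):
        self.result = lhs * rhs  # the op's single result value

def mulf(lhs, rhs):
    """Stand-in snake_case value builder: unwraps .result for the caller."""
    return MulFOp(lhs, rhs).result

assert MulFOp(3.0, 2.0).result == 6.0  # op form: .result needed
assert mulf(3.0, 2.0) == 6.0           # builder form: value directly
```

For multi-result ops the builders instead return a tuple of values, which is why the `.result` elision only applies to the single-result case.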

```python
        anytype, mod, "convert-linalg-to-loops"
    )
    # Cleanup.
    transform.ApplyCommonSubexpressionEliminationOp(mod)
```

I believe `transform.apply_cse` is the snake_case version, though I might be wrong.

Comment on lines +1200 to +1205
```python
module = generate_module(ctx, ir_type)
bufferize_module(ctx, module)
schedule = create_schedule(ctx)
apply_schedule(module, schedule)
pm = create_pass_pipeline(ctx)
pm.run(module.operation)
```

Suggested change:

```diff
-module = generate_module(ctx, ir_type)
-bufferize_module(ctx, module)
-schedule = create_schedule(ctx)
-apply_schedule(module, schedule)
-pm = create_pass_pipeline(ctx)
-pm.run(module.operation)
+module = generate_module(ctx, ir_type)
+schedule = create_schedule(ctx)
+apply_schedule(module, schedule)
```

Just move the passes from inside `bufferize_module(ctx, module)` and `create_pass_pipeline(ctx)` into the start and end of the schedule, i.e. with `transform.apply_registered_pass`.

I know this antipattern originates in an example script we merged, but we should not let this proliferate. It clearly is already confusing people.
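The restructuring can be sketched without the MLIR bindings (every name below is hypothetical): rather than interleaving standalone pass pipelines with a transform schedule, the schedule itself carries all the steps, so there is exactly one thing to apply to the payload module.

```python
# Stand-in for a schedule that subsumes the former bufferization and
# lowering pass pipelines as its prefix and suffix.
def create_schedule():
    return [
        ("pass", "one-shot-bufferize"),       # was bufferize_module
        ("transform", "tile-and-fuse"),       # the existing schedule body
        ("pass", "convert-linalg-to-loops"),  # was create_pass_pipeline
    ]

def apply_schedule(module, schedule):
    # Stand-in interpreter: record which steps ran, in order.
    return [f"{kind}:{name}" for kind, name in schedule]

steps = apply_schedule("payload-module", create_schedule())
```

With this shape, `steps` comes out as the full ordered lowering, and there is no separate `PassManager.run` call for a reader to sequence correctly relative to the schedule.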

Comment on lines +115 to +123
```python
    return schedule


def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
    interpreter.apply_named_sequence(
        payload_root=kernel,
        transform_root=schedule.body.operations[0],
        transform_module=schedule,
    )
```

Suggested change:

```diff
-    return schedule
-
-
-def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
-    interpreter.apply_named_sequence(
-        payload_root=kernel,
-        transform_root=schedule.body.operations[0],
-        transform_module=schedule,
-    )
+    return named_seq
```

If we do this, you can simply do:

```python
schedule = create_schedule()
schedule.apply(module)
```

If you need access to the Module around the named_sequence, just ask for its `.parent`.
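A dependency-free sketch of this proposed API shape (names are hypothetical, not the real bindings): `create_schedule` returns a handle exposing `.apply()` plus a `.parent` giving the surrounding module when it is needed.

```python
class NamedSequence:
    """Stand-in for a named-sequence handle with an apply helper."""
    def __init__(self, parent):
        self.parent = parent   # the module that contains the sequence
        self.applied_to = []

    def apply(self, payload):
        # Stand-in for interpreter.apply_named_sequence(payload, ...)
        self.applied_to.append(payload)

def create_schedule():
    return NamedSequence(parent="schedule-module")

schedule = create_schedule()
schedule.apply("payload-module")
```

The caller then never touches the transform module directly; both the payload hookup and the enclosing module stay behind the handle.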

```diff
@@ -1,2 +1,2 @@
 import ctypes
 import torch
```

I know this PR didn't introduce it, but looking at it now, I feel we should think about compartmentalizing code that depends on heavy dependencies a bit more. That is, not have it in the same module as code that doesn't need the dependency, e.g. `get_packed_arg`.
