Skip to content

Commit fa58880

Browse files
authored
[transform] Basic example of applying a schedule to a payload (#12)
Simple demonstration of applying a schedule to a payload. Relies on the little wrapper method `NamedSequenceOp.apply`, now existing upstream, to make things as simple as can be.
1 parent c60d954 commit fa58880

File tree

3 files changed

+91
-20
lines changed

3 files changed

+91
-20
lines changed

.github/workflows/examples.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ jobs:
3030
- name: Run Compile And Run
3131
run: |-
3232
uv run python/examples/mlir/compile_and_run.py
33+
34+
- name: Run apply basic schedule to basic payload
35+
run: |-
36+
uv run python/examples/schedule/transform_a_payload_according_to_a_schedule.py

python/examples/mlir/compile_and_run.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from mlir import ir
55
from mlir.dialects import transform
66
from mlir.dialects.transform import structured
7-
from mlir.dialects.transform import interpreter
87
from mlir.execution_engine import ExecutionEngine
98
from mlir.passmanager import PassManager
109

@@ -89,22 +88,6 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
8988
return schedule
9089

9190

92-
def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
93-
"""
94-
Apply transformation schedule to a kernel module.
95-
The kernel is modified in-place.
96-
97-
Args:
98-
kernel: A module with payload function.
99-
schedule: A module with transform schedule.
100-
"""
101-
interpreter.apply_named_sequence(
102-
payload_root=kernel,
103-
transform_root=schedule.body.operations[0],
104-
transform_module=schedule,
105-
)
106-
107-
10891
def create_pass_pipeline(ctx: ir.Context) -> PassManager:
10992
"""
11093
Create an MLIR pass pipeline.
@@ -141,9 +124,11 @@ def main(args):
141124
ctx = ir.Context()
142125
kernel = create_kernel(ctx)
143126

144-
# Create a transform schedule and apply initial lowering.
145-
schedule = create_schedule(ctx)
146-
apply_schedule(kernel, schedule)
127+
# Create a transform schedule and apply initial lowering to kernel.
128+
# The kernel is modified in-place.
129+
schedule_module = create_schedule(ctx)
130+
named_seq: transform.NamedSequenceOp = schedule_module.body.operations[0]
131+
named_seq.apply(kernel)
147132

148133
# Create a pass pipeline and lower the kernel to LLVM dialect.
149134
pm = create_pass_pipeline(ctx)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Simply demonstrates applying a schedule to a payload.
2+
# To do so generates a basic payload and a basic schedule, purely as an example.
3+
4+
from mlir.ir import Context, Location, InsertionPoint, Operation, Module
5+
from mlir.ir import RankedTensorType, F32Type, UnitAttr
6+
from mlir.dialects import arith, func, linalg, tensor, transform
7+
from mlir.dialects.transform import structured
8+
9+
10+
def example_payload() -> Module:
11+
"""IR for:
12+
Zero = ...
13+
X = matmul(..., C=Zero)
14+
Y = matmul(..., C=Zero)
15+
Res = add(X, Y)
16+
17+
Can be re-written to:
18+
X = matmul(..., C=Zero)
19+
Res = matmul(..., C=X)
20+
"""
21+
22+
print("NOTE: example payload module:")
23+
payload = Module.create()
24+
with InsertionPoint(payload.body):
25+
matrixType = RankedTensorType.get([16, 16], F32Type.get())
26+
27+
@func.func(matrixType, matrixType, matrixType)
28+
def fold_add_on_two_matmuls(matrixA, matrixB, weights):
29+
empty = tensor.empty(matrixType.shape, matrixType.element_type)
30+
c0 = arith.constant(F32Type.get(), 0.0)
31+
zero_init = linalg.fill(c0, outs=[empty])
32+
A_x_weights = linalg.matmul(matrixA, weights, outs=[zero_init])
33+
empty2 = tensor.empty(matrixType.shape, matrixType.element_type)
34+
zero_init2 = linalg.fill(c0, outs=[empty2])
35+
B_x_weights = linalg.matmul(matrixB, weights, outs=[zero_init2])
36+
added = linalg.add(A_x_weights, B_x_weights, outs=[empty])
37+
return added
38+
39+
print(payload)
40+
return payload
41+
42+
43+
def example_schedule() -> Module:
44+
"""Basic schedule wrapping a single rewrite pattern."""
45+
46+
print("NOTE: example schedule module:")
47+
schedule_module = Module.create()
48+
schedule_module.operation.attributes["transform.with_named_sequence"] = (
49+
UnitAttr.get()
50+
)
51+
with InsertionPoint(schedule_module.body):
52+
named_seq = transform.named_sequence(
53+
"__transform_main",
54+
input_types=[transform.any_op_t()],
55+
result_types=[],
56+
arg_attrs=[{"transform.readonly": UnitAttr.get()}],
57+
)
58+
59+
with InsertionPoint(named_seq.body):
60+
func = structured.MatchOp.match_op_names(
61+
named_seq.bodyTarget, ["func.func"]
62+
) # TODO: fix syntax upstream
63+
with InsertionPoint(transform.apply_patterns(func).patterns):
64+
Operation.create(
65+
"transform.apply_patterns.linalg.fold_add_into_dest"
66+
) # TODO: expose dedicated builder upstream
67+
transform.yield_([])
68+
69+
print(schedule_module)
70+
return schedule_module
71+
72+
73+
with Context(), Location.unknown():
74+
payload = example_payload()
75+
schedule_module = example_schedule()
76+
# Actual schedule is defined by the contained transform.named_sequence:
77+
schedule: transform.NamedSequenceOp = schedule_module.body.operations[0]
78+
79+
schedule.apply(payload) # The actual transformation happens here.
80+
81+
print("NOTE: result of applying schedule to payload:")
82+
print(payload)

0 commit comments

Comments
 (0)