Skip to content

Commit ee34b37

Browse files
committed
simplify schedule and use snake_case op names where possible
1 parent 5957395 commit ee34b37

File tree

1 file changed

+18
-38
lines changed

1 file changed

+18
-38
lines changed

python/examples/xegpu_matmul/schedule.py

Lines changed: 18 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,48 +25,41 @@ def get_schedule_module(
2525
mod = ir.Module.create()
2626
mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
2727
with ir.InsertionPoint(mod.body):
28-
named_sequence = transform.NamedSequenceOp(
28+
named_sequence = transform.named_sequence(
2929
"__transform_main",
3030
[transform.AnyOpType.get()], # input types
3131
[], # output types
3232
arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}],
3333
)
3434
with ir.InsertionPoint(named_sequence.body):
35+
# match the payload module
36+
anytype = transform.AnyOpType.get()
37+
func = match(named_sequence.bodyTarget, ops={"func.func"})
38+
payload_mod = transform.get_parent_op(
39+
anytype,
40+
func,
41+
op_name="builtin.module",
42+
deduplicate=True,
43+
)
3544
xegpu_matmul_transform_schedule(
36-
named_sequence,
45+
payload_mod,
3746
has_bias=has_bias,
3847
has_relu=has_relu,
3948
dump_kernel=dump_kernel,
4049
params=params,
4150
)
42-
# placeholder for parameter division op
43-
i32 = ir.IntegerType.get_signless(32)
44-
paramInt32Type = transform.ParamType.get(i32)
45-
div_named_sequence = transform.NamedSequenceOp(
46-
"param_div",
47-
[paramInt32Type, paramInt32Type], # input types
48-
[paramInt32Type], # output types
49-
arg_attrs=[
50-
{"transform.readonly": ir.UnitAttr.get()},
51-
{"transform.readonly": ir.UnitAttr.get()},
52-
],
53-
)
54-
with ir.InsertionPoint(div_named_sequence.body):
55-
p = transform.ParamConstantOp(paramInt32Type, ir.IntegerAttr.get(i32, 1))
56-
transform.YieldOp(p)
5751

5852
return mod
5953

6054

6155
def xegpu_matmul_transform_schedule(
62-
named_sequence: transform.NamedSequenceOp,
56+
mod: ir.Value,
6357
has_bias: bool = False,
6458
has_relu: bool = False,
6559
dump_kernel: str = "",
6660
params: Optional[dict] = None,
6761
):
6862
"""Transform schedule for matmul-like payload."""
69-
mod = bundle_header(named_sequence)
7063
mod, interrupted = bundle_xepu_matmul_schedule(
7164
mod,
7265
has_bias=has_bias,
@@ -75,27 +68,14 @@ def xegpu_matmul_transform_schedule(
7568
params=params,
7669
)
7770
if interrupted:
78-
transform.YieldOp()
71+
transform.yield_()
7972
return
8073

8174
mod, interrupted = bundle_xegpu_to_binary(
8275
mod,
8376
dump_kernel=dump_kernel,
8477
)
85-
transform.YieldOp()
86-
87-
88-
def bundle_header(named_sequence: transform.NamedSequenceOp):
89-
"""Matches the payload module."""
90-
anytype = transform.AnyOpType.get()
91-
func = match(named_sequence.bodyTarget, ops={"func.func"})
92-
mod = transform.get_parent_op(
93-
anytype,
94-
func,
95-
op_name="builtin.module",
96-
deduplicate=True,
97-
)
98-
return mod
78+
transform.yield_()
9979

10080

10181
def bundle_xepu_matmul_schedule(
@@ -217,7 +197,7 @@ def bundle_xepu_matmul_schedule(
217197

218198
# convert forall to parallel
219199
wg_loop = match(mod, ops={"scf.forall"})
220-
wg_loop = loop.ForallToParallelOp([anytype], wg_loop)
200+
wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop)
221201
func = transform.get_parent_op(anytype, wg_loop)
222202

223203
# convert to scf.parallel to gpu.launch
@@ -257,9 +237,9 @@ def bundle_xepu_matmul_schedule(
257237
# add layouts to DPAS op operands
258238
k_loop = match(gpu_func, ops={"scf.for"})
259239
dpas_op = match(k_loop, ops={"xegpu.dpas"})
260-
tile_a = transform.GetOperandOp(anyvalue, dpas_op, [0])
261-
tile_b = transform.GetOperandOp(anyvalue, dpas_op, [1])
262-
tile_c = transform.GetOperandOp(anyvalue, dpas_op, [2])
240+
tile_a = transform.get_operand(anyvalue, dpas_op, [0])
241+
tile_b = transform.get_operand(anyvalue, dpas_op, [1])
242+
tile_c = transform.get_operand(anyvalue, dpas_op, [2])
263243

264244
def convert_layout(value, input, target):
265245
xegpu.convert_layout(

0 commit comments

Comments
 (0)