@@ -25,48 +25,41 @@ def get_schedule_module(
2525 mod = ir .Module .create ()
2626 mod .operation .attributes ["transform.with_named_sequence" ] = ir .UnitAttr .get ()
2727 with ir .InsertionPoint (mod .body ):
28- named_sequence = transform .NamedSequenceOp (
28+ named_sequence = transform .named_sequence (
2929 "__transform_main" ,
3030 [transform .AnyOpType .get ()], # input types
3131 [], # output types
3232 arg_attrs = [{"transform.readonly" : ir .UnitAttr .get ()}],
3333 )
3434 with ir .InsertionPoint (named_sequence .body ):
35+ # match the payload module
36+ anytype = transform .AnyOpType .get ()
37+ func = match (named_sequence .bodyTarget , ops = {"func.func" })
38+ payload_mod = transform .get_parent_op (
39+ anytype ,
40+ func ,
41+ op_name = "builtin.module" ,
42+ deduplicate = True ,
43+ )
3544 xegpu_matmul_transform_schedule (
36- named_sequence ,
45+ payload_mod ,
3746 has_bias = has_bias ,
3847 has_relu = has_relu ,
3948 dump_kernel = dump_kernel ,
4049 params = params ,
4150 )
42- # placeholder for parameter division op
43- i32 = ir .IntegerType .get_signless (32 )
44- paramInt32Type = transform .ParamType .get (i32 )
45- div_named_sequence = transform .NamedSequenceOp (
46- "param_div" ,
47- [paramInt32Type , paramInt32Type ], # input types
48- [paramInt32Type ], # output types
49- arg_attrs = [
50- {"transform.readonly" : ir .UnitAttr .get ()},
51- {"transform.readonly" : ir .UnitAttr .get ()},
52- ],
53- )
54- with ir .InsertionPoint (div_named_sequence .body ):
55- p = transform .ParamConstantOp (paramInt32Type , ir .IntegerAttr .get (i32 , 1 ))
56- transform .YieldOp (p )
5751
5852 return mod
5953
6054
6155def xegpu_matmul_transform_schedule (
62- named_sequence : transform . NamedSequenceOp ,
56+ mod : ir . Value ,
6357 has_bias : bool = False ,
6458 has_relu : bool = False ,
6559 dump_kernel : str = "" ,
6660 params : Optional [dict ] = None ,
6761):
6862 """Transform schedule for matmul-like payload."""
69- mod = bundle_header (named_sequence )
7063 mod , interrupted = bundle_xepu_matmul_schedule (
7164 mod ,
7265 has_bias = has_bias ,
@@ -75,27 +68,14 @@ def xegpu_matmul_transform_schedule(
7568 params = params ,
7669 )
7770 if interrupted :
78- transform .YieldOp ()
71+ transform .yield_ ()
7972 return
8073
8174 mod , interrupted = bundle_xegpu_to_binary (
8275 mod ,
8376 dump_kernel = dump_kernel ,
8477 )
85- transform .YieldOp ()
86-
87-
88- def bundle_header (named_sequence : transform .NamedSequenceOp ):
89- """Matches the payload module."""
90- anytype = transform .AnyOpType .get ()
91- func = match (named_sequence .bodyTarget , ops = {"func.func" })
92- mod = transform .get_parent_op (
93- anytype ,
94- func ,
95- op_name = "builtin.module" ,
96- deduplicate = True ,
97- )
98- return mod
78+ transform .yield_ ()
9979
10080
10181def bundle_xepu_matmul_schedule (
@@ -217,7 +197,7 @@ def bundle_xepu_matmul_schedule(
217197
218198 # convert forall to parallel
219199 wg_loop = match (mod , ops = {"scf.forall" })
220- wg_loop = loop .ForallToParallelOp ([anytype ], wg_loop )
200+ wg_loop = loop .loop_forall_to_parallel ([anytype ], wg_loop )
221201 func = transform .get_parent_op (anytype , wg_loop )
222202
223203 # convert to scf.parallel to gpu.launch
@@ -257,9 +237,9 @@ def bundle_xepu_matmul_schedule(
257237 # add layouts to DPAS op operands
258238 k_loop = match (gpu_func , ops = {"scf.for" })
259239 dpas_op = match (k_loop , ops = {"xegpu.dpas" })
260- tile_a = transform .GetOperandOp (anyvalue , dpas_op , [0 ])
261- tile_b = transform .GetOperandOp (anyvalue , dpas_op , [1 ])
262- tile_c = transform .GetOperandOp (anyvalue , dpas_op , [2 ])
240+ tile_a = transform .get_operand (anyvalue , dpas_op , [0 ])
241+ tile_b = transform .get_operand (anyvalue , dpas_op , [1 ])
242+ tile_c = transform .get_operand (anyvalue , dpas_op , [2 ])
263243
264244 def convert_layout (value , input , target ):
265245 xegpu .convert_layout (
0 commit comments