55from mlir .dialects .bufferization import LayoutMapOption
66from mlir .dialects import transform
77from mlir .dialects .transform import structured
8- from mlir_utils import apply_registered_pass , match , cse , canonicalize
8+ from mlir_utils import apply_registered_pass , match , canonicalize
99from typing import Optional
1010
1111
@@ -139,28 +139,31 @@ def bundle_xepu_matmul_schedule(
139139
140140 # wg tiling
141141 if has_relu :
142- terminal = match (mod , ops = {"linalg.max" }). result
142+ terminal = match (mod , ops = {"linalg.max" })
143143 elif has_bias :
144- terminal = match (mod , ops = {"linalg.add" }). result
144+ terminal = match (mod , ops = {"linalg.add" })
145145 else :
146- terminal = match (mod , ops = {"linalg.matmul" }).result
146+ terminal = match (mod , ops = {"linalg.matmul" })
147+ # FIXME use structured.structured_fuse
147148 structured .FuseOp (terminal , tile_sizes = wg_tile , use_forall = True )
148- cse (mod )
149+ transform . apply_cse (mod )
149150 canonicalize (mod )
150151
151152 # k loop tiling
152- wg_matmul = match (mod , ops = {"linalg.matmul" }).result
153+ wg_matmul = match (mod , ops = {"linalg.matmul" })
154+ # FIXME use structured.structured_tile_using_for
153155 wgk_matmul , k_loop = structured .TileUsingForOp (
154156 wg_matmul , sizes = [0 , 0 , k_tile ]
155157 ).results
156158
157- cse (func )
159+ transform . apply_cse (func )
158160 canonicalize (func )
159161
160162 if dump_kernel == "tiled" :
161163 return mod , True
162164
163165 # vectorize
166+ # FIXME use structured.structured_vectorize_children_and_apply_patterns
164167 func = structured .VectorizeChildrenAndApplyPatternsOp (
165168 func ,
166169 fold_type_extensions_into_contract = True ,
@@ -170,7 +173,7 @@ def bundle_xepu_matmul_schedule(
170173 k_loop = match (func , ops = {"scf.for" })
171174 loop .HoistLoopInvariantSubsetsOp (k_loop )
172175
173- cse (func )
176+ transform . apply_cse (func )
174177 canonicalize (func )
175178
176179 if dump_kernel == "vectorized" :
@@ -189,7 +192,7 @@ def bundle_xepu_matmul_schedule(
189192 ).result
190193 # fold memref.subviews into vector.transfer_read/write ops
191194 mod = apply_registered_pass (mod , "fold-memref-alias-ops" )
192- cse (mod )
195+ transform . apply_cse (mod )
193196 canonicalize (mod )
194197
195198 if dump_kernel == "bufferized" :
@@ -204,7 +207,7 @@ def bundle_xepu_matmul_schedule(
204207 func = apply_registered_pass (func , "gpu-map-parallel-loops" )
205208 func = apply_registered_pass (func , "convert-parallel-loops-to-gpu" )
206209 func = apply_registered_pass (func , "lower-affine" )
207- cse (func )
210+ transform . apply_cse (func )
208211 canonicalize (func )
209212
210213 # set correct number of gpu threads
@@ -216,7 +219,7 @@ def bundle_xepu_matmul_schedule(
216219 canonicalize (func )
217220 func = apply_registered_pass (func , "gpu-launch-sink-index-computations" )
218221 mod = apply_registered_pass (mod , "gpu-kernel-outlining" )
219- cse (mod )
222+ transform . apply_cse (mod )
220223
221224 # set xevm target
222225 mod = apply_registered_pass (
@@ -229,7 +232,7 @@ def bundle_xepu_matmul_schedule(
229232 gpu_mod = match (mod , ops = {"gpu.module" })
230233 gpu_func = match (gpu_mod , ops = {"gpu.func" })
231234 gpu_func = apply_registered_pass (gpu_func , "convert-vector-to-xegpu" )
232- cse (gpu_func )
235+ transform . apply_cse (gpu_func )
233236
234237 if dump_kernel == "xegpu-initial" :
235238 return mod , True
@@ -319,14 +322,14 @@ def convert_layout(value, input, target):
319322
320323 if has_relu :
321324 # for post ops we need to add C layout manually
322- max_op = match (gpu_func , ops = {"arith.maximumf" }). result
325+ max_op = match (gpu_func , ops = {"arith.maximumf" })
323326 xegpu .set_op_layout_attr (max_op , result = True , index = 0 , ** output_layout )
324327 # find zero constant buffer and annotate it
325328 const_buffer = transform .get_producer_of_operand (anytype , max_op , 1 )
326329 xegpu .set_op_layout_attr (const_buffer , result = True , index = 0 , ** output_layout )
327330 if has_bias :
328331 # for post ops we need to add C layout manually
329- add_op = match (gpu_func , ops = {"arith.addf" }). result
332+ add_op = match (gpu_func , ops = {"arith.addf" })
330333 xegpu .set_op_layout_attr (add_op , result = True , index = 0 , ** output_layout )
331334
332335 # annotate broadcast op operands
@@ -350,14 +353,14 @@ def convert_layout(value, input, target):
350353 mask = transform .get_producer_of_operand (anytype , bcast_load , 2 )
351354 xegpu .set_op_layout_attr (mask , result = True , index = 0 , ** output_layout_dim1 )
352355 raise NotImplementedError ("Bias layout propagation is not supported." )
353- cse (gpu_func )
356+ transform . apply_cse (gpu_func )
354357 canonicalize (gpu_func )
355358
356359 # hoist desc ops out of reduction loop
357360 transform .apply_licm (k_loop )
358361
359362 canonicalize (gpu_func )
360- cse (gpu_func )
363+ transform . apply_cse (gpu_func )
361364
362365 if dump_kernel == "xegpu-wg" :
363366 return mod , True
@@ -379,33 +382,33 @@ def bundle_xegpu_to_binary(mod, dump_kernel: str = ""):
379382 # xegpu distribution
380383 gpu_func = match (gpu_mod , ops = {"gpu.func" })
381384 gpu_func = apply_registered_pass (gpu_func , "xegpu-wg-to-sg-distribute" )
382- cse (gpu_func )
385+ transform . apply_cse (gpu_func )
383386
384387 if dump_kernel == "xegpu-sg" :
385388 return mod , True
386389
387390 gpu_func = apply_registered_pass (gpu_func , "lower-affine" )
388- cse (gpu_func )
391+ transform . apply_cse (gpu_func )
389392 gpu_func = apply_registered_pass (gpu_func , "xegpu-blocking" )
390393 canonicalize (gpu_func )
391- cse (gpu_func )
394+ transform . apply_cse (gpu_func )
392395
393396 if dump_kernel == "xegpu-inst" :
394397 return mod , True
395398
396399 gpu_func = apply_registered_pass (gpu_func , "xegpu-propagate-layout" )
397400 gpu_mod = apply_registered_pass (gpu_mod , "xegpu-subgroup-distribute" )
398401 canonicalize (gpu_mod )
399- cse (gpu_mod )
402+ transform . apply_cse (gpu_mod )
400403 gpu_mod = apply_registered_pass (gpu_mod , "loop-invariant-code-motion" )
401- cse (gpu_mod )
404+ transform . apply_cse (gpu_mod )
402405 gpu_mod = apply_registered_pass (gpu_mod , "xegpu-vector-linearize" )
403406 gpu_mod = apply_registered_pass (gpu_mod , "convert-xegpu-to-xevm" )
404407 gpu_mod = apply_registered_pass (
405408 gpu_mod , "convert-gpu-to-llvm-spv" , options = {"use-64bit-index" : "true" }
406409 )
407410 gpu_mod = apply_registered_pass (gpu_mod , "convert-xevm-to-llvm" )
408- cse (gpu_mod )
411+ transform . apply_cse (gpu_mod )
409412
410413 func = match (mod , ops = {"func.func" })
411414 func = apply_registered_pass (func , "gpu-async-region" )
@@ -424,7 +427,7 @@ def bundle_xegpu_to_binary(mod, dump_kernel: str = ""):
424427 mod = apply_registered_pass (mod , "gpu-to-llvm" )
425428 mod = apply_registered_pass (mod , "lower-affine" )
426429 mod = apply_registered_pass (mod , "reconcile-unrealized-casts" )
427- cse (mod )
430+ transform . apply_cse (mod )
428431 mod = apply_registered_pass (mod , "gpu-module-to-binary" )
429432
430433 return mod , False
0 commit comments