Skip to content

Commit b345707

Browse files
tkarnaadam-smnk
authored and committed
more snake_case transform ops
1 parent c82ecd4 commit b345707

File tree

2 files changed

+28
-29
lines changed

2 files changed

+28
-29
lines changed

python/examples/xegpu_matmul/mlir_utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,12 @@ def apply_registered_pass(*args, **kwargs):
99

1010

1111
def match(*args, **kwargs):
12-
return structured.MatchOp(transform.AnyOpType.get(), *args, **kwargs)
13-
14-
15-
def cse(op):
16-
transform.ApplyCommonSubexpressionEliminationOp(op)
12+
return structured.structured_match(transform.AnyOpType.get(), *args, **kwargs)
1713

1814

1915
def canonicalize(op):
2016
with ir.InsertionPoint(transform.apply_patterns(op).patterns):
21-
transform.ApplyCanonicalizationPatternsOp()
17+
transform.apply_patterns_canonicalization()
2218

2319

2420
def get_mlir_library_path():

python/examples/xegpu_matmul/schedule.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mlir.dialects.bufferization import LayoutMapOption
66
from mlir.dialects import transform
77
from mlir.dialects.transform import structured
8-
from mlir_utils import apply_registered_pass, match, cse, canonicalize
8+
from mlir_utils import apply_registered_pass, match, canonicalize
99
from typing import Optional
1010

1111

@@ -139,28 +139,31 @@ def bundle_xepu_matmul_schedule(
139139

140140
# wg tiling
141141
if has_relu:
142-
terminal = match(mod, ops={"linalg.max"}).result
142+
terminal = match(mod, ops={"linalg.max"})
143143
elif has_bias:
144-
terminal = match(mod, ops={"linalg.add"}).result
144+
terminal = match(mod, ops={"linalg.add"})
145145
else:
146-
terminal = match(mod, ops={"linalg.matmul"}).result
146+
terminal = match(mod, ops={"linalg.matmul"})
147+
# FIXME use structured.structured_fuse
147148
structured.FuseOp(terminal, tile_sizes=wg_tile, use_forall=True)
148-
cse(mod)
149+
transform.apply_cse(mod)
149150
canonicalize(mod)
150151

151152
# k loop tiling
152-
wg_matmul = match(mod, ops={"linalg.matmul"}).result
153+
wg_matmul = match(mod, ops={"linalg.matmul"})
154+
# FIXME use structured.structured_tile_using_for
153155
wgk_matmul, k_loop = structured.TileUsingForOp(
154156
wg_matmul, sizes=[0, 0, k_tile]
155157
).results
156158

157-
cse(func)
159+
transform.apply_cse(func)
158160
canonicalize(func)
159161

160162
if dump_kernel == "tiled":
161163
return mod, True
162164

163165
# vectorize
166+
# FIXME use structured.structured_vectorize_children_and_apply_patterns
164167
func = structured.VectorizeChildrenAndApplyPatternsOp(
165168
func,
166169
fold_type_extensions_into_contract=True,
@@ -170,7 +173,7 @@ def bundle_xepu_matmul_schedule(
170173
k_loop = match(func, ops={"scf.for"})
171174
loop.HoistLoopInvariantSubsetsOp(k_loop)
172175

173-
cse(func)
176+
transform.apply_cse(func)
174177
canonicalize(func)
175178

176179
if dump_kernel == "vectorized":
@@ -189,7 +192,7 @@ def bundle_xepu_matmul_schedule(
189192
).result
190193
# fold memref.subviews into vector.transfer_read/write ops
191194
mod = apply_registered_pass(mod, "fold-memref-alias-ops")
192-
cse(mod)
195+
transform.apply_cse(mod)
193196
canonicalize(mod)
194197

195198
if dump_kernel == "bufferized":
@@ -204,7 +207,7 @@ def bundle_xepu_matmul_schedule(
204207
func = apply_registered_pass(func, "gpu-map-parallel-loops")
205208
func = apply_registered_pass(func, "convert-parallel-loops-to-gpu")
206209
func = apply_registered_pass(func, "lower-affine")
207-
cse(func)
210+
transform.apply_cse(func)
208211
canonicalize(func)
209212

210213
# set correct number of gpu threads
@@ -216,7 +219,7 @@ def bundle_xepu_matmul_schedule(
216219
canonicalize(func)
217220
func = apply_registered_pass(func, "gpu-launch-sink-index-computations")
218221
mod = apply_registered_pass(mod, "gpu-kernel-outlining")
219-
cse(mod)
222+
transform.apply_cse(mod)
220223

221224
# set xevm target
222225
mod = apply_registered_pass(
@@ -229,7 +232,7 @@ def bundle_xepu_matmul_schedule(
229232
gpu_mod = match(mod, ops={"gpu.module"})
230233
gpu_func = match(gpu_mod, ops={"gpu.func"})
231234
gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu")
232-
cse(gpu_func)
235+
transform.apply_cse(gpu_func)
233236

234237
if dump_kernel == "xegpu-initial":
235238
return mod, True
@@ -319,14 +322,14 @@ def convert_layout(value, input, target):
319322

320323
if has_relu:
321324
# for post ops we need to add C layout manually
322-
max_op = match(gpu_func, ops={"arith.maximumf"}).result
325+
max_op = match(gpu_func, ops={"arith.maximumf"})
323326
xegpu.set_op_layout_attr(max_op, result=True, index=0, **output_layout)
324327
# find zero constant buffer and annotate it
325328
const_buffer = transform.get_producer_of_operand(anytype, max_op, 1)
326329
xegpu.set_op_layout_attr(const_buffer, result=True, index=0, **output_layout)
327330
if has_bias:
328331
# for post ops we need to add C layout manually
329-
add_op = match(gpu_func, ops={"arith.addf"}).result
332+
add_op = match(gpu_func, ops={"arith.addf"})
330333
xegpu.set_op_layout_attr(add_op, result=True, index=0, **output_layout)
331334

332335
# annotate broadcast op operands
@@ -350,14 +353,14 @@ def convert_layout(value, input, target):
350353
mask = transform.get_producer_of_operand(anytype, bcast_load, 2)
351354
xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1)
352355
raise NotImplementedError("Bias layout propagation is not supported.")
353-
cse(gpu_func)
356+
transform.apply_cse(gpu_func)
354357
canonicalize(gpu_func)
355358

356359
# hoist desc ops out of reduction loop
357360
transform.apply_licm(k_loop)
358361

359362
canonicalize(gpu_func)
360-
cse(gpu_func)
363+
transform.apply_cse(gpu_func)
361364

362365
if dump_kernel == "xegpu-wg":
363366
return mod, True
@@ -379,33 +382,33 @@ def bundle_xegpu_to_binary(mod, dump_kernel: str = ""):
379382
# xegpu distribution
380383
gpu_func = match(gpu_mod, ops={"gpu.func"})
381384
gpu_func = apply_registered_pass(gpu_func, "xegpu-wg-to-sg-distribute")
382-
cse(gpu_func)
385+
transform.apply_cse(gpu_func)
383386

384387
if dump_kernel == "xegpu-sg":
385388
return mod, True
386389

387390
gpu_func = apply_registered_pass(gpu_func, "lower-affine")
388-
cse(gpu_func)
391+
transform.apply_cse(gpu_func)
389392
gpu_func = apply_registered_pass(gpu_func, "xegpu-blocking")
390393
canonicalize(gpu_func)
391-
cse(gpu_func)
394+
transform.apply_cse(gpu_func)
392395

393396
if dump_kernel == "xegpu-inst":
394397
return mod, True
395398

396399
gpu_func = apply_registered_pass(gpu_func, "xegpu-propagate-layout")
397400
gpu_mod = apply_registered_pass(gpu_mod, "xegpu-subgroup-distribute")
398401
canonicalize(gpu_mod)
399-
cse(gpu_mod)
402+
transform.apply_cse(gpu_mod)
400403
gpu_mod = apply_registered_pass(gpu_mod, "loop-invariant-code-motion")
401-
cse(gpu_mod)
404+
transform.apply_cse(gpu_mod)
402405
gpu_mod = apply_registered_pass(gpu_mod, "xegpu-vector-linearize")
403406
gpu_mod = apply_registered_pass(gpu_mod, "convert-xegpu-to-xevm")
404407
gpu_mod = apply_registered_pass(
405408
gpu_mod, "convert-gpu-to-llvm-spv", options={"use-64bit-index": "true"}
406409
)
407410
gpu_mod = apply_registered_pass(gpu_mod, "convert-xevm-to-llvm")
408-
cse(gpu_mod)
411+
transform.apply_cse(gpu_mod)
409412

410413
func = match(mod, ops={"func.func"})
411414
func = apply_registered_pass(func, "gpu-async-region")
@@ -424,7 +427,7 @@ def bundle_xegpu_to_binary(mod, dump_kernel: str = ""):
424427
mod = apply_registered_pass(mod, "gpu-to-llvm")
425428
mod = apply_registered_pass(mod, "lower-affine")
426429
mod = apply_registered_pass(mod, "reconcile-unrealized-casts")
427-
cse(mod)
430+
transform.apply_cse(mod)
428431
mod = apply_registered_pass(mod, "gpu-module-to-binary")
429432

430433
return mod, False

0 commit comments

Comments (0)