Skip to content

Commit 15cb803

Browse files
authored
add InterfacePass to generate pass pipeline (#122)
1 parent e39e7bb commit 15cb803

File tree

2 files changed

+124
-1
lines changed

2 files changed

+124
-1
lines changed

mlir/extras/runtime/passes.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,35 @@ def affine_loop_tile(
431431
)
432432
return self
433433

434+
def affine_loop_unroll(
435+
self,
436+
unroll_factor: int = None,
437+
unroll_up_to_factor: bool = None,
438+
unroll_full: bool = None,
439+
unroll_num_reps: int = None,
440+
unroll_full_threshold: int = None,
441+
cleanup_unroll: bool = None,
442+
):
443+
"""Unroll affine loops
444+
Args:
445+
unroll-factor: Use this unroll factor for all loops being unrolled
446+
unroll-up-to-factor: Allow unrolling up to the factor specified
447+
unroll-full: Fully unroll loops
448+
unroll-num-reps: Unroll innermost loops repeatedly this many times
449+
unroll-full-threshold: Unroll all loops with trip count less than or equal to this
450+
cleanup-unroll: Fully unroll the cleanup loop when possible.
451+
"""
452+
self.add_pass(
453+
"affine-loop-unroll",
454+
unroll_factor=unroll_factor,
455+
unroll_up_to_factor=unroll_up_to_factor,
456+
unroll_full=unroll_full,
457+
unroll_num_reps=unroll_num_reps,
458+
unroll_full_threshold=unroll_full_threshold,
459+
cleanup_unroll=cleanup_unroll,
460+
)
461+
return self
462+
434463
def affine_loop_unroll_jam(self, unroll_jam_factor: int = None):
435464
"""Unroll and jam affine loops
436465
Args:
@@ -1051,6 +1080,21 @@ def control_flow_sink(self):
10511080
self.add_pass("control-flow-sink")
10521081
return self
10531082

1083+
def convert_affine_for_to_gpu(
1084+
self, gpu_block_dims: int = None, gpu_thread_dims: int = None
1085+
):
1086+
"""Convert top-level AffineFor Ops to GPU kernels
1087+
Args:
1088+
gpu-block-dims: Number of GPU block dimensions for mapping
1089+
gpu-thread-dims: Number of GPU thread dimensions for mapping
1090+
"""
1091+
self.add_pass(
1092+
"convert-affine-for-to-gpu",
1093+
gpu_block_dims=gpu_block_dims,
1094+
gpu_thread_dims=gpu_thread_dims,
1095+
)
1096+
return self
1097+
10541098
def convert_amdgpu_to_rocdl(self, chipset: str = None):
10551099
"""Convert AMDGPU dialect to ROCDL dialect
10561100
@@ -1120,6 +1164,16 @@ def convert_arith_to_spirv(self, emulate_lt_32_bit_scalar_types: bool = None):
11201164
)
11211165
return self
11221166

1167+
def convert_arm_sme_to_llvm(self, dump_tile_live_ranges: bool = None):
1168+
"""Lower the operations from the ArmSME dialect into the LLVM dialect
1169+
Args:
1170+
dump-tile-live-ranges: Dump the live ranges of SME tiles (for debugging)
1171+
"""
1172+
self.add_pass(
1173+
"convert-arm-sme-to-llvm", dump_tile_live_ranges=dump_tile_live_ranges
1174+
)
1175+
return self
1176+
11231177
def convert_arm_sme_to_scf(self):
11241178
"""Lower the operations from the ArmSME dialect into the SCF dialect"""
11251179
self.add_pass("convert-arm-sme-to-scf")
@@ -2289,6 +2343,39 @@ def linalg_block_pack_matmul(
22892343
)
22902344
return self
22912345

2346+
def linalg_detensorize(self, aggressive_mode: bool = None):
2347+
"""Detensorize linalg ops
2348+
2349+
Detensoring is the process through which a tensor value is converted to one
2350+
or potentially more primitive value(s). During this process, operations with
2351+
such detensored operands are also converted to an equivalent form that works
2352+
on primitives.
2353+
2354+
The detensoring process is driven by linalg-on-tensor ops. In particular, a
2355+
linalg-on-tensor op is checked to see whether *all* its operands can be
2356+
detensored. If so, those operands are converted to their primitive
2357+
counterparts and the linalg op is replaced by an equivalent op that takes
2358+
those new primitive values as operands. Therefore, detensoring an op can be
2359+
divided into 2 main logical phases:
2360+
2361+
1. Detect/match an op that can be detensored.
2362+
2. Detensor the operands of the op and replace it with a primitive
2363+
equivalent.
2364+
2365+
In addition to detensoring individual ops, this pass detensors internal
2366+
control flow inside a function. All blocks except for the entry block are
2367+
detensored by converting their arguments whenever possible.
2368+
2369+
This can be run on any FunctionOpInterface op and must not be
2370+
run on others. This is because it performs specific legalization of the
2371+
blocks that make up the body, which it assumes is a FunctionOpInterface.
2372+
2373+
Args:
2374+
aggressive-mode: Detensorize all ops that qualify for detensoring along with branch operands and basic-block arguments.
2375+
"""
2376+
self.add_pass("linalg-detensorize", aggressive_mode=aggressive_mode)
2377+
return self
2378+
22922379
def linalg_fold_unit_extent_dims(self, use_rank_reducing_slices: bool = None):
22932380
"""Remove unit-extent dimension in Linalg ops on tensors
22942381
Args:
@@ -4497,6 +4584,42 @@ def tosa_to_arith(
44974584
)
44984585
return self
44994586

4587+
def tosa_to_linalg(
4588+
self,
4589+
disable_tosa_decompositions: bool = None,
4590+
aggressive_reduce_constant: bool = None,
4591+
):
4592+
"""Lower TOSA to LinAlg on tensors
4593+
4594+
Pass that converts TOSA operations to the equivalent operations using the
4595+
tensor operations in LinAlg.
4596+
4597+
Args:
4598+
disable-tosa-decompositions: Disable tosa decompositions pass
4599+
aggressive-reduce-constant: Always perform the reduce constant optimization
4600+
"""
4601+
self.add_pass(
4602+
"tosa-to-linalg",
4603+
disable_tosa_decompositions=disable_tosa_decompositions,
4604+
aggressive_reduce_constant=aggressive_reduce_constant,
4605+
)
4606+
return self
4607+
4608+
def tosa_to_linalg_named(self, prefer_conv2d_kernel_layout_hwcf: bool = None):
4609+
"""Lower TOSA to LinAlg named operations
4610+
4611+
Pass that converts TOSA operations to the equivalent operations using the
4612+
Linalg named operations.
4613+
4614+
Args:
4615+
prefer-conv2d-kernel-layout-hwcf: Prefer generating linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc
4616+
"""
4617+
self.add_pass(
4618+
"tosa-to-linalg-named",
4619+
prefer_conv2d_kernel_layout_hwcf=prefer_conv2d_kernel_layout_hwcf,
4620+
)
4621+
return self
4622+
45004623
def tosa_to_mlprogram(self):
45014624
"""Lower TOSA to the MLProgram dialect
45024625

scripts/generate_pass_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def {pass_name.replace('-', '_')}(self):"""
184184

185185
def gather_passes_from_td_json(j):
186186
passes = []
187-
for pass_ in j["!instanceof"]["Pass"]:
187+
for pass_ in j["!instanceof"]["Pass"] + j["!instanceof"]["InterfacePass"]:
188188
pass_ = j[pass_]
189189
options = []
190190
for o in pass_["options"]:

0 commit comments

Comments
 (0)