@@ -431,6 +431,35 @@ def affine_loop_tile(
431
431
)
432
432
return self
433
433
434
def affine_loop_unroll(
    self,
    unroll_factor: int = None,
    unroll_up_to_factor: bool = None,
    unroll_full: bool = None,
    unroll_num_reps: int = None,
    unroll_full_threshold: int = None,
    cleanup_unroll: bool = None,
):
    """Add the ``affine-loop-unroll`` pass (unroll affine loops).

    Every argument is handed to ``add_pass`` exactly as received,
    including the ``None`` defaults.

    Args:
        unroll_factor: Use this unroll factor for all loops being unrolled.
        unroll_up_to_factor: Allow unrolling up to the factor specified.
        unroll_full: Fully unroll loops.
        unroll_num_reps: Unroll innermost loops repeatedly this many times.
        unroll_full_threshold: Unroll all loops with trip count less than
            or equal to this.
        cleanup_unroll: Fully unroll the cleanup loop when possible.

    Returns:
        ``self``, so that further calls can be chained.
    """
    # Collect the options first; insertion order matches the signature so
    # the keywords reach add_pass in the same order as before.
    options = {
        "unroll_factor": unroll_factor,
        "unroll_up_to_factor": unroll_up_to_factor,
        "unroll_full": unroll_full,
        "unroll_num_reps": unroll_num_reps,
        "unroll_full_threshold": unroll_full_threshold,
        "cleanup_unroll": cleanup_unroll,
    }
    self.add_pass("affine-loop-unroll", **options)
    return self
462
+
434
463
def affine_loop_unroll_jam (self , unroll_jam_factor : int = None ):
435
464
"""Unroll and jam affine loops
436
465
Args:
@@ -1051,6 +1080,21 @@ def control_flow_sink(self):
1051
1080
self .add_pass ("control-flow-sink" )
1052
1081
return self
1053
1082
1083
def convert_affine_for_to_gpu(
    self, gpu_block_dims: int = None, gpu_thread_dims: int = None
):
    """Add the ``convert-affine-for-to-gpu`` pass.

    The pass converts top-level AffineFor Ops to GPU kernels. Both
    arguments are forwarded to ``add_pass`` unchanged, including the
    ``None`` defaults.

    Args:
        gpu_block_dims: Number of GPU block dimensions for mapping.
        gpu_thread_dims: Number of GPU thread dimensions for mapping.

    Returns:
        ``self``, so that further calls can be chained.
    """
    mapping_options = {
        "gpu_block_dims": gpu_block_dims,
        "gpu_thread_dims": gpu_thread_dims,
    }
    self.add_pass("convert-affine-for-to-gpu", **mapping_options)
    return self
1097
+
1054
1098
def convert_amdgpu_to_rocdl (self , chipset : str = None ):
1055
1099
"""Convert AMDGPU dialect to ROCDL dialect
1056
1100
@@ -1120,6 +1164,16 @@ def convert_arith_to_spirv(self, emulate_lt_32_bit_scalar_types: bool = None):
1120
1164
)
1121
1165
return self
1122
1166
1167
def convert_arm_sme_to_llvm(self, dump_tile_live_ranges: bool = None):
    """Add the ``convert-arm-sme-to-llvm`` pass.

    The pass lowers the operations from the ArmSME dialect into the LLVM
    dialect. The argument is forwarded to ``add_pass`` unchanged,
    including the ``None`` default.

    Args:
        dump_tile_live_ranges: Dump the live ranges of SME tiles (for
            debugging).

    Returns:
        ``self``, so that further calls can be chained.
    """
    pass_name = "convert-arm-sme-to-llvm"
    self.add_pass(pass_name, dump_tile_live_ranges=dump_tile_live_ranges)
    return self
1176
+
1123
1177
def convert_arm_sme_to_scf (self ):
1124
1178
"""Lower the operations from the ArmSME dialect into the SCF dialect"""
1125
1179
self .add_pass ("convert-arm-sme-to-scf" )
@@ -2289,6 +2343,39 @@ def linalg_block_pack_matmul(
2289
2343
)
2290
2344
return self
2291
2345
2346
def linalg_detensorize(self, aggressive_mode: bool = None):
    """Add the ``linalg-detensorize`` pass (detensorize linalg ops).

    Detensoring converts a tensor value into one or more primitive
    values. Operations whose operands are detensored in this way are
    likewise converted to an equivalent form that works on primitives.

    The process is driven by linalg-on-tensor ops. A linalg-on-tensor op
    is checked to see whether *all* of its operands can be detensored.
    If so, those operands are converted to their primitive counterparts
    and the linalg op is replaced by an equivalent op that takes the new
    primitive values as operands. Detensoring an op therefore has two
    main logical phases:

    1. Detect/match an op that can be detensored.
    2. Detensor the operands of the op and replace it with a primitive
       equivalent.

    Besides individual ops, the pass also detensors internal control
    flow inside a function: all blocks except the entry block are
    detensored by converting their arguments whenever possible.

    The pass can be run on any FunctionOpInterface op and must not be
    run on others, because it performs specific legalization of the
    blocks that make up the body, which it assumes belongs to a
    FunctionOpInterface op.

    Args:
        aggressive_mode: Detensorize all ops that qualify for detensoring
            along with branch operands and basic-block arguments.
            Forwarded to ``add_pass`` unchanged, including the ``None``
            default.

    Returns:
        ``self``, so that further calls can be chained.
    """
    pass_name = "linalg-detensorize"
    self.add_pass(pass_name, aggressive_mode=aggressive_mode)
    return self
2378
+
2292
2379
def linalg_fold_unit_extent_dims (self , use_rank_reducing_slices : bool = None ):
2293
2380
"""Remove unit-extent dimension in Linalg ops on tensors
2294
2381
Args:
@@ -4497,6 +4584,42 @@ def tosa_to_arith(
4497
4584
)
4498
4585
return self
4499
4586
4587
def tosa_to_linalg(
    self,
    disable_tosa_decompositions: bool = None,
    aggressive_reduce_constant: bool = None,
):
    """Add the ``tosa-to-linalg`` pass (lower TOSA to LinAlg on tensors).

    The pass converts TOSA operations to the equivalent operations using
    the tensor operations in LinAlg. Both arguments are forwarded to
    ``add_pass`` unchanged, including the ``None`` defaults.

    Args:
        disable_tosa_decompositions: Disable tosa decompositions pass.
        aggressive_reduce_constant: Always perform the reduce constant
            optimization.

    Returns:
        ``self``, so that further calls can be chained.
    """
    lowering_options = {
        "disable_tosa_decompositions": disable_tosa_decompositions,
        "aggressive_reduce_constant": aggressive_reduce_constant,
    }
    self.add_pass("tosa-to-linalg", **lowering_options)
    return self
4607
+
4608
def tosa_to_linalg_named(self, prefer_conv2d_kernel_layout_hwcf: bool = None):
    """Add the ``tosa-to-linalg-named`` pass.

    The pass lowers TOSA to LinAlg named operations: it converts TOSA
    operations to the equivalent operations using the Linalg named
    operations. The argument is forwarded to ``add_pass`` unchanged,
    including the ``None`` default.

    Args:
        prefer_conv2d_kernel_layout_hwcf: Prefer generating
            linalg.conv_2d_nhwc_hwcf over linalg.conv_2d_nhwc_fhwc.

    Returns:
        ``self``, so that further calls can be chained.
    """
    pass_name = "tosa-to-linalg-named"
    self.add_pass(
        pass_name,
        prefer_conv2d_kernel_layout_hwcf=prefer_conv2d_kernel_layout_hwcf,
    )
    return self
4622
+
4500
4623
def tosa_to_mlprogram (self ):
4501
4624
"""Lower TOSA to the MLProgram dialect
4502
4625
0 commit comments