11from typing import Optional , Sequence
22
3+ from mlir import ir
34from mlir .dialects import transform
45from .common import apply_registered_pass , match
56from .utils import GpuBackend , PipelineInterrupt
@@ -67,7 +68,7 @@ def linalg_lowering(mod, /, *, skip_operations: Sequence[str] = (), **_config):
6768 func = apply_registered_pass (
6869 func ,
6970 "convert-linalg-to-xsmm" ,
70- options = "skip-operations=" + "," .join (skip_operations ),
71+ options = { "skip-operations" : "," .join (skip_operations )} ,
7172 )
7273 func = apply_registered_pass (func , "combine-xsmm-op-optimization" )
7374 func = apply_registered_pass (func , "fold-xsmm-flags" )
@@ -130,7 +131,7 @@ def low_level_parallel(
130131 # Run cleanup after LICM to allow CSE to eliminate common operations now
131132 # that they are hoisted out of loops.
132133 mod = cleanup (mod )
133- options = "parallel-loop-tile-sizes=" + "," .join (map (str , parallel_task_grid ))
134+ options = { "parallel-loop-tile-sizes" : "," .join (map (str , parallel_task_grid ))}
134135 mod = apply_registered_pass (mod , "scf-parallel-loop-tiling" , options = options )
135136 return mod
136137
@@ -228,7 +229,7 @@ def default_tpp_passes(
228229 mod = linalg_lowering (mod , skip_operations = skip_ops , ** config )
229230 if linalg_to_vector or force_linalg_to_vector :
230231 func = match (mod , ops = {"func.func" })
231- options = "registerTileShape=" + "," .join (map (str , register_blocking ))
232+ options = { "registerTileShape" : "," .join (map (str , register_blocking ))}
232233 func = apply_registered_pass (func , "brgemm-linalg-tiling" , options = options )
233234 func = apply_registered_pass (func , "loop-invariant-code-motion" )
234235 apply_registered_pass (func , "vectorization-pass" )
@@ -315,7 +316,7 @@ def default_pipeline(
315316 # #if defined(__x86_64__)
316317 # options.x86Vector = true;
317318 # #endif
318- options = f "enable-amx= { int (xsmm_utils .has_amx ())} "
319+ options = { "enable-amx" : int (xsmm_utils .has_amx ())}
319320 mod = apply_registered_pass (mod , "convert-vector-to-llvm" , options = options )
320321 mod = apply_registered_pass (mod , "finalize-memref-to-llvm" )
321322 mod = apply_registered_pass (mod , "convert-scf-to-cf" )
@@ -327,9 +328,8 @@ def default_pipeline(
327328 # gpu-to-llvm cannot be invoked from transform-interpreter as it
328329 # tries to load ... something while multi-threaded PassManager is running.
329330 mod = apply_registered_pass (mod , "gpu-to-llvm" )
330- mod = apply_registered_pass (
331- mod , "gpu-module-to-binary" , options = "compilation-target=fatbin"
332- )
331+ options = {"compilation-target" : "fatbin" }
332+ mod = apply_registered_pass (mod , "gpu-module-to-binary" , options = options )
333333 mod = apply_registered_pass (mod , "convert-math-to-llvm" )
334334 if gpu_backend :
335335 mod = apply_registered_pass (mod , "async-to-async-runtime" )
0 commit comments