1- from typing import Optional , Sequence
1+ from typing import Optional , Sequence , Union
22
3+ from mlir import ir
34from mlir .dialects import transform
45from .common import apply_registered_pass , match , select
56from .utils import GpuBackend , PipelineInterrupt
@@ -21,7 +22,14 @@ def cleanup(op, **_config):
2122
2223
2324# TODO: make bundle into a NamedSequence to call with IncludeOp
24- def tpp_mapping (mod , lower_pack_unpack_without_transpose : bool = False , ** _config ):
25+ def tpp_mapping (
26+ mod ,
27+ lower_pack_unpack_without_transpose : bool = False ,
28+ pack_block_factors : Optional [
29+ Sequence [Union [Sequence [Union [int , ir .IntegerAttr ]], int , ir .IntegerAttr ]]
30+ ] = None ,
31+ ** _config ,
32+ ):
2533 "High-level transforms that map operations to TPP-compatible forms."
2634
2735 # Preprocess convolutions.
@@ -33,12 +41,14 @@ def tpp_mapping(mod, lower_pack_unpack_without_transpose: bool = False, **_confi
3341 func = apply_registered_pass (func , "pack-conv2DNchwFchw" )
3442 func = apply_registered_pass (func , "pack-conv2DNhwcHwcf" )
3543 func = apply_registered_pass (func , "rewrite-conv-to-matmul-or-brgemm" )
36- m = select ("m" , [2 , 4 , 8 ])
37- n = select ("n" , [4 , 8 , 16 ])
38- k = select ("k" , [2 , 4 , 8 , 16 ])
39- func = apply_registered_pass (
40- func , "pack-matmul" , options = {"block-factors" : [m , n , k ]}
41- )
44+ options = None
45+ if pack_block_factors :
46+ m_vals , n_vals , k_vals = pack_block_factors
47+ m = select ("m" , m_vals if isinstance (m_vals , Sequence ) else [m_vals ])
48+ n = select ("n" , n_vals if isinstance (n_vals , Sequence ) else [n_vals ])
49+ k = select ("k" , k_vals if isinstance (k_vals , Sequence ) else [k_vals ])
50+ options = {"block-factors" : [m , n , k ]}
51+ func = apply_registered_pass (func , "pack-matmul" , options = options )
4252 apply_registered_pass (func , "pack-vnni" )
4353 if lower_pack_unpack_without_transpose :
4454 mod = apply_registered_pass (mod , "lower-packs-unpacks-without-transpose" )
0 commit comments