1818from ..utils import *
1919import os
2020import traceback
21+ from iree .compiler import ir
2122
2223
2324@dataclass
@@ -71,63 +72,84 @@ def get_flops(self) -> int:
7172 flops = 2 * self .M * self .N * self .K
7273 return flops
7374
75+ def _convert_dtype_to_mlir (dtype : str ) -> ir .Type :
76+ dtypes = {
77+ "i8" : lambda : ir .IntegerType .get_signless (8 ),
78+ "i16" : lambda : ir .IntegerType .get_signless (16 ),
79+ "i32" : lambda : ir .IntegerType .get_signless (32 ),
80+ "i64" : lambda : ir .IntegerType .get_signless (64 ),
81+ "f16" : lambda : ir .F16Type .get (),
82+ "f32" : lambda : ir .F32Type .get (),
83+ "f64" : lambda : ir .F64Type .get (),
84+ "bf16" : lambda : ir .BF16Type .get (),
85+ }
86+ return dtypes [dtype ]()
7487
7588def generate_mlir (config : GemmConfig ):
7689 K = config .K
7790 M = config .M
7891 N = config .N
79- operand_element_type = config .operand_element_type
80- acc_element_type = config .accumulator_element_type
81- result_element_type = config .result_element_type
82- is_integer = operand_element_type .startswith ('i' )
83- literal_zero = "0" if is_integer else "0.0"
92+
93+ with ir .Location .name (config .get_name ()):
94+ operand_element_type = _convert_dtype_to_mlir (config .operand_element_type )
95+ acc_element_type = _convert_dtype_to_mlir (config .accumulator_element_type )
96+ result_element_type = _convert_dtype_to_mlir (config .result_element_type )
97+ is_integer = isinstance (operand_element_type , ir .IntegerType )
98+ literal_zero = ir .IntegerAttr .get (acc_element_type , 0 ) if is_integer else ir .FloatAttr .get (acc_element_type , 0.0 )
99+ K_M_operand_tensor_type = ir .RankedTensorType .get ([K , M ], operand_element_type )
100+ M_K_operand_tensor_type = ir .RankedTensorType .get ([M , K ], operand_element_type )
101+ K_N_operand_tensor_type = ir .RankedTensorType .get ([K , N ], operand_element_type )
102+ N_K_operand_tensor_type = ir .RankedTensorType .get ([N , K ], operand_element_type )
103+ M_N_acc_tensor_type = ir .RankedTensorType .get ([M , N ], acc_element_type )
104+ M_N_result_tensor_type = ir .RankedTensorType .get ([M , N ], result_element_type )
105+
84106 trunc_op = "arith.trunci" if is_integer else "arith.truncf"
85107
86108 tA = config .tA
87109 tB = config .tB
88110 mlir_template_matmul_transpose_a = f"""
89111module {{
90- func.func @main(%arg0: tensor< { K } x { M } x { operand_element_type } > , %arg1: tensor< { K } x { N } x { operand_element_type } > ) -> tensor< { M } x { N } x { result_element_type } > {{
91- %cst = arith.constant { literal_zero } : { acc_element_type }
92- %0 = tensor.empty() : tensor< { M } x { N } x { acc_element_type } >
93- %1 = linalg.fill ins(%cst : { acc_element_type } ) outs(%0 : tensor< { M } x { N } x { acc_element_type } > ) -> tensor< { M } x { N } x { acc_element_type } >
94- %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor< { K } x { M } x { operand_element_type } >, tensor< { K } x { N } x { operand_element_type } > )
95- outs(%1 : tensor< { M } x { N } x { acc_element_type } > )
96- -> tensor< { M } x { N } x { acc_element_type } >
112+ func.func @main(%arg0: { K_M_operand_tensor_type } , %arg1: { K_N_operand_tensor_type } ) -> { M_N_result_tensor_type } {{
113+ %cst = arith.constant { literal_zero }
114+ %0 = tensor.empty() : { M_N_acc_tensor_type }
115+ %1 = linalg.fill ins(%cst : { acc_element_type } ) outs(%0 : { M_N_acc_tensor_type } ) -> { M_N_acc_tensor_type }
116+ %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : { K_M_operand_tensor_type } , { K_N_operand_tensor_type } )
117+ outs(%1 : { M_N_acc_tensor_type } )
118+ -> { M_N_acc_tensor_type }
97119"""
98120
99121 mlir_template_matmul_transpose_b = f"""
100122module {{
101- func.func @main(%arg0: tensor< { M } x { K } x { operand_element_type } > , %arg1: tensor< { N } x { K } x { operand_element_type } > ) -> tensor< { M } x { N } x { result_element_type } > {{
102- %cst = arith.constant { literal_zero } : { acc_element_type }
103- %0 = tensor.empty() : tensor< { M } x { N } x { acc_element_type } >
104- %1 = linalg.fill ins(%cst : { acc_element_type } ) outs(%0 : tensor< { M } x { N } x { acc_element_type } > ) -> tensor< { M } x { N } x { acc_element_type } >
105- %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor< { M } x { K } x { operand_element_type } >, tensor< { N } x { K } x { operand_element_type } > )
106- outs(%1 : tensor< { M } x { N } x { acc_element_type } > )
107- -> tensor< { M } x { N } x { acc_element_type } >
123+ func.func @main(%arg0: { M_K_operand_tensor_type } , %arg1: { N_K_operand_tensor_type } ) -> { M_N_result_tensor_type } {{
124+ %cst = arith.constant { literal_zero }
125+ %0 = tensor.empty() : { M_N_acc_tensor_type }
126+ %1 = linalg.fill ins(%cst : { acc_element_type } ) outs(%0 : { M_N_acc_tensor_type } ) -> { M_N_acc_tensor_type }
127+ %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : { M_K_operand_tensor_type } , { N_K_operand_tensor_type } )
128+ outs(%1 : { M_N_acc_tensor_type } )
129+ -> { M_N_acc_tensor_type }
108130"""
109131
110132 mlir_template_matmul_normal = f"""
111133module {{
112- func.func @main(%arg0: tensor< { M } x { K } x { operand_element_type } > , %arg1: tensor< { K } x { N } x { operand_element_type } > ) -> tensor< { M } x { N } x { result_element_type } > {{
113- %cst = arith.constant { literal_zero } : { acc_element_type }
114- %0 = tensor.empty() : tensor< { M } x { N } x { acc_element_type } >
115- %1 = linalg.fill ins(%cst : { acc_element_type } ) outs(%0 : tensor< { M } x { N } x { acc_element_type } > ) -> tensor< { M } x { N } x { acc_element_type } >
116- %2 = linalg.matmul ins(%arg0, %arg1 : tensor< { M } x { K } x { operand_element_type } >, tensor< { K } x { N } x { operand_element_type } > )
117- outs(%1 : tensor< { M } x { N } x { acc_element_type } > )
118- -> tensor< { M } x { N } x { acc_element_type } >
134+ func.func @main(%arg0: { M_K_operand_tensor_type } , %arg1: { K_N_operand_tensor_type } ) -> { M_N_result_tensor_type } {{
135+ %cst = arith.constant { literal_zero }
136+ %0 = tensor.empty() : { M_N_acc_tensor_type }
137+ %1 = linalg.fill ins(%cst : { acc_element_type } ) outs(%0 : { M_N_acc_tensor_type } ) -> { M_N_acc_tensor_type }
138+ %2 = linalg.matmul ins(%arg0, %arg1 : { M_K_operand_tensor_type } , { K_N_operand_tensor_type } )
139+ outs(%1 : { M_N_acc_tensor_type } )
140+ -> { M_N_acc_tensor_type }
119141"""
120142 mlir_template_matmul = mlir_template_matmul_transpose_a if tA == "T" else mlir_template_matmul_transpose_b if tB == "T" else mlir_template_matmul_normal
121143
122144 mlir_template_return_truncated = f"""
123- %3 = { trunc_op } %2 : tensor< { M } x { N } x { acc_element_type } > to tensor< { M } x { N } x { result_element_type } >
124- return %3 : tensor< { M } x { N } x { result_element_type } >
145+ %3 = { trunc_op } %2 : { M_N_acc_tensor_type } to { M_N_result_tensor_type }
146+ return %3 : { M_N_result_tensor_type }
125147 }}
126148}}
127149"""
128150
129151 mlir_template_return_untruncated = f"""
130- return %2 : tensor< { M } x { N } x { result_element_type } >
152+ return %2 : { M_N_result_tensor_type }
131153 }}
132154}}
133155"""
@@ -167,7 +189,7 @@ def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
167189 # Default config
168190 return TkTunedConfig (64 , 64 , 32 , 2 , 2 , 1 , 2 , 2 , 2 , 1 , 1 , 2 )
169191
170- def _convert_dtype (dtype : str ):
192+ def _convert_dtype_to_tk (dtype : str ):
171193 dtypes = {
172194 "i8" : tkl .i8 ,
173195 "i16" : tkl .i16 ,
@@ -189,7 +211,7 @@ def generate_tk_mlir(config: GemmConfig, vmfb_file: Path):
189211 assert config .operand_element_type == 'f16' , "Unsupported problem"
190212 assert config .accumulator_element_type == 'f32' , "Unsupported problem"
191213
192- res_dtype = _convert_dtype (config .result_element_type )
214+ res_dtype = _convert_dtype_to_tk (config .result_element_type )
193215 # Input sizes
194216 M = tkl .sym .M
195217 N = tkl .sym .N
@@ -306,7 +328,8 @@ def compile_gemm_config(
306328 f .write (traceback .format_exc ())
307329 return mlir_file , None
308330 else :
309- mlir_content = generate_mlir (config )
331+ with ir .Context ():
332+ mlir_content = generate_mlir (config )
310333
311334 # Write MLIR content to file
312335 with open (mlir_file , "w" ) as f :
0 commit comments