16 | 16 | from iree.compiler.dialects import func # type: ignore |
17 | 17 | from iree.compiler.dialects import iree_gpu # type: ignore |
18 | 18 | from iree.compiler.dialects import iree_codegen # type: ignore |
19 | | -from iree.compiler.dialects import linalg # type: ignore |
| 19 | +from iree.compiler.dialects import linalg, arith, tensor # type: ignore |
20 | 20 |
21 | 21 | from . import common |
22 | 22 | from . import dispatch_parser |
@@ -128,6 +128,98 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None: |
128 | 128 |     assert shapes.matmul_size.K == [15, 256] |
129 | 129 |
130 | 130 |
| 131 | +def test_get_matmul_named_op(tuner_ctx: common.TunerContext) -> None: |
| 132 | +    context = tuner_ctx.mlir_ctx |
| 133 | +    with ir.Location.unknown(context): |
| 134 | +        module = ir.Module.create() |
| 135 | +        f16 = ir.F16Type.get() |
| 136 | +        f32 = ir.F32Type.get() |
| 137 | + |
| 138 | +        with ir.InsertionPoint(module.body): |
| 139 | +            a_type = ir.RankedTensorType.get((16, 64), f16) |
| 140 | +            b_type = ir.RankedTensorType.get((64, 32), f16) |
| 141 | +            c_type = ir.RankedTensorType.get((16, 32), f32) |
| 142 | + |
| 143 | +            dim_m = ir.AffineDimExpr.get(0) |
| 144 | +            dim_n = ir.AffineDimExpr.get(1) |
| 145 | +            dim_k = ir.AffineDimExpr.get(2) |
| 146 | +            a_map = ir.AffineMap.get(3, 0, [dim_m, dim_k]) |
| 147 | +            b_map = ir.AffineMap.get(3, 0, [dim_k, dim_n]) |
| 148 | +            c_map = ir.AffineMap.get(3, 0, [dim_m, dim_n]) |
| 149 | + |
| 150 | +            @func.FuncOp.from_py_func(a_type, b_type, c_type) |
| 151 | +            def named_matmul(a, b, c): |
| 152 | +                matmul_op = linalg.MatmulOp( |
| 153 | +                    result_tensors=[c_type], |
| 154 | +                    inputs=[a, b], |
| 155 | +                    outputs=[c], |
| 156 | +                    indexing_maps=[a_map, b_map, c_map], |
| 157 | +                ) |
| 158 | +                matmul_op.operation.attributes["root_op"] = ir.UnitAttr.get() |
| 159 | + |
| 160 | +        root_op_list = iree_codegen.get_tuner_root_ops(module) |
| 161 | +        assert len(root_op_list) == 1, "Expected one root op" |
| 162 | +        root_op = root_op_list[0] |
| 163 | + |
| 164 | +        parser = dispatch_parser.ContractionOpInterfaceParser(root_op) |
| 165 | +        shapes = parser.get_problem_size() |
| 166 | + |
| 167 | +        assert shapes.matmul_size.B == [] |
| 168 | +        assert shapes.matmul_size.M == [16] |
| 169 | +        assert shapes.matmul_size.N == [32] |
| 170 | +        assert shapes.matmul_size.K == [64] |
| 171 | +        assert shapes.lhs_type.shape == [16, 64] |
| 172 | +        assert isinstance(shapes.lhs_type.element_type, ir.F16Type) |
| 173 | +        assert shapes.rhs_type.shape == [64, 32] |
| 174 | +        assert isinstance(shapes.rhs_type.element_type, ir.F16Type) |
| 175 | +        assert shapes.res_type.shape == [16, 32] |
| 176 | +        assert isinstance(shapes.res_type.element_type, ir.F32Type) |
| 177 | + |
| 178 | + |
| 179 | +def test_get_named_contraction_op(tuner_ctx: common.TunerContext) -> None: |
| 180 | +    with ir.Location.unknown(tuner_ctx.mlir_ctx): |
| 181 | +        module = ir.Module.create() |
| 182 | +        f32 = ir.F32Type.get() |
| 183 | + |
| 184 | +        with ir.InsertionPoint(module.body): |
| 185 | +            lhs_type = ir.RankedTensorType.get((5, 3), f32) |
| 186 | +            rhs_type = ir.RankedTensorType.get((7, 3), f32) |
| 187 | +            res_type = ir.RankedTensorType.get((5, 7), f32) |
| 188 | + |
| 189 | +            @func.FuncOp.from_py_func(lhs_type, rhs_type, res_type) |
| 190 | +            def named_contraction(lhs, rhs, res): |
| 191 | +                dim_i = ir.AffineDimExpr.get(0) |
| 192 | +                dim_j = ir.AffineDimExpr.get(1) |
| 193 | +                dim_k = ir.AffineDimExpr.get(2) |
| 194 | + |
| 195 | +                lhs_map = ir.AffineMap.get(3, 0, [dim_i, dim_k]) |
| 196 | +                rhs_map = ir.AffineMap.get(3, 0, [dim_j, dim_k]) |
| 197 | +                res_map = ir.AffineMap.get(3, 0, [dim_i, dim_j]) |
| 198 | + |
| 199 | +                contraction_op = linalg.ContractOp( |
| 200 | +                    result_tensors=[res_type], |
| 201 | +                    inputs=[lhs, rhs], |
| 202 | +                    outputs=[res], |
| 203 | +                    indexing_maps=[lhs_map, rhs_map, res_map], |
| 204 | +                ) |
| 205 | +                contraction_op.attributes["root_op"] = ir.UnitAttr.get() |
| 206 | + |
| 207 | +        root_op_list = iree_codegen.get_tuner_root_ops(module) |
| 208 | +        assert len(root_op_list) == 1 |
| 209 | +        root_op = root_op_list[0] |
| 210 | + |
| 211 | +        parser = dispatch_parser.ContractionOpInterfaceParser(root_op) |
| 212 | +        shape = parser.get_problem_size() |
| 213 | + |
| 214 | +        assert shape.matmul_size.B == [] |
| 215 | +        assert shape.matmul_size.M == [5] |
| 216 | +        assert shape.matmul_size.N == [7] |
| 217 | +        assert shape.matmul_size.K == [3] |
| 218 | +        assert shape.lhs_type.shape == [5, 3] |
| 219 | +        assert shape.rhs_type.shape == [7, 3] |
| 220 | +        assert shape.res_type.shape == [5, 7] |
| 221 | + |
| 222 | + |
131 | 223 | def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None: |
132 | 224 |     context = tuner_ctx.mlir_ctx |
133 | 225 |     module_str = """ |