Skip to content

Commit 6519ca9

Browse files
authored
[tuner] add tests for named ops (#1289)
Per the suggestion in comment #1264, this PR removes the TODO comment and adds tests for linalg named ops. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent 0bd1b1e commit 6519ca9

File tree

2 files changed

+94
-6
lines changed

2 files changed

+94
-6
lines changed

tuner/tuner/dispatch_parser.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,13 @@ def get_problem_size(self) -> ProblemSize:
4444
pass
4545

4646

47-
# TODO(Max191): Support linalg named op versions of contraction ops. The
48-
# current matchers only work for linalg.generic ops.
4947
class ContractionOpInterfaceParser(DispatchParser):
5048
def __init__(self, root_op: ir.Operation):
    """Initialize the parser with the dispatch's root operation."""
    super().__init__(root_op)
5250

5351
def has_valid_root_op(self) -> bool:
    """Return True when the root op implements the contraction op interface.

    Named contraction ops (e.g. linalg.matmul, linalg.contract) are accepted
    in addition to linalg.generic contractions.
    """
    return linalg.isa_contraction_op(self.get_root_op())
5854

5955
def get_problem_size(self) -> ProblemSize:
6056
root_op = self.get_root_op()

tuner/tuner/dispatch_parser_test.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from iree.compiler.dialects import func # type: ignore
1717
from iree.compiler.dialects import iree_gpu # type: ignore
1818
from iree.compiler.dialects import iree_codegen # type: ignore
19-
from iree.compiler.dialects import linalg # type: ignore
19+
from iree.compiler.dialects import linalg, arith, tensor, func # type: ignore
2020

2121
from . import common
2222
from . import dispatch_parser
@@ -128,6 +128,98 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
128128
assert shapes.matmul_size.K == [15, 256]
129129

130130

131+
def test_get_matmul_named_op(tuner_ctx: common.TunerContext) -> None:
    """The contraction parser should accept a named linalg.matmul root op."""
    ctx = tuner_ctx.mlir_ctx
    with ir.Location.unknown(ctx):
        module = ir.Module.create()
        f16 = ir.F16Type.get()
        f32 = ir.F32Type.get()

        with ir.InsertionPoint(module.body):
            lhs_type = ir.RankedTensorType.get((16, 64), f16)
            rhs_type = ir.RankedTensorType.get((64, 32), f16)
            acc_type = ir.RankedTensorType.get((16, 32), f32)

            # Standard (m, k) x (k, n) -> (m, n) matmul indexing maps.
            dim_m, dim_n, dim_k = (ir.AffineDimExpr.get(d) for d in range(3))
            indexing_maps = [
                ir.AffineMap.get(3, 0, [dim_m, dim_k]),
                ir.AffineMap.get(3, 0, [dim_k, dim_n]),
                ir.AffineMap.get(3, 0, [dim_m, dim_n]),
            ]

            @func.FuncOp.from_py_func(lhs_type, rhs_type, acc_type)
            def named_matmul(lhs, rhs, acc):
                matmul_op = linalg.MatmulOp(
                    result_tensors=[acc_type],
                    inputs=[lhs, rhs],
                    outputs=[acc],
                    indexing_maps=indexing_maps,
                )
                # Tag the op so get_tuner_root_ops can discover it.
                matmul_op.operation.attributes["root_op"] = ir.UnitAttr.get()

        roots = iree_codegen.get_tuner_root_ops(module)
        assert len(roots) == 1, "Expected one root op"

        parser = dispatch_parser.ContractionOpInterfaceParser(roots[0])
        shapes = parser.get_problem_size()

        assert shapes.matmul_size.B == []
        assert shapes.matmul_size.M == [16]
        assert shapes.matmul_size.N == [32]
        assert shapes.matmul_size.K == [64]
        assert shapes.lhs_type.shape == [16, 64]
        assert isinstance(shapes.lhs_type.element_type, ir.F16Type)
        assert shapes.rhs_type.shape == [64, 32]
        assert isinstance(shapes.rhs_type.element_type, ir.F16Type)
        assert shapes.res_type.shape == [16, 32]
        assert isinstance(shapes.res_type.element_type, ir.F32Type)
177+
178+
179+
def test_get_named_contraction_op() -> None:
    """The contraction parser should accept a named linalg.contract root op.

    Builds a transposed-B contraction, (5x3) x (7x3)^T -> (5x7), and checks
    that the parsed problem size matches the constructed shapes.
    """
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        f32 = ir.F32Type.get()

        with ir.InsertionPoint(module.body):
            lhs_type = ir.RankedTensorType.get((5, 3), f32)
            rhs_type = ir.RankedTensorType.get((7, 3), f32)
            res_type = ir.RankedTensorType.get((5, 7), f32)

            @func.FuncOp.from_py_func(lhs_type, rhs_type, res_type)
            def named_contraction(lhs, rhs, res):
                dim_i = ir.AffineDimExpr.get(0)
                dim_j = ir.AffineDimExpr.get(1)
                dim_k = ir.AffineDimExpr.get(2)

                # rhs indexes (j, k): the reduction dim is the second axis,
                # i.e. B is transposed relative to a plain matmul.
                lhs_map = ir.AffineMap.get(3, 0, [dim_i, dim_k])
                rhs_map = ir.AffineMap.get(3, 0, [dim_j, dim_k])
                res_map = ir.AffineMap.get(3, 0, [dim_i, dim_j])

                contraction_op = linalg.ContractOp(
                    result_tensors=[res_type],
                    inputs=[lhs, rhs],
                    outputs=[res],
                    indexing_maps=[lhs_map, rhs_map, res_map],
                )
                # Tag the op so get_tuner_root_ops can discover it.
                contraction_op.attributes["root_op"] = ir.UnitAttr.get()

        root_op_list = iree_codegen.get_tuner_root_ops(module)
        assert len(root_op_list) == 1
        root_op = root_op_list[0]

        parser = dispatch_parser.ContractionOpInterfaceParser(root_op)
        shape = parser.get_problem_size()

        assert shape.matmul_size.B == []
        assert shape.matmul_size.M == [5]
        assert shape.matmul_size.N == [7]
        assert shape.matmul_size.K == [3]
        assert shape.lhs_type.shape == [5, 3]
        assert shape.rhs_type.shape == [7, 3]
        assert shape.res_type.shape == [5, 7]
221+
222+
131223
def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None:
132224
context = tuner_ctx.mlir_ctx
133225
module_str = """

0 commit comments

Comments
 (0)