
Commit 0886ba7

[tuner] retire op_matchers.py through IREE python bindings (#1264)
This PR offloads the `op_matchers.py` logic to the IREE compiler through exposed Python bindings, as shown below:

- `root_op_list = iree_codegen.get_tuner_root_ops(input_module)`
- `linalg::isaContractionOpInterface` and `linalg::inferContractionDims`
- `linalg::isaConvolutionOpInterface` and `linalg::inferConvolutionDims`

Issue: #1110 #814 nod-ai/playbook#74

Signed-off-by: Bangtian Liu <[email protected]>
1 parent 2522f23 commit 0886ba7
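For reference, here is a minimal sketch of how the new bindings replace the old `op_matchers.py` walk, mirroring the `generate_configs_and_td_specs` changes in the diff below. It assumes the input module was compiled with `--iree-config-add-tuner-attributes` so the compiler can report its root ops; the `pick_dispatch_tuner` helper name and the import paths are illustrative, not part of this PR.

```python
# Sketch only: assumes an IREE build that exposes get_tuner_root_ops and a
# tuner package importable as `tuner.candidate_gen`.
from iree.compiler import ir  # type: ignore
from iree.compiler.dialects import iree_codegen  # type: ignore

from tuner import candidate_gen


def pick_dispatch_tuner(input_module: ir.Module):
    # Ask the compiler for the ops tagged as tuner root ops instead of
    # pattern-matching the IR in Python (the old op_matchers.py approach).
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1, "expected exactly one root op"
    root_op = root_op_list[0]

    # Each tuner validates the root op itself (contraction vs. convolution);
    # the first one that accepts it handles spec generation.
    for tuner_class in (
        candidate_gen.ContractionOpInterfaceTuner,
        candidate_gen.ConvolutionOpInterfaceTuner,
    ):
        tuner = tuner_class(root_op)
        if tuner.has_valid_root_op():
            return tuner
    return None
```

A selected tuner then produces specs directly from the stored root op, e.g. `tuner.get_td_spec(compilation_info)`, without re-walking the input module.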

File tree: 9 files changed (+183, -300 lines)

tuner/tuner/candidate_gen.py

Lines changed: 33 additions & 21 deletions
@@ -39,7 +39,6 @@ class DispatchTuner(DispatchParser):
     @abstractmethod
     def get_td_spec(
         self,
-        ir_module: ir.Module,
         compilation_info: iree_codegen.CompilationInfoAttr,
     ) -> ir.Module:
         """Generate a transform dialect spec that applies the compilation info attr."""
@@ -62,12 +61,14 @@ def find_handler(self, op_name: str) -> DispatchTuner:
 
 
 class ContractionOpInterfaceTuner(DispatchTuner, ContractionOpInterfaceParser):
+    def __init__(self, root_op: ir.Operation):
+        super().__init__(root_op)
+
     def get_td_spec(
         self,
-        ir_module: ir.Module,
         compilation_info: iree_codegen.CompilationInfoAttr,
     ) -> ir.Module:
-        contraction_op: ir.Operation = self.get_contraction_operation(ir_module)
+        contraction_op = self.get_root_op()
         lhs_type = ir.ShapedType(contraction_op.operands[0].type)
         rhs_type = ir.ShapedType(contraction_op.operands[1].type)
         acc_type = ir.ShapedType(contraction_op.operands[2].type)
@@ -77,17 +78,16 @@ def get_td_spec(
         # TODO(Max191): Get the function name from the func.func in the input module.
         func_name = f"match_contraction_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
         return build_td_spec(
-            ir_module.context, contraction_op, compilation_info, func_name
+            contraction_op.context, contraction_op, compilation_info, func_name
         )
 
 
 class ConvolutionOpInterfaceTuner(DispatchTuner, ConvolutionOpInterfaceParser):
     def get_td_spec(
         self,
-        ir_module: ir.Module,
         compilation_info: iree_codegen.CompilationInfoAttr,
     ) -> ir.Module:
-        conv_op: ir.Operation = self.get_conv_operation(ir_module)
+        conv_op = self.get_root_op()
         assert (
             conv_op.name == "linalg.conv_2d_nhwc_hwcf"
         ), "expected linalg.conv_2d_nhwc_hwcf"
@@ -104,7 +104,7 @@ def get_td_spec(
         conv_type = conv_op.name.split(".")[-1]
         # TODO(Max191): Get the function name from the func.func in the input module.
         func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
-        return build_td_spec(ir_module.context, conv_op, compilation_info, func_name)
+        return build_td_spec(conv_op.context, conv_op, compilation_info, func_name)
 
 
 @dataclass
@@ -156,21 +156,33 @@ def generate_configs_and_td_specs(
     pipeline_options_search_space: PipelineOptionsSearchSpace = PipelineOptionsSearchSpace(),
     codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute,
 ) -> list[ir.Module]:
-    dispatch_tuner_registry = DispatchTunerRegistry()
-    dispatch_tuner_registry.register(
-        [
-            ContractionOpInterfaceTuner(),
-            ConvolutionOpInterfaceTuner(),
-        ]
-    )
+    dispatch_tuners: list[type[DispatchTuner]] = [
+        ContractionOpInterfaceTuner,
+        ConvolutionOpInterfaceTuner,
+    ]
+
+    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
+    if len(root_op_list) == 0:
+        tune_logger.error(
+            "No root ops found. Did you forget to pass "
+            "--iree-config-add-tuner-attributes during compilation?"
+        )
+        return []
+    elif len(root_op_list) > 1:
+        tune_logger.error("Multiple root ops found. Only one is currently supported.")
+        return []
 
-    walk_result: OpWalkResult = walk_mlir_op(input_module, dispatch_tuner_registry)
+    root_op = root_op_list[0]
+
+    dispatch_tuner: Optional[DispatchTuner] = None
+    for tuner_class in dispatch_tuners:
+        tuner = tuner_class(root_op)
+        if tuner.has_valid_root_op():
+            dispatch_tuner = tuner
+            break
 
-    dispatch_tuner = walk_result.dispatch_tuner
     assert dispatch_tuner, "No suitable dispatch tuner found"
-    problem_size: ProblemSize = dispatch_tuner.get_shapes(
-        str(input_module).splitlines()
-    )
+    problem_size: ProblemSize = dispatch_tuner.get_problem_size()
     tune_logger.debug(str(problem_size))
 
     # Index 0 is reserved for default config, so it gets a placeholder spec.
@@ -196,7 +208,7 @@ def generate_configs_and_td_specs(
         if i >= limit:
             break
         tune_logger.debug(f"Solution #{i+1}: {config}")
-        td_spec_module = dispatch_tuner.get_td_spec(input_module, config)
+        td_spec_module = dispatch_tuner.get_td_spec(config)
         assert td_spec_module, "Failed to generate transform dialect spec"
         config_specs.append(td_spec_module)
 
@@ -263,7 +275,7 @@ def run_command(run_pack: RunPack) -> RunResult:
 # info makes the inputs to compilation consistent, and allows for overwriting
 # the compilation info with generated TD specs during codegen.
 def strip_root_op_attr(module: ir.Module):
-    root_ops: list[ir.Operation] = get_ops_from_module(module, is_root_op)
+    root_ops: list[ir.Operation] = iree_codegen.get_tuner_root_ops(module)
     for root_op in root_ops:
         assert (
             ROOT_OP_ATTR_NAME in root_op.opview.attributes

tuner/tuner/candidate_gen_test.py

Lines changed: 32 additions & 12 deletions
@@ -19,11 +19,30 @@
 
 from . import candidate_gen
 from . import common
-from . import op_matchers
 
 from .test_utils import tuner_ctx
 
 
+def walk_collect_ops(
+    op: ir.Operation,
+    ops: list[ir.Operation],
+    fn,
+) -> ir.WalkResult:
+    if fn(op):
+        ops.append(op)
+    return ir.WalkResult.ADVANCE
+
+
+def get_ops_from_module(module: ir.Module, fn):
+    ops: list[ir.Operation] = []
+    for op in module.body.operations:
+        op.walk(
+            lambda op: walk_collect_ops(op, ops, fn),
+            ir.WalkOrder.POST_ORDER,
+        )
+    return ops
+
+
 def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
     context = tuner_ctx.mlir_ctx
     module_str = """
@@ -75,14 +94,15 @@ def test_get_td_spec_contraction(tuner_ctx: common.TunerContext) -> None:
     )
 
     ir_module = ir.Module.parse(module_str, context)
+    root_op_list = iree_codegen.get_tuner_root_ops(ir_module)
+    assert len(root_op_list) == 1
+    root_op = root_op_list[0]
 
-    tuner = candidate_gen.ContractionOpInterfaceTuner()
-    td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
+    tuner = candidate_gen.ContractionOpInterfaceTuner(root_op)
+    td_spec_module = tuner.get_td_spec(compilation_info)
     assert td_spec_module
 
-    named_sequence_ops: list[
-        transform.NamedSequenceOp
-    ] = op_matchers.get_ops_from_module(
+    named_sequence_ops: list[transform.NamedSequenceOp] = get_ops_from_module(
         module=td_spec_module,
         fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp),
     )
@@ -157,14 +177,14 @@ def test_get_td_spec_convolution(tuner_ctx: common.TunerContext) -> None:
     )
 
     ir_module = ir.Module.parse(module_str, context)
-
-    tuner = candidate_gen.ConvolutionOpInterfaceTuner()
-    td_spec_module = tuner.get_td_spec(ir_module, compilation_info)
+    root_op_list = iree_codegen.get_tuner_root_ops(ir_module)
+    assert len(root_op_list) == 1
+    root_op = root_op_list[0]
+    tuner = candidate_gen.ConvolutionOpInterfaceTuner(root_op)
+    td_spec_module = tuner.get_td_spec(compilation_info)
    assert td_spec_module
 
-    named_sequence_ops: list[
-        transform.NamedSequenceOp
-    ] = op_matchers.get_ops_from_module(
+    named_sequence_ops: list[transform.NamedSequenceOp] = get_ops_from_module(
         module=td_spec_module,
         fn=lambda op: isinstance(op.opview, transform.NamedSequenceOp),
     )

tuner/tuner/common.py

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,7 @@
 from iree.compiler.dialects import iree_gpu  # type: ignore
 from iree.compiler.dialects import transform  # type: ignore
 import iree.compiler as ireec  # type: ignore
+from iree.compiler._mlir_libs._mlir import ir  # type: ignore
 
 
 class CommonTypes:
@@ -144,6 +145,13 @@ def MNK(self) -> tuple[list[int], list[int], list[int]]:
         return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)
 
 
+def get_map_result_dim_positions(map: ir.AffineMap) -> Optional[list[int]]:
+    if not map.is_projected_permutation:
+        return None
+
+    return [ir.AffineDimExpr(expr).position for expr in map.results]
+
+
 def get_compatible_mfma_intrinsics(
     problem_size: ProblemSize,
     mma_intrinsics: list[iree_gpu.MMAIntrinsic],

tuner/tuner/common_test.py

Lines changed: 17 additions & 0 deletions
@@ -63,6 +63,23 @@ def test_gpu_pipeline_options(tuner_ctx: common.TunerContext) -> None:
     )
 
 
+def test_get_map_result_dim_positions(tuner_ctx: common.TunerContext) -> None:
+    dim0 = ir.AffineDimExpr.get(0)
+    dim1 = ir.AffineDimExpr.get(1)
+    dim2 = ir.AffineDimExpr.get(2)
+
+    # Valid projected permutation: (d0, d1, d2) -> (d0, d2).
+    valid_map = ir.AffineMap.get(3, 0, [dim0, dim2])
+    result = common.get_map_result_dim_positions(valid_map)
+    assert result == [0, 2], f"Expected [0, 2], got {result}"
+
+    # Not a projected permutation: (d0, d1, d2) -> (d0 + d1).
+    sum_expr = dim0 + dim1
+    invalid_map = ir.AffineMap.get(3, 0, [sum_expr])
+    result = common.get_map_result_dim_positions(invalid_map)
+    assert result is None, "Expected None for non-projected permutation"
+
+
 def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
