Skip to content

Commit f73990d

Browse files
committed
[Tune] Introduce tpp-tune
Walks schedule IR, collects options (name -> values mapping), calls an `autotune` procedure (dummy implementation just selects first value) to select one value for each name, walks schedule IR again to introduce a param constant for the each option's value and replaces the users of the `transform.tune.select` op with the constant op's result. Run it as `tpp-sched | tpp-tune` where `tpp-sched` produces IR containing `transform.tune.select` ops.
1 parent 0575d06 commit f73990d

File tree

7 files changed

+119
-7
lines changed

7 files changed

+119
-7
lines changed

python/mlir/dialects/transform/tune.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33

44
register_dialect_extension(get_dialect_registry())
55

6-
from ...ir import ArrayAttr, SymbolRefAttr, Attribute, Type
6+
from ...ir import (
7+
ArrayAttr,
8+
SymbolRefAttr,
9+
Attribute,
10+
Type,
11+
StringAttr,
12+
IntegerAttr,
13+
IntegerType,
14+
BoolAttr,
15+
)
716
from .._tune_transform_ops_gen import TuneSelectOp
817

918
from collections.abc import Sequence
@@ -13,13 +22,27 @@
1322
def select(
1423
selected: Type, # transform.any_param or transform.param<...>
1524
name: Union[str, Attribute],
16-
options: Union[ArrayAttr, Sequence[Attribute]],
25+
options: Union[ArrayAttr, Sequence[Union[Attribute, str, int, bool]]],
1726
loc=None,
1827
ip=None,
1928
) -> TuneSelectOp:
2029
if isinstance(name, str):
2130
name = SymbolRefAttr.get([name])
2231

32+
if not isinstance(options, ArrayAttr):
33+
option_attrs = []
34+
for option in options:
35+
if isinstance(option, str):
36+
option_attrs.append(StringAttr.get(option))
37+
elif isinstance(option, int):
38+
int_type = IntegerType.get_signless(64)
39+
option_attrs.append(IntegerAttr.get(int_type, option))
40+
elif isinstance(option, bool):
41+
option_attrs.append(BoolAttr.get(option))
42+
elif isinstance(option, Attribute):
43+
option_attrs.append(option)
44+
options = ArrayAttr.get(option_attrs)
45+
2346
return TuneSelectOp(
2447
selected=selected,
2548
name=name,

python/mlir/dialects/tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .._mlir_libs import get_dialect_registry
2-
from .._mlir_libs._tppDialects.tune import register_dialect
2+
from .._mlir_libs._tppDialects.tune import *
33

44
register_dialect(get_dialect_registry())

python/mlir/tpp/sched/bundles.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import Optional, Sequence
22

3-
from mlir import ir
43
from mlir.dialects import transform
5-
from .common import apply_registered_pass, match
4+
from .common import apply_registered_pass, match, select
65
from .utils import GpuBackend, PipelineInterrupt
76

87
from ..xsmm import utils as xsmm_utils
@@ -34,7 +33,12 @@ def tpp_mapping(mod, lower_pack_unpack_without_transpose: bool = False, **_confi
3433
func = apply_registered_pass(func, "pack-conv2DNchwFchw")
3534
func = apply_registered_pass(func, "pack-conv2DNhwcHwcf")
3635
func = apply_registered_pass(func, "rewrite-conv-to-matmul-or-brgemm")
37-
func = apply_registered_pass(func, "pack-matmul")
36+
m = select("m", [2, 4, 8])
37+
n = select("n", [4, 8, 16])
38+
k = select("k", [2, 4, 8, 16])
39+
func = apply_registered_pass(
40+
func, "pack-matmul", options={"block-factors": [m, n, k]}
41+
)
3842
apply_registered_pass(func, "pack-vnni")
3943
if lower_pack_unpack_without_transpose:
4044
mod = apply_registered_pass(mod, "lower-packs-unpacks-without-transpose")

python/mlir/tpp/sched/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
from mlir.dialects import transform
2-
from mlir.dialects.transform import structured
2+
from mlir.dialects.transform import structured, tune
33

44

55
# Wrapper to addresss verbosity.
66
def apply_registered_pass(*args, **kwargs):
77
return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs)
88

99

10+
# Wrapper to addresss verbosity.
11+
def select(*args, **kwargs):
12+
return tune.select(transform.AnyParamType.get(), *args, **kwargs)
13+
14+
1015
# Wrapper to addresss verbosity.
1116
def match(*args, **kwargs):
1217
return structured.MatchOp(transform.AnyOpType.get(), *args, **kwargs)

tools/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ add_subdirectory(mlir-gen)
22
add_subdirectory(tpp-opt)
33
add_subdirectory(tpp-run)
44
add_subdirectory(tpp-sched)
5+
add_subdirectory(tpp-tune)
56
add_subdirectory(fpcmp)
67
add_subdirectory(bench-ref)

tools/tpp-tune/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
file(MAKE_DIRECTORY
2+
${CMAKE_BINARY_DIR}/bin)
3+
file(CREATE_LINK
4+
${CMAKE_CURRENT_SOURCE_DIR}/tpp-tune.py
5+
${CMAKE_BINARY_DIR}/bin/tpp-tune
6+
SYMBOLIC)
7+
8+
9+
add_custom_target(tpp-tune DEPENDS ${CMAKE_BINARY_DIR}/bin/tpp-tune TPPPythonModules)

tools/tpp-tune/tpp-tune.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#!/usr/bin/env python3
2+
3+
import sys
4+
from pathlib import Path
5+
from typing import Union, Sequence, Dict
6+
from pprint import pprint
7+
8+
9+
# Enable automagically finding TPP-MLIR's python modules (which include
10+
# and extend MLIR's Python bindings).
11+
python_packages_path = Path(__file__).parent.parent / "python_packages"
12+
if python_packages_path.exists():
13+
sys.path = [str(python_packages_path)] + sys.path
14+
15+
16+
from mlir import ir
17+
from mlir.dialects import transform
18+
from mlir.dialects.transform import tune as transform_tune
19+
20+
21+
def walker(f):
22+
def wrapper(op: Union[ir.OpView, ir.Operation]):
23+
f(op)
24+
for region in op.regions:
25+
for block in region.blocks:
26+
for child_op in block:
27+
wrapper(child_op)
28+
29+
return wrapper
30+
31+
32+
def autotune(choices: Dict[str, Sequence[ir.Attribute]]) -> Dict[str, ir.Attribute]:
33+
# Aint tuning easy!!
34+
return {key: values[0] for key, values in choices.items()}
35+
36+
37+
file = sys.stdin
38+
if len(sys.argv) > 1 and sys.argv[1] != "-":
39+
file = open(sys.argv[1])
40+
41+
42+
with ir.Context(), ir.Location.unknown():
43+
schedule = ir.Module.parse(file.read())
44+
45+
choices = {}
46+
47+
@walker
48+
def choices_finder(op):
49+
if isinstance(op, transform_tune.TuneSelectOp):
50+
choices[op.name] = tuple(op.options)
51+
52+
choices_finder(schedule.operation)
53+
54+
pprint(choices, stream=sys.stderr)
55+
56+
selected = autotune(choices)
57+
58+
@walker
59+
def selected_rewriter(op: Union[ir.OpView, ir.Operation]):
60+
if isinstance(op, transform_tune.TuneSelectOp):
61+
with ir.InsertionPoint(op):
62+
param = transform.param_constant(
63+
transform.AnyParamType.get(), selected[op.name]
64+
)
65+
for use in op.result.uses:
66+
use.owner.operands[use.operand_number] = param
67+
68+
selected_rewriter(schedule.operation)
69+
70+
print(schedule)

0 commit comments

Comments
 (0)