Skip to content

Commit b5681dd

Browse files
committed
transform.tune.pick: tune.select but without forgetting the options
1 parent d80f2ba commit b5681dd

File tree

6 files changed

+120
-11
lines changed

6 files changed

+120
-11
lines changed

include/TPP/Dialect/Tune/TuneTransformOps.td

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef TUNE_TRANSFORM_OPS
22
#define TUNE_TRANSFORM_OPS
33

4+
include "mlir/IR/CommonAttrConstraints.td"
45
include "mlir/Dialect/Transform/IR/TransformDialect.td"
56
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
67
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -17,7 +18,25 @@ def TuneSelectOp : Op<Transform_Dialect, "tune.select", [
1718

1819
let arguments = (ins SymbolRefAttr:$name,
1920
ArrayAttr:$options);
20-
let results = (outs TransformParamTypeInterface:$selected);
21+
let results = (outs TransformParamTypeInterface:$result);
22+
let assemblyFormat =
23+
"$name `from` $options attr-dict `:` type(results)";
24+
}
25+
26+
27+
def TunePickOp : Op<Transform_Dialect, "tune.pick", [
28+
DeclareOpInterfaceMethods<TransformOpInterface>,
29+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
30+
]> {
31+
let summary = "Non-deterministically select a value from a set of values";
32+
let description = [{
33+
TODO
34+
}];
35+
36+
let arguments = (ins SymbolRefAttr:$name,
37+
ArrayAttr:$options,
38+
OptionalAttr<AnyAttr>:$selected);
39+
let results = (outs TransformParamTypeInterface:$result);
2140
let assemblyFormat =
2241
"$name `from` $options attr-dict `:` type(results)";
2342
}

lib/TPP/Dialect/Tune/TransformOps/TuneTransformOps.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,34 @@ transform::TuneSelectOp::apply(transform::TransformRewriter &rewriter,
2929
<< "this op does not resolve non-deterministic choice!";
3030
}
3131

32+
//===----------------------------------------------------------------------===//
33+
// TunePickOp
34+
//===----------------------------------------------------------------------===//
35+
36+
void transform::TunePickOp::getEffects(
37+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
38+
producesHandle(getOperation()->getOpResults(), effects);
39+
onlyReadsPayload(effects);
40+
}
41+
42+
DiagnosedSilenceableFailure
43+
transform::TunePickOp::apply(transform::TransformRewriter &rewriter,
44+
transform::TransformResults &results,
45+
transform::TransformState &state) {
46+
if (getSelected()) {
47+
results.setParams(getOperation()->getOpResults()[0], *getSelected());
48+
return DiagnosedSilenceableFailure::success();
49+
}
50+
51+
if (getOptions().size() == 1) {
52+
results.setParams(getOperation()->getOpResults()[0], getOptions()[0]);
53+
return DiagnosedSilenceableFailure::success();
54+
}
55+
56+
return emitDefiniteFailure() << "non-deterministic choice is only resolved "
57+
"through providing a `selected` attr!";
58+
}
59+
3260
//===----------------------------------------------------------------------===//
3361
// Transform op registration
3462
//===----------------------------------------------------------------------===//

python/mlir/dialects/transform/tune.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
IntegerType,
1414
BoolAttr,
1515
)
16-
from .._tune_transform_ops_gen import TuneSelectOp
16+
from .._tune_transform_ops_gen import *
1717

1818
from collections.abc import Sequence
19-
from typing import Union
19+
from typing import Union, Optional
2020

2121

2222
def select(
23-
selected: Type, # transform.any_param or transform.param<...>
23+
result: Type, # transform.any_param or transform.param<...>
2424
name: Union[str, Attribute],
2525
options: Union[ArrayAttr, Sequence[Union[Attribute, str, int, bool]]],
2626
loc=None,
@@ -44,9 +44,60 @@ def select(
4444
options = ArrayAttr.get(option_attrs)
4545

4646
return TuneSelectOp(
47-
selected=selected,
47+
result=result,
48+
name=name,
49+
options=options,
50+
loc=loc,
51+
ip=ip,
52+
)
53+
54+
55+
def pick(
56+
result: Type, # transform.any_param or transform.param<...>
57+
name: Union[str, Attribute],
58+
options: Union[ArrayAttr, Sequence[Union[Attribute, str, int, bool]]],
59+
*,
60+
selected: Optional[Union[Attribute, str, int, bool]] = None,
61+
loc=None,
62+
ip=None,
63+
) -> TunePickOp:
64+
if isinstance(name, str):
65+
name = SymbolRefAttr.get([name])
66+
67+
if not isinstance(options, ArrayAttr):
68+
option_attrs = []
69+
for option in options:
70+
if isinstance(option, str):
71+
option_attrs.append(StringAttr.get(option))
72+
elif isinstance(option, int):
73+
int_type = IntegerType.get_signless(64)
74+
option_attrs.append(IntegerAttr.get(int_type, option))
75+
elif isinstance(option, bool):
76+
option_attrs.append(BoolAttr.get(option))
77+
elif isinstance(option, Attribute):
78+
option_attrs.append(option)
79+
else:
80+
assert False
81+
options = ArrayAttr.get(option_attrs)
82+
83+
84+
if selected is None:
85+
pass
86+
elif isinstance(selected, str):
87+
selected = StringAttr.get(selected)
88+
elif isinstance(selected, int):
89+
int_type = IntegerType.get_signless(64)
90+
selected = IntegerAttr.get(int_type, selected)
91+
elif isinstance(selected, bool):
92+
selected = BoolAttr.get(selected)
93+
elif not isinstance(selected, Attribute):
94+
assert False
95+
96+
return TunePickOp(
97+
result=result,
4898
name=name,
4999
options=options,
100+
selected=selected,
50101
loc=loc,
51102
ip=ip,
52103
)

python/mlir/tpp/sched/bundles.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from mlir import ir
44
from mlir.dialects import transform
5-
from .common import apply_registered_pass, match, select
5+
from .common import apply_registered_pass, match, select, pick
66
from .utils import GpuBackend, PipelineInterrupt
77

88
from ..xsmm import utils as xsmm_utils
@@ -46,7 +46,7 @@ def tpp_mapping(
4646
m_vals, n_vals, k_vals = pack_block_factors
4747
m = select("m", m_vals if isinstance(m_vals, Sequence) else [m_vals])
4848
n = select("n", n_vals if isinstance(n_vals, Sequence) else [n_vals])
49-
k = select("k", k_vals if isinstance(k_vals, Sequence) else [k_vals])
49+
k = pick("k", k_vals if isinstance(k_vals, Sequence) else [k_vals])
5050
options = {"block-factors": [m, n, k]}
5151
func = apply_registered_pass(func, "pack-matmul", options=options)
5252
apply_registered_pass(func, "pack-vnni")

python/mlir/tpp/sched/common.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@
33

44

55
# Wrapper to addresss verbosity.
6-
def apply_registered_pass(*args, **kwargs):
7-
return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs)
6+
def select(*args, **kwargs):
7+
return tune.select(transform.AnyParamType.get(), *args, **kwargs)
88

99

1010
# Wrapper to addresss verbosity.
11-
def select(*args, **kwargs):
12-
return tune.select(transform.AnyParamType.get(), *args, **kwargs)
11+
def pick(*args, **kwargs):
12+
return tune.pick(transform.AnyParamType.get(), *args, **kwargs)
13+
14+
15+
# Wrapper to addresss verbosity.
16+
def apply_registered_pass(*args, **kwargs):
17+
return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs)
1318

1419

1520
# Wrapper to addresss verbosity.

tools/tpp-tune/tpp-tune.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def choices_finder(op):
4949
if op.name in choices:
5050
raise RuntimeError(f"options name collision: {op.name} used twice")
5151
choices[op.name] = tuple(op.options)
52+
elif isinstance(op, transform_tune.TunePickOp):
53+
if op.name in choices:
54+
raise RuntimeError(f"options name collision: {op.name} used twice")
55+
choices[op.name] = tuple(op.options)
5256

5357
choices_finder(schedule.operation)
5458

@@ -63,6 +67,8 @@ def selected_rewriter(op: Union[ir.OpView, ir.Operation]):
6367
)
6468
for use in op.result.uses:
6569
use.owner.operands[use.operand_number] = param
70+
elif isinstance(op, transform_tune.TunePickOp):
71+
op.attributes["selected"] = selected[op.name]
6672

6773
selected_rewriter(schedule.operation)
6874

0 commit comments

Comments
 (0)