Skip to content

Commit a2d8da6

Browse files
authored
[MLIR][Transform][Tune] Introduce transform.tune.alternatives op (#160724)
This op enables expressing uncertainty regarding what should be happening at particular places in transform-dialect schedules. In particular, it enables representing a choice among alternative regions. This choice is resolved through providing a `selected_region` argument. When this argument is provided, the semantics are such that it is valid to rewrite the op through substituting in the selected region -- with the op's interpreted semantics corresponding to exactly this. This op represents another piece of the puzzle w.r.t. a toolkit for expressing autotuning problems with the transform dialect. Note that this goes beyond tuning knobs _on_ transforms, going further by making it tunable which (sequences of) transforms are to be applied.
1 parent be6a8c5 commit a2d8da6

File tree

1 file changed

+63
-3
lines changed
  • mlir/python/mlir/dialects/transform

1 file changed

+63
-3
lines changed

mlir/python/mlir/dialects/transform/tune.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
from ...ir import (
88
Type,
9+
Value,
10+
Operation,
11+
OpView,
912
Attribute,
1013
ArrayAttr,
1114
StringAttr,
@@ -19,7 +22,10 @@
1922
from .._transform_tune_extension_ops_gen import _Dialect
2023

2124
try:
22-
from .._ods_common import _cext as _ods_cext
25+
from .._ods_common import (
26+
get_op_result_or_value as _get_op_result_or_value,
27+
_cext as _ods_cext,
28+
)
2329
except ImportError as e:
2430
raise RuntimeError("Error loading imports from extension module") from e
2531

@@ -36,7 +42,7 @@ def __init__(
3642
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
3743
],
3844
*,
39-
selected: Optional[Attribute] = None,
45+
selected: Optional[Union[Attribute, bool, int, float, str]] = None,
4046
loc=None,
4147
ip=None,
4248
):
@@ -75,8 +81,62 @@ def knob(
7581
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
7682
],
7783
*,
78-
selected: Optional[Attribute] = None,
84+
selected: Optional[Union[Attribute, bool, int, float, str]] = None,
7985
loc=None,
8086
ip=None,
8187
):
8288
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
89+
90+
91+
@_ods_cext.register_operation(_Dialect, replace=True)
92+
class AlternativesOp(AlternativesOp):
93+
def __init__(
94+
self,
95+
results: Sequence[Type],
96+
name: Union[StringAttr, str],
97+
num_alternatives: int,
98+
*,
99+
selected_region: Optional[
100+
Union[int, IntegerAttr, Value, Operation, OpView]
101+
] = None,
102+
loc=None,
103+
ip=None,
104+
):
105+
if isinstance(name, str):
106+
name = StringAttr.get(name)
107+
108+
selected_region_attr = selected_region_param = None
109+
if isinstance(selected_region, IntegerAttr):
110+
selected_region_attr = selected_region
111+
elif isinstance(selected_region, int):
112+
selected_region_attr = IntegerAttr.get(
113+
IntegerType.get_signless(32), selected_region
114+
)
115+
elif isinstance(selected_region, (Value, Operation, OpView)):
116+
selected_region_param = _get_op_result_or_value(selected_region)
117+
118+
super().__init__(
119+
results,
120+
name,
121+
num_alternatives,
122+
selected_region_attr=selected_region_attr,
123+
selected_region_param=selected_region_param,
124+
loc=loc,
125+
ip=ip,
126+
)
127+
for region in self.regions:
128+
region.blocks.append()
129+
130+
131+
def alternatives(
132+
results: Sequence[Type],
133+
name: Union[StringAttr, str],
134+
num_alternatives: int,
135+
*,
136+
selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None,
137+
loc=None,
138+
ip=None,
139+
):
140+
return AlternativesOp(
141+
results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip
142+
)

0 commit comments

Comments
 (0)