Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit ea7589d

Browse files
authored
[MLIR][Transform] Introduce transform.tune.knob op (#146732)
A new transform op to represent that an attribute is to be chosen from a set of alternatives and that this choice is made available as a `!transform.param`. When a `selected` argument is provided, the op's `apply()` semantics is that of just making this selected attribute available as the result. When `selected` is not provided, `apply()` complains that nothing has resolved the non-determinism that the op is representing.
1 parent 556b741 commit ea7589d

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
180180
DIALECT_NAME transform
181181
EXTENSION_NAME transform_debug_extension)
182182

183+
declare_mlir_dialect_extension_python_bindings(
184+
ADD_TO_PARENT MLIRPythonSources.Dialects
185+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
186+
TD_FILE dialects/TransformTuneExtensionOps.td
187+
SOURCES
188+
dialects/transform/tune.py
189+
DIALECT_NAME transform
190+
EXTENSION_NAME transform_tune_extension)
191+
183192
declare_mlir_dialect_python_bindings(
184193
ADD_TO_PARENT MLIRPythonSources.Dialects
185194
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===-- TransformTuneExtensionOps.td - Binding entry point -*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Entry point of the generated Python bindings for the Tune extension of the
10+
// Transform dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
15+
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
16+
17+
include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td"
18+
19+
#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from typing import Optional, Sequence
6+
7+
from ...ir import (
8+
Type,
9+
Attribute,
10+
ArrayAttr,
11+
StringAttr,
12+
F64Type,
13+
IntegerType,
14+
IntegerAttr,
15+
FloatAttr,
16+
BoolAttr,
17+
)
18+
from .._transform_tune_extension_ops_gen import *
19+
from .._transform_tune_extension_ops_gen import _Dialect
20+
21+
try:
22+
from .._ods_common import _cext as _ods_cext
23+
except ImportError as e:
24+
raise RuntimeError("Error loading imports from extension module") from e
25+
26+
from typing import Union
27+
28+
29+
@_ods_cext.register_operation(_Dialect, replace=True)
30+
class KnobOp(KnobOp):
31+
def __init__(
32+
self,
33+
result: Type, # !transform.any_param or !transform.param<Type>
34+
name: Union[StringAttr, str],
35+
options: Union[
36+
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
37+
],
38+
*,
39+
selected: Optional[Attribute] = None,
40+
loc=None,
41+
ip=None,
42+
):
43+
if isinstance(name, str):
44+
name = StringAttr.get(name)
45+
46+
def map_to_attr(value):
47+
if isinstance(value, bool):
48+
return BoolAttr.get(value)
49+
if isinstance(value, int):
50+
return IntegerAttr.get(IntegerType.get_signless(64), value)
51+
if isinstance(value, float):
52+
return FloatAttr.get(F64Type.get(), value)
53+
if isinstance(value, str):
54+
return StringAttr.get(value)
55+
assert isinstance(value, Attribute)
56+
return value
57+
58+
if isinstance(options, Sequence) and not isinstance(options, ArrayAttr):
59+
options = ArrayAttr.get([map_to_attr(opt) for opt in options])
60+
61+
super().__init__(
62+
result,
63+
name,
64+
options,
65+
selected=selected and map_to_attr(selected),
66+
loc=loc,
67+
ip=ip,
68+
)
69+
70+
71+
def knob(
72+
result: Type, # !transform.any_param or !transform.param<Type>
73+
name: Union[StringAttr, str],
74+
options: Union[
75+
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
76+
],
77+
*,
78+
selected: Optional[Attribute] = None,
79+
loc=None,
80+
ip=None,
81+
):
82+
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)

0 commit comments

Comments
 (0)