Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit b7cbca4

Browse files
authored
[MLIR][Transform][SMT] Introduce transform.smt.constrain_params (#159450)
Introduces a Transform-dialect SMT-extension so that we can have an op to express constrains on Transform-dialect params, in particular when these params are knobs -- see transform.tune.knob -- and can hence be seen as symbolic variables. This op allows expressing joint constraints over multiple params/knobs together. While the op's semantics are clearly defined, per SMTLIB, the interpreted semantics -- i.e. the `apply()` method -- for now just defaults to failure. In the future we should support attaching an implementation so that users can Bring Your Own Solver and thereby control performance of interpreting the op. For now the main usage is to walk schedule IR and collect these constraints so that knobs can be rewritten to constants that satisfy the constraints.
1 parent 781d025 commit b7cbca4

File tree

5 files changed

+82
-10
lines changed

5 files changed

+82
-10
lines changed

mlir/lib/Bindings/Python/DialectSMT.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,26 @@ using namespace mlir::python::nanobind_adaptors;
2626

2727
static void populateDialectSMTSubmodule(nanobind::module_ &m) {
2828

29-
auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
30-
.def_classmethod(
31-
"get",
32-
[](const nb::object &, MlirContext context) {
33-
return mlirSMTTypeGetBool(context);
34-
},
35-
"cls"_a, "context"_a = nb::none());
29+
auto smtBoolType =
30+
mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
31+
.def_staticmethod(
32+
"get",
33+
[](MlirContext context) { return mlirSMTTypeGetBool(context); },
34+
"context"_a = nb::none());
3635
auto smtBitVectorType =
3736
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
38-
.def_classmethod(
37+
.def_staticmethod(
3938
"get",
40-
[](const nb::object &, int32_t width, MlirContext context) {
39+
[](int32_t width, MlirContext context) {
4140
return mlirSMTTypeGetBitVector(context, width);
4241
},
43-
"cls"_a, "width"_a, "context"_a = nb::none());
42+
"width"_a, "context"_a = nb::none());
43+
auto smtIntType =
44+
mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
45+
.def_staticmethod(
46+
"get",
47+
[](MlirContext context) { return mlirSMTTypeGetInt(context); },
48+
"context"_a = nb::none());
4449

4550
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
4651
bool indentLetBody) {

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
171171
DIALECT_NAME transform
172172
EXTENSION_NAME transform_pdl_extension)
173173

174+
declare_mlir_dialect_extension_python_bindings(
175+
ADD_TO_PARENT MLIRPythonSources.Dialects
176+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
177+
TD_FILE dialects/TransformSMTExtensionOps.td
178+
SOURCES
179+
dialects/transform/smt.py
180+
DIALECT_NAME transform
181+
EXTENSION_NAME transform_smt_extension)
182+
174183
declare_mlir_dialect_extension_python_bindings(
175184
ADD_TO_PARENT MLIRPythonSources.Dialects
176185
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===-- TransformSMTExtensionOps.td - Binding entry point --*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Entry point of the generated Python bindings for the SMT extension of the
10+
// Transform dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
15+
#define PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
16+
17+
include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td"
18+
19+
#endif // PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS

mlir/python/mlir/dialects/smt.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._smt_ops_gen import *
6+
from ._smt_enum_gen import *
67

78
from .._mlir_libs._mlirDialectsSMT import *
89
from ..extras.meta import region_op
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from typing import Sequence
6+
7+
from ...ir import Type, Block
8+
from .._transform_smt_extension_ops_gen import *
9+
from .._transform_smt_extension_ops_gen import _Dialect
10+
from ...dialects import transform
11+
12+
try:
13+
from .._ods_common import _cext as _ods_cext
14+
except ImportError as e:
15+
raise RuntimeError("Error loading imports from extension module") from e
16+
17+
18+
@_ods_cext.register_operation(_Dialect, replace=True)
19+
class ConstrainParamsOp(ConstrainParamsOp):
20+
def __init__(
21+
self,
22+
params: Sequence[transform.AnyParamType],
23+
arg_types: Sequence[Type],
24+
loc=None,
25+
ip=None,
26+
):
27+
if len(params) != len(arg_types):
28+
raise ValueError(f"{params=} not same length as {arg_types=}")
29+
super().__init__(
30+
params,
31+
loc=loc,
32+
ip=ip,
33+
)
34+
self.regions[0].blocks.append(*arg_types)
35+
36+
@property
37+
def body(self) -> Block:
38+
return self.regions[0].blocks[0]

0 commit comments

Comments
 (0)