Skip to content

Commit 500acb4

Browse files
PragmaTwicemahesh-attarde
authored andcommitted
[MLIR][Python] Add bindings for PDL constraint function registering (llvm#160520)
This is a follow-up to llvm#159926. That PR (llvm#159926) exposed native rewrite function registration in PDL through the C API and Python, enabling use with `pdl.apply_native_rewrite`. In this PR, we add support for native constraint functions in PDL via `pdl.apply_native_constraint`, further completing the PDL API.
1 parent 220ca5a commit 500acb4

File tree

4 files changed

+127
-10
lines changed

4 files changed

+127
-10
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,20 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
375375
MlirPDLPatternModule pdlModule, MlirStringRef name,
376376
MlirPDLRewriteFunction rewriteFn, void *userData);
377377

378+
/// This function type is used as callbacks for PDL native constraint functions.
379+
/// Input values can be accessed by `values` with its size `nValues`;
380+
/// output values can be added into `results` by `mlirPDLResultListPushBack*`
381+
/// APIs. And the return value indicates whether the constraint holds.
382+
typedef MlirLogicalResult (*MlirPDLConstraintFunction)(
383+
MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
384+
MlirPDLValue *values, void *userData);
385+
386+
/// Register a constraint function into the given PDL pattern module.
387+
/// `userData` will be provided as an argument to the constraint function.
388+
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterConstraintFunction(
389+
MlirPDLPatternModule pdlModule, MlirStringRef name,
390+
MlirPDLConstraintFunction constraintFn, void *userData);
391+
378392
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
379393

380394
#undef DEFINE_C_API_STRUCT

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@ static nb::object objectFromPDLValue(MlirPDLValue value) {
4040
throw std::runtime_error("unsupported PDL value type");
4141
}
4242

43+
static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
44+
MlirPDLValue *values) {
45+
std::vector<nb::object> args;
46+
args.reserve(nValues);
47+
for (size_t i = 0; i < nValues; ++i)
48+
args.push_back(objectFromPDLValue(values[i]));
49+
return args;
50+
}
51+
4352
// Convert the Python object to a boolean.
4453
// If it evaluates to False, treat it as success;
4554
// otherwise, treat it as failure.
@@ -74,11 +83,22 @@ class PyPDLPatternModule {
7483
size_t nValues, MlirPDLValue *values,
7584
void *userData) -> MlirLogicalResult {
7685
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
77-
std::vector<nb::object> args;
78-
args.reserve(nValues);
79-
for (size_t i = 0; i < nValues; ++i)
80-
args.push_back(objectFromPDLValue(values[i]));
81-
return logicalResultFromObject(f(rewriter, results, args));
86+
return logicalResultFromObject(
87+
f(rewriter, results, objectsFromPDLValues(nValues, values)));
88+
},
89+
fn.ptr());
90+
}
91+
92+
void registerConstraintFunction(const std::string &name,
93+
const nb::callable &fn) {
94+
mlirPDLPatternModuleRegisterConstraintFunction(
95+
get(), mlirStringRefCreate(name.data(), name.size()),
96+
[](MlirPatternRewriter rewriter, MlirPDLResultList results,
97+
size_t nValues, MlirPDLValue *values,
98+
void *userData) -> MlirLogicalResult {
99+
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
100+
return logicalResultFromObject(
101+
f(rewriter, results, objectsFromPDLValues(nValues, values)));
82102
},
83103
fn.ptr());
84104
}
@@ -199,6 +219,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
199219
const nb::callable &fn) {
200220
self.registerRewriteFunction(name, fn);
201221
},
222+
nb::keep_alive<1, 3>())
223+
.def(
224+
"register_constraint_function",
225+
[](PyPDLPatternModule &self, const std::string &name,
226+
const nb::callable &fn) {
227+
self.registerConstraintFunction(name, fn);
228+
},
202229
nb::keep_alive<1, 3>());
203230
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
204231
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -398,21 +398,41 @@ void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
398398
unwrap(results)->push_back(unwrap(value));
399399
}
400400

401+
inline std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) {
402+
std::vector<MlirPDLValue> mlirValues;
403+
mlirValues.reserve(values.size());
404+
for (auto &value : values) {
405+
mlirValues.push_back(wrap(&value));
406+
}
407+
return mlirValues;
408+
}
409+
401410
void mlirPDLPatternModuleRegisterRewriteFunction(
402411
MlirPDLPatternModule pdlModule, MlirStringRef name,
403412
MlirPDLRewriteFunction rewriteFn, void *userData) {
404413
unwrap(pdlModule)->registerRewriteFunction(
405414
unwrap(name),
406415
[userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
407416
ArrayRef<PDLValue> values) -> LogicalResult {
408-
std::vector<MlirPDLValue> mlirValues;
409-
mlirValues.reserve(values.size());
410-
for (auto &value : values) {
411-
mlirValues.push_back(wrap(&value));
412-
}
417+
std::vector<MlirPDLValue> mlirValues = wrap(values);
413418
return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
414419
mlirValues.size(), mlirValues.data(),
415420
userData));
416421
});
417422
}
423+
424+
void mlirPDLPatternModuleRegisterConstraintFunction(
425+
MlirPDLPatternModule pdlModule, MlirStringRef name,
426+
MlirPDLConstraintFunction constraintFn, void *userData) {
427+
unwrap(pdlModule)->registerConstraintFunction(
428+
unwrap(name),
429+
[userData, constraintFn](PatternRewriter &rewriter,
430+
PDLResultList &results,
431+
ArrayRef<PDLValue> values) -> LogicalResult {
432+
std::vector<MlirPDLValue> mlirValues = wrap(values);
433+
return unwrap(constraintFn(wrap(&rewriter), wrap(&results),
434+
mlirValues.size(), mlirValues.data(),
435+
userData));
436+
});
437+
}
418438
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

mlir/test/python/integration/dialects/pdl.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,43 @@ def rew():
153153
)
154154
pdl.ReplaceOp(op0, with_op=newOp)
155155

156+
@pdl.pattern(benefit=1, sym_name="myint_add_zero_fold")
157+
def pat():
158+
t = pdl.TypeOp(i32)
159+
v0 = pdl.OperandOp()
160+
v1 = pdl.OperandOp()
161+
v = pdl.apply_native_constraint([pdl.ValueType.get()], "has_zero", [v0, v1])
162+
op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
163+
164+
@pdl.rewrite()
165+
def rew():
166+
pdl.ReplaceOp(op0, with_values=[v])
167+
156168
def add_fold(rewriter, results, values):
157169
a0, a1 = values
158170
results.append(IntegerAttr.get(i32, a0.value + a1.value))
159171

172+
def is_zero(value):
173+
op = value.owner
174+
if isinstance(op, Operation):
175+
return op.name == "myint.constant" and op.attributes["value"].value == 0
176+
return False
177+
178+
# Check if either operand is a constant zero,
179+
# and append the other operand to the results if so.
180+
def has_zero(rewriter, results, values):
181+
v0, v1 = values
182+
if is_zero(v0):
183+
results.append(v1)
184+
return False
185+
if is_zero(v1):
186+
results.append(v0)
187+
return False
188+
return True
189+
160190
pdl_module = PDLModule(m)
161191
pdl_module.register_rewrite_function("add_fold", add_fold)
192+
pdl_module.register_constraint_function("has_zero", has_zero)
162193
return pdl_module.freeze()
163194

164195

@@ -181,3 +212,28 @@ def test_pdl_register_function(module_):
181212
apply_patterns_and_fold_greedily(module_, frozen)
182213

183214
return module_
215+
216+
217+
# CHECK-LABEL: TEST: test_pdl_register_function_constraint
218+
# CHECK: return %arg0 : i32
219+
@construct_and_print_in_module
220+
def test_pdl_register_function_constraint(module_):
221+
load_myint_dialect()
222+
223+
module_ = Module.parse(
224+
"""
225+
func.func @f(%x : i32) -> i32 {
226+
%c0 = "myint.constant"() { value = 1 }: () -> (i32)
227+
%c1 = "myint.constant"() { value = -1 }: () -> (i32)
228+
%a = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
229+
%b = "myint.add"(%a, %x): (i32, i32) -> (i32)
230+
%c = "myint.add"(%b, %a): (i32, i32) -> (i32)
231+
func.return %c : i32
232+
}
233+
"""
234+
)
235+
236+
frozen = get_pdl_pattern_fold()
237+
apply_patterns_and_fold_greedily(module_, frozen)
238+
239+
return module_

0 commit comments

Comments
 (0)