Skip to content

Commit 06e2c78

Browse files
[MLIR][Python] Pass OpView subclasses instead of Operation in rewrite patterns (#163080)
This is a follow-up PR for #162699. Currently, in the function where we define rewrite patterns, the `op` we receive is of type `ir.Operation` rather than a specific `OpView` type (such as `arith.AddIOp`). This means we can’t conveniently access certain parts of the operation — for example, we need to use `op.operands[0]` instead of `op.lhs`. The following example code illustrates this situation. ```python def to_muli(op, rewriter): # op is typed ir.Operation instead of arith.AddIOp pass patterns.add(arith.AddIOp, to_muli) ``` In this PR, we convert the operation to its corresponding `OpView` subclass before invoking the rewrite pattern callback, making it much easier to write patterns. --------- Co-authored-by: Maksim Levental <[email protected]>
1 parent 6785c4f commit 06e2c78

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,12 @@ class PyRewritePatternSet {
197197
MlirPatternRewriter rewriter,
198198
void *userData) -> MlirLogicalResult {
199199
nb::handle f(static_cast<PyObject *>(userData));
200-
nb::object res = f(op, PyPatternRewriter(rewriter));
200+
201+
PyMlirContextRef ctx =
202+
PyMlirContext::forContext(mlirOperationGetContext(op));
203+
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
204+
205+
nb::object res = f(opView, PyPatternRewriter(rewriter));
201206
return logicalResultFromObject(res);
202207
};
203208
MlirRewritePattern pattern = mlirOpRewritePattenCreate(

mlir/python/mlir/dialects/arith.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def type(self):
9292

9393
@property
9494
def value(self):
95-
return Attribute(self.operation.attributes["value"])
95+
return self.operation.attributes["value"]
9696

9797
@property
9898
def literal_value(self) -> Union[int, float]:

mlir/test/python/rewrite.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@ def run(f):
1717
def testRewritePattern():
1818
def to_muli(op, rewriter):
1919
with rewriter.ip:
20-
new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
20+
assert isinstance(op, arith.AddIOp)
21+
new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
2122
rewriter.replace_op(op, new_op.owner)
2223

2324
def constant_1_to_2(op, rewriter):
24-
c = op.attributes["value"].value
25+
c = op.value.value
2526
if c != 1:
2627
return True # failed to match
2728
with rewriter.ip:
28-
new_op = arith.constant(op.result.type, 2, loc=op.location)
29+
new_op = arith.constant(op.type, 2, loc=op.location)
2930
rewriter.replace_op(op, [new_op])
3031

3132
with Context():

0 commit comments

Comments
 (0)