Skip to content

Commit 8056f88

Browse files
committed
[MLIR][Transform][Python] Expose applying named_sequences as a method
Makes it so that a NamedSequenceOp can be directly applied to a Module, via a method `apply(...)`.
1 parent a4e7d15 commit 8056f88

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .._transform_ops_gen import _Dialect
88
from ..._mlir_libs._mlirDialectsTransform import *
99
from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
10+
from . import interpreter
1011

1112
try:
1213
from ...ir import *
@@ -324,6 +325,25 @@ def bodyTarget(self) -> Value:
324325
def bodyExtraArgs(self) -> BlockArgumentList:
325326
return self.body.arguments[1:]
326327

328+
def apply(
329+
self,
330+
payload: Module,
331+
transform_options: Optional[interpreter.TransformOptions] = None,
332+
) -> Module:
333+
assert self.parent
334+
assert "transform.with_named_sequence" in self.parent.attributes
335+
assert isinstance(
336+
self.parent.attributes["transform.with_named_sequence"], UnitAttr
337+
)
338+
339+
interpreter.apply_named_sequence(
340+
payload_root=payload,
341+
transform_root=self,
342+
transform_module=self.parent,
343+
transform_options=transform_options,
344+
)
345+
return payload # NB: was modified in-place (if any transformation happened)
346+
327347

328348
def named_sequence(
329349
sym_name: Union[str, SymbolRefAttr],

mlir/test/python/dialects/transform_interpreter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ def print_self():
3030
# CHECK: transform.print
3131
# CHECK: transform.yield
3232

33+
@test_in_context
34+
def print_self_via_apply_method():
35+
m = ir.Module.parse(print_root_module.replace("from interpreter", "print_self_via_apply_method"))
36+
m.body.operations[0].apply(m)
37+
38+
# CHECK-LABEL: print_self_via_apply_method
39+
# CHECK: transform.named_sequence @__transform_main
40+
# CHECK: transform.print
41+
# CHECK: transform.yield
42+
3343

3444
@test_in_context
3545
def print_other():

0 commit comments

Comments
 (0)