Skip to content

Commit eb9d56c

Browse files
authored
[MLIR][Transform][Python] Expose applying named_sequences as a method (#168223)
Makes it so that a NamedSequenceOp can be directly applied to a Module, via a method `apply(...)`.
1 parent 82214ff commit eb9d56c

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .._transform_ops_gen import _Dialect
88
from ..._mlir_libs._mlirDialectsTransform import *
99
from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType
10+
from . import interpreter
1011

1112
try:
1213
from ...ir import *
@@ -324,6 +325,25 @@ def bodyTarget(self) -> Value:
324325
def bodyExtraArgs(self) -> BlockArgumentList:
325326
return self.body.arguments[1:]
326327

328+
def apply(
329+
self,
330+
payload: Module,
331+
transform_options: Optional[interpreter.TransformOptions] = None,
332+
) -> Module:
333+
assert self.parent
334+
assert "transform.with_named_sequence" in self.parent.attributes
335+
assert isinstance(
336+
self.parent.attributes["transform.with_named_sequence"], UnitAttr
337+
)
338+
339+
interpreter.apply_named_sequence(
340+
payload_root=payload,
341+
transform_root=self,
342+
transform_module=self.parent,
343+
transform_options=transform_options,
344+
)
345+
return payload # NB: was modified in-place (if any transformation happened)
346+
327347

328348
def named_sequence(
329349
sym_name: Union[str, SymbolRefAttr],

mlir/test/python/dialects/transform_interpreter.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ def print_self():
3131
# CHECK: transform.yield
3232

3333

34+
@test_in_context
35+
def print_self_via_apply_method():
36+
m = ir.Module.parse(
37+
print_root_module.replace("from interpreter", "print_self_via_apply_method")
38+
)
39+
m.body.operations[0].apply(m)
40+
41+
42+
# CHECK-LABEL: print_self_via_apply_method
43+
# CHECK: transform.named_sequence @__transform_main
44+
# CHECK: transform.print
45+
# CHECK: transform.yield
46+
47+
3448
@test_in_context
3549
def print_other():
3650
transform = ir.Module.parse(

0 commit comments

Comments
 (0)