Skip to content

Commit 64fb6ac

Browse files
erick-xanadumudit2812mehrdad2m
authored
Redefine ApplyRegisteredPass in the Transform dialect (#7956)
**Context:** Catalyst is updating LLVM, which includes a redefinition of this operation in the transform dialect. Note: Only merge after PennyLaneAI/catalyst#1916 **Description of the Change:** Redefine the ApplyRegisteredPass in the transform dialect to support Catalyst's generic format. **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-96972] --------- Co-authored-by: Mudit Pandey <[email protected]> Co-authored-by: Mehrdad Malekmohammadi <[email protected]> Co-authored-by: Mehrdad Malek <[email protected]>
1 parent c077269 commit 64fb6ac

File tree

6 files changed

+290
-14
lines changed

6 files changed

+290
-14
lines changed

pennylane/compiler/python_compiler/dialects/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414

1515
"""This submodule contains xDSL dialects for the Python compiler."""
1616

17+
from .catalyst import Catalyst
1718
from .mbqc import MBQC
1819
from .quantum import Quantum
19-
from .catalyst import Catalyst
2020
from .qec import QEC
21+
from .transform import Transform
22+
2123

22-
__all__ = ["Catalyst", "MBQC", "Quantum", "QEC"]
24+
__all__ = ["Catalyst", "MBQC", "Quantum", "QEC", "Transform"]
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
This file contains an updated version of the transform dialect.
16+
As of the time of writing, xDSL uses the MLIR released with LLVM's
17+
version 20.1.7. However, https://github.com/PennyLaneAI/catalyst/pull/1916
18+
will be updating MLIR where the transform dialect has the
19+
`apply_registered_pass` operation re-defined.
20+
21+
See the following changelog on the above PR
22+
23+
Things related to transform.apply_registered_pass op:
24+
25+
It now takes in a dynamic_options
26+
27+
[MLIR][Transform] Allow ApplyRegisteredPassOp to take options as
28+
a param llvm/llvm-project#142683. We don't need to use this as all our pass options are static.
29+
https://github.com/llvm/llvm-project/pull/142683
30+
31+
The options it takes in are now dictionaries instead of strings
32+
[MLIR][Transform] apply_registered_pass op's options as a dict llvm/llvm-project#143159
33+
https://github.com/llvm/llvm-project/pull/143159
34+
35+
This file will re-define the apply_registered_pass operation in xDSL
36+
and the transform dialect.
37+
38+
Once xDSL moves to a newer version of MLIR, these changes should
39+
be contributed upstream.
40+
"""
41+
42+
from xdsl.dialects.builtin import Dialect
43+
44+
# pylint: disable=unused-wildcard-import,wildcard-import,undefined-variable,too-few-public-methods
45+
from xdsl.dialects.transform import ApplyRegisteredPassOp as xApplyRegisteredPassOp
46+
from xdsl.dialects.transform import (
47+
DictionaryAttr,
48+
StringAttr,
49+
)
50+
from xdsl.dialects.transform import Transform as xTransform
51+
from xdsl.dialects.transform import (
52+
TransformHandleType,
53+
irdl_op_definition,
54+
operand_def,
55+
prop_def,
56+
result_def,
57+
)
58+
from xdsl.ir import Attribute, SSAValue
59+
from xdsl.irdl import IRDLOperation, ParsePropInAttrDict
60+
61+
62+
@irdl_op_definition
63+
# pylint: disable=function-redefined
64+
class ApplyRegisteredPassOp(IRDLOperation):
65+
"""
66+
See external [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop).
67+
"""
68+
69+
name = "transform.apply_registered_pass"
70+
71+
options = prop_def(DictionaryAttr, default_value=DictionaryAttr({}))
72+
pass_name = prop_def(StringAttr)
73+
target = operand_def(TransformHandleType)
74+
result = result_def(TransformHandleType)
75+
# While this assembly format doesn't match
76+
# the one in upstream MLIR,
77+
# this is because xDSL currently lacks CustomDirectives
78+
# https://mlir.llvm.org/docs/DefiningDialects/Operations/#custom-directives
79+
# https://github.com/xdslproject/xdsl/pull/4829
80+
# However, storing the property in the attribute should still work
81+
# specially when parsing and printing in generic format.
82+
# Which is how Catalyst and XDSL currently communicate at the moment.
83+
# TODO: Add test.
84+
assembly_format = "$pass_name `to` $target attr-dict `:` functional-type(operands, results)"
85+
irdl_options = [ParsePropInAttrDict()]
86+
87+
def __init__(
88+
self,
89+
pass_name: str | StringAttr,
90+
target: SSAValue,
91+
options: dict[str | StringAttr, Attribute | str | bool | int] | None = None,
92+
):
93+
if isinstance(pass_name, str):
94+
pass_name = StringAttr(pass_name)
95+
96+
if isinstance(options, dict):
97+
options = DictionaryAttr(options)
98+
99+
super().__init__(
100+
properties={
101+
"pass_name": pass_name,
102+
"options": options,
103+
},
104+
operands=[target],
105+
result_types=[target.type],
106+
)
107+
108+
109+
# Copied over from xDSL's sources
110+
# the main difference will be the use
111+
# of a different ApplyRegisteredPassOp
112+
operations = list(xTransform.operations)
113+
del operations[operations.index(xApplyRegisteredPassOp)]
114+
operations.append(ApplyRegisteredPassOp)
115+
116+
Transform = Dialect(
117+
"transform",
118+
[
119+
*operations,
120+
],
121+
[
122+
*xTransform.attributes,
123+
],
124+
)

pennylane/compiler/python_compiler/jax_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,11 @@
3030
from xdsl.dialects import scf as xscf
3131
from xdsl.dialects import stablehlo as xstablehlo
3232
from xdsl.dialects import tensor as xtensor
33-
from xdsl.dialects import transform as xtransform
3433
from xdsl.ir import Dialect as xDialect
3534
from xdsl.parser import Parser as xParser
3635
from xdsl.traits import SymbolTable as xSymbolTable
3736

38-
from .dialects import MBQC, QEC, Catalyst, Quantum
37+
from .dialects import MBQC, QEC, Catalyst, Quantum, Transform
3938

4039
JaxJittedFunction: TypeAlias = _jax.PjitFunction # pylint: disable=c-extension-no-member
4140

@@ -59,7 +58,7 @@ class QuantumParser(xParser): # pylint: disable=abstract-method,too-few-public-
5958
xscf.Scf,
6059
xstablehlo.StableHLO,
6160
xtensor.Tensor,
62-
xtransform.Transform,
61+
Transform,
6362
Quantum,
6463
MBQC,
6564
Catalyst,

pennylane/compiler/python_compiler/transforms/api/transform_interpreter.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424
from catalyst.compiler import _quantum_opt # pylint: disable=protected-access
2525
from xdsl.context import Context
26-
from xdsl.dialects import builtin, transform
26+
from xdsl.dialects import builtin
27+
from xdsl.dialects.transform import NamedSequenceOp
2728
from xdsl.interpreter import Interpreter, PythonValues, impl, register_impls
2829
from xdsl.interpreters.transform import TransformFunctions
2930
from xdsl.parser import Parser
@@ -32,6 +33,8 @@
3233
from xdsl.rewriter import Rewriter
3334
from xdsl.utils.exceptions import PassFailedException
3435

36+
from ...dialects.transform import ApplyRegisteredPassOp
37+
3538

3639
# pylint: disable=too-few-public-methods
3740
@register_impls
@@ -43,11 +46,11 @@ class TransformFunctionsExt(TransformFunctions):
4346
then it will try to run this pass in Catalyst.
4447
"""
4548

46-
@impl(transform.ApplyRegisteredPassOp)
49+
@impl(ApplyRegisteredPassOp)
4750
def run_apply_registered_pass_op( # pragma: no cover
4851
self,
4952
_interpreter: Interpreter,
50-
op: transform.ApplyRegisteredPassOp,
53+
op: ApplyRegisteredPassOp,
5154
args: PythonValues,
5255
) -> PythonValues:
5356
"""Try to run the pass in xDSL, if it can't run on catalyst"""
@@ -56,7 +59,7 @@ def run_apply_registered_pass_op( # pragma: no cover
5659
if pass_name in self.passes:
5760
# pragma: no cover
5861
pass_class = self.passes[pass_name]()
59-
pipeline = PassPipeline((pass_class(),))
62+
pipeline = PassPipeline((pass_class(**op.options.data),))
6063
pipeline.apply(self.ctx, args[0])
6164
return (args[0],)
6265

@@ -86,12 +89,10 @@ def __init__(self, passes):
8689
self.passes = passes
8790

8891
@staticmethod
89-
def find_transform_entry_point(
90-
root: builtin.ModuleOp, entry_point: str
91-
) -> transform.NamedSequenceOp:
92+
def find_transform_entry_point(root: builtin.ModuleOp, entry_point: str) -> NamedSequenceOp:
9293
"""Find the entry point of the program"""
9394
for op in root.walk():
94-
if isinstance(op, transform.NamedSequenceOp) and op.sym_name.data == entry_point:
95+
if isinstance(op, NamedSequenceOp) and op.sym_name.data == entry_point:
9596
return op
9697
raise PassFailedException( # pragma: no cover
9798
f"{root} could not find a nested named sequence with name: {entry_point}"
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit test module for pennylane/compiler/python_compiler/transform.py."""
16+
17+
from dataclasses import dataclass
18+
19+
import pytest
20+
21+
# pylint: disable=wrong-import-position
22+
23+
xdsl = pytest.importorskip("xdsl")
24+
filecheck = pytest.importorskip("filecheck")
25+
26+
pytestmark = pytest.mark.external
27+
28+
from xdsl import passes
29+
from xdsl.context import Context
30+
from xdsl.dialects import builtin
31+
from xdsl.dialects.builtin import DictionaryAttr, IntegerAttr, i64
32+
from xdsl.dialects.transform import AnyOpType
33+
from xdsl.utils.exceptions import VerifyException
34+
from xdsl.utils.test_value import create_ssa_value
35+
36+
from pennylane.compiler.python_compiler.dialects import transform
37+
from pennylane.compiler.python_compiler.dialects.transform import ApplyRegisteredPassOp
38+
from pennylane.compiler.python_compiler.jax_utils import xdsl_from_docstring
39+
from pennylane.compiler.python_compiler.transforms.api import (
40+
ApplyTransformSequence,
41+
compiler_transform,
42+
)
43+
44+
45+
def test_dict_options():
46+
"""Test ApplyRegisteredPassOp constructor with dict options."""
47+
target = create_ssa_value(AnyOpType())
48+
options = {"option1": 1, "option2": True}
49+
50+
op = ApplyRegisteredPassOp("canonicalize", target, options)
51+
52+
assert op.pass_name.data == "canonicalize"
53+
assert isinstance(op.options, DictionaryAttr)
54+
assert op.options == DictionaryAttr({"option1": 1, "option2": True})
55+
assert op.verify_() is None
56+
57+
58+
def test_attr_options():
59+
"""Test ApplyRegisteredPassOp constructor with DictionaryAttr options."""
60+
target = create_ssa_value(AnyOpType())
61+
options = DictionaryAttr({"test-option": IntegerAttr(42, i64)})
62+
63+
# This should trigger the __init__ method
64+
op = ApplyRegisteredPassOp("canonicalize", target, options)
65+
66+
assert op.pass_name.data == "canonicalize"
67+
assert isinstance(op.options, DictionaryAttr)
68+
assert op.options == DictionaryAttr({"test-option": IntegerAttr(42, i64)})
69+
assert op.verify_() is None
70+
71+
72+
def test_none_options():
73+
"""Test ApplyRegisteredPassOp constructor with None options."""
74+
target = create_ssa_value(AnyOpType())
75+
76+
# This should trigger the __init__ method
77+
op = ApplyRegisteredPassOp("canonicalize", target, None)
78+
79+
assert op.pass_name.data == "canonicalize"
80+
assert isinstance(op.options, DictionaryAttr)
81+
assert op.options == DictionaryAttr({})
82+
assert op.verify_() is None
83+
84+
85+
def test_invalid_options():
86+
"""Test ApplyRegisteredPassOp constructor with invalid options type."""
87+
target = create_ssa_value(AnyOpType())
88+
89+
with pytest.raises(
90+
VerifyException, match="invalid_options should be of base attribute dictionary"
91+
):
92+
ApplyRegisteredPassOp("canonicalize", target, "invalid_options").verify_()
93+
94+
95+
def test_transform_dialect_filecheck(run_filecheck):
96+
"""Test that the transform dialect operations are parsed correctly."""
97+
program = """
98+
"builtin.module"() ({
99+
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
100+
^bb0(%arg0: !transform.any_op):
101+
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
102+
// CHECK: options = {"invalid-option" = 1 : i64}
103+
%1 = "transform.apply_registered_pass"(%0) <{options = {"invalid-option" = 1 : i64}, pass_name = "canonicalize"}> : (!transform.any_op) -> !transform.any_op
104+
"transform.yield"() : () -> ()
105+
}) : () -> ()
106+
}) {transform.with_named_sequence} : () -> ()
107+
"""
108+
109+
run_filecheck(program)
110+
111+
112+
def test_integration_for_transform_interpreter(capsys):
113+
"""Test that a pass with options is run via the transform interpreter"""
114+
115+
@compiler_transform
116+
@dataclass(frozen=True)
117+
class _HelloWorld(passes.ModulePass):
118+
name = "hello-world"
119+
120+
custom_print: str | None = None
121+
122+
def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None:
123+
if self.custom_print:
124+
print(self.custom_print)
125+
else:
126+
print("hello world")
127+
128+
@xdsl_from_docstring
129+
def program():
130+
"""
131+
builtin.module {
132+
builtin.module {
133+
transform.named_sequence @__transform_main(%arg0 : !transform.op<"builtin.module">) {
134+
%0 = "transform.apply_registered_pass"(%arg0) <{options = {"custom_print" = "Hello from custom option!"}, pass_name = "hello-world"}> : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
135+
transform.yield
136+
}
137+
}
138+
}
139+
"""
140+
141+
ctx = xdsl.context.Context()
142+
ctx.load_dialect(builtin.Builtin)
143+
ctx.load_dialect(transform.Transform)
144+
145+
mod = program()
146+
pipeline = xdsl.passes.PassPipeline((ApplyTransformSequence(),))
147+
pipeline.apply(ctx, mod)
148+
149+
assert "Hello from custom option!" in capsys.readouterr().out

tests/python_compiler/test_python_compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@
2929
from catalyst import CompileError
3030
from xdsl import passes
3131
from xdsl.context import Context
32-
from xdsl.dialects import builtin, transform
32+
from xdsl.dialects import builtin
3333
from xdsl.interpreters import Interpreter
3434

3535
from pennylane.compiler.python_compiler import Compiler
36+
from pennylane.compiler.python_compiler.dialects import transform
3637
from pennylane.compiler.python_compiler.jax_utils import (
3738
jax_from_docstring,
3839
module,

0 commit comments

Comments
 (0)