Skip to content

Commit 797dbb4

Browse files
authored
Passes in capture path can take options (#2154)
**Context:** Passes in capture path used to swallow the options. This is likely an oversight. **Description of the Change:** Construct a `Pass` object with the option arguments in capture path, during `handle_transform` of `WorkflowInterpreter`.
1 parent 63d40bf commit 797dbb4

File tree

3 files changed

+40
-14
lines changed

3 files changed

+40
-14
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from Python into the compiler IR, which can make it easier to read when debugging programs.
99
[(#2054)](https://github.com/PennyLaneAI/catalyst/pull/2054)
1010

11+
* Passes registered under `qml.transform` can now take in options when used with
12+
:func:`~.qjit` with program capture enabled.
13+
[(#2154)](https://github.com/PennyLaneAI/catalyst/pull/2154)
14+
1115
<h3>Breaking changes 💔</h3>
1216

1317
<h3>Deprecations 👋</h3>

frontend/catalyst/from_plxpr/from_plxpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def wrapper(*args):
336336
return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args)
337337

338338
# Apply the corresponding Catalyst pass counterpart
339-
self._pass_pipeline.insert(0, Pass(catalyst_pass_name))
339+
self._pass_pipeline.insert(0, Pass(catalyst_pass_name, *targs, **tkwargs))
340340
return self.eval(inner_jaxpr, consts, *non_const_args)
341341

342342

frontend/test/pytest/from_plxpr/test_capture_integration.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import catalyst
2525
from catalyst import qjit
26+
from catalyst.from_plxpr import register_transform
2627

2728
pytestmark = pytest.mark.usefixtures("disable_capture")
2829

@@ -1048,6 +1049,30 @@ def circuit(x: float):
10481049

10491050
assert jnp.allclose(circuit(0.1), capture_result)
10501051

1052+
@pytest.mark.usefixtures("use_capture")
1053+
def test_pass_with_options(self, backend):
1054+
"""Test the integration for a circuit with a pass that takes in options."""
1055+
1056+
@qml.transform
1057+
def my_pass(_tape, my_option=None, my_other_option=None): # pylint: disable=unused-argument
1058+
"""A dummy qml.transform."""
1059+
return
1060+
1061+
register_transform(my_pass, "my-pass", False)
1062+
1063+
@qjit(target="mlir")
1064+
@partial(my_pass, my_option="my_option_value", my_other_option=False)
1065+
@qml.qnode(qml.device(backend, wires=1))
1066+
def captured_circuit():
1067+
return qml.expval(qml.PauliZ(0))
1068+
1069+
capture_mlir = captured_circuit.mlir
1070+
assert 'transform.apply_registered_pass "my-pass"' in capture_mlir
1071+
assert (
1072+
'with options = {"my-option" = "my_option_value", "my-other-option" = false}'
1073+
in capture_mlir
1074+
)
1075+
10511076
def test_transform_cancel_inverses_workflow(self, backend):
10521077
"""Test the integration for a circuit with a 'cancel_inverses' transform."""
10531078

@@ -1234,11 +1259,9 @@ def captured_circuit(U: ShapedArray([2, 2], float)):
12341259

12351260
# Catalyst 'cancel_inverses' should have been scheduled as a pass
12361261
# whereas PL 'unitary_to_rot' should have been expanded
1237-
assert (
1238-
'transform.apply_registered_pass "remove-chained-self-inverse"'
1239-
in captured_inverses_unitary.mlir
1240-
)
1241-
assert is_unitary_rotated(captured_inverses_unitary.mlir)
1262+
capture_mlir = captured_inverses_unitary.mlir
1263+
assert 'transform.apply_registered_pass "remove-chained-self-inverse"' in capture_mlir
1264+
assert is_unitary_rotated(capture_mlir)
12421265

12431266
# Case 2: During plxpr interpretation, first comes the PL transform
12441267
# without Catalyst counterpart, second comes the PL transform with it
@@ -1251,12 +1274,10 @@ def captured_circuit(U: ShapedArray([2, 2], float)):
12511274

12521275
# Both PL transforms should have been expaned and no Catalyst pass should have been
12531276
# scheduled
1254-
assert (
1255-
'transform.apply_registered_pass "remove-chained-self-inverse"'
1256-
not in captured_unitary_inverses.mlir
1257-
)
1258-
assert 'quantum.custom "Hadamard"' not in captured_unitary_inverses.mlir
1259-
assert is_unitary_rotated(captured_unitary_inverses.mlir)
1277+
capture_mlir = captured_unitary_inverses.mlir
1278+
assert 'transform.apply_registered_pass "remove-chained-self-inverse"' not in capture_mlir
1279+
assert 'quantum.custom "Hadamard"' not in capture_mlir
1280+
assert is_unitary_rotated(capture_mlir)
12601281

12611282
qml.capture.disable()
12621283

@@ -1446,11 +1467,12 @@ def captured_circuit():
14461467

14471468
capture_result = captured_circuit()
14481469

1470+
capture_mlir = captured_circuit.mlir
14491471
assert is_controlled_pushed_back(
1450-
captured_circuit.mlir, 'quantum.custom "RX"', 'quantum.custom "CNOT"'
1472+
capture_mlir, 'quantum.custom "RX"', 'quantum.custom "CNOT"'
14511473
)
14521474
assert is_controlled_pushed_back(
1453-
captured_circuit.mlir, 'quantum.custom "PauliX"', 'quantum.custom "CRX"'
1475+
capture_mlir, 'quantum.custom "PauliX"', 'quantum.custom "CRX"'
14541476
)
14551477

14561478
qml.capture.disable()

0 commit comments

Comments
 (0)