Skip to content

Commit f6ac5b6

Browse files
albi3ropaul0403
andauthored
Use pass_name from pennylane transform (#2149)
**Context:** Trying to improve our integration between pennylane transforms and catalyst passes. See #2206 for where we are able to update all built-in catalyst passes to be transforms instead. That PR helps validate we can make this change with no breaking behaviour. Going forward, we can setup a plan to get rid of the catalyst-specific pass handling infrastructure. Also note that this PR targets the PennyLane branch, so we can confirm we won't need anymore changes there to get catalyst working correctly before approving and merging. **Description of the Change:** Uses the `transform.pass_name` property as the higher priority source of the `pass_name`. Updates the `pass_pipeline` from the `quantum_kernel_p` primitive to be a tuple of `TransformContainer` objects instead of `Pass`, but leaves in the logic to handle `tuple[Pass,...]` for safer backwards compatibility. Also makes using transforms with pass names by splitting a `QNode`'s transform program into "things at the start of the program that are tape transforms" and "things at the end of the program that are MLIR passes". This will allow us to get rid of `PassPipelineWrapper` use for both systems and soften the change from old frontend to new frontend. **Benefits:** Unified handling of transforms across pennylane and catalyst. **Possible Drawbacks:** **Related GitHub Issues:** Depends on PennyLaneAI/pennylane#8539 [sc-103775] --------- Co-authored-by: Paul <[email protected]>
1 parent ad9dd18 commit f6ac5b6

File tree

11 files changed

+254
-61
lines changed

11 files changed

+254
-61
lines changed

.dep-versions

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ enzyme=v0.0.203
1010

1111
# For a custom PL version, update the package version here and at
1212
# 'doc/requirements.txt'
13-
pennylane=0.44.0-dev42
13+
pennylane=0.44.0-dev44
1414

1515
# For a custom LQ/LK version, update the package version here and at
1616
# 'doc/requirements.txt'

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@
7070

7171
<h3>Improvements 🛠</h3>
7272

73+
* Catalyst can now use the new `pass_name` property of pennylane transform objects. Passes can now
74+
be created using `qml.transform(pass_name=pass_name)` instead of `PassPipelineWrapper`.
75+
[(#2149](https://github.com/PennyLaneAI/catalyst/pull/2149)
76+
7377
* An error is now raised if a transform is applied inside a QNode when program capture is enabled.
7478
[(#2256)](https://github.com/PennyLaneAI/catalyst/pull/2256)
7579

doc/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ lxml_html_clean
3333
--extra-index-url https://test.pypi.org/simple/
3434
pennylane-lightning-kokkos==0.44.0-dev16
3535
pennylane-lightning==0.44.0-dev16
36-
pennylane==0.44.0-dev42
36+
pennylane==0.44.0-dev44

frontend/catalyst/from_plxpr/from_plxpr.py

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@
2929
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
3030
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
3131
from pennylane.capture.primitives import transform_prim
32-
from pennylane.transforms import cancel_inverses as pl_cancel_inverses
3332
from pennylane.transforms import commute_controlled as pl_commute_controlled
3433
from pennylane.transforms import decompose as pl_decompose
3534
from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding
36-
from pennylane.transforms import merge_rotations as pl_merge_rotations
3735
from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion
3836
from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot
3937

@@ -48,7 +46,6 @@
4846
qdealloc_p,
4947
quantum_kernel_p,
5048
)
51-
from catalyst.passes.pass_api import Pass
5249
from catalyst.utils.patching import Patcher
5350

5451
from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter
@@ -286,7 +283,9 @@ def handle_qnode(
286283
# Fallback to the legacy decomposition if the graph-based decomposition failed
287284
if not graph_succeeded:
288285
# Remove the decompose-lowering pass from the pipeline
289-
self._pass_pipeline = [p for p in self._pass_pipeline if p.name != "decompose-lowering"]
286+
self._pass_pipeline = [
287+
p for p in self._pass_pipeline if p.pass_name != "decompose-lowering"
288+
]
290289
closed_jaxpr = _apply_compiler_decompose_to_plxpr(
291290
inner_jaxpr=closed_jaxpr.jaxpr,
292291
consts=closed_jaxpr.consts,
@@ -334,11 +333,9 @@ def calling_convention(*args):
334333
# otherwise their value will be None. The second value indicates if the transform
335334
# requires decomposition to be supported by Catalyst.
336335
transforms_to_passes = {
337-
pl_cancel_inverses: ("cancel-inverses", False),
338336
pl_commute_controlled: (None, False),
339337
pl_decompose: (None, False),
340338
pl_merge_amplitude_embedding: (None, True),
341-
pl_merge_rotations: ("merge-rotations", False),
342339
pl_single_qubit_fusion: (None, False),
343340
pl_unitary_to_rot: (None, False),
344341
}
@@ -349,6 +346,47 @@ def register_transform(pl_transform, pass_name, decomposition):
349346
transforms_to_passes[pl_transform] = (pass_name, decomposition)
350347

351348

349+
def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwargs):
350+
if not self.requires_decompose_lowering:
351+
self.requires_decompose_lowering = True
352+
else:
353+
raise NotImplementedError("Multiple decomposition transforms are not yet supported.")
354+
355+
next_eval = copy(self)
356+
# Update the decompose_gateset to be used by the quantum kernel primitive
357+
# TODO: we originally wanted to treat decompose_gateset as a queue of
358+
# gatesets to be used by the decompose-lowering pass at MLIR
359+
# but this requires a C++ implementation of the graph-based decomposition
360+
# which doesn't exist yet.
361+
next_eval.decompose_tkwargs = tkwargs
362+
363+
# Note. We don't perform the compiler-specific decomposition here
364+
# to be able to support multiple decomposition transforms
365+
# and collect all the required gatesets
366+
# as well as being able to support other transforms in between.
367+
368+
# The compiler specific transformation will be performed
369+
# in the qnode handler.
370+
371+
# Add the decompose-lowering pass to the start of the pipeline
372+
t = qml.transform(pass_name="decompose-lowering")
373+
pass_container = qml.transforms.core.TransformContainer(t)
374+
next_eval._pass_pipeline.insert(0, pass_container)
375+
376+
# We still need to construct and solve the graph based on
377+
# the current jaxpr based on the current gateset
378+
# but we don't rewrite the jaxpr at this stage.
379+
380+
# gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs)
381+
382+
# def gds_wrapper(*args):
383+
# return gds_interpreter.eval(inner_jaxpr, consts, *args)
384+
385+
# final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
386+
# return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
387+
return next_eval.eval(inner_jaxpr, consts, *non_const_args)
388+
389+
352390
# pylint: disable=too-many-arguments
353391
@WorkflowInterpreter.register_primitive(transform_prim)
354392
def handle_transform(
@@ -375,45 +413,11 @@ def handle_transform(
375413
and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr"
376414
and qml.decomposition.enabled_graph()
377415
):
378-
# Handle the conversion from plxpr to Catalyst jaxpr for a PL transform.
379-
if not self.requires_decompose_lowering:
380-
self.requires_decompose_lowering = True
381-
else:
382-
raise NotImplementedError("Multiple decomposition transforms are not yet supported.")
383-
384-
next_eval = copy(self)
385-
# Update the decompose_gateset to be used by the quantum kernel primitive
386-
# TODO: we originally wanted to treat decompose_gateset as a queue of
387-
# gatesets to be used by the decompose-lowering pass at MLIR
388-
# but this requires a C++ implementation of the graph-based decomposition
389-
# which doesn't exist yet.
390-
next_eval.decompose_tkwargs = tkwargs
391-
392-
# Note. We don't perform the compiler-specific decomposition here
393-
# to be able to support multiple decomposition transforms
394-
# and collect all the required gatesets
395-
# as well as being able to support other transforms in between.
396-
397-
# The compiler specific transformation will be performed
398-
# in the qnode handler.
399-
400-
# Add the decompose-lowering pass to the start of the pipeline
401-
next_eval._pass_pipeline.insert(0, Pass("decompose-lowering"))
416+
return _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwargs)
402417

403-
# We still need to construct and solve the graph based on
404-
# the current jaxpr based on the current gateset
405-
# but we don't rewrite the jaxpr at this stage.
406-
407-
# gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs)
408-
409-
# def gds_wrapper(*args):
410-
# return gds_interpreter.eval(inner_jaxpr, consts, *args)
411-
412-
# final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
413-
# return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
414-
return next_eval.eval(inner_jaxpr, consts, *non_const_args)
415-
416-
catalyst_pass_name = transforms_to_passes.get(transform, (None,))[0]
418+
catalyst_pass_name = transform.pass_name
419+
if catalyst_pass_name is None:
420+
catalyst_pass_name = transforms_to_passes.get(transform, (None,))[0]
417421
if catalyst_pass_name is None:
418422
# Use PL's ExpandTransformsInterpreter to expand this and any embedded
419423
# transform according to PL rules. It works by overriding the primitive
@@ -435,7 +439,9 @@ def wrapper(*args):
435439

436440
# Apply the corresponding Catalyst pass counterpart
437441
next_eval = copy(self)
438-
next_eval._pass_pipeline.insert(0, Pass(catalyst_pass_name, *targs, **tkwargs))
442+
t = qml.transform(pass_name=catalyst_pass_name)
443+
bound_pass = qml.transforms.core.TransformContainer(t, args=targs, kwargs=tkwargs)
444+
next_eval._pass_pipeline.insert(0, bound_pass)
439445
return next_eval.eval(inner_jaxpr, consts, *non_const_args)
440446

441447

frontend/catalyst/jax_primitives_utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):
325325
self.ctx.module_context = self.old_module_context
326326

327327

328+
def _lowered_options(args, kwargs):
329+
lowered_options = {}
330+
for arg in args:
331+
lowered_options[str(arg)] = get_mlir_attribute_from_pyval(True)
332+
for option, value in kwargs.items():
333+
mlir_option = str(option).replace("_", "-")
334+
lowered_options[mlir_option] = get_mlir_attribute_from_pyval(value)
335+
return lowered_options
336+
337+
328338
def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipeline):
329339
"""Generate a transform module embedded in the current module and schedule
330340
the transformations in pipeline"""
@@ -371,11 +381,16 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
371381
with ir.InsertionPoint(bb_named_sequence):
372382
target = bb_named_sequence.arguments[0]
373383
for _pass in pipeline:
374-
options = _pass.get_options()
384+
if isinstance(_pass, qml.transforms.core.TransformContainer):
385+
options = _lowered_options(_pass.args, _pass.kwargs)
386+
name = _pass.pass_name
387+
else:
388+
options = _pass.get_options()
389+
name = _pass.name
375390
apply_registered_pass_op = ApplyRegisteredPassOp(
376391
result=transform_mod_type,
377392
target=target,
378-
pass_name=_pass.name,
393+
pass_name=name,
379394
options=options,
380395
dynamic_options={},
381396
)
@@ -387,7 +402,7 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
387402
is_xdsl_pass,
388403
)
389404

390-
if is_xdsl_pass(_pass.name):
405+
if is_xdsl_pass(name):
391406
uses_xdsl_passes = True
392407
apply_registered_pass_op.operation.attributes["catalyst.xdsl_pass"] = (
393408
ir.UnitAttr.get()

frontend/catalyst/qfunc.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,18 +285,23 @@ def __call__(self, *args, **kwargs):
285285

286286
assert isinstance(self, qml.QNode)
287287

288+
new_transform_program, new_pipeline = _extract_passes(self.transform_program)
288289
# Update the qnode with peephole pipeline
289-
pass_pipeline = kwargs.pop("pass_pipeline", [])
290-
pass_pipeline = dictionary_to_list_of_passes(pass_pipeline)
290+
old_pipeline = kwargs.pop("pass_pipeline", None)
291+
processed_old_pipeline = tuple(dictionary_to_list_of_passes(old_pipeline))
292+
pass_pipeline = processed_old_pipeline + new_pipeline
293+
new_qnode = copy(self)
294+
# pylint: disable=attribute-defined-outside-init, protected-access
295+
new_qnode._transform_program = new_transform_program
291296

292297
# Mid-circuit measurement configuration/execution
293-
fn_result = configure_mcm_and_try_one_shot(self, args, kwargs, pass_pipeline)
298+
fn_result = configure_mcm_and_try_one_shot(new_qnode, args, kwargs, pass_pipeline)
294299

295300
# If the qnode is failed to execute as one-shot, fn_result will be None
296301
if fn_result is not None:
297302
return fn_result
298303

299-
new_device = copy(self.device)
304+
new_device = copy(new_qnode.device)
300305
qjit_device = QJITDevice(new_device)
301306

302307
static_argnums = kwargs.pop("static_argnums", ())
@@ -307,11 +312,11 @@ def __call__(self, *args, **kwargs):
307312

308313
def _eval_quantum(*args, **kwargs):
309314
trace_result = trace_quantum_function(
310-
self.func,
315+
new_qnode.func,
311316
qjit_device,
312317
args,
313318
kwargs,
314-
self,
319+
new_qnode,
315320
static_argnums,
316321
debug_info,
317322
)
@@ -655,3 +660,22 @@ def wrap_single_shot_qnode(*_):
655660
return _finalize_output(out, ctx)
656661

657662
return one_shot_wrapper
663+
664+
665+
def _extract_passes(transform_program):
666+
"""Extract transforms with pass names from the end of the TransformProgram."""
667+
tape_transforms = []
668+
pass_pipeline = []
669+
i = len(transform_program)
670+
for t in reversed(transform_program):
671+
if t.pass_name is None:
672+
break
673+
i -= 1
674+
pass_pipeline = transform_program[i:]
675+
tape_transforms = transform_program[:i]
676+
for t in tape_transforms:
677+
if t.transform is None:
678+
raise ValueError(
679+
f"{t} without a tape definition occurs before tape transform {tape_transforms[-1]}."
680+
)
681+
return qml.transforms.core.TransformProgram(tape_transforms), tuple(pass_pipeline)

frontend/test/lit/test_decomposition.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def wrapper():
4646
error_msg = str(e)
4747
if (
4848
"Unsupported type annotation None for parameter pauli_word" in error_msg
49+
or "Unsupported type annotation <class 'str'> for parameter pauli_word" in error_msg
4950
or "index is out of bounds for axis" in error_msg
5051
):
5152
print(f"# SKIPPED {test_func.__name__}: PauliRot type annotation issue")

frontend/test/lit/test_peephole_optimizations.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,42 @@ def test_pipeline_lowering_workflow(x):
8787
test_pipeline_lowering()
8888

8989

90+
def test_transform_lowering():
91+
"""
92+
Basic pipeline lowering on one qnode.
93+
"""
94+
95+
@qjit(keep_intermediate=True)
96+
@qml.transforms.merge_rotations
97+
@qml.transforms.cancel_inverses
98+
@qml.qnode(qml.device("lightning.qubit", wires=2))
99+
def test_pipeline_lowering_workflow(x):
100+
qml.RX(x, wires=[0])
101+
qml.Hadamard(wires=[1])
102+
qml.Hadamard(wires=[1])
103+
return qml.expval(qml.PauliY(wires=0))
104+
105+
# CHECK: pipeline=(<cancel_inverses((), {})>, <merge_rotations((), {})>)
106+
print_jaxpr(test_pipeline_lowering_workflow, 1.2)
107+
108+
# CHECK: transform.named_sequence @__transform_main
109+
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "cancel-inverses" to {{%.+}}
110+
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}}
111+
# CHECK-NEXT: transform.yield
112+
print_mlir(test_pipeline_lowering_workflow, 1.2)
113+
114+
# CHECK: {{%.+}} = call @test_pipeline_lowering_workflow_0(
115+
# CHECK: func.func public @test_pipeline_lowering_workflow_0(
116+
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
117+
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
118+
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
119+
test_pipeline_lowering_workflow(42.42)
120+
flush_peephole_opted_mlir_to_iostream(test_pipeline_lowering_workflow)
121+
122+
123+
test_transform_lowering()
124+
125+
90126
def test_pipeline_lowering_keep_original():
91127
"""
92128
Test when the pipelined qnode and the original qnode are both used,

frontend/test/pytest/from_plxpr/test_capture_integration.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,7 @@ def circuit(x: float):
10481048
assert jnp.allclose(circuit(0.1), capture_result)
10491049

10501050
@pytest.mark.usefixtures("use_capture")
1051-
def test_pass_with_options(self, backend):
1051+
def test_pass_with_options_patch(self, backend):
10521052
"""Test the integration for a circuit with a pass that takes in options."""
10531053

10541054
@qml.transform
@@ -1071,6 +1071,25 @@ def captured_circuit():
10711071
in capture_mlir
10721072
)
10731073

1074+
@pytest.mark.usefixtures("use_capture")
1075+
def test_pass_with_options(self, backend):
1076+
"""Test the integration for a circuit with a pass that takes in options."""
1077+
1078+
my_pass = qml.transform(pass_name="my-pass")
1079+
1080+
@qjit(target="mlir")
1081+
@partial(my_pass, my_option="my_option_value", my_other_option=False)
1082+
@qml.qnode(qml.device(backend, wires=1))
1083+
def captured_circuit():
1084+
return qml.expval(qml.PauliZ(0))
1085+
1086+
capture_mlir = captured_circuit.mlir
1087+
assert 'transform.apply_registered_pass "my-pass"' in capture_mlir
1088+
assert (
1089+
'with options = {"my-option" = "my_option_value", "my-other-option" = false}'
1090+
in capture_mlir
1091+
)
1092+
10741093
def test_transform_cancel_inverses_workflow(self, backend):
10751094
"""Test the integration for a circuit with a 'cancel_inverses' transform."""
10761095

0 commit comments

Comments
 (0)