Skip to content

Commit 1dab5d8

Browse files
albi3ropaul0403
andauthored
fix translation of transforms with multiple qnodes (#2167)
**Context:** Transforms were global for the entire workflow. We needed to make them specific to the code they are actually applied to. **Description of the Change:** Create new interpreters for transforms and only append transforms to the new interpreter. **Benefits:** Proper application of transforms. **Possible Drawbacks:** **Related GitHub Issues:** [sc-102818] --------- Co-authored-by: Paul <[email protected]>
1 parent 39144c8 commit 1dab5d8

File tree

3 files changed

+53
-6
lines changed

3 files changed

+53
-6
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
[(#2128)](https://github.com/PennyLaneAI/catalyst/pull/2128)
3636
[(#2133)](https://github.com/PennyLaneAI/catalyst/pull/2133)
3737

38+
* Fixes the translation of a workflow with different transforms applied to different qnodes.
39+
[(#2167)](https://github.com/PennyLaneAI/catalyst/pull/2167)
40+
3841
<h3>Internal changes ⚙️</h3>
3942

4043
* Refactor Catalyst pass registering so that it's no longer necessary to manually add new

frontend/catalyst/from_plxpr/from_plxpr.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
import warnings
21+
from copy import copy
2122
from functools import partial
2223
from typing import Callable
2324

@@ -137,6 +138,14 @@ def f(x):
137138
class WorkflowInterpreter(PlxprInterpreter):
138139
"""An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant."""
139140

141+
def __copy__(self):
142+
new_version = WorkflowInterpreter()
143+
new_version._pass_pipeline = copy(self._pass_pipeline)
144+
new_version.init_qreg = self.init_qreg
145+
new_version.requires_decompose_lowering = self.requires_decompose_lowering
146+
new_version.decompose_tkwargs = copy(self.decompose_tkwargs)
147+
return new_version
148+
140149
def __init__(self):
141150
self._pass_pipeline = []
142151
self.init_qreg = None
@@ -284,12 +293,13 @@ def handle_transform(
284293
"Multiple decomposition transforms are not yet supported."
285294
)
286295

296+
next_eval = copy(self)
287297
# Update the decompose_gateset to be used by the quantum kernel primitive
288298
# TODO: we originally wanted to treat decompose_gateset as a queue of
289299
# gatesets to be used by the decompose-lowering pass at MLIR
290300
# but this requires a C++ implementation of the graph-based decomposition
291301
# which doesn't exist yet.
292-
self.decompose_tkwargs = tkwargs
302+
next_eval.decompose_tkwargs = tkwargs
293303

294304
# Note. We don't perform the compiler-specific decomposition here
295305
# to be able to support multiple decomposition transforms
@@ -300,7 +310,7 @@ def handle_transform(
300310
# in the qnode handler.
301311

302312
# Add the decompose-lowering pass to the start of the pipeline
303-
self._pass_pipeline.insert(0, Pass("decompose-lowering"))
313+
next_eval._pass_pipeline.insert(0, Pass("decompose-lowering"))
304314

305315
# We still need to construct and solve the graph based on
306316
# the current jaxpr based on the current gateset
@@ -313,7 +323,7 @@ def handle_transform(
313323

314324
# final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
315325
# return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
316-
return self.eval(inner_jaxpr, consts, *non_const_args)
326+
return next_eval.eval(inner_jaxpr, consts, *non_const_args)
317327

318328
if catalyst_pass_name is None:
319329
# Use PL's ExpandTransformsInterpreter to expand this and any embedded
@@ -333,11 +343,12 @@ def wrapper(*args):
333343
final_jaxpr.jaxpr, final_jaxpr.consts, targs, tkwargs, *non_const_args
334344
)
335345

336-
return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args)
346+
return copy(self).eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args)
337347

338348
# Apply the corresponding Catalyst pass counterpart
339-
self._pass_pipeline.insert(0, Pass(catalyst_pass_name, *targs, **tkwargs))
340-
return self.eval(inner_jaxpr, consts, *non_const_args)
349+
next_eval = copy(self)
350+
next_eval._pass_pipeline.insert(0, Pass(catalyst_pass_name, *targs, **tkwargs))
351+
return next_eval.eval(inner_jaxpr, consts, *non_const_args)
341352

342353

343354
# This is our registration factory for PL transforms. The loop below iterates

frontend/test/lit/test_from_plxpr.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,36 @@ def circuit3():
421421

422422

423423
test_pass_decomposition()
424+
425+
426+
def test_two_qnodes_with_different_passes_in_one_workflow():
427+
"""Two qnodes with different passes in one workflow."""
428+
429+
dev = qml.device("null.qubit", wires=1)
430+
431+
qml.capture.enable()
432+
433+
@qml.qjit(target="mlir")
434+
def workflow():
435+
@qml.transforms.merge_rotations
436+
@qml.qnode(dev)
437+
def circuit1():
438+
return qml.probs()
439+
440+
@qml.transforms.cancel_inverses
441+
@qml.qnode(dev)
442+
def circuit2():
443+
return qml.probs()
444+
445+
return circuit1() + circuit2()
446+
447+
# CHECK: module @module_circuit1 {
448+
# CHECK: transform.apply_registered_pass "merge-rotations"
449+
# CHECK: module @module_circuit2 {
450+
# CHECK: transform.apply_registered_pass "remove-chained-self-inverse"
451+
452+
print(workflow.mlir)
453+
qml.capture.disable()
454+
455+
456+
test_two_qnodes_with_different_passes_in_one_workflow()

0 commit comments

Comments
 (0)