Skip to content

Commit 98dbc5b

Browse files
authored
[frontend] outline register_transform (#1705)
**Context:** I need to be able to register equivalence from plxpr transforms to MLIR transforms. There is this piece of code that does that, but it is not a function. **Description of the Change:** Outline that piece of code into its own function. **Benefits:** I can register the equivalence from plxpr transforms to MLIR transforms. **Possible Drawbacks:** None **Related GitHub Issues:** Fixes #1703
1 parent e558aa7 commit 98dbc5b

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959

6060
<h3>Internal changes ⚙️</h3>
6161

62+
* Creates a function that allows developers to register an equivalent MLIR transform for a given PLxPR transform.
63+
[(#1705)](https://github.com/PennyLaneAI/catalyst/pull/1705)
64+
6265
* Stop overriding the `num_wires` property when the operator can exist on `AnyWires`. This allows the deprecation
6366
of `WiresEnum` in pennylane.
6467
[(#1667)](https://github.com/PennyLaneAI/catalyst/pull/1667)

frontend/catalyst/from_plxpr.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,10 @@ def handle_qnode(
201201
}
202202

203203

204-
# This is our registration factory for PL transforms. The loop below iterates
205-
# across the map above and generates a custom handler for each transform.
206-
# In order to ensure early binding, we pass the PL plxpr transform and the
207-
# Catalyst pass as arguments whose default values are set by the loop.
208-
for pl_transform, (pass_name, decomposition) in transforms_to_passes.items():
204+
# pylint: disable-next=redefined-outer-name
205+
def register_transform(pl_transform, pass_name, decomposition):
206+
"""Register pennylane transforms and their conversion to Catalyst transforms"""
207+
209208
# pylint: disable=unused-argument, too-many-arguments, cell-var-from-loop
210209
@WorkflowInterpreter.register_primitive(pl_transform._primitive)
211210
def handle_transform(
@@ -251,6 +250,14 @@ def wrapper(*args):
251250
return self.eval(inner_jaxpr, consts, *non_const_args)
252251

253252

253+
# This is our registration factory for PL transforms. The loop below iterates
254+
# across the map above and generates a custom handler for each transform.
255+
# In order to ensure early binding, we pass the PL plxpr transform and the
256+
# Catalyst pass as arguments whose default values are set by the loop.
257+
for pl_transform, (pass_name, decomposition) in transforms_to_passes.items():
258+
register_transform(pl_transform, pass_name, decomposition)
259+
260+
254261
class QFuncPlxprInterpreter(PlxprInterpreter):
255262
"""An interpreter that converts plxpr into catalyst-variant jaxpr.
256263

0 commit comments

Comments
 (0)