Skip to content
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
53df1bc
use pass name from transform
albi3ro Oct 24, 2025
cf01501
some udpates
albi3ro Nov 3, 2025
a08905b
make backwards compatible
albi3ro Nov 5, 2025
0cb9476
messed from plxpr up somehow
albi3ro Nov 13, 2025
dd3a655
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 13, 2025
3b2c5b8
more polishing
albi3ro Nov 13, 2025
42b5052
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 14, 2025
e885b91
some test fixes
albi3ro Nov 14, 2025
5c99bba
fix failing test
albi3ro Nov 17, 2025
ffedfe4
see if that fixes the failure
albi3ro Nov 17, 2025
77a211d
oops
albi3ro Nov 17, 2025
c738102
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 17, 2025
a1abd24
[skip ci] starting to test
albi3ro Nov 17, 2025
b6441bf
try using transform instead of passes
albi3ro Nov 18, 2025
838e45f
switch passes to being tranfsorms
albi3ro Nov 18, 2025
b2ca54d
update apply_pass and apply_pass_plugin
albi3ro Nov 18, 2025
f76ac1b
update apply_pass and apply_pass_plugin
albi3ro Nov 18, 2025
ded40ef
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 19, 2025
8b045d0
adding in some tests
albi3ro Nov 19, 2025
ed528ba
black and isort
albi3ro Nov 19, 2025
80c4a33
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 20, 2025
ec38c77
minor fixes
albi3ro Nov 20, 2025
d6477f5
Merge branch 'pass-pipeline-transform-program' into use-transform-pas…
albi3ro Nov 20, 2025
4771069
remove tests
albi3ro Nov 20, 2025
b5595af
Apply suggestion from @albi3ro
albi3ro Nov 20, 2025
579c27b
leave pipeline test in
albi3ro Nov 20, 2025
068e30b
Merge branch 'use-transform-pass-name' of https://github.com/PennyLan…
albi3ro Nov 20, 2025
e01158e
update name of cancel_inverses in test
albi3ro Nov 20, 2025
d37d89b
Try to fix the lit test again
albi3ro Nov 20, 2025
3586dd3
delete test files
albi3ro Nov 20, 2025
9d167c6
Merge branch 'main' into use-transform-pass-name
albi3ro Dec 3, 2025
80e149e
unpin pl branch
albi3ro Dec 3, 2025
3c8734e
Merge branch 'use-transform-pass-name' of https://github.com/PennyLan…
albi3ro Dec 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/workflows/check-catalyst.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,10 @@ jobs:
python3 -m pip install oqc-qcaas-client
make frontend
- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name
- name: Get Cached LLVM Build
id: cache-llvm-build
uses: actions/cache@v4
Expand Down Expand Up @@ -558,6 +562,10 @@ jobs:
python3 -m pip install -r requirements.txt
make frontend
- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name
- name: Get Cached LLVM Build
id: cache-llvm-build
uses: actions/cache@v4
Expand Down Expand Up @@ -620,6 +628,10 @@ jobs:
python3 -m pip install -r requirements.txt
make frontend
- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-pass-name
- name: Get Cached LLVM Build
id: cache-llvm-build
uses: actions/cache@v4
Expand Down
95 changes: 51 additions & 44 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import transform_prim
from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires
from pennylane.transforms import cancel_inverses as pl_cancel_inverses
from pennylane.transforms import commute_controlled as pl_commute_controlled
from pennylane.transforms import decompose as pl_decompose
from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding
from pennylane.transforms import merge_rotations as pl_merge_rotations
from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion
from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot

Expand All @@ -48,7 +46,6 @@
qdealloc_p,
quantum_kernel_p,
)
from catalyst.passes.pass_api import Pass

from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter
from .qubit_handler import (
Expand Down Expand Up @@ -215,7 +212,9 @@ def handle_qnode(
# Fallback to the legacy decomposition if the graph-based decomposition failed
if not graph_succeeded:
# Remove the decompose-lowering pass from the pipeline
self._pass_pipeline = [p for p in self._pass_pipeline if p.name != "decompose-lowering"]
self._pass_pipeline = [
p for p in self._pass_pipeline if p.pass_name != "decompose-lowering"
]
closed_jaxpr = _apply_compiler_decompose_to_plxpr(
inner_jaxpr=closed_jaxpr.jaxpr,
consts=closed_jaxpr.consts,
Expand Down Expand Up @@ -263,12 +262,10 @@ def calling_convention(*args):
# otherwise their value will be None. The second value indicates if the transform
# requires decomposition to be supported by Catalyst.
transforms_to_passes = {
pl_cancel_inverses: ("cancel-inverses", False),
pl_commute_controlled: (None, False),
pl_decompose: (None, False),
pl_map_wires: (None, False),
pl_merge_amplitude_embedding: (None, True),
pl_merge_rotations: ("merge-rotations", False),
pl_single_qubit_fusion: (None, False),
pl_unitary_to_rot: (None, False),
}
Expand All @@ -279,6 +276,47 @@ def register_transform(pl_transform, pass_name, decomposition):
transforms_to_passes[pl_transform] = (pass_name, decomposition)


def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwargs):
if not self.requires_decompose_lowering:
self.requires_decompose_lowering = True
else:
raise NotImplementedError("Multiple decomposition transforms are not yet supported.")

next_eval = copy(self)
# Update the decompose_gateset to be used by the quantum kernel primitive
# TODO: we originally wanted to treat decompose_gateset as a queue of
# gatesets to be used by the decompose-lowering pass at MLIR
# but this requires a C++ implementation of the graph-based decomposition
# which doesn't exist yet.
next_eval.decompose_tkwargs = tkwargs

# Note. We don't perform the compiler-specific decomposition here
# to be able to support multiple decomposition transforms
# and collect all the required gatesets
# as well as being able to support other transforms in between.

# The compiler specific transformation will be performed
# in the qnode handler.

# Add the decompose-lowering pass to the start of the pipeline
t = qml.transform(pass_name="decompose-lowering")
pass_container = qml.transforms.core.TransformContainer(t)
next_eval._pass_pipeline.insert(0, pass_container)

# We still need to construct and solve the graph based on
# the current jaxpr based on the current gateset
# but we don't rewrite the jaxpr at this stage.

# gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs)

# def gds_wrapper(*args):
# return gds_interpreter.eval(inner_jaxpr, consts, *args)

# final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
# return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
return next_eval.eval(inner_jaxpr, consts, *non_const_args)


# pylint: disable=too-many-arguments
@WorkflowInterpreter.register_primitive(transform_prim)
def handle_transform(
Expand All @@ -304,44 +342,11 @@ def handle_transform(
and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr"
and qml.decomposition.enabled_graph()
):
if not self.requires_decompose_lowering:
self.requires_decompose_lowering = True
else:
raise NotImplementedError("Multiple decomposition transforms are not yet supported.")

next_eval = copy(self)
# Update the decompose_gateset to be used by the quantum kernel primitive
# TODO: we originally wanted to treat decompose_gateset as a queue of
# gatesets to be used by the decompose-lowering pass at MLIR
# but this requires a C++ implementation of the graph-based decomposition
# which doesn't exist yet.
next_eval.decompose_tkwargs = tkwargs

# Note. We don't perform the compiler-specific decomposition here
# to be able to support multiple decomposition transforms
# and collect all the required gatesets
# as well as being able to support other transforms in between.

# The compiler specific transformation will be performed
# in the qnode handler.

# Add the decompose-lowering pass to the start of the pipeline
next_eval._pass_pipeline.insert(0, Pass("decompose-lowering"))
return _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwargs)

# We still need to construct and solve the graph based on
# the current jaxpr based on the current gateset
# but we don't rewrite the jaxpr at this stage.

# gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs)

# def gds_wrapper(*args):
# return gds_interpreter.eval(inner_jaxpr, consts, *args)

# final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
# return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
return next_eval.eval(inner_jaxpr, consts, *non_const_args)

catalyst_pass_name = transforms_to_passes.get(transform, (None,))[0]
catalyst_pass_name = transform.pass_name
if catalyst_pass_name is None:
catalyst_pass_name = transforms_to_passes.get(transform, (None,))[0]
if catalyst_pass_name is None:
# Use PL's ExpandTransformsInterpreter to expand this and any embedded
# transform according to PL rules. It works by overriding the primitive
Expand All @@ -363,7 +368,9 @@ def wrapper(*args):

# Apply the corresponding Catalyst pass counterpart
next_eval = copy(self)
next_eval._pass_pipeline.insert(0, Pass(catalyst_pass_name, *targs, **tkwargs))
t = qml.transform(pass_name=catalyst_pass_name)
bound_pass = qml.transforms.core.TransformContainer(t, args=targs, kwargs=tkwargs)
next_eval._pass_pipeline.insert(0, bound_pass)
return next_eval.eval(inner_jaxpr, consts, *non_const_args)


Expand Down
21 changes: 18 additions & 3 deletions frontend/catalyst/jax_primitives_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.ctx.module_context = self.old_module_context


def _lowered_options(args, kwargs):
lowered_options = {}
for arg in args:
lowered_options[str(arg)] = get_mlir_attribute_from_pyval(True)
for option, value in kwargs.items():
mlir_option = str(option).replace("_", "-")
lowered_options[mlir_option] = get_mlir_attribute_from_pyval(value)
return lowered_options


def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipeline):
"""Generate a transform module embedded in the current module and schedule
the transformations in pipeline"""
Expand Down Expand Up @@ -366,11 +376,16 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
with ir.InsertionPoint(bb_named_sequence):
target = bb_named_sequence.arguments[0]
for _pass in pipeline:
options = _pass.get_options()
if isinstance(_pass, qml.transforms.core.TransformContainer):
options = _lowered_options(_pass.args, _pass.kwargs)
name = _pass.pass_name
else:
options = _pass.get_options()
name = _pass.name
apply_registered_pass_op = ApplyRegisteredPassOp(
result=transform_mod_type,
target=target,
pass_name=_pass.name,
pass_name=name,
options=options,
dynamic_options={},
)
Expand All @@ -382,7 +397,7 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
is_xdsl_pass,
)

if is_xdsl_pass(_pass.name):
if is_xdsl_pass(name):
uses_xdsl_passes = True
apply_registered_pass_op.operation.attributes["catalyst.xdsl_pass"] = (
ir.UnitAttr.get()
Expand Down
Loading
Loading