2929from pennylane .capture .expand_transforms import ExpandTransformsInterpreter
3030from pennylane .capture .primitives import jacobian_prim as pl_jac_prim
3131from pennylane .capture .primitives import transform_prim
32- from pennylane .transforms import cancel_inverses as pl_cancel_inverses
3332from pennylane .transforms import commute_controlled as pl_commute_controlled
3433from pennylane .transforms import decompose as pl_decompose
3534from pennylane .transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding
36- from pennylane .transforms import merge_rotations as pl_merge_rotations
3735from pennylane .transforms import single_qubit_fusion as pl_single_qubit_fusion
3836from pennylane .transforms import unitary_to_rot as pl_unitary_to_rot
3937
4846 qdealloc_p ,
4947 quantum_kernel_p ,
5048)
51- from catalyst .passes .pass_api import Pass
5249from catalyst .utils .patching import Patcher
5350
5451from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter
@@ -286,7 +283,9 @@ def handle_qnode(
286283 # Fallback to the legacy decomposition if the graph-based decomposition failed
287284 if not graph_succeeded :
288285 # Remove the decompose-lowering pass from the pipeline
289- self ._pass_pipeline = [p for p in self ._pass_pipeline if p .name != "decompose-lowering" ]
286+ self ._pass_pipeline = [
287+ p for p in self ._pass_pipeline if p .pass_name != "decompose-lowering"
288+ ]
290289 closed_jaxpr = _apply_compiler_decompose_to_plxpr (
291290 inner_jaxpr = closed_jaxpr .jaxpr ,
292291 consts = closed_jaxpr .consts ,
@@ -334,11 +333,9 @@ def calling_convention(*args):
334333# otherwise their value will be None. The second value indicates if the transform
335334# requires decomposition to be supported by Catalyst.
336335transforms_to_passes = {
337- pl_cancel_inverses : ("cancel-inverses" , False ),
338336 pl_commute_controlled : (None , False ),
339337 pl_decompose : (None , False ),
340338 pl_merge_amplitude_embedding : (None , True ),
341- pl_merge_rotations : ("merge-rotations" , False ),
342339 pl_single_qubit_fusion : (None , False ),
343340 pl_unitary_to_rot : (None , False ),
344341}
@@ -349,6 +346,47 @@ def register_transform(pl_transform, pass_name, decomposition):
349346 transforms_to_passes [pl_transform ] = (pass_name , decomposition )
350347
351348
349+ def _handle_decompose_transform (self , inner_jaxpr , consts , non_const_args , tkwargs ):
350+ if not self .requires_decompose_lowering :
351+ self .requires_decompose_lowering = True
352+ else :
353+ raise NotImplementedError ("Multiple decomposition transforms are not yet supported." )
354+
355+ next_eval = copy (self )
356+ # Update the decompose_gateset to be used by the quantum kernel primitive
357+ # TODO: we originally wanted to treat decompose_gateset as a queue of
358+ # gatesets to be used by the decompose-lowering pass at MLIR
359+ # but this requires a C++ implementation of the graph-based decomposition
360+ # which doesn't exist yet.
361+ next_eval .decompose_tkwargs = tkwargs
362+
363+ # Note. We don't perform the compiler-specific decomposition here
364+ # to be able to support multiple decomposition transforms
365+ # and collect all the required gatesets
366+ # as well as being able to support other transforms in between.
367+
368+ # The compiler specific transformation will be performed
369+ # in the qnode handler.
370+
371+ # Add the decompose-lowering pass to the start of the pipeline
372+ t = qml .transform (pass_name = "decompose-lowering" )
373+ pass_container = qml .transforms .core .TransformContainer (t )
374+ next_eval ._pass_pipeline .insert (0 , pass_container )
375+
376+ # We still need to construct and solve the graph based on
377+ # the current jaxpr based on the current gateset
378+ # but we don't rewrite the jaxpr at this stage.
379+
380+ # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs)
381+
382+ # def gds_wrapper(*args):
383+ # return gds_interpreter.eval(inner_jaxpr, consts, *args)
384+
385+ # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
386+ # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
387+ return next_eval .eval (inner_jaxpr , consts , * non_const_args )
388+
389+
352390# pylint: disable=too-many-arguments
353391@WorkflowInterpreter .register_primitive (transform_prim )
354392def handle_transform (
@@ -375,45 +413,11 @@ def handle_transform(
375413 and transform ._plxpr_transform .__name__ == "decompose_plxpr_to_plxpr"
376414 and qml .decomposition .enabled_graph ()
377415 ):
378- # Handle the conversion from plxpr to Catalyst jaxpr for a PL transform.
379- if not self .requires_decompose_lowering :
380- self .requires_decompose_lowering = True
381- else :
382- raise NotImplementedError ("Multiple decomposition transforms are not yet supported." )
383-
384- next_eval = copy (self )
385- # Update the decompose_gateset to be used by the quantum kernel primitive
386- # TODO: we originally wanted to treat decompose_gateset as a queue of
387- # gatesets to be used by the decompose-lowering pass at MLIR
388- # but this requires a C++ implementation of the graph-based decomposition
389- # which doesn't exist yet.
390- next_eval .decompose_tkwargs = tkwargs
391-
392- # Note. We don't perform the compiler-specific decomposition here
393- # to be able to support multiple decomposition transforms
394- # and collect all the required gatesets
395- # as well as being able to support other transforms in between.
396-
397- # The compiler specific transformation will be performed
398- # in the qnode handler.
399-
400- # Add the decompose-lowering pass to the start of the pipeline
401- next_eval ._pass_pipeline .insert (0 , Pass ("decompose-lowering" ))
416+ return _handle_decompose_transform (self , inner_jaxpr , consts , non_const_args , tkwargs )
402417
403- # We still need to construct and solve the graph based on
404- # the current jaxpr based on the current gateset
405- # but we don't rewrite the jaxpr at this stage.
406-
407- # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs)
408-
409- # def gds_wrapper(*args):
410- # return gds_interpreter.eval(inner_jaxpr, consts, *args)
411-
412- # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
413- # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
414- return next_eval .eval (inner_jaxpr , consts , * non_const_args )
415-
416- catalyst_pass_name = transforms_to_passes .get (transform , (None ,))[0 ]
418+ catalyst_pass_name = transform .pass_name
419+ if catalyst_pass_name is None :
420+ catalyst_pass_name = transforms_to_passes .get (transform , (None ,))[0 ]
417421 if catalyst_pass_name is None :
418422 # Use PL's ExpandTransformsInterpreter to expand this and any embedded
419423 # transform according to PL rules. It works by overriding the primitive
@@ -435,7 +439,9 @@ def wrapper(*args):
435439
436440 # Apply the corresponding Catalyst pass counterpart
437441 next_eval = copy (self )
438- next_eval ._pass_pipeline .insert (0 , Pass (catalyst_pass_name , * targs , ** tkwargs ))
442+ t = qml .transform (pass_name = catalyst_pass_name )
443+ bound_pass = qml .transforms .core .TransformContainer (t , args = targs , kwargs = tkwargs )
444+ next_eval ._pass_pipeline .insert (0 , bound_pass )
439445 return next_eval .eval (inner_jaxpr , consts , * non_const_args )
440446
441447
0 commit comments