diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 6ae7e5a0d3..b514b7f3fe 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -45,7 +45,7 @@ from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot -from catalyst.device import extract_backend_info +from catalyst.device import extract_backend_info, get_device_capabilities from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder, get_in_qubit_values from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config @@ -192,6 +192,16 @@ def __init__(self): super().__init__() +def _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device): + gate_set = set(get_device_capabilities(device).operations) + if get_device_capabilities(device).initial_state_prep: + gate_set.add("StatePrep") + targs = () + tkwargs = {"gate_set": gate_set} + breakpoint() + return qml.transforms.decompose.plxpr_transform(qfunc_jaxpr, consts, targs, tkwargs) + + # pylint: disable=unused-argument, too-many-arguments @WorkflowInterpreter.register_primitive(qnode_prim) def handle_qnode( @@ -208,6 +218,10 @@ def handle_qnode( consts = args[shots_len : n_consts + shots_len] non_const_args = args[shots_len + n_consts :] + # hopefully this patch stays patchy and doesn't become permanent + # TODO: Too much has changed within this function, need to rework the patch + closed_jaxpr = _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device) + closed_jaxpr = ( ClosedJaxpr(qfunc_jaxpr, consts) if not self.requires_decompose_lowering