From eae156c41a8a7fd438d0872dab92e28529959653 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 6 Aug 2025 17:45:38 -0400 Subject: [PATCH 01/10] decompose to device gateset in from_plxpr --- frontend/catalyst/from_plxpr/from_plxpr.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index c044d0f832..1bd4f4baff 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -182,6 +182,12 @@ def __init__(self): super().__init__() +def _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device): + gate_set = set(device.capabilities.operations) + targs = () + tkwargs = {"gate_set": gate_set} + return qml.transforms.decompose.plxpr_transform(qfunc_jaxpr, consts, targs, tkwargs) + # pylint: disable=unused-argument, too-many-arguments @WorkflowInterpreter.register_primitive(qnode_prim) def handle_qnode( @@ -191,7 +197,8 @@ def handle_qnode( consts = args[:n_consts] non_const_args = args[n_consts:] - closed_jaxpr = ClosedJaxpr(qfunc_jaxpr, consts) + # hopefully this patch stays patchy and doesn't become permanent + closed_jaxpr = _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device) def extract_shots_value(shots: qml.measurements.Shots | int): """Extract the shots value according to the type""" From e5edece6d8634abf155ebc6244cdace62466a1b2 Mon Sep 17 00:00:00 2001 From: Joseph Lee Date: Thu, 7 Aug 2025 18:43:01 +0000 Subject: [PATCH 02/10] add singleexcitation qis --- runtime/include/RuntimeCAPI.h | 1 + runtime/lib/capi/RuntimeCAPI.cpp | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/runtime/include/RuntimeCAPI.h b/runtime/include/RuntimeCAPI.h index bceffb7b7d..83366a7df4 100644 --- a/runtime/include/RuntimeCAPI.h +++ b/runtime/include/RuntimeCAPI.h @@ -82,6 +82,7 @@ void __catalyst__qis__MultiRZ(double, const Modifiers *, int64_t, /*qubits*/...) void __catalyst__qis__GlobalPhase(double, const Modifiers *); void __catalyst__qis__ISWAP(QUBIT *, QUBIT *, const Modifiers *); void __catalyst__qis__PSWAP(double, QUBIT *, QUBIT *, const Modifiers *); +void __catalyst__qis__SingleExcitation(double, QUBIT *, QUBIT *, const Modifiers *); // Struct pointer arguments for these instructions represent real arguments, // as passing structs by value is too unreliable / compiler dependant. diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp index 26f8245ab5..da3ab166a5 100644 --- a/runtime/lib/capi/RuntimeCAPI.cpp +++ b/runtime/lib/capi/RuntimeCAPI.cpp @@ -777,6 +777,14 @@ void __catalyst__qis__PSWAP(double phi, QUBIT *wire0, QUBIT *wire1, const Modifi MODIFIERS_ARGS(modifiers)); } +void __catalyst__qis__SingleExcitation(double phi, QUBIT *wire0, QUBIT *wire1, const Modifiers *modifiers) +{ + getQuantumDevicePtr()->NamedOperation( + "SingleExcitation", {phi}, + {reinterpret_cast(wire0), reinterpret_cast(wire1)}, + MODIFIERS_ARGS(modifiers)); +} + static void _qubitUnitary_impl(MemRefT_CplxT_double_2d *matrix, int64_t numQubits, std::vector> &coeffs, std::vector &wires, va_list *args) From 0a815987c60cf5b0502cb51328d591fb04322605 Mon Sep 17 00:00:00 2001 From: Joseph Lee Date: Thu, 7 Aug 2025 18:43:11 +0000 Subject: [PATCH 03/10] comment out eliminate-empty-tensor --- frontend/catalyst/pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 257b0df0fa..edc77a3891 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -274,7 +274,7 @@ def get_bufferization_stage(options: CompileOptions) -> List[str]: "convert-tensor-to-linalg", # tensor.pad "convert-elementwise-to-linalg", # Must be run before --one-shot-bufferize "gradient-preprocess", - "eliminate-empty-tensors", + #"eliminate-empty-tensors", #################### "one-shot-bufferize{" + bufferization_options + "}", #################### From cf00833a90d2fb69f59708f0f5deeb0a1a26b56d Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Mon, 11 Aug 2025 21:30:01 +0000 Subject: [PATCH 04/10] Change how capabilities are gathered to make compatible with null.qubit --- frontend/catalyst/from_plxpr/from_plxpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 1bd4f4baff..21d06992f8 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -183,7 +183,7 @@ def __init__(self): def _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device): - gate_set = set(device.capabilities.operations) + gate_set = set(get_device_capabilities(device).operations) targs = () tkwargs = {"gate_set": gate_set} return qml.transforms.decompose.plxpr_transform(qfunc_jaxpr, consts, targs, tkwargs) From 1bcf6aa6fcaa0876d2827f25181f77134cda9edc Mon Sep 17 00:00:00 2001 From: Joseph Lee Date: Tue, 12 Aug 2025 15:50:22 +0000 Subject: [PATCH 05/10] add StatePrep to supported gate set --- frontend/catalyst/from_plxpr/from_plxpr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 21d06992f8..223870e44b 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -184,6 +184,8 @@ def __init__(self): def _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device): gate_set = set(get_device_capabilities(device).operations) + if get_device_capabilities(device).initial_state_prep: + gate_set.add("StatePrep") targs = () tkwargs = {"gate_set": gate_set} return qml.transforms.decompose.plxpr_transform(qfunc_jaxpr, consts, targs, tkwargs) From a5ef3ecfacd86f139b7e0ead540c16d8ca596b00 Mon Sep 17 00:00:00 2001 From: Joseph Lee Date: Tue, 12 Aug 2025 15:53:35 +0000 Subject: [PATCH 06/10] format --- frontend/catalyst/from_plxpr/from_plxpr.py | 1 + runtime/lib/capi/RuntimeCAPI.cpp | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 223870e44b..b5f3d68e7e 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -190,6 +190,7 @@ def _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device): tkwargs = {"gate_set": gate_set} return qml.transforms.decompose.plxpr_transform(qfunc_jaxpr, consts, targs, tkwargs) + # pylint: disable=unused-argument, too-many-arguments @WorkflowInterpreter.register_primitive(qnode_prim) def handle_qnode( diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp index da3ab166a5..9ffe5c7fac 100644 --- a/runtime/lib/capi/RuntimeCAPI.cpp +++ b/runtime/lib/capi/RuntimeCAPI.cpp @@ -777,7 +777,8 @@ void __catalyst__qis__PSWAP(double phi, QUBIT *wire0, QUBIT *wire1, const Modifi MODIFIERS_ARGS(modifiers)); } -void __catalyst__qis__SingleExcitation(double phi, QUBIT *wire0, QUBIT *wire1, const Modifiers *modifiers) +void __catalyst__qis__SingleExcitation(double phi, QUBIT *wire0, QUBIT *wire1, + const Modifiers *modifiers) { getQuantumDevicePtr()->NamedOperation( "SingleExcitation", {phi}, From 47b4d72f553175972e7f2de063da8229ab7b36bd Mon Sep 17 00:00:00 2001 From: Joseph Lee Date: Tue, 26 Aug 2025 21:16:36 +0000 Subject: [PATCH 07/10] remove singleexcitation --- runtime/include/RuntimeCAPI.h | 1 - runtime/lib/capi/RuntimeCAPI.cpp | 9 --------- 2 files changed, 10 deletions(-) diff --git a/runtime/include/RuntimeCAPI.h b/runtime/include/RuntimeCAPI.h index ac47e02ce2..1333527c9e 100644 --- a/runtime/include/RuntimeCAPI.h +++ b/runtime/include/RuntimeCAPI.h @@ -86,7 +86,6 @@ void __catalyst__qis__GlobalPhase(double, const Modifiers *); void __catalyst__qis__PCPhase(double, double, const Modifiers *, int64_t, /*qubits*/...); void __catalyst__qis__ISWAP(QUBIT *, QUBIT *, const Modifiers *); void __catalyst__qis__PSWAP(double, QUBIT *, QUBIT *, const Modifiers *); -void __catalyst__qis__SingleExcitation(double, QUBIT *, QUBIT *, const Modifiers *); // Struct pointer arguments for these instructions represent real arguments, // as passing structs by value is too unreliable / compiler dependant. diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp index 10426e6593..e03a292939 100644 --- a/runtime/lib/capi/RuntimeCAPI.cpp +++ b/runtime/lib/capi/RuntimeCAPI.cpp @@ -814,15 +814,6 @@ void __catalyst__qis__PSWAP(double phi, QUBIT *wire0, QUBIT *wire1, const Modifi MODIFIERS_ARGS(modifiers)); } -void __catalyst__qis__SingleExcitation(double phi, QUBIT *wire0, QUBIT *wire1, - const Modifiers *modifiers) -{ - getQuantumDevicePtr()->NamedOperation( - "SingleExcitation", {phi}, - {reinterpret_cast(wire0), reinterpret_cast(wire1)}, - MODIFIERS_ARGS(modifiers)); -} - static void _qubitUnitary_impl(MemRefT_CplxT_double_2d *matrix, int64_t numQubits, std::vector> &coeffs, std::vector &wires, va_list *args) From 01135a1a1aeeae69f0a200f016a09a04dadb2161 Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Wed, 3 Sep 2025 16:23:45 +0000 Subject: [PATCH 08/10] Fix get_device_capabilities --- frontend/catalyst/from_plxpr/from_plxpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index de88149fb7..b5f3d68e7e 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -43,7 +43,7 @@ from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot -from catalyst.device import extract_backend_info +from catalyst.device import extract_backend_info, get_device_capabilities from catalyst.from_plxpr.qreg_manager import QregManager from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( From bbecb78e1a35a6ed505e03f58afca2f505b692f8 Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Wed, 1 Oct 2025 16:34:28 +0000 Subject: [PATCH 09/10] Remove unused imports --- frontend/catalyst/from_plxpr/from_plxpr.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index c6d158a602..bfec77ed57 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -46,8 +46,6 @@ from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot from catalyst.device import extract_backend_info, get_device_capabilities -from catalyst.from_plxpr.qreg_manager import QregManager -from catalyst.device import extract_backend_info from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder, get_in_qubit_values from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config From f72b35ac58fe798b000b3d081970b7426d712a4e Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Wed, 1 Oct 2025 19:35:43 +0000 Subject: [PATCH 10/10] move patch --- frontend/catalyst/from_plxpr/from_plxpr.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index bfec77ed57..b514b7f3fe 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -198,6 +198,7 @@ def _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device): gate_set.add("StatePrep") targs = () tkwargs = {"gate_set": gate_set} + breakpoint() return qml.transforms.decompose.plxpr_transform(qfunc_jaxpr, consts, targs, tkwargs) @@ -208,9 +209,6 @@ def handle_qnode( ): """Handle the conversion from plxpr to Catalyst jaxpr for the qnode primitive""" - # hopefully this patch stays patchy and doesn't become permanent - closed_jaxpr = _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device) - self.qubit_index_recorder = QubitIndexRecorder() if shots_len > 1: @@ -220,6 +218,10 @@ def handle_qnode( consts = args[shots_len : n_consts + shots_len] non_const_args = args[shots_len + n_consts :] + # hopefully this patch stays patchy and doesn't become permanent + # TODO: Too much has changed within this function, need to rework the patch + closed_jaxpr = _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device) + closed_jaxpr = ( ClosedJaxpr(qfunc_jaxpr, consts) if not self.requires_decompose_lowering