Skip to content

Commit 07369d6

Browse files
authored
Use jax.extend.core.Primitive (#1731)
**Context:** we are currently accessing `Primitive` through a private path. **Description of the Change:** Changes how `Primitive` is accessed to use the preferred path. **Benefits:** Less private paths. **Possible Drawbacks:** None. **Related GitHub Issues:**
1 parent 834b73a commit 07369d6

File tree

1 file changed

+34
-33
lines changed

1 file changed

+34
-33
lines changed

frontend/catalyst/jax_primitives.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from jax._src.lib.mlir import ir
3232
from jax._src.lib.mlir.dialects import hlo
3333
from jax.core import AbstractValue
34+
from jax.extend.core import Primitive
3435
from jax.interpreters import mlir
3536
from jax.tree_util import PyTreeDef, tree_unflatten
3637
from jaxlib.hlo_helpers import shape_dtype_to_ir_type
@@ -246,67 +247,67 @@ class MeasurementPlane(Enum):
246247
# Primitives #
247248
##############
248249

249-
zne_p = core.Primitive("zne")
250-
device_init_p = core.Primitive("device_init")
250+
zne_p = Primitive("zne")
251+
device_init_p = Primitive("device_init")
251252
device_init_p.multiple_results = True
252-
device_release_p = core.Primitive("device_release")
253+
device_release_p = Primitive("device_release")
253254
device_release_p.multiple_results = True
254-
qalloc_p = core.Primitive("qalloc")
255-
qdealloc_p = core.Primitive("qdealloc")
255+
qalloc_p = Primitive("qalloc")
256+
qdealloc_p = Primitive("qdealloc")
256257
qdealloc_p.multiple_results = True
257-
qextract_p = core.Primitive("qextract")
258-
qinsert_p = core.Primitive("qinsert")
259-
gphase_p = core.Primitive("gphase")
258+
qextract_p = Primitive("qextract")
259+
qinsert_p = Primitive("qinsert")
260+
gphase_p = Primitive("gphase")
260261
gphase_p.multiple_results = True
261-
qinst_p = core.Primitive("qinst")
262+
qinst_p = Primitive("qinst")
262263
qinst_p.multiple_results = True
263-
unitary_p = core.Primitive("unitary")
264+
unitary_p = Primitive("unitary")
264265
unitary_p.multiple_results = True
265-
measure_p = core.Primitive("measure")
266+
measure_p = Primitive("measure")
266267
measure_p.multiple_results = True
267-
compbasis_p = core.Primitive("compbasis")
268-
namedobs_p = core.Primitive("namedobs")
269-
hermitian_p = core.Primitive("hermitian")
270-
tensorobs_p = core.Primitive("tensorobs")
271-
hamiltonian_p = core.Primitive("hamiltonian")
272-
sample_p = core.Primitive("sample")
273-
counts_p = core.Primitive("counts")
268+
compbasis_p = Primitive("compbasis")
269+
namedobs_p = Primitive("namedobs")
270+
hermitian_p = Primitive("hermitian")
271+
tensorobs_p = Primitive("tensorobs")
272+
hamiltonian_p = Primitive("hamiltonian")
273+
sample_p = Primitive("sample")
274+
counts_p = Primitive("counts")
274275
counts_p.multiple_results = True
275-
expval_p = core.Primitive("expval")
276-
var_p = core.Primitive("var")
277-
probs_p = core.Primitive("probs")
278-
state_p = core.Primitive("state")
276+
expval_p = Primitive("expval")
277+
var_p = Primitive("var")
278+
probs_p = Primitive("probs")
279+
state_p = Primitive("state")
279280
cond_p = DynshapePrimitive("cond")
280281
cond_p.multiple_results = True
281282
while_p = DynshapePrimitive("while_loop")
282283
while_p.multiple_results = True
283284
for_p = DynshapePrimitive("for_loop")
284285
for_p.multiple_results = True
285-
grad_p = core.Primitive("grad")
286+
grad_p = Primitive("grad")
286287
grad_p.multiple_results = True
287288
func_p = core.CallPrimitive("func")
288289
func_p.multiple_results = True
289-
jvp_p = core.Primitive("jvp")
290+
jvp_p = Primitive("jvp")
290291
jvp_p.multiple_results = True
291-
vjp_p = core.Primitive("vjp")
292+
vjp_p = Primitive("vjp")
292293
vjp_p.multiple_results = True
293-
adjoint_p = jax.core.Primitive("adjoint")
294+
adjoint_p = Primitive("adjoint")
294295
adjoint_p.multiple_results = True
295-
print_p = jax.core.Primitive("debug_print")
296+
print_p = Primitive("debug_print")
296297
print_p.multiple_results = True
297-
python_callback_p = core.Primitive("python_callback")
298+
python_callback_p = Primitive("python_callback")
298299
python_callback_p.multiple_results = True
299-
value_and_grad_p = core.Primitive("value_and_grad")
300+
value_and_grad_p = Primitive("value_and_grad")
300301
value_and_grad_p.multiple_results = True
301-
assert_p = core.Primitive("assert")
302+
assert_p = Primitive("assert")
302303
assert_p.multiple_results = True
303-
set_state_p = jax.core.Primitive("state_prep")
304+
set_state_p = Primitive("state_prep")
304305
set_state_p.multiple_results = True
305-
set_basis_state_p = jax.core.Primitive("set_basis_state")
306+
set_basis_state_p = Primitive("set_basis_state")
306307
set_basis_state_p.multiple_results = True
307308
quantum_kernel_p = core.CallPrimitive("quantum_kernel")
308309
quantum_kernel_p.multiple_results = True
309-
measure_in_basis_p = core.Primitive("measure_in_basis")
310+
measure_in_basis_p = Primitive("measure_in_basis")
310311
measure_in_basis_p.multiple_results = True
311312

312313

0 commit comments

Comments
 (0)