|
31 | 31 | from jax._src.lib.mlir import ir |
32 | 32 | from jax._src.lib.mlir.dialects import hlo |
33 | 33 | from jax.core import AbstractValue |
| 34 | +from jax.extend.core import Primitive |
34 | 35 | from jax.interpreters import mlir |
35 | 36 | from jax.tree_util import PyTreeDef, tree_unflatten |
36 | 37 | from jaxlib.hlo_helpers import shape_dtype_to_ir_type |
@@ -246,67 +247,67 @@ class MeasurementPlane(Enum): |
246 | 247 | # Primitives # |
247 | 248 | ############## |
248 | 249 |
|
249 | | -zne_p = core.Primitive("zne") |
250 | | -device_init_p = core.Primitive("device_init") |
| 250 | +zne_p = Primitive("zne") |
| 251 | +device_init_p = Primitive("device_init") |
251 | 252 | device_init_p.multiple_results = True |
252 | | -device_release_p = core.Primitive("device_release") |
| 253 | +device_release_p = Primitive("device_release") |
253 | 254 | device_release_p.multiple_results = True |
254 | | -qalloc_p = core.Primitive("qalloc") |
255 | | -qdealloc_p = core.Primitive("qdealloc") |
| 255 | +qalloc_p = Primitive("qalloc") |
| 256 | +qdealloc_p = Primitive("qdealloc") |
256 | 257 | qdealloc_p.multiple_results = True |
257 | | -qextract_p = core.Primitive("qextract") |
258 | | -qinsert_p = core.Primitive("qinsert") |
259 | | -gphase_p = core.Primitive("gphase") |
| 258 | +qextract_p = Primitive("qextract") |
| 259 | +qinsert_p = Primitive("qinsert") |
| 260 | +gphase_p = Primitive("gphase") |
260 | 261 | gphase_p.multiple_results = True |
261 | | -qinst_p = core.Primitive("qinst") |
| 262 | +qinst_p = Primitive("qinst") |
262 | 263 | qinst_p.multiple_results = True |
263 | | -unitary_p = core.Primitive("unitary") |
| 264 | +unitary_p = Primitive("unitary") |
264 | 265 | unitary_p.multiple_results = True |
265 | | -measure_p = core.Primitive("measure") |
| 266 | +measure_p = Primitive("measure") |
266 | 267 | measure_p.multiple_results = True |
267 | | -compbasis_p = core.Primitive("compbasis") |
268 | | -namedobs_p = core.Primitive("namedobs") |
269 | | -hermitian_p = core.Primitive("hermitian") |
270 | | -tensorobs_p = core.Primitive("tensorobs") |
271 | | -hamiltonian_p = core.Primitive("hamiltonian") |
272 | | -sample_p = core.Primitive("sample") |
273 | | -counts_p = core.Primitive("counts") |
| 268 | +compbasis_p = Primitive("compbasis") |
| 269 | +namedobs_p = Primitive("namedobs") |
| 270 | +hermitian_p = Primitive("hermitian") |
| 271 | +tensorobs_p = Primitive("tensorobs") |
| 272 | +hamiltonian_p = Primitive("hamiltonian") |
| 273 | +sample_p = Primitive("sample") |
| 274 | +counts_p = Primitive("counts") |
274 | 275 | counts_p.multiple_results = True |
275 | | -expval_p = core.Primitive("expval") |
276 | | -var_p = core.Primitive("var") |
277 | | -probs_p = core.Primitive("probs") |
278 | | -state_p = core.Primitive("state") |
| 276 | +expval_p = Primitive("expval") |
| 277 | +var_p = Primitive("var") |
| 278 | +probs_p = Primitive("probs") |
| 279 | +state_p = Primitive("state") |
279 | 280 | cond_p = DynshapePrimitive("cond") |
280 | 281 | cond_p.multiple_results = True |
281 | 282 | while_p = DynshapePrimitive("while_loop") |
282 | 283 | while_p.multiple_results = True |
283 | 284 | for_p = DynshapePrimitive("for_loop") |
284 | 285 | for_p.multiple_results = True |
285 | | -grad_p = core.Primitive("grad") |
| 286 | +grad_p = Primitive("grad") |
286 | 287 | grad_p.multiple_results = True |
287 | 288 | func_p = core.CallPrimitive("func") |
288 | 289 | func_p.multiple_results = True |
289 | | -jvp_p = core.Primitive("jvp") |
| 290 | +jvp_p = Primitive("jvp") |
290 | 291 | jvp_p.multiple_results = True |
291 | | -vjp_p = core.Primitive("vjp") |
| 292 | +vjp_p = Primitive("vjp") |
292 | 293 | vjp_p.multiple_results = True |
293 | | -adjoint_p = jax.core.Primitive("adjoint") |
| 294 | +adjoint_p = Primitive("adjoint") |
294 | 295 | adjoint_p.multiple_results = True |
295 | | -print_p = jax.core.Primitive("debug_print") |
| 296 | +print_p = Primitive("debug_print") |
296 | 297 | print_p.multiple_results = True |
297 | | -python_callback_p = core.Primitive("python_callback") |
| 298 | +python_callback_p = Primitive("python_callback") |
298 | 299 | python_callback_p.multiple_results = True |
299 | | -value_and_grad_p = core.Primitive("value_and_grad") |
| 300 | +value_and_grad_p = Primitive("value_and_grad") |
300 | 301 | value_and_grad_p.multiple_results = True |
301 | | -assert_p = core.Primitive("assert") |
| 302 | +assert_p = Primitive("assert") |
302 | 303 | assert_p.multiple_results = True |
303 | | -set_state_p = jax.core.Primitive("state_prep") |
| 304 | +set_state_p = Primitive("state_prep") |
304 | 305 | set_state_p.multiple_results = True |
305 | | -set_basis_state_p = jax.core.Primitive("set_basis_state") |
| 306 | +set_basis_state_p = Primitive("set_basis_state") |
306 | 307 | set_basis_state_p.multiple_results = True |
307 | 308 | quantum_kernel_p = core.CallPrimitive("quantum_kernel") |
308 | 309 | quantum_kernel_p.multiple_results = True |
309 | | -measure_in_basis_p = core.Primitive("measure_in_basis") |
| 310 | +measure_in_basis_p = Primitive("measure_in_basis") |
310 | 311 | measure_in_basis_p.multiple_results = True |
311 | 312 |
|
312 | 313 |
|
|
0 commit comments