Skip to content

Commit dcdf896

Browse files
paul0403erick-xanadumehrdad2m
authored
Update jax to 0.5.3 (#1652)
**Context:** We update the jax version to 0.5.3. For the detailed changes, see the corresponding PRs: - LAPACK kernels are updated to adhere to the new JAX lowering rules for external functions. #1685 - The trace stack is removed and replaced with a tracing context manager. #1662 - A new `debug_info` argument/field is added to `Jaxpr`, the `make_jaxpr` functions, and `jax.extend.linear_util.wrap_init`. #1670 #1671 #1681 [sc-88694] --------- Co-authored-by: Erick Ochoa Lopez <[email protected]> Co-authored-by: erick-xanadu <[email protected]> Co-authored-by: Mehrdad Malek <[email protected]> Co-authored-by: Mehrdad Malekmohammadi <[email protected]>
1 parent 77b282a commit dcdf896

36 files changed

+794
-648
lines changed

.dep-versions

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Always update the version check in catalyst.__init__ when changing the JAX version.
2-
jax=0.4.28
2+
jax=0.5.3
33
mhlo=89a891c986650c33df76885f5620e0a92150d90f
44
llvm=3a8316216807d64a586b971f51695e23883331f7
55
enzyme=v0.0.149

doc/releases/changelog-dev.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,22 @@
7676
will handle direct decomposition of PPRs into PPMs.
7777
[(#1688)](https://github.com/PennyLaneAI/catalyst/pull/1688)
7878

79+
* The version of JAX used by Catalyst is updated to 0.5.3.
80+
[(#1652)](https://github.com/PennyLaneAI/catalyst/pull/1652)
81+
82+
Several internal changes were made for this update.
83+
- LAPACK kernels are updated to adhere to the new JAX lowering rules for external functions.
84+
[(#1685)](https://github.com/PennyLaneAI/catalyst/pull/1685)
85+
86+
- The trace stack is removed and replaced with a tracing context manager.
87+
[(#1662)](https://github.com/PennyLaneAI/catalyst/pull/1662)
88+
89+
- A new `debug_info` argument is added to `Jaxpr`, the `make_jaxpr`
90+
functions, and `jax.extend.linear_util.wrap_init`.
91+
[(#1670)](https://github.com/PennyLaneAI/catalyst/pull/1670)
92+
[(#1671)](https://github.com/PennyLaneAI/catalyst/pull/1671)
93+
[(#1681)](https://github.com/PennyLaneAI/catalyst/pull/1681)
94+
7995
<h3>Deprecations 👋</h3>
8096

8197
<h3>Bug fixes 🐛</h3>
@@ -150,4 +166,5 @@ David Ittah,
150166
Tzung-Han Juang,
151167
Christina Lee,
152168
Erick Ochoa Lopez,
169+
Mehrdad Malekmohammadi,
153170
Paul Haochen Wang.

frontend/catalyst/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import jaxlib as _jaxlib
2525

26-
_jaxlib_version = "0.4.28"
26+
_jaxlib_version = "0.5.3"
2727
if _jaxlib.__version__ != _jaxlib_version:
2828
import warnings
2929

frontend/catalyst/api_extensions/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
import jax
2929
import jax.numpy as jnp
30-
from jax._src.api_util import shaped_abstractify
30+
from jax._src.core import shaped_abstractify
3131
from jax._src.tree_util import (
3232
Partial,
3333
tree_flatten,

frontend/catalyst/api_extensions/control_flow.py

Lines changed: 83 additions & 59 deletions
Large diffs are not rendered by default.

frontend/catalyst/api_extensions/differentiation.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
from typing import Callable, Iterable, List, Optional, Union
2525

2626
import jax
27-
import jax.numpy as jnp
2827
from jax._src.api import _dtype
2928
from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten
29+
from jax.api_util import debug_info
3030
from pennylane import QNode
3131

3232
import catalyst
@@ -46,6 +46,7 @@
4646
from catalyst.tracing.contexts import EvaluationContext, GradContext
4747
from catalyst.utils.callables import CatalystCallable
4848
from catalyst.utils.exceptions import DifferentiableCompileError
49+
from catalyst.utils.types import get_shape
4950

5051
Differentiable = Union[Function, QNode]
5152

@@ -474,10 +475,10 @@ def check_is_iterable(x, hint):
474475
f"{_dtype(p)}, but got tangent dtype {_dtype(t)} instead."
475476
)
476477

477-
if jnp.shape(p) != jnp.shape(t):
478+
if get_shape(p) != get_shape(t):
478479
raise ValueError(
479480
"catalyst.jvp called with different function params and tangent shapes; "
480-
f"got function params shape {jnp.shape(p)} and tangent shape {jnp.shape(t)}"
481+
f"got function params shape {get_shape(p)} and tangent shape {get_shape(t)}"
481482
)
482483

483484
jaxpr, out_tree = _make_jaxpr_check_differentiable(fn, grad_params, *params)
@@ -584,11 +585,11 @@ def check_is_iterable(x, hint):
584585
f"{_dtype(p)}, but got cotangent dtype {_dtype(t)} instead."
585586
)
586587

587-
if jnp.shape(p) != jnp.shape(t):
588+
if get_shape(p) != get_shape(t):
588589
raise ValueError(
589590
"catalyst.vjp called with different function output params and cotangent "
590-
f"shapes; got function output params shape {jnp.shape(p)} and cotangent shape "
591-
f"{jnp.shape(t)}"
591+
f"shapes; got function output params shape {get_shape(p)} and cotangent shape "
592+
f"{get_shape(t)}"
592593
)
593594

594595
cotangents, _ = tree_flatten(cotangents)
@@ -807,7 +808,9 @@ def _make_jaxpr_check_differentiable(
807808
return the output tree."""
808809
method = grad_params.method
809810
with mark_gradient_tracing(method):
810-
jaxpr, _, out_tree = make_jaxpr2(f)(*args, **kwargs)
811+
jaxpr, _, out_tree = make_jaxpr2(
812+
f, debug_info=debug_info("grad make jaxpr", f, args, kwargs)
813+
)(*args, **kwargs)
811814

812815
for pos, arg in enumerate(jaxpr.in_avals):
813816
if arg.dtype.kind != "f" and pos in grad_params.expanded_argnums:

frontend/catalyst/api_extensions/function_maps.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import jax.numpy as jnp
2828
import numpy as np
2929
from jax._src.tree_util import tree_flatten, tree_leaves, tree_structure, tree_unflatten
30+
from jax.api_util import debug_info
3031

3132
from catalyst.api_extensions.control_flow import for_loop
3233
from catalyst.jax_extras import make_jaxpr2
@@ -227,7 +228,9 @@ def __call__(self, *args, **kwargs):
227228
fn_args = tree_unflatten(args_tree, fn_args_flat)
228229

229230
# Run 'fn' one time to get output-shape
230-
_, shapes, init_result_tree = make_jaxpr2(self.fn)(*fn_args, **kwargs)
231+
_, shapes, init_result_tree = make_jaxpr2(
232+
self.fn, debug_info=debug_info("vmap", self.fn, args, kwargs)
233+
)(*fn_args, **kwargs)
231234

232235
init_result_flat = [jnp.zeros(shape=shape.shape, dtype=shape.dtype) for shape, _ in shapes]
233236
init_result = tree_unflatten(init_result_tree, init_result_flat)

frontend/catalyst/api_extensions/quantum_operators.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import jax
2727
import pennylane as qml
2828
from jax._src.tree_util import tree_flatten
29+
from jax.api_util import debug_info
2930
from jax.core import get_aval
3031
from pennylane import QueuingManager
3132
from pennylane.operation import Operator
@@ -146,7 +147,7 @@ def circuit():
146147
EvaluationContext.check_is_quantum_tracing(
147148
"catalyst.measure can only be used from within a qml.qnode."
148149
)
149-
ctx = EvaluationContext.get_main_tracing_context()
150+
cur_trace = EvaluationContext.get_current_trace()
150151
wires = list(wires) if isinstance(wires, (list, tuple)) else [wires]
151152
if len(wires) != 1:
152153
raise TypeError(f"Only one element is supported for the 'wires' parameter, got {wires}.")
@@ -161,7 +162,7 @@ def circuit():
161162
if postselect is not None and postselect not in [0, 1]:
162163
raise TypeError(f"postselect must be '0' or '1', got {postselect}")
163164

164-
m = new_inner_tracer(ctx.trace, get_aval(True))
165+
m = new_inner_tracer(cur_trace, get_aval(True))
165166
MidCircuitMeasure(
166167
in_classical_tracers=in_classical_tracers,
167168
out_classical_tracers=[m],
@@ -344,14 +345,12 @@ def trace_quantum(self, ctx, device, trace, qrp, postselect_mode=None) -> QRegPr
344345
qubit = qrp.extract(self.wires)[0]
345346
if postselect_mode == "hw-like":
346347
qubit2 = self.bind_overwrite_classical_tracers(
347-
ctx,
348348
trace,
349349
in_expanded_tracers=[qubit],
350350
out_expanded_tracers=self.out_classical_tracers,
351351
)
352352
else:
353353
qubit2 = self.bind_overwrite_classical_tracers(
354-
ctx,
355354
trace,
356355
in_expanded_tracers=[qubit],
357356
out_expanded_tracers=self.out_classical_tracers,
@@ -405,12 +404,14 @@ def __call__(self, *args, **kwargs):
405404
return create_adjoint_op(base_op, self.lazy)
406405

407406
tracing_artifacts = self.trace_body(args, kwargs)
408-
409-
return HybridAdjoint(*tracing_artifacts)
407+
dbg = tracing_artifacts[3]
408+
return HybridAdjoint(*tracing_artifacts[0:3], debug_info=dbg)
410409

411410
def trace_body(self, args, kwargs):
412411
"""Generate a HybridOpRegion for use by Catalyst."""
413412

413+
dbg = debug_info("AdjointCallable", self.target, args, kwargs)
414+
414415
# Allow the creation of HybridAdjoint instances outside of any contexts.
415416
# Don't create a JAX context here as otherwise we could be dealing with escaped tracers.
416417
if not EvaluationContext.is_tracing():
@@ -420,18 +421,17 @@ def trace_body(self, args, kwargs):
420421

421422
adjoint_region = HybridOpRegion(None, quantum_tape, [], [])
422423

423-
return [], [], [adjoint_region]
424+
return [], [], [adjoint_region], dbg
424425

425426
# Create a nested jaxpr scope for the body of the adjoint.
426-
ctx = EvaluationContext.get_main_tracing_context()
427-
with EvaluationContext.frame_tracing_context(ctx) as inner_trace:
427+
with EvaluationContext.frame_tracing_context(debug_info=dbg) as inner_trace:
428428
in_classical_tracers, _ = tree_flatten((args, kwargs))
429-
wffa, in_avals, _, _ = deduce_avals(self.target, args, kwargs)
429+
wffa, in_avals, _, _ = deduce_avals(self.target, args, kwargs, debug_info=dbg)
430430
arg_classical_tracers = _input_type_to_tracers(inner_trace.new_arg, in_avals)
431431
with QueuingManager.stop_recording(), QuantumTape() as quantum_tape:
432-
# FIXME: move all full_raise calls into a separate function
432+
# FIXME: move all to_jaxpr_tracer calls into a separate function
433433
res_classical_tracers = [
434-
inner_trace.full_raise(t)
434+
inner_trace.to_jaxpr_tracer(t)
435435
for t in wffa.call_wrapped(*arg_classical_tracers)
436436
if isinstance(t, DynamicJaxprTracer)
437437
]
@@ -442,7 +442,7 @@ def trace_body(self, args, kwargs):
442442
inner_trace, quantum_tape, arg_classical_tracers, res_classical_tracers
443443
)
444444

445-
return in_classical_tracers, [], [adjoint_region]
445+
return in_classical_tracers, [], [adjoint_region], wffa.debug_info
446446

447447

448448
class HybridAdjoint(HybridOp):
@@ -457,19 +457,20 @@ def trace_quantum(self, ctx, device, _trace, qrp) -> QRegPromise:
457457
body_trace = op.regions[0].trace
458458
body_tape = op.regions[0].quantum_tape
459459
res_classical_tracers = op.regions[0].res_classical_tracers
460+
dbg = op.debug_info
460461

461462
# Handle ops that were instantiated outside of a tracing context.
462463
if body_trace is None:
463-
frame_ctx = EvaluationContext.frame_tracing_context(ctx)
464+
frame_ctx = EvaluationContext.frame_tracing_context(debug_info=dbg)
464465
else:
465-
frame_ctx = EvaluationContext.frame_tracing_context(ctx, body_trace)
466+
frame_ctx = EvaluationContext.frame_tracing_context(body_trace)
466467

467468
with frame_ctx as body_trace:
468469
qreg_in = _input_type_to_tracers(body_trace.new_arg, [AbstractQreg()])[0]
469470
qrp_out = trace_quantum_operations(body_tape, device, qreg_in, ctx, body_trace)
470471
qreg_out = qrp_out.actualize()
471-
body_jaxpr, _, body_consts = ctx.frames[body_trace].to_jaxpr2(
472-
res_classical_tracers + [qreg_out]
472+
body_jaxpr, _, body_consts = body_trace.frame.to_jaxpr2(
473+
res_classical_tracers + [qreg_out], dbg
473474
)
474475

475476
qreg = qrp.actualize()
@@ -657,9 +658,9 @@ def ctrl_distribute(
657658

658659
# Allow decompositions outside of a Catalyst context.
659660
if EvaluationContext.is_tracing():
660-
ctx = EvaluationContext.get_main_tracing_context()
661+
cur_trace = EvaluationContext.get_current_trace()
661662
else:
662-
ctx = None
663+
cur_trace = None
663664

664665
new_ops = []
665666
for op in tape.operations:
@@ -675,8 +676,8 @@ def ctrl_distribute(
675676
else:
676677
for region in [region for region in op.regions if region.quantum_tape is not None]:
677678
# Re-enter a JAXPR frame but do not create a new one is none exists.
678-
if ctx and region.trace:
679-
trace_manager = EvaluationContext.frame_tracing_context(ctx, region.trace)
679+
if cur_trace and region.trace:
680+
trace_manager = EvaluationContext.frame_tracing_context(region.trace)
680681
else:
681682
trace_manager = nullcontext
682683

frontend/catalyst/device/decomposition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def _decompose_nested_tapes(op, ctx, capabilities: DeviceCapabilities):
152152
if region.quantum_tape is None:
153153
new_tape = None
154154
else:
155-
with EvaluationContext.frame_tracing_context(ctx, region.trace):
155+
with EvaluationContext.frame_tracing_context(region.trace):
156156
tapes, _ = catalyst_decompose(
157157
region.quantum_tape, ctx=ctx, capabilities=capabilities
158158
)

frontend/catalyst/device/verification.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,12 @@ def _verify_nested(
6363
) -> Any:
6464
"""Traverse the nested quantum tape, carry a caller-defined state."""
6565

66-
ctx = EvaluationContext.get_main_tracing_context()
6766
for op in operations:
6867
inner_state = op_checker_fn(op, state)
6968
if has_nested_tapes(op):
7069
for region in nested_quantum_regions(op):
7170
if region.trace is not None:
72-
with EvaluationContext.frame_tracing_context(ctx, region.trace):
71+
with EvaluationContext.frame_tracing_context(region.trace):
7372
inner_state = _verify_nested(
7473
region.quantum_tape.operations, inner_state, op_checker_fn
7574
)

0 commit comments

Comments
 (0)