Skip to content

Commit 51c224c

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
Removed deprecated jax.core.{full_lower,jaxpr_as_fun,lattice_join}
PiperOrigin-RevId: 744754730
1 parent ff00fa9 commit 51c224c

File tree

3 files changed

+6
-16
lines changed

3 files changed

+6
-16
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4949
* From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`,
5050
`Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`,
5151
`Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`,
52-
`dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `get_referent`,
53-
`join_effects`, `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`,
52+
`dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`,
53+
`leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`,
5454
`raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`,
5555
`substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most
5656
have no public replacement, though a few are available at {mod}`jax.extend.core`.

jax/_src/core.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,11 +1496,6 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
14961496
for v in jaxpr.invars]
14971497
return tuple(out)
14981498

1499-
# TODO(dougalm): Deprecate. This is here for backwards compat.
1500-
def lattice_join(x, y):
1501-
assert typematch(x, y)
1502-
return x
1503-
15041499
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
15051500
Value = Any
15061501

jax/core.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,11 @@
9797
"typecheck": ("jax.core.typecheck is deprecated.", _src_core.typecheck),
9898
"typematch": ("jax.core.typematch is deprecated.", _src_core.typematch),
9999
# Added 2024-12-10
100-
"full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.",
101-
_src_core.full_lower),
102-
"jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, "
100+
"full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", None),
101+
"jaxpr_as_fun": ("jax.core.jaxpr_as_fun was removed in JAX v0.6.0. Use jax.extend.core.jaxpr_as_fun instead, "
103102
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
104-
_src_core.jaxpr_as_fun),
105-
"lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.",
106-
_src_core.lattice_join),
103+
None),
104+
"lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", None),
107105
# Finalized 2025-03-25 for JAX v0.6.0; remove after 2025-06-25
108106
"AxisSize": ("jax.core.AxisSize was removed in JAX v0.6.0.", None),
109107
"ClosedJaxpr": ("jax.core.ClosedJaxpr was removed in JAX v0.6.0. Use jax.extend.core.ClosedJaxpr instead, "
@@ -152,10 +150,7 @@
152150
axis_frame = _src_core.axis_frame
153151
call_p = _src_core.call_p
154152
closed_call_p = _src_core.closed_call_p
155-
full_lower = _src_core.full_lower
156153
get_type = _src_core.get_aval
157-
jaxpr_as_fun = _src_core.jaxpr_as_fun
158-
lattice_join = _src_core.lattice_join
159154
trace_state_clean = _src_core.trace_state_clean
160155
typecheck = _src_core.typecheck
161156
typematch = _src_core.typematch

0 commit comments

Comments
 (0)