Skip to content

Commit f858a71

Browse files
committed
Finalize some deprecations in jax.core, jax.lib.xla_bridge, and jax.lib.xla_client.
1 parent 01206f8 commit f858a71

File tree

5 files changed

+25
-31
lines changed

5 files changed

+25
-31
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1919
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
2020
{mod}`jax.extend` for information on the compatibility guarantees of these
2121
semi-public extensions.
22+
* Several previously-deprecated APIs have been removed, including:
23+
* from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
24+
`non_negative_dim`.
25+
* from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
26+
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
2227

2328
## jax 0.4.37 (Dec 9, 2024)
2429

docs/jax.lib.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ jax.lib.xla_bridge
1111
.. autosummary::
1212
:toctree: _autosummary
1313

14-
default_backend
1514
get_backend
1615
get_compile_options
1716

jax/core.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,16 @@
160160
"Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, "
161161
"and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.",
162162
_src_core.Var),
163-
# Added 2024-08-14
164-
"check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn),
165-
"check_type": ("jax.core.check_type is deprecated.", _src_core.check_type),
163+
# Finalized 2024-12-11; remove after 2025-3-11
164+
"check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None),
165+
"check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None),
166166
"check_valid_jaxtype": (
167-
("jax.core.check_valid_jaxtype is deprecated. Instead, you can manually"
167+
("jax.core.check_valid_jaxtype was removed in JAX v0.4.38. Instead, you can manually"
168168
" raise an error if core.valid_jaxtype() returns False."),
169-
_src_core.check_valid_jaxtype),
169+
None),
170+
"non_negative_dim": (
171+
"jax.core.non_negative_dim was removed in JAX v0.4.38. Use max_dim(..., 0).", None,
172+
),
170173
# Finalized 2024-09-25; remove after 2024-12-25
171174
"pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None),
172175
"pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None),
@@ -180,10 +183,6 @@
180183
"pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None),
181184
"pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None),
182185
"pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None),
183-
# Added Jan 8, 2024
184-
"non_negative_dim": (
185-
"jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim,
186-
),
187186
}
188187

189188
import typing
@@ -207,9 +206,6 @@
207206
Var = _src_core.Var
208207
axis_frame = _src_core.axis_frame
209208
call_p = _src_core.call_p
210-
check_eqn = _src_core.check_eqn
211-
check_type = _src_core.check_type
212-
check_valid_jaxtype = _src_core.check_valid_jaxtype
213209
closed_call_p = _src_core.closed_call_p
214210
concrete_aval = _src_core.concrete_aval
215211
dedup_referents = _src_core.dedup_referents
@@ -223,7 +219,6 @@
223219
lattice_join = _src_core.lattice_join
224220
leaked_tracer_error = _src_core.leaked_tracer_error
225221
maybe_find_leaked_tracers = _src_core.maybe_find_leaked_tracers
226-
non_negative_dim = _src_core.non_negative_dim
227222
raise_to_shaped = _src_core.raise_to_shaped
228223
raise_to_shaped_mappings = _src_core.raise_to_shaped_mappings
229224
reset_trace_state = _src_core.reset_trace_state

jax/lib/xla_bridge.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414

1515
# ruff: noqa: F401
1616
from jax._src.xla_bridge import (
17-
default_backend as _deprecated_default_backend,
1817
get_backend as _deprecated_get_backend,
19-
xla_client as _deprecated_xla_client,
2018
)
2119

2220
from jax._src.compiler import (
@@ -25,25 +23,24 @@
2523

2624
_deprecations = {
2725
# Added July 31, 2024
28-
"xla_client": (
29-
"jax.lib.xla_bridge.xla_client is deprecated; use jax.lib.xla_client directly.",
30-
_deprecated_xla_client
31-
),
3226
"get_backend": (
3327
"jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.",
3428
_deprecated_get_backend
3529
),
30+
# Finalized 2024-12-11; remove after 2025-3-11
31+
"xla_client": (
32+
"jax.lib.xla_bridge.xla_client was removed in JAX v0.4.38; use jax.lib.xla_client directly.",
33+
None
34+
),
3635
"default_backend": (
37-
"jax.lib.xla_bridge.default_backend is deprecated; use jax.default_backend.",
38-
_deprecated_default_backend
36+
"jax.lib.xla_bridge.default_backend was removed in JAX v0.4.38; use jax.default_backend.",
37+
None
3938
),
4039
}
4140

4241
import typing as _typing
4342
if _typing.TYPE_CHECKING:
44-
from jax._src.xla_bridge import default_backend as default_backend
4543
from jax._src.xla_bridge import get_backend as get_backend
46-
from jax._src.xla_bridge import xla_client as xla_client
4744
else:
4845
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
4946
__getattr__ = _deprecation_getattr(__name__, _deprecations)

jax/lib/xla_client.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@
2727
Traceback = _xc.Traceback
2828

2929
_deprecations = {
30-
# Added Aug 5 2024
30+
# Finalized 2024-12-11; remove after 2025-3-11
3131
"_xla": (
32-
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
33-
_xc._xla,
32+
"jax.lib.xla_client._xla was removed in JAX v0.4.38; use jax.lib.xla_extension.",
33+
None,
3434
),
3535
"bfloat16": (
36-
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
37-
_xc.bfloat16,
36+
"jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.",
37+
None,
3838
),
3939
# Added Sep 26 2024
4040
"Device": (
@@ -104,8 +104,6 @@
104104
import typing as _typing
105105

106106
if _typing.TYPE_CHECKING:
107-
_xla = _xc._xla
108-
bfloat16 = _xc.bfloat16
109107
dtype_to_etype = _xc.dtype_to_etype
110108
ops = _xc.ops
111109
register_custom_call_target = _xc.register_custom_call_target

0 commit comments

Comments
 (0)