Skip to content

Commit 85e2969

Browse files
committed
Deprecate several private APIs in jax.lib
1 parent a582df0 commit 85e2969

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
6565
result in an indexing overflow for batch sizes close to int32 max. See
6666
{jax-issue}`#24843` for more details.
6767

68+
* Deprecations
69+
* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
70+
use `jax.Array` instead.
71+
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
72+
instead.
73+
6874
## jax 0.4.35 (Oct 22, 2024)
6975

7076
* Breaking Changes

jax/lib/xla_client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
get_topology_for_devices = _xc.get_topology_for_devices
1919
heap_profile = _xc.heap_profile
2020
mlir_api_version = _xc.mlir_api_version
21-
ArrayImpl = _xc.ArrayImpl
2221
Client = _xc.Client
2322
CompileOptions = _xc.CompileOptions
2423
DeviceAssignment = _xc.DeviceAssignment
@@ -95,6 +94,11 @@
9594
"XlaComputation is deprecated; use StableHLO instead.",
9695
_xc.XlaComputation,
9796
),
97+
# Added Nov 20 2024
98+
"ArrayImpl": (
99+
"jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.",
100+
_xc.ArrayImpl,
101+
),
98102
}
99103

100104
import typing as _typing
@@ -106,6 +110,7 @@
106110
ops = _xc.ops
107111
register_custom_call_target = _xc.register_custom_call_target
108112
shape_from_pyval = _xc.shape_from_pyval
113+
ArrayImpl = _xc.ArrayImpl
109114
Device = _xc.Device
110115
FftType = _FftType
111116
PaddingType = _xc.PaddingType

jax/lib/xla_extension.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
pmap_lib = _xe.pmap_lib
2525
profiler = _xe.profiler
2626
pytree = _xe.pytree
27-
ArrayImpl = _xe.ArrayImpl
2827
Device = _xe.Device
2928
DistributedRuntimeClient = _xe.DistributedRuntimeClient
3029
HloModule = _xe.HloModule
@@ -33,6 +32,28 @@
3332
PjitFunctionCache = _xe.PjitFunctionCache
3433
PjitFunction = _xe.PjitFunction
3534
PmapFunction = _xe.PmapFunction
36-
XlaRuntimeError = _xe.XlaRuntimeError
3735

36+
_deprecations = {
37+
# Added Nov 20 2024
38+
"ArrayImpl": (
39+
"jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.",
40+
_xe.ArrayImpl,
41+
),
42+
"XlaRuntimeError": (
43+
"jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.",
44+
_xe.XlaRuntimeError,
45+
),
46+
}
47+
48+
import typing as _typing
49+
50+
if _typing.TYPE_CHECKING:
51+
ArrayImpl = _xe.ArrayImpl
52+
XlaRuntimeError = _xe.XlaRuntimeError
53+
else:
54+
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
55+
56+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
57+
del _deprecation_getattr
58+
del _typing
3859
del _xe

0 commit comments

Comments
 (0)