Skip to content

Commit 25524ab

Browse files
Reverts b56dc63
PiperOrigin-RevId: 707501925
1 parent 96d4a75 commit 25524ab

File tree

10 files changed

+56
-30
lines changed

10 files changed

+56
-30
lines changed

CHANGELOG.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1212

1313
## Unreleased
1414

15-
* Deprecations
16-
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
17-
are now deprecated, having been replaced by symbols of the same name
18-
in {mod}`jax.core`.
19-
2015
## jax 0.4.38 (Dec 17, 2024)
2116

2217
* Changes:

jax/_src/abstract_arrays.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def masked_array_error(*args, **kwargs):
4949
"Use arr.filled() to convert the value to a standard numpy array.")
5050

5151
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
52+
core.xla_pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
5253

5354

5455
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
@@ -57,6 +58,7 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
5758
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
5859

5960
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
61+
core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
6062

6163

6264
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
@@ -66,6 +68,7 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
6668

6769
for t in numpy_scalar_types:
6870
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
71+
core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
6972

7073
core.literalable_types.update(array_types)
7174

@@ -78,5 +81,6 @@ def _make_abstract_python_scalar(typ, val):
7881

7982
for t in dtypes.python_scalar_dtypes:
8083
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
84+
core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
8185

8286
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

jax/_src/array.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,7 @@ def _get_aval_array(self):
10381038

10391039
api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
10401040
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
1041+
core.xla_pytype_aval_mappings[ArrayImpl] = _get_aval_array
10411042

10421043
# TODO(jakevdp) replace this with true inheritance at the C++ level.
10431044
basearray.Array.register(ArrayImpl)

jax/_src/core.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,7 +1388,7 @@ def lattice_join(x, y):
13881388

13891389
def valid_jaxtype(x) -> bool:
13901390
try:
1391-
abstractify(x)
1391+
concrete_aval(x)
13921392
except TypeError:
13931393
return False
13941394
else:
@@ -1400,9 +1400,35 @@ def check_valid_jaxtype(x):
14001400
f"Value {x!r} of type {type(x)} is not a valid JAX type")
14011401

14021402

1403-
def abstractify(x):
1403+
# TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
1404+
# This is tricky because concrete_aval includes sharding information, and
1405+
# abstractify does not; further, because abstractify is in the dispatch path,
1406+
# performance is important and simply adding sharding there is not an option.
1407+
def concrete_aval(x):
1408+
# This differs from abstractify below in that the abstract values
1409+
# include sharding where applicable. Historically (before stackless)
1410+
# the returned avals were concrete, but after the stackless change
1411+
# this returns ShapedArray like abstractify.
1412+
# Rules are registered in pytype_aval_mappings.
14041413
for typ in type(x).__mro__:
1405-
aval_fn = pytype_aval_mappings.get(typ)
1414+
handler = pytype_aval_mappings.get(typ)
1415+
if handler: return handler(x)
1416+
if hasattr(x, '__jax_array__'):
1417+
return concrete_aval(x.__jax_array__())
1418+
raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
1419+
"type")
1420+
1421+
1422+
def abstractify(x):
1423+
# Historically, this was called xla.abstractify. It differs from
1424+
# concrete_aval in that it excludes sharding information, and
1425+
# uses a more performant path for accessing avals. Rules are
1426+
# registered in xla_pytype_aval_mappings.
1427+
typ = type(x)
1428+
aval_fn = xla_pytype_aval_mappings.get(typ)
1429+
if aval_fn: return aval_fn(x)
1430+
for typ in typ.__mro__:
1431+
aval_fn = xla_pytype_aval_mappings.get(typ)
14061432
if aval_fn: return aval_fn(x)
14071433
if hasattr(x, '__jax_array__'):
14081434
return abstractify(x.__jax_array__())
@@ -1413,7 +1439,7 @@ def get_aval(x):
14131439
if isinstance(x, Tracer):
14141440
return x.aval
14151441
else:
1416-
return abstractify(x)
1442+
return concrete_aval(x)
14171443

14181444
get_type = get_aval
14191445

@@ -1809,6 +1835,7 @@ def to_tangent_aval(self):
18091835
self.weak_type)
18101836

18111837
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
1838+
xla_pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
18121839

18131840

18141841
class DArray:
@@ -1865,6 +1892,7 @@ def data(self):
18651892

18661893
pytype_aval_mappings[DArray] = \
18671894
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
1895+
xla_pytype_aval_mappings[DArray] = lambda x: x._aval
18681896

18691897
@dataclass(frozen=True)
18701898
class bint(dtypes.ExtendedDType):
@@ -1897,6 +1925,7 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
18971925
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
18981926
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
18991927
pytype_aval_mappings[MutableArray] = lambda x: x._aval
1928+
xla_pytype_aval_mappings[MutableArray] = lambda x: x._aval
19001929

19011930
def mutable_array(init_val):
19021931
return mutable_array_p.bind(init_val)
@@ -1952,6 +1981,7 @@ def __init__(self, buf):
19521981
def block_until_ready(self):
19531982
self._buf.block_until_ready()
19541983
pytype_aval_mappings[Token] = lambda _: abstract_token
1984+
xla_pytype_aval_mappings[Token] = lambda _: abstract_token
19551985

19561986

19571987
# TODO(dougalm): Deprecate these. They're just here for backwards compat.

jax/_src/export/shape_poly.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,7 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
12051205
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
12061206

12071207
core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
1208+
core.xla_pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
12081209
dtypes._weak_types.append(_DimExpr)
12091210

12101211
def _convertible_to_int(p: DimSize) -> bool:

jax/_src/interpreters/xla.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,13 @@ def _canonicalize_python_scalar_dtype(typ, x):
146146
canonicalize_dtype_handlers[core.DArray] = identity
147147
canonicalize_dtype_handlers[core.MutableArray] = identity
148148

149+
# TODO(jakevdp): deprecate and remove this.
150+
def abstractify(x) -> Any:
151+
return core.abstractify(x)
152+
153+
# TODO(jakevdp): deprecate and remove this.
154+
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.xla_pytype_aval_mappings
155+
149156
initial_style_primitives: set[core.Primitive] = set()
150157

151158
def register_initial_style_primitive(prim: core.Primitive):

jax/_src/prng.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@ def __hash__(self) -> int:
463463

464464

465465
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
466+
core.xla_pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
466467

467468
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
468469

jax/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@
122122
_src_core.call_p),
123123
"closed_call_p": ("jax.core.closed_call_p is deprecated. Use jax.extend.core.primitives.closed_call_p",
124124
_src_core.closed_call_p),
125-
"concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.abstractify),
125+
"concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.concrete_aval),
126126
"dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents),
127127
"escaped_tracer_error": ("jax.core.escaped_tracer_error is deprecated.",
128128
_src_core.escaped_tracer_error),
@@ -207,7 +207,7 @@
207207
axis_frame = _src_core.axis_frame
208208
call_p = _src_core.call_p
209209
closed_call_p = _src_core.closed_call_p
210-
concrete_aval = _src_core.abstractify
210+
concrete_aval = _src_core.concrete_aval
211211
dedup_referents = _src_core.dedup_referents
212212
escaped_tracer_error = _src_core.escaped_tracer_error
213213
extend_axis_env_nd = _src_core.extend_axis_env_nd

jax/interpreters/xla.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414

1515
from jax._src.interpreters.xla import (
16+
abstractify as abstractify,
1617
canonicalize_dtype as canonicalize_dtype,
1718
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
19+
pytype_aval_mappings as pytype_aval_mappings,
1820
)
1921

2022
from jax._src.dispatch import (
@@ -25,19 +27,8 @@
2527
Backend = _xc._xla.Client
2628
del _xc
2729

28-
from jax._src import core as _src_core
29-
3030
# Deprecations
3131
_deprecations = {
32-
# Added 2024-12-17
33-
"abstractify": (
34-
"jax.interpreters.xla.abstractify is deprecated.",
35-
_src_core.abstractify
36-
),
37-
"pytype_aval_mappings": (
38-
"jax.interpreters.xla.pytype_aval_mappings is deprecated.",
39-
_src_core.pytype_aval_mappings
40-
),
4132
# Finalized 2024-10-24; remove after 2025-01-24
4233
"xb": (
4334
("jax.interpreters.xla.xb was removed in JAX v0.4.36. "
@@ -53,13 +44,6 @@
5344
),
5445
}
5546

56-
import typing as _typing
5747
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
58-
if _typing.TYPE_CHECKING:
59-
abstractify = _src_core.abstractify
60-
pytype_aval_mappings = _src_core.pytype_aval_mappings
61-
else:
62-
__getattr__ = _deprecation_getattr(__name__, _deprecations)
48+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
6349
del _deprecation_getattr
64-
del _typing
65-
del _src_core

tests/lax_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3959,6 +3959,8 @@ def setUp(self):
39593959
core.pytype_aval_mappings[FooArray] = \
39603960
lambda x: core.ShapedArray(x.shape, FooTy())
39613961
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
3962+
core.xla_pytype_aval_mappings[FooArray] = \
3963+
lambda x: core.ShapedArray(x.shape, FooTy())
39623964
pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
39633965
mlir._constant_handlers[FooArray] = foo_array_constant_handler
39643966
mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
@@ -3971,6 +3973,7 @@ def setUp(self):
39713973
def tearDown(self):
39723974
del core.pytype_aval_mappings[FooArray]
39733975
del xla.canonicalize_dtype_handlers[FooArray]
3976+
del core.xla_pytype_aval_mappings[FooArray]
39743977
del mlir._constant_handlers[FooArray]
39753978
del mlir._lowerings[make_p]
39763979
del mlir._lowerings[bake_p]

0 commit comments

Comments
 (0)