Skip to content

Commit 3cecbf3

Browse files
committed
Remove core.concrete_aval and replace with abstractify
1 parent 1e22149 commit 3cecbf3

File tree

10 files changed

+30
-56
lines changed

10 files changed

+30
-56
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1212

1313
## Unreleased
1414

15+
* Deprecations
16+
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
17+
are now deprecated, having been replaced by symbols of the same name
18+
in {mod}`jax.core`.
19+
1520
## jax 0.4.38 (Dec 17, 2024)
1621

1722
* Changes:

jax/_src/abstract_arrays.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def masked_array_error(*args, **kwargs):
4949
"Use arr.filled() to convert the value to a standard numpy array.")
5050

5151
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
52-
core.xla_pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
5352

5453

5554
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
@@ -58,7 +57,6 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
5857
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
5958

6059
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
61-
core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
6260

6361

6462
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
@@ -68,7 +66,6 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
6866

6967
for t in numpy_scalar_types:
7068
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
71-
core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
7269

7370
core.literalable_types.update(array_types)
7471

@@ -81,6 +78,5 @@ def _make_abstract_python_scalar(typ, val):
8178

8279
for t in dtypes.python_scalar_dtypes:
8380
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
84-
core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
8581

8682
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

jax/_src/array.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,6 @@ def _get_aval_array(self):
10381038

10391039
api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
10401040
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
1041-
core.xla_pytype_aval_mappings[ArrayImpl] = _get_aval_array
10421041

10431042
# TODO(jakevdp) replace this with true inheritance at the C++ level.
10441043
basearray.Array.register(ArrayImpl)

jax/_src/core.py

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,7 +1388,7 @@ def lattice_join(x, y):
13881388

13891389
def valid_jaxtype(x) -> bool:
13901390
try:
1391-
concrete_aval(x)
1391+
abstractify(x)
13921392
except TypeError:
13931393
return False
13941394
else:
@@ -1400,35 +1400,9 @@ def check_valid_jaxtype(x):
14001400
f"Value {x!r} of type {type(x)} is not a valid JAX type")
14011401

14021402

1403-
# TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
1404-
# This is tricky because concrete_aval includes sharding information, and
1405-
# abstractify does not; further, because abstractify is in the dispatch path,
1406-
# performance is important and simply adding sharding there is not an option.
1407-
def concrete_aval(x):
1408-
# This differs from abstractify below in that the abstract values
1409-
# include sharding where applicable. Historically (before stackless)
1410-
# the returned avals were concrete, but after the stackless change
1411-
# this returns ShapedArray like abstractify.
1412-
# Rules are registered in pytype_aval_mappings.
1413-
for typ in type(x).__mro__:
1414-
handler = pytype_aval_mappings.get(typ)
1415-
if handler: return handler(x)
1416-
if hasattr(x, '__jax_array__'):
1417-
return concrete_aval(x.__jax_array__())
1418-
raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
1419-
"type")
1420-
1421-
14221403
def abstractify(x):
1423-
# Historically, this was called xla.abstractify. It differs from
1424-
# concrete_aval in that it excludes sharding information, and
1425-
# uses a more performant path for accessing avals. Rules are
1426-
# registered in xla_pytype_aval_mappings.
1427-
typ = type(x)
1428-
aval_fn = xla_pytype_aval_mappings.get(typ)
1429-
if aval_fn: return aval_fn(x)
1430-
for typ in typ.__mro__:
1431-
aval_fn = xla_pytype_aval_mappings.get(typ)
1404+
for typ in type(x).__mro__:
1405+
aval_fn = pytype_aval_mappings.get(typ)
14321406
if aval_fn: return aval_fn(x)
14331407
if hasattr(x, '__jax_array__'):
14341408
return abstractify(x.__jax_array__())
@@ -1439,7 +1413,7 @@ def get_aval(x):
14391413
if isinstance(x, Tracer):
14401414
return x.aval
14411415
else:
1442-
return concrete_aval(x)
1416+
return abstractify(x)
14431417

14441418
get_type = get_aval
14451419

@@ -1835,7 +1809,6 @@ def to_tangent_aval(self):
18351809
self.weak_type)
18361810

18371811
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
1838-
xla_pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
18391812

18401813

18411814
class DArray:
@@ -1892,7 +1865,6 @@ def data(self):
18921865

18931866
pytype_aval_mappings[DArray] = \
18941867
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
1895-
xla_pytype_aval_mappings[DArray] = lambda x: x._aval
18961868

18971869
@dataclass(frozen=True)
18981870
class bint(dtypes.ExtendedDType):
@@ -1925,7 +1897,6 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
19251897
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
19261898
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
19271899
pytype_aval_mappings[MutableArray] = lambda x: x._aval
1928-
xla_pytype_aval_mappings[MutableArray] = lambda x: x._aval
19291900

19301901
def mutable_array(init_val):
19311902
return mutable_array_p.bind(init_val)
@@ -1979,7 +1950,6 @@ def __init__(self, buf):
19791950
def block_until_ready(self):
19801951
self._buf.block_until_ready()
19811952
pytype_aval_mappings[Token] = lambda _: abstract_token
1982-
xla_pytype_aval_mappings[Token] = lambda _: abstract_token
19831953

19841954

19851955
# TODO(dougalm): Deprecate these. They're just here for backwards compat.

jax/_src/export/shape_poly.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,6 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
12051205
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
12061206

12071207
core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
1208-
core.xla_pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
12091208
dtypes._weak_types.append(_DimExpr)
12101209

12111210
def _convertible_to_int(p: DimSize) -> bool:

jax/_src/interpreters/xla.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,6 @@ def _canonicalize_python_scalar_dtype(typ, x):
146146
canonicalize_dtype_handlers[core.DArray] = identity
147147
canonicalize_dtype_handlers[core.MutableArray] = identity
148148

149-
# TODO(jakevdp): deprecate and remove this.
150-
def abstractify(x) -> Any:
151-
return core.abstractify(x)
152-
153-
# TODO(jakevdp): deprecate and remove this.
154-
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.xla_pytype_aval_mappings
155-
156149
initial_style_primitives: set[core.Primitive] = set()
157150

158151
def register_initial_style_primitive(prim: core.Primitive):

jax/_src/prng.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,6 @@ def __hash__(self) -> int:
463463

464464

465465
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
466-
core.xla_pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
467466

468467
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
469468

jax/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@
122122
_src_core.call_p),
123123
"closed_call_p": ("jax.core.closed_call_p is deprecated. Use jax.extend.core.primitives.closed_call_p",
124124
_src_core.closed_call_p),
125-
"concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.concrete_aval),
125+
"concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.abstractify),
126126
"dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents),
127127
"escaped_tracer_error": ("jax.core.escaped_tracer_error is deprecated.",
128128
_src_core.escaped_tracer_error),
@@ -207,7 +207,7 @@
207207
axis_frame = _src_core.axis_frame
208208
call_p = _src_core.call_p
209209
closed_call_p = _src_core.closed_call_p
210-
concrete_aval = _src_core.concrete_aval
210+
concrete_aval = _src_core.abstractify
211211
dedup_referents = _src_core.dedup_referents
212212
escaped_tracer_error = _src_core.escaped_tracer_error
213213
extend_axis_env_nd = _src_core.extend_axis_env_nd

jax/interpreters/xla.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
# limitations under the License.
1414

1515
from jax._src.interpreters.xla import (
16-
abstractify as abstractify,
1716
canonicalize_dtype as canonicalize_dtype,
1817
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
19-
pytype_aval_mappings as pytype_aval_mappings,
2018
)
2119

2220
from jax._src.dispatch import (
@@ -27,8 +25,19 @@
2725
Backend = _xc._xla.Client
2826
del _xc
2927

28+
from jax._src import core as _src_core
29+
3030
# Deprecations
3131
_deprecations = {
32+
# Added 2024-12-17
33+
"abstractify": (
34+
"jax.interpreters.xla.abstractify is deprecated.",
35+
_src_core.abstractify
36+
),
37+
"pytype_aval_mappings": (
38+
"jax.interpreters.xla.pytype_aval_mappings is deprecated.",
39+
_src_core.pytype_aval_mappings
40+
),
3241
# Finalized 2024-10-24; remove after 2025-01-24
3342
"xb": (
3443
("jax.interpreters.xla.xb was removed in JAX v0.4.36. "
@@ -44,6 +53,13 @@
4453
),
4554
}
4655

56+
import typing as _typing
4757
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
48-
__getattr__ = _deprecation_getattr(__name__, _deprecations)
58+
if _typing.TYPE_CHECKING:
59+
abstractify = _src_core.abstractify
60+
pytype_aval_mappings = _src_core.pytype_aval_mappings
61+
else:
62+
__getattr__ = _deprecation_getattr(__name__, _deprecations)
4963
del _deprecation_getattr
64+
del _typing
65+
del _src_core

tests/lax_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3959,8 +3959,6 @@ def setUp(self):
39593959
core.pytype_aval_mappings[FooArray] = \
39603960
lambda x: core.ShapedArray(x.shape, FooTy())
39613961
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
3962-
core.xla_pytype_aval_mappings[FooArray] = \
3963-
lambda x: core.ShapedArray(x.shape, FooTy())
39643962
pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
39653963
mlir._constant_handlers[FooArray] = foo_array_constant_handler
39663964
mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
@@ -3973,7 +3971,6 @@ def setUp(self):
39733971
def tearDown(self):
39743972
del core.pytype_aval_mappings[FooArray]
39753973
del xla.canonicalize_dtype_handlers[FooArray]
3976-
del core.xla_pytype_aval_mappings[FooArray]
39773974
del mlir._constant_handlers[FooArray]
39783975
del mlir._lowerings[make_p]
39793976
del mlir._lowerings[bake_p]

0 commit comments

Comments
 (0)