Skip to content

Commit 0fa5419

Browse files
Merge pull request #25456 from jakevdp:xla-abstractify
PiperOrigin-RevId: 707175097
2 parents 7fe2579 + 2c722d9 commit 0fa5419

File tree

15 files changed

+79
-59
lines changed

15 files changed

+79
-59
lines changed

benchmarks/api_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from jax import lax
2424
from jax._src.api_util import shaped_abstractify # technically not an api fn
2525
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
26+
from jax._src import core
2627
from jax._src.lib import xla_client as xc
27-
from jax.interpreters import xla
2828
from jax._src import array
2929
from jax._src import op_shardings
3030
from jax._src.pjit import pjit_check_aval_sharding
@@ -427,7 +427,7 @@ def bench_shaped_abstractify(state):
427427

428428
def _run_benchmark_for_xla_abstractify(arg, state):
429429
while state:
430-
xla.abstractify(arg)
430+
core.abstractify(arg)
431431

432432
def bench_xla_abstractify():
433433
_abstractify_args = [

jax/_src/abstract_arrays.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,23 +50,49 @@ def canonical_concrete_aval(val, weak_type=None):
5050
sharding = core._get_abstract_sharding(val)
5151
return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding)
5252

53+
5354
def masked_array_error(*args, **kwargs):
5455
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
5556
"Use arr.filled() to convert the value to a standard numpy array.")
5657

5758
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
5859

59-
for t in array_types:
60+
61+
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
62+
dtype = x.dtype
63+
dtypes.check_valid_dtype(dtype)
64+
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
65+
66+
core.pytype_aval_mappings[np.ndarray] = canonical_concrete_aval
67+
core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
68+
69+
70+
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
71+
dtype = np.dtype(x)
72+
dtypes.check_valid_dtype(dtype)
73+
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
74+
75+
for t in numpy_scalar_types:
6076
core.pytype_aval_mappings[t] = canonical_concrete_aval
77+
core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
6178

6279
core.literalable_types.update(array_types)
6380

81+
6482
def _make_concrete_python_scalar(t, x):
6583
dtype = dtypes._scalar_type_to_dtype(t, x)
6684
weak_type = dtypes.is_weakly_typed(x)
6785
return canonical_concrete_aval(np.array(x, dtype=dtype), weak_type=weak_type)
6886

87+
88+
def _make_abstract_python_scalar(typ, val):
89+
# Note: all python scalar types are weak except bool, because bool only
90+
# comes in a single width.
91+
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
92+
weak_type=typ is not bool)
93+
6994
for t in dtypes.python_scalar_dtypes:
7095
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
96+
core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
7197

7298
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

jax/_src/api_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def _shaped_abstractify_slow(x):
600600
"does not have a dtype attribute")
601601
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)
602602

603-
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
603+
# TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior
604604
def shaped_abstractify(x):
605605
handler = _shaped_abstractify_handlers.get(type(x), None)
606606
return handler(x) if handler is not None else _shaped_abstractify_slow(x)

jax/_src/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,7 @@ def make_array_from_single_device_arrays(
10291029

10301030

10311031
core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
1032-
xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
1032+
core.xla_pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
10331033
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
10341034
def _get_aval_array(self):
10351035
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):

jax/_src/core.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,7 +1400,16 @@ def check_valid_jaxtype(x):
14001400
f"Value {x!r} of type {type(x)} is not a valid JAX type")
14011401

14021402

1403+
# TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
1404+
# This is tricky because concrete_aval includes sharding information, and
1405+
# abstractify does not; further, because abstractify is in the dispatch path,
1406+
# performance is important and simply adding sharding there is not an option.
14031407
def concrete_aval(x):
1408+
# This differs from abstractify below in that the abstract values
1409+
# include sharding where applicable. Historically (before stackless)
1410+
# the returned avals were concrete, but after the stackless change
1411+
# this returns ShapedArray like abstractify.
1412+
# Rules are registered in pytype_aval_mappings.
14041413
for typ in type(x).__mro__:
14051414
handler = pytype_aval_mappings.get(typ)
14061415
if handler: return handler(x)
@@ -1410,6 +1419,22 @@ def concrete_aval(x):
14101419
"type")
14111420

14121421

1422+
def abstractify(x):
1423+
# Historically, this was called xla.abstractify. It differs from
1424+
# concrete_aval in that it excludes sharding information, and
1425+
# uses a more performant path for accessing avals. Rules are
1426+
# registered in xla_pytype_aval_mappings.
1427+
typ = type(x)
1428+
aval_fn = xla_pytype_aval_mappings.get(typ)
1429+
if aval_fn: return aval_fn(x)
1430+
for typ in typ.__mro__:
1431+
aval_fn = xla_pytype_aval_mappings.get(typ)
1432+
if aval_fn: return aval_fn(x)
1433+
if hasattr(x, '__jax_array__'):
1434+
return abstractify(x.__jax_array__())
1435+
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
1436+
1437+
14131438
def get_aval(x):
14141439
if isinstance(x, Tracer):
14151440
return x.aval
@@ -1810,6 +1835,7 @@ def to_tangent_aval(self):
18101835
self.weak_type)
18111836

18121837
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
1838+
xla_pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
18131839

18141840

18151841
class DArray:
@@ -1866,6 +1892,7 @@ def data(self):
18661892

18671893
pytype_aval_mappings[DArray] = \
18681894
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
1895+
xla_pytype_aval_mappings[DArray] = lambda x: x._aval
18691896

18701897
@dataclass(frozen=True)
18711898
class bint(dtypes.ExtendedDType):
@@ -1898,6 +1925,7 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
18981925
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
18991926
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
19001927
pytype_aval_mappings[MutableArray] = lambda x: x._aval
1928+
xla_pytype_aval_mappings[MutableArray] = lambda x: x._aval
19011929

19021930
def mutable_array(init_val):
19031931
return mutable_array_p.bind(init_val)
@@ -1951,6 +1979,7 @@ def __init__(self, buf):
19511979
def block_until_ready(self):
19521980
self._buf.block_until_ready()
19531981
pytype_aval_mappings[Token] = lambda _: abstract_token
1982+
xla_pytype_aval_mappings[Token] = lambda _: abstract_token
19541983

19551984

19561985
# TODO(dougalm): Deprecate these. They're just here for backwards compat.

jax/_src/dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def _device_put_impl(
457457
" please provide a concrete Sharding with memory_kind.")
458458

459459
try:
460-
aval = xla.abstractify(x)
460+
aval = core.abstractify(x)
461461
except TypeError as err:
462462
raise TypeError(
463463
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err

jax/_src/export/shape_poly.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import opt_einsum
3636

3737
import jax
38-
from jax.interpreters import xla
3938

4039
from jax._src import config
4140
from jax._src import core
@@ -1206,7 +1205,7 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
12061205
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
12071206

12081207
core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
1209-
xla.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
1208+
core.xla_pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
12101209
dtypes._weak_types.append(_DimExpr)
12111210

12121211
def _convertible_to_int(p: DimSize) -> bool:

jax/_src/interpreters/mlir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1825,7 +1825,7 @@ def read(v: core.Atom) -> IrValues:
18251825

18261826
def aval(v: core.Atom) -> core.AbstractValue:
18271827
if type(v) is core.Literal:
1828-
return xla.abstractify(v.val)
1828+
return core.abstractify(v.val)
18291829
else:
18301830
return v.aval
18311831

jax/_src/interpreters/pxla.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def _emap_apply_fn(*args):
349349
donated_invars=donated_invars,
350350
is_explicit_global_axis_size=is_explicit_global_axis_size)
351351
return _emap_apply_fn
352-
abstract_args = unsafe_map(xla.abstractify, args)
352+
abstract_args = unsafe_map(core.abstractify, args)
353353
compiled_fun, fingerprint = parallel_callable(
354354
fun, backend, axis_name, axis_size, global_axis_size, devices, name,
355355
in_axes, out_axes_thunk, donated_invars,
@@ -360,7 +360,7 @@ def _emap_apply_fn(*args):
360360
distributed_debug_log(("Running pmapped function", name),
361361
("python function", fun.f),
362362
("devices", devices),
363-
("abstract args", map(xla.abstractify, args)),
363+
("abstract args", map(core.abstractify, args)),
364364
("fingerprint", fingerprint))
365365
return compiled_fun
366366

@@ -598,7 +598,7 @@ def __init__(self, trace: MapTrace, val, shard_axes: dict[core.AxisName, int]):
598598

599599
@property
600600
def aval(self):
601-
aval = xla.abstractify(self.val)
601+
aval = core.abstractify(self.val)
602602
shard_axes = dict(self.shard_axes)
603603
for axis_idx in sorted(shard_axes.values())[::-1]:
604604
aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
@@ -1145,7 +1145,7 @@ def xla_extension_executable(self):
11451145
@profiler.annotate_function
11461146
def call(self, *args):
11471147
# TODO(frostig): do we need to check sharding and sharded avals?
1148-
arg_avals = map(xla.abstractify, args)
1148+
arg_avals = map(core.abstractify, args)
11491149
check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
11501150
return self.unsafe_call(*args) # pylint: disable=not-callable
11511151

@@ -3092,7 +3092,7 @@ def call(self, *args):
30923092
ref_avals = self._all_args_info.in_avals
30933093
debug_info = self._all_args_info.debug_info
30943094

3095-
all_arg_avals = map(xla.abstractify, kept_args)
3095+
all_arg_avals = map(core.abstractify, kept_args)
30963096
check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
30973097
check_array_xla_sharding_layout_match(
30983098
args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info,

jax/_src/interpreters/xla.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -146,44 +146,12 @@ def _canonicalize_python_scalar_dtype(typ, x):
146146
canonicalize_dtype_handlers[core.DArray] = identity
147147
canonicalize_dtype_handlers[core.MutableArray] = identity
148148

149+
# TODO(jakevdp): deprecate and remove this.
149150
def abstractify(x) -> Any:
150-
typ = type(x)
151-
aval_fn = pytype_aval_mappings.get(typ)
152-
if aval_fn: return aval_fn(x)
153-
for typ in typ.__mro__:
154-
aval_fn = pytype_aval_mappings.get(typ)
155-
if aval_fn: return aval_fn(x)
156-
if hasattr(x, '__jax_array__'):
157-
return abstractify(x.__jax_array__())
158-
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
159-
160-
def _make_abstract_python_scalar(typ, val):
161-
# Note: all python scalar types are weak except bool, because bool only
162-
# comes in a single width.
163-
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
164-
weak_type=typ is not bool)
165-
166-
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
167-
dtype = np.dtype(x)
168-
dtypes.check_valid_dtype(dtype)
169-
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
170-
171-
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
172-
dtype = x.dtype
173-
dtypes.check_valid_dtype(dtype)
174-
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
175-
176-
177-
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
178-
pytype_aval_mappings[core.DArray] = lambda x: x._aval
179-
pytype_aval_mappings[core.MutableArray] = lambda x: x._aval
180-
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
181-
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
182-
for t in numpy_scalar_types)
183-
pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
184-
pytype_aval_mappings.update(
185-
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
151+
return core.abstractify(x)
186152

153+
# TODO(jakevdp): deprecate and remove this.
154+
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.xla_pytype_aval_mappings
187155

188156
initial_style_primitives: set[core.Primitive] = set()
189157

0 commit comments

Comments
 (0)