Skip to content

Commit 7680532

Browse files
Merge pull request #25595 from jakevdp:mv-shaped-abstractify
PiperOrigin-RevId: 707888615
2 parents ad00ec1 + 676070f commit 7680532

File tree

8 files changed

+57
-58
lines changed

8 files changed

+57
-58
lines changed

jax/_src/abstract_arrays.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,30 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
5656
dtypes.check_valid_dtype(dtype)
5757
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
5858

59+
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
60+
dtype = x.dtype
61+
dtypes.check_valid_dtype(dtype)
62+
return ShapedArray(x.shape,
63+
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
64+
5965
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
66+
core.shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
6067

6168

6269
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
6370
dtype = np.dtype(x)
6471
dtypes.check_valid_dtype(dtype)
6572
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
6673

74+
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
75+
dtype = np.dtype(x)
76+
dtypes.check_valid_dtype(dtype)
77+
return ShapedArray(np.shape(x),
78+
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
79+
6780
for t in numpy_scalar_types:
6881
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
82+
core.shaped_abstractify_handlers[t] = _np_scalar_abstractify
6983

7084
core.literalable_types.update(array_types)
7185

@@ -76,7 +90,13 @@ def _make_abstract_python_scalar(typ, val):
7690
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
7791
weak_type=typ is not bool)
7892

93+
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
94+
typ = type(x)
95+
dtype = dtypes._scalar_type_to_dtype(typ, x)
96+
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
97+
7998
for t in dtypes.python_scalar_dtypes:
8099
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
100+
core.shaped_abstractify_handlers[t] = _python_scalar_abstractify
81101

82102
core.literalable_types.update(dtypes.python_scalar_dtypes.keys())

jax/_src/api_util.py

Lines changed: 5 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,10 @@
2020
from functools import partial, lru_cache
2121
from typing import Any
2222

23-
import numpy as np
24-
2523
from jax._src import core
2624
from jax._src import config
2725
from jax._src import dtypes
2826
from jax._src.state.types import AbstractRef
29-
from jax._src.abstract_arrays import numpy_scalar_types
30-
from jax._src.core import ShapedArray
3127
from jax._src.tree_util import (
3228
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
3329
treedef_children, generate_key_paths, keystr, broadcast_prefix,
@@ -587,54 +583,13 @@ def _dtype(x):
587583
except ValueError:
588584
return dtypes.result_type(getattr(x, 'dtype'))
589585

590-
def _shaped_abstractify_slow(x):
591-
try:
592-
return x if isinstance(x, core.AbstractValue) else core.get_aval(x)
593-
except TypeError:
594-
pass
595-
596-
weak_type = getattr(x, 'weak_type', False)
597-
if hasattr(x, 'dtype'):
598-
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
599-
else:
600-
raise TypeError(
601-
f"Cannot interpret value of type {type(x)} as an abstract array; it "
602-
"does not have a dtype attribute")
603-
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)
604-
605586
# TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior
587+
# TODO(jakevdp): fix downstream consumers and remove this.
606588
def shaped_abstractify(x):
607-
handler = _shaped_abstractify_handlers.get(type(x), None)
608-
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
609-
610-
_shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {}
611-
612-
613-
def _str_abstractify(x):
614-
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
615-
_shaped_abstractify_handlers[str] = _str_abstractify
616-
617-
def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
618-
dtype = x.dtype
619-
dtypes.check_valid_dtype(dtype)
620-
return ShapedArray(x.shape,
621-
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
622-
_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
623-
624-
def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
625-
dtype = np.dtype(x)
626-
dtypes.check_valid_dtype(dtype)
627-
return ShapedArray(np.shape(x),
628-
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
629-
_shaped_abstractify_handlers.update((t, _np_scalar_abstractify)
630-
for t in numpy_scalar_types)
631-
632-
def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray:
633-
typ = type(x)
634-
dtype = dtypes._scalar_type_to_dtype(typ, x)
635-
return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types)
636-
_shaped_abstractify_handlers.update((t, _python_scalar_abstractify)
637-
for t in dtypes.python_scalar_dtypes)
589+
return core.shaped_abstractify(x)
590+
591+
# TODO(jakevdp): fix downstream consumers and remove this.
592+
_shaped_abstractify_handlers = core.shaped_abstractify_handlers
638593

639594
# This decorator exists to make it easier to monkey-patch APIs in JAX.
640595
# By default it does nothing, but it can be monkey-patched to do other things.

jax/_src/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ def _get_aval_array(self):
10361036
else:
10371037
return self.aval
10381038

1039-
api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
1039+
core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
10401040
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
10411041

10421042
# TODO(jakevdp) replace this with true inheritance at the C++ level.

jax/_src/core.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,29 @@ def check_valid_jaxtype(x):
14001400
f"Value {x!r} of type {type(x)} is not a valid JAX type")
14011401

14021402

1403+
def _shaped_abstractify_slow(x):
1404+
try:
1405+
return x if isinstance(x, AbstractValue) else get_aval(x)
1406+
except TypeError:
1407+
pass
1408+
1409+
weak_type = getattr(x, 'weak_type', False)
1410+
if hasattr(x, 'dtype'):
1411+
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
1412+
else:
1413+
raise TypeError(
1414+
f"Cannot interpret value of type {type(x)} as an abstract array; it "
1415+
"does not have a dtype attribute")
1416+
return ShapedArray(np.shape(x), dtype, weak_type=weak_type)
1417+
1418+
# TODO(jakevdp): deduplicate this with abstractify
1419+
def shaped_abstractify(x):
1420+
# This was originally api_util.shaped_abstractify; temporarily moved
1421+
# here in order to facilitate combining it with abstractify.
1422+
handler = shaped_abstractify_handlers.get(type(x), None)
1423+
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
1424+
1425+
14031426
def abstractify(x):
14041427
for typ in type(x).__mro__:
14051428
aval_fn = pytype_aval_mappings.get(typ)
@@ -1809,7 +1832,11 @@ def to_tangent_aval(self):
18091832
self.weak_type)
18101833

18111834
pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
1835+
shaped_abstractify_handlers: dict[Any, Callable[[Any], ShapedArray]] = {}
18121836

1837+
def _str_abstractify(x):
1838+
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
1839+
shaped_abstractify_handlers[str] = _str_abstractify
18131840

18141841
class DArray:
18151842
_aval: DShapedArray

jax/_src/earray.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import math
1818

19-
from jax._src import api_util
2019
from jax._src import basearray
2120
from jax._src import core
2221
from jax._src import tree_util
@@ -116,7 +115,7 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
116115
return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
117116
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
118117

119-
api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
118+
core.shaped_abstractify_handlers[EArray] = lambda self: self.aval
120119
core.pytype_aval_mappings[EArray] = lambda x: x.aval
121120
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
122121
tree_util.dispatch_registry.register_node(

jax/_src/interpreters/partial_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1572,7 +1572,7 @@ def get_referent(self):
15721572

15731573
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
15741574
return x.aval
1575-
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
1575+
core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
15761576

15771577
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
15781578
sentinel = object()

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from jax import errors
4444
from jax import jit
4545
from jax import lax
46-
from jax._src import api_util
4746
from jax._src import config
4847
from jax._src import core
4948
from jax._src import deprecations
@@ -192,7 +191,7 @@ def __instancecheck__(self, instance: Any) -> bool:
192191

193192
def _abstractify_scalar_meta(x):
194193
raise TypeError(f"JAX scalar type {x} cannot be interpreted as a JAX array.")
195-
api_util._shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
194+
core.shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
196195

197196
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
198197
meta = _ScalarMeta(np_scalar_type.__name__, (object,),

jax/_src/prng.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from jax import numpy as jnp
2727
from jax import tree_util
2828

29-
from jax._src import api_util
3029
from jax._src import api
3130
from jax._src import config as config
3231
from jax._src import core
@@ -303,7 +302,6 @@ def transpose(self, *_, **__) -> PRNGKeyArray: assert False
303302
'at', 'flatten', 'ravel', 'reshape',
304303
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
305304

306-
api_util._shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
307305

308306
def prngkeyarray_flatten(x):
309307
return (x._base_array,), x._impl
@@ -463,6 +461,7 @@ def __hash__(self) -> int:
463461

464462

465463
core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
464+
core.shaped_abstractify_handlers[PRNGKeyArray] = op.attrgetter('aval')
466465

467466
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
468467

0 commit comments

Comments
 (0)