|
20 | 20 | from functools import partial, lru_cache |
21 | 21 | from typing import Any |
22 | 22 |
|
23 | | -import numpy as np |
24 | | - |
25 | 23 | from jax._src import core |
26 | 24 | from jax._src import config |
27 | 25 | from jax._src import dtypes |
28 | 26 | from jax._src.state.types import AbstractRef |
29 | | -from jax._src.abstract_arrays import numpy_scalar_types |
30 | | -from jax._src.core import ShapedArray |
31 | 27 | from jax._src.tree_util import ( |
32 | 28 | PyTreeDef, tree_flatten, tree_unflatten, tree_map, |
33 | 29 | treedef_children, generate_key_paths, keystr, broadcast_prefix, |
@@ -587,54 +583,13 @@ def _dtype(x): |
587 | 583 | except ValueError: |
588 | 584 | return dtypes.result_type(getattr(x, 'dtype')) |
589 | 585 |
|
590 | | -def _shaped_abstractify_slow(x): |
591 | | - try: |
592 | | - return x if isinstance(x, core.AbstractValue) else core.get_aval(x) |
593 | | - except TypeError: |
594 | | - pass |
595 | | - |
596 | | - weak_type = getattr(x, 'weak_type', False) |
597 | | - if hasattr(x, 'dtype'): |
598 | | - dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) |
599 | | - else: |
600 | | - raise TypeError( |
601 | | - f"Cannot interpret value of type {type(x)} as an abstract array; it " |
602 | | - "does not have a dtype attribute") |
603 | | - return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type) |
604 | | - |
605 | 586 | # TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior |
| 587 | +# TODO(jakevdp): fix downstream consumers and remove this. |
606 | 588 | def shaped_abstractify(x): |
607 | | - handler = _shaped_abstractify_handlers.get(type(x), None) |
608 | | - return handler(x) if handler is not None else _shaped_abstractify_slow(x) |
609 | | - |
610 | | -_shaped_abstractify_handlers: dict[Any, Callable[[Any], core.ShapedArray]] = {} |
611 | | - |
612 | | - |
613 | | -def _str_abstractify(x): |
614 | | - raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type") |
615 | | -_shaped_abstractify_handlers[str] = _str_abstractify |
616 | | - |
617 | | -def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray: |
618 | | - dtype = x.dtype |
619 | | - dtypes.check_valid_dtype(dtype) |
620 | | - return ShapedArray(x.shape, |
621 | | - dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) |
622 | | -_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify |
623 | | - |
624 | | -def _np_scalar_abstractify(x: np.generic) -> ShapedArray: |
625 | | - dtype = np.dtype(x) |
626 | | - dtypes.check_valid_dtype(dtype) |
627 | | - return ShapedArray(np.shape(x), |
628 | | - dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) |
629 | | -_shaped_abstractify_handlers.update((t, _np_scalar_abstractify) |
630 | | - for t in numpy_scalar_types) |
631 | | - |
632 | | -def _python_scalar_abstractify(x: int | float | complex | bool) -> ShapedArray: |
633 | | - typ = type(x) |
634 | | - dtype = dtypes._scalar_type_to_dtype(typ, x) |
635 | | - return ShapedArray((), dtype, weak_type=typ in dtypes._weak_types) |
636 | | -_shaped_abstractify_handlers.update((t, _python_scalar_abstractify) |
637 | | - for t in dtypes.python_scalar_dtypes) |
| 589 | + return core.shaped_abstractify(x) |
| 590 | + |
| 591 | +# TODO(jakevdp): fix downstream consumers and remove this. |
| 592 | +_shaped_abstractify_handlers = core.shaped_abstractify_handlers |
638 | 593 |
|
639 | 594 | # This decorator exists to make it easier to monkey-patch APIs in JAX. |
640 | 595 | # By default it does nothing, but it can be monkey-patched to do other things. |
|
0 commit comments