Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions etils/enp/array_types/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from etils.enp.array_types import dtypes
import numpy as np


_T = TypeVar('_T')

# Match both `np.dtype('int32')` and np.int32
Expand Down Expand Up @@ -98,6 +99,16 @@ def __instancecheck__(cls, instance: np.ndarray) -> bool:
"""`isinstance(array, f32['h w c'])`."""
raise NotImplementedError

try:
import dill # pylint: disable=g-import-not-at-top # pytype: disable=import-error

@dill.register(ArrayAliasMeta)
def _save_array_alias_meta(pickler, obj: ArrayAliasMeta) -> None:
args = (obj.shape, obj.dtype)
pickler.save_reduce(ArrayAliasMeta, args, obj=obj)

except ImportError:
pass

def _normalize_shape_item(item: _ShapeItem) -> ShapeSpec:
"""Returns the `str` representation associated with the shape element."""
Expand Down
19 changes: 19 additions & 0 deletions etils/enp/array_types/typing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
import numpy as np
import pytest

try:
import dill

_DILL_AVAILABLE = True
except ImportError:
dill = None
_DILL_AVAILABLE = False

# TODO(epot): Add `bfloat16` to array_types. Not this might require some
# LazyDType to lazy-load jax.
bf16 = enp.typing.ArrayAliasMeta(shape=None, dtype=np.dtype(jnp.bfloat16))
Expand Down Expand Up @@ -89,3 +97,14 @@ def test_array_eq():
assert f32['h w'] != ui8['h w']

assert {f32['h w'], f32['h w'], f32['h', 'w']} == {f32['h w']}


@pytest.mark.skipif(not _DILL_AVAILABLE, reason='dill not available')
def test_f32_can_be_pickled_unpickled_with_dill():
assert dill is not None, (
'dill library is not available. We should have skipped this test, but for'
' some reason we did not.'
)
my_type = f32['N']
my_type_pickled_unpickled = dill.loads(dill.dumps(my_type))
assert my_type_pickled_unpickled == my_type