Skip to content

Commit 87db8ba

Browse files
authored
Set object codec for object arrays (#573)
1 parent d24f83b commit 87db8ba

File tree

3 files changed

+39
-11
lines changed

3 files changed

+39
-11
lines changed

.github/workflows/jax-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
- name: Run tests
4545
run: |
4646
# exclude tests that rely on structured types since JAX doesn't support these
47-
pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby"
47+
pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not object_dtype"
4848
env:
4949
CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
5050
JAX_ENABLE_X64: True

cubed/storage/backend.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,7 @@
44
from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
55

66

7-
def open_backend_array(
8-
store: T_Store,
9-
mode: str,
10-
*,
11-
shape: Optional[T_Shape] = None,
12-
dtype: Optional[T_DType] = None,
13-
chunks: Optional[T_RegularChunks] = None,
14-
path: Optional[str] = None,
15-
**kwargs,
16-
):
7+
def backend_storage_name():
178
# get storage name from top-level config
189
# e.g. set globally with CUBED_STORAGE_NAME=tensorstore
1910
storage_name = config.get("storage_name", None)
@@ -26,10 +17,35 @@ def open_backend_array(
2617
else:
2718
storage_name = "zarr-python"
2819

20+
return storage_name
21+
22+
23+
def open_backend_array(
24+
store: T_Store,
25+
mode: str,
26+
*,
27+
shape: Optional[T_Shape] = None,
28+
dtype: Optional[T_DType] = None,
29+
chunks: Optional[T_RegularChunks] = None,
30+
path: Optional[str] = None,
31+
**kwargs,
32+
):
33+
storage_name = backend_storage_name()
34+
2935
if storage_name == "zarr-python":
3036
from cubed.storage.backends.zarr_python import open_zarr_array
3137

3238
open_func = open_zarr_array
39+
40+
# set object codec if needed
41+
import numpy as np
42+
43+
if np.dtype(dtype).hasobject and "object_codec" not in kwargs:
44+
import numcodecs
45+
46+
object_codec = numcodecs.Pickle()
47+
kwargs["object_codec"] = object_codec
48+
3349
elif storage_name == "zarr-python-v3":
3450
from cubed.storage.backends.zarr_python_v3 import open_zarr_v3_array
3551

cubed/tests/test_types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1+
import pytest
12
from numpy.testing import assert_array_equal
23

4+
import cubed
35
import cubed.array_api as xp
6+
from cubed.storage.backend import backend_storage_name
47

58

69
# This is less strict than the spec, but is supported by implementations like NumPy
710
def test_prod_sum_bool():
811
a = xp.ones((2,), dtype=xp.bool)
912
assert_array_equal(xp.prod(a).compute(), xp.asarray([1], dtype=xp.int64))
1013
assert_array_equal(xp.sum(a).compute(), xp.asarray([2], dtype=xp.int64))
14+
15+
16+
@pytest.mark.skipif(
17+
backend_storage_name() != "zarr-python",
18+
reason="object dtype only works on zarr-python",
19+
)
20+
def test_object_dtype():
21+
a = xp.asarray(["a", "b"], dtype=object, chunks=2)
22+
cubed.to_zarr(a, store=None)

0 commit comments

Comments
 (0)