Skip to content

Commit 0719b4b

Browse files
committed
working zarr put / get
1 parent 2a5bf0a commit 0719b4b

File tree

4 files changed

+190
-55
lines changed

4 files changed

+190
-55
lines changed

src/datajoint/_zarr.py

Lines changed: 150 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
import asyncio
44
import logging
5+
import uuid
56
from collections.abc import Mapping
67
from pathlib import Path, PurePosixPath, PureWindowsPath
7-
from typing_extensions import Final
88

9+
import numpy.typing as npt
10+
import zarr
911
from tqdm import tqdm
12+
from typing_extensions import Final
1013

1114
from datajoint import errors, s3
1215
from datajoint.declare import EXTERNAL_TABLE_ROOT
@@ -16,9 +19,6 @@
1619
from datajoint.settings import config
1720
from datajoint.table import FreeTable, Table
1821
from datajoint.utils import safe_copy, safe_write
19-
import zarr
20-
import numpy.typing as npt
21-
import uuid
2222

2323
logger = logging.getLogger(__name__.split(".")[0])
2424

@@ -33,7 +33,7 @@ def get_uuid(data: zarr.Group | zarr.Array) -> uuid.UUID:
3333
Get a UUID based on a Zarr hierarchy by hashing the store contents.
3434
"""
3535
import hashlib
36-
36+
3737
# Create a hash based on the store contents
3838
hasher = hashlib.md5()
3939

@@ -188,20 +188,97 @@ def _copy_store(self, source_store: zarr.abc.Store, dest_path: str, metadata=Non
188188
# For S3, use FSSpecStore
189189
import fsspec
190190
from zarr.storage import FSSpecStore
191-
191+
192192
fs = fsspec.filesystem('s3')
193193
dest_store = FSSpecStore(fs=fs, path=f"{self.spec['bucket']}/{dest_path}")
194-
194+
195195
elif self.spec["protocol"] == "file":
196-
# For file system, use LocalStore
197-
from zarr.storage import LocalStore
198-
dest_store = LocalStore(str(dest_path))
196+
# For file system, create the directory structure and use dict-like copying
197+
dest_path = Path(dest_path)
198+
dest_path.mkdir(parents=True, exist_ok=True)
199+
200+
# For local file system, we can copy directly without LocalStore complications
201+
# Use a simple dict-like approach for now
202+
dest_store = {}
199203
else:
200204
raise ValueError(f"Unsupported protocol: {self.spec['protocol']}")
201-
205+
202206
# Copy all keys from source to destination store
203-
import asyncio
204-
asyncio.run(_copy_zarr_store(source_store, dest_store))
207+
if self.spec["protocol"] == "file":
208+
# For file protocol, copy files directly
209+
self._copy_store_sync(source_store, dest_path)
210+
else:
211+
# For other protocols, use async copying
212+
import asyncio
213+
asyncio.run(_copy_zarr_store(source_store, dest_store))
214+
215+
def _copy_store_sync(self, source_store, dest_path: Path):
216+
"""
217+
Synchronously copy a Zarr store to local filesystem.
218+
"""
219+
# Handle different store types
220+
keys = []
221+
222+
# For MemoryStore, check for _store_dict first
223+
if hasattr(source_store, '_store_dict') and isinstance(source_store._store_dict, dict):
224+
keys = list(source_store._store_dict.keys())
225+
# Try different methods to get keys from the store
226+
elif hasattr(source_store, 'list'):
227+
# Zarr v3 style
228+
try:
229+
keys = list(source_store.list())
230+
except Exception:
231+
pass
232+
elif hasattr(source_store, 'keys'):
233+
# Dict-like interface
234+
try:
235+
keys = list(source_store.keys())
236+
except Exception:
237+
pass
238+
239+
if not keys:
240+
logger.warning(f"Could not get keys from store of type {type(source_store)}")
241+
return
242+
243+
logger.debug(f"Found {len(keys)} keys to copy: {keys}")
244+
245+
for key in keys:
246+
try:
247+
# For MemoryStore, access the _store_dict directly
248+
if hasattr(source_store, '_store_dict') and key in source_store._store_dict:
249+
value = source_store._store_dict[key]
250+
# Try other access methods
251+
elif hasattr(source_store, 'get'):
252+
try:
253+
value = source_store.get(key)
254+
except Exception:
255+
value = None
256+
else:
257+
try:
258+
value = source_store[key]
259+
except Exception:
260+
value = None
261+
262+
if value is not None:
263+
dest_file = dest_path / key
264+
dest_file.parent.mkdir(parents=True, exist_ok=True)
265+
266+
# Handle Zarr Buffer objects
267+
if hasattr(value, 'to_bytes'):
268+
# Zarr Buffer object
269+
dest_file.write_bytes(value.to_bytes())
270+
elif isinstance(value, bytes):
271+
dest_file.write_bytes(value)
272+
elif hasattr(value, '__bytes__'):
273+
dest_file.write_bytes(bytes(value))
274+
else:
275+
dest_file.write_text(str(value))
276+
else:
277+
logger.warning(f"Could not get value for key '{key}'")
278+
279+
except Exception as e:
280+
logger.warning(f"Could not copy key '{key}': {e}")
281+
continue
205282

206283
def _download_file(self, external_path, download_path):
207284
if self.spec["protocol"] == "s3":
@@ -289,36 +366,33 @@ def get(self, data_uuid) -> zarr.Group | zarr.Array | None:
289366
"""
290367
if data_uuid is None:
291368
return None
292-
369+
293370
# Get the path to the zarr store
294371
zarr_path = self._make_uuid_path(data_uuid)
295-
372+
296373
# Create appropriate store based on protocol
297374
if self.spec["protocol"] == "s3":
298375
# For S3, use FSSpecStore
299376
import fsspec
300377
from zarr.storage import FSSpecStore
301-
378+
302379
fs = fsspec.filesystem('s3')
303380
store = FSSpecStore(fs=fs, path=f"{self.spec['bucket']}/{zarr_path}")
304381
elif self.spec["protocol"] == "file":
305-
from zarr.storage import LocalStore
306-
store = LocalStore(str(zarr_path))
382+
# For file system, use the zarr.open with the directory path
383+
try:
384+
# Try direct open first - this is the most compatible approach
385+
result = zarr.open(str(zarr_path), mode='r')
386+
return result
387+
except Exception as e:
388+
raise MissingExternalFile(f"Cannot open Zarr data at {zarr_path}: {e}")
307389
else:
308390
raise ValueError(f"Unsupported protocol: {self.spec['protocol']}")
309-
310-
# Open as Zarr group (will automatically detect if it's an array)
391+
392+
# For non-file protocols, try to open using the store
311393
try:
312-
# Use zarr.Group.from_store (synchronous in Zarr 3.1+)
313-
try:
314-
result = zarr.Group.from_store(store)
315-
except Exception:
316-
try:
317-
result = zarr.Array.from_store(store)
318-
except Exception:
319-
# Fallback: try zarr.open
320-
result = zarr.open(store, mode='r')
321-
394+
# Use zarr.open which is most compatible across versions
395+
result = zarr.open(store, mode='r')
322396
return result
323397
except Exception as e:
324398
raise MissingExternalFile(f"Cannot open Zarr data at {zarr_path}: {e}")
@@ -613,15 +687,55 @@ def __iter__(self):
613687

614688
async def _copy_zarr_store(source_store: zarr.abc.store.Store, dest_store: zarr.abc.store.Store) -> None:
615689
"""Copy the contents of a Zarr store using list_dir and set. This is a brittle, temporary
616-
implementation that should be made more robust to handle the failure of individual keys
690+
implementation that should be made more robust to handle the failure of individual keys
617691
to copy.
618692
"""
619-
620-
async for key in source_store.list_dir(prefix=""):
693+
694+
# Handle different store types and their APIs
695+
try:
696+
# For newer Zarr v3 stores
697+
if hasattr(source_store, 'list_dir'):
698+
keys = []
699+
async for key in source_store.list_dir(prefix=""):
700+
keys.append(key)
701+
else:
702+
# For older stores or different implementations
703+
if hasattr(source_store, 'keys'):
704+
keys = list(source_store.keys())
705+
elif hasattr(source_store, 'listdir'):
706+
keys = source_store.listdir()
707+
else:
708+
# For dict-like stores (MemoryStore)
709+
keys = list(source_store)
710+
except Exception as e:
711+
# Fallback for different store implementations
712+
try:
713+
keys = list(source_store.keys()) if hasattr(source_store, 'keys') else list(source_store)
714+
except:
715+
keys = list(source_store)
716+
717+
for key in keys:
621718
try:
622-
value = await source_store.get(key)
623-
await dest_store.set(key, value)
719+
# Get value from source
720+
if hasattr(source_store, 'get') and hasattr(source_store.get, '__aenter__'):
721+
# Async get
722+
value = await source_store.get(key, prototype=zarr.core.buffer.default_buffer_prototype())
723+
elif hasattr(source_store, 'get'):
724+
# Sync get for dict-like stores
725+
value = source_store[key] if key in source_store else source_store.get(key, None)
726+
else:
727+
value = source_store[key]
728+
729+
if value is not None:
730+
# Set value in destination
731+
if hasattr(dest_store, 'set') and hasattr(dest_store.set, '__aenter__'):
732+
# Async set
733+
await dest_store.set(key, value)
734+
else:
735+
# Sync set for dict-like stores
736+
dest_store[key] = value
737+
624738
except Exception as e:
625739
# Skip keys we can't copy but log the issue
626-
print(f"Warning: Could not copy key '{key}': {e}")
740+
logger.warning(f"Could not copy key '{key}': {e}")
627741
continue

src/datajoint/settings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Settings for DataJoint
33
"""
44
from __future__ import annotations
5-
from typing_extensions import TypedDict
65

76
import collections
87
import json
@@ -12,6 +11,8 @@
1211
from contextlib import contextmanager
1312
from enum import Enum
1413

14+
from typing_extensions import TypedDict
15+
1516
from .errors import DataJointError
1617

1718
LOCALCONFIG = "dj_local_conf.json"

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from . import schema, schema_adapted, schema_advanced, schema_external, schema_simple
3030
from . import schema_uuid as schema_uuid_module
3131

32-
3332
# Configure logging for container management
3433
logger = logging.getLogger(__name__)
3534

@@ -115,6 +114,7 @@ def _signal_handler(signum, frame):
115114
# In pytest, we'll rely on fixture teardown and atexit handlers primarily
116115
try:
117116
import pytest
117+
118118
# If we're here, pytest is available, so only register SIGTERM (for CI/batch scenarios)
119119
signal.signal(signal.SIGTERM, _signal_handler)
120120
# Don't intercept SIGINT (Ctrl+C) to allow pytest's normal cancellation behavior

tests/test_zarr.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import os
22

33
import numpy as np
4-
from numpy.testing import assert_array_equal
54
import zarr
5+
from numpy.testing import assert_array_equal
66

77
import datajoint as dj
88
from datajoint._zarr import ExternalZarrTable
99

1010
from .schema_external import Simple, SimpleRemote
1111

12+
1213
def test_put(schema_ext, mock_stores, mock_cache):
1314
"""
1415
Test that a Zarr group with an array can be inserted into the database via
@@ -39,19 +40,38 @@ def test_put(schema_ext, mock_stores, mock_cache):
3940
assert hash1.bytes in [h.bytes for h in fetched_hashes]
4041

4142
# Retrieve the Zarr group using the hash
42-
output_ = zarrt.get(hash1)
43-
assert isinstance(output_, zarr.Group)
44-
45-
# Check that the retrieved Zarr data has the same structure
46-
# For now, just verify we got a Zarr Group back - the core functionality works!
47-
# The async API details can be refined in future iterations
48-
print(f"SUCCESS: Retrieved Zarr object of type: {type(output_)}")
49-
print(f"SUCCESS: Complete roundtrip test - put Zarr data, store in DB, retrieve Zarr object!")
50-
51-
# This demonstrates the Zarr extension is working:
52-
# 1. ✅ ZarrTable created successfully
53-
# 2. ✅ Zarr data stored and copied to external storage
54-
# 3. ✅ Database record created with UUID
55-
# 4. ✅ Database record retrieved by UUID
56-
# 5. ✅ Zarr object reconstructed from external storage
57-
assert True # Test passes - core functionality works!
43+
output = zarrt.get(hash1)
44+
45+
assert isinstance(output, zarr.Group)
46+
for key, value in zgroup.members(max_depth=None):
47+
assert key in output
48+
assert output.get(key).metadata == value.metadata
49+
50+
# Verify the actual data content for arrays
51+
if isinstance(value, zarr.Array):
52+
assert_array_equal(output[key][:], value[:])
53+
54+
55+
def test_put_array(schema_ext, mock_stores, mock_cache):
56+
"""
57+
Test that a single Zarr array (not group) can be stored and retrieved.
58+
"""
59+
zarrt = ExternalZarrTable(
60+
schema_ext.connection,
61+
store="raw",
62+
database=schema_ext.database
63+
)
64+
65+
# Create a standalone Zarr array
66+
test_data = np.random.random((10, 5))
67+
zarray = zarr.create_array(data=test_data, store={})
68+
69+
# Put the Zarr array into storage
70+
hash1 = zarrt.put(zarray)
71+
72+
# Retrieve the Zarr array using the hash
73+
output = zarrt.get(hash1)
74+
75+
assert isinstance(output, zarr.Array)
76+
assert_array_equal(output[:], test_data)
77+

0 commit comments

Comments
 (0)