Skip to content

Commit 167a5b3

Browse files
authored
Support region(s) in to_zarr and store (#799)
* Support region(s) in `to_zarr` and `store` * Generalise test for checking for a backend storage array * Fix pickling error on beam and spark, see #800 * Fix some tests running under tensorstore
1 parent de56acf commit 167a5b3

File tree

5 files changed

+271
-52
lines changed

5 files changed

+271
-52
lines changed

cubed/core/ops.py

Lines changed: 89 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from cubed.primitive.memory import get_buffer_copies
2424
from cubed.primitive.rechunk import rechunk as primitive_rechunk
2525
from cubed.spec import spec_from_config
26-
from cubed.storage.backend import open_backend_array
26+
from cubed.storage.backend import is_backend_storage_array, open_backend_array
2727
from cubed.storage.zarr import lazy_zarr_array
2828
from cubed.types import T_RegularChunks, T_Shape
2929
from cubed.utils import (
@@ -125,7 +125,13 @@ def from_zarr(store, path=None, spec=None) -> "Array":
125125
return Array(name, target, spec, plan)
126126

127127

128-
def store(sources: Union["Array", Sequence["Array"]], targets, executor=None, **kwargs):
128+
def store(
129+
sources: Union["Array", Sequence["Array"]],
130+
targets,
131+
regions: tuple[slice, ...] | list[tuple[slice, ...]] | None = None,
132+
executor=None,
133+
**kwargs,
134+
):
129135
"""Save source arrays to array-like objects.
130136
131137
In the current implementation ``targets`` must be Zarr arrays.
@@ -135,10 +141,12 @@ def store(sources: Union["Array", Sequence["Array"]], targets, executor=None, **
135141
136142
Parameters
137143
----------
138-
x : cubed.Array or collection of cubed.Array
144+
sources : cubed.Array or collection of cubed.Array
139145
Arrays to save
140-
store : zarr.Array or collection of zarr.Array
146+
targets : string or Zarr store or collection of strings or Zarr stores
141147
Zarr arrays to write to
148+
regions : tuple of slices or list of tuple of slices, optional
149+
The regions of data that should be written to in targets.
142150
executor : cubed.runtime.types.Executor, optional
143151
The executor to use to run the computation.
144152
Defaults to using the in-process Python executor.
@@ -155,19 +163,38 @@ def store(sources: Union["Array", Sequence["Array"]], targets, executor=None, **
155163
f"Different number of sources ({len(sources)}) and targets ({len(targets)})"
156164
)
157165

166+
if isinstance(regions, tuple) or regions is None:
167+
regions_list = [regions] * len(sources)
168+
else:
169+
regions_list = list(regions)
170+
if len(sources) != len(regions_list):
171+
raise ValueError(
172+
f"Different number of sources [{len(sources)}] and "
173+
f"targets [{len(targets)}] than regions [{len(regions_list)}]"
174+
)
175+
158176
arrays = []
159-
for source, target in zip(sources, targets):
160-
identity = lambda a: a
161-
ind = tuple(range(source.ndim))
177+
for source, target, region in zip(sources, targets, regions_list):
178+
array = _store_array(source, target, region=region)
179+
arrays.append(array)
180+
compute(*arrays, executor=executor, _return_in_memory_array=False, **kwargs)
162181

163-
if target is not None and not isinstance(target, zarr.Array):
164-
target = lazy_zarr_array(
165-
target,
166-
shape=source.shape,
167-
dtype=source.dtype,
168-
chunks=source.chunksize,
169-
)
170-
array = blockwise(
182+
183+
def _store_array(source: "Array", target, path=None, region=None):
184+
if target is not None and not is_backend_storage_array(target):
185+
target = lazy_zarr_array(
186+
target,
187+
shape=source.shape,
188+
dtype=source.dtype,
189+
chunks=source.chunksize,
190+
path=path,
191+
)
192+
if target is None and region is not None:
193+
raise ValueError("Target store must be specified when setting a region")
194+
identity = lambda a: a
195+
if region is None or all(r == slice(None) for r in region):
196+
ind = tuple(range(source.ndim))
197+
return blockwise(
171198
identity,
172199
ind,
173200
source,
@@ -176,11 +203,50 @@ def store(sources: Union["Array", Sequence["Array"]], targets, executor=None, **
176203
align_arrays=False,
177204
target_store=target,
178205
)
179-
arrays.append(array)
180-
compute(*arrays, executor=executor, _return_in_memory_array=False, **kwargs)
206+
else:
207+
# treat a region as an offset within the target store
208+
shape = target.shape
209+
chunks = target.chunks
210+
for i, (sl, cs) in enumerate(zip(region, chunks)):
211+
if sl.start % cs != 0 or (sl.stop % cs != 0 and sl.stop != shape[i]):
212+
raise ValueError(
213+
f"Region {region} does not align with target chunks {chunks}"
214+
)
215+
block_offsets = [sl.start // cs for sl, cs in zip(region, chunks)]
216+
217+
def key_function(out_key):
218+
out_coords = out_key[1:]
219+
in_coords = tuple(bi - off for bi, off in zip(out_coords, block_offsets))
220+
return ((source.name, *in_coords),)
221+
222+
# calculate output block ids from region selection
223+
indexer = _create_zarr_indexer(region, shape, chunks)
224+
if source.shape != indexer.shape:
225+
raise ValueError(
226+
f"Source array shape {source.shape} does not match region shape {indexer.shape}"
227+
)
228+
# TODO(#800): make Zarr indexer pickle-able so we don't have to materialize all the block IDs
229+
output_blocks = map(
230+
lambda chunk_projection: list(chunk_projection[0]), list(indexer)
231+
)
232+
233+
out = general_blockwise(
234+
identity,
235+
key_function,
236+
source,
237+
shapes=[shape],
238+
dtypes=[source.dtype],
239+
chunkss=[chunks],
240+
target_stores=[target],
241+
output_blocks=output_blocks,
242+
)
243+
from cubed import Array
181244

245+
assert isinstance(out, Array) # single output
246+
return out
182247

183-
def to_zarr(x: "Array", store, path=None, executor=None, **kwargs):
248+
249+
def to_zarr(x: "Array", store, path=None, region=None, executor=None, **kwargs):
184250
"""Save an array to Zarr storage.
185251
186252
Note that this operation is eager, and will run the computation
@@ -190,35 +256,17 @@ def to_zarr(x: "Array", store, path=None, executor=None, **kwargs):
190256
----------
191257
x : cubed.Array
192258
Array to save
193-
store : string or Zarr Store
259+
store : string or Zarr store
194260
Output Zarr store
195261
path : string, optional
196262
Group path
263+
region : tuple of slices, optional
264+
The region of data that should be written to in target.
197265
executor : cubed.runtime.types.Executor, optional
198266
The executor to use to run the computation.
199267
Defaults to using the in-process Python executor.
200268
"""
201-
# Note that the intermediate write to x's store will be optimized away
202-
# by map fusion (if it was produced with a blockwise operation).
203-
identity = lambda a: a
204-
ind = tuple(range(x.ndim))
205-
if store is not None and not isinstance(store, zarr.Array):
206-
store = lazy_zarr_array(
207-
store,
208-
shape=x.shape,
209-
dtype=x.dtype,
210-
chunks=x.chunksize,
211-
path=path,
212-
)
213-
out = blockwise(
214-
identity,
215-
ind,
216-
x,
217-
ind,
218-
dtype=x.dtype,
219-
align_arrays=False,
220-
target_store=store,
221-
)
269+
out = _store_array(x, store, path=path, region=region)
222270
out.compute(executor=executor, _return_in_memory_array=False, **kwargs)
223271

224272

@@ -466,7 +514,7 @@ def _general_blockwise(
466514
spec = check_array_specs(arrays)
467515
buffer_copies = get_buffer_copies(spec)
468516

469-
if isinstance(target_stores, list): # multiple outputs
517+
if isinstance(target_stores, list) and len(target_stores) > 1: # multiple outputs
470518
name = [gensym() for _ in range(len(target_stores))]
471519
target_stores = [
472520
ts if ts is not None else context_dir_path(spec=spec)

cubed/core/plan.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from typing import Any, Callable, Dict, Optional, Tuple
1010

1111
import networkx as nx
12-
import zarr
1312

1413
from cubed.core.optimization import multiple_inputs_optimize_dag
1514
from cubed.primitive.blockwise import BlockwiseSpec
1615
from cubed.primitive.types import PrimitiveOperation
1716
from cubed.runtime.pipeline import visit_nodes
1817
from cubed.runtime.types import ComputeEndEvent, ComputeStartEvent, CubedPipeline
18+
from cubed.storage.backend import is_backend_storage_array
1919
from cubed.storage.zarr import LazyZarrArray, open_if_lazy_zarr_array
2020
from cubed.utils import (
2121
chunk_memory,
@@ -456,7 +456,9 @@ def visualize(
456456
chunkmem = memory_repr(chunk_memory(target))
457457

458458
# materialized arrays are light orange, virtual arrays are white
459-
if isinstance(target, (LazyZarrArray, zarr.Array)):
459+
if isinstance(target, LazyZarrArray) or is_backend_storage_array(
460+
target
461+
):
460462
d["style"] = "filled"
461463
d["fillcolor"] = "#ffd8b1"
462464
if n in array_display_names:

cubed/primitive/blockwise.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from cubed.primitive.memory import BufferCopies, MemoryModeller, calculate_projected_mem
1919
from cubed.primitive.types import CubedArrayProxy, PrimitiveOperation
2020
from cubed.runtime.types import CubedPipeline
21+
from cubed.storage.backend import is_backend_storage_array
2122
from cubed.storage.zarr import LazyZarrArray, T_ZarrArray, lazy_zarr_array
2223
from cubed.types import T_Chunks, T_DType, T_RegularChunks, T_Shape, T_Store
2324
from cubed.utils import (
@@ -294,6 +295,7 @@ def general_blockwise(
294295
iterable_input_blocks: Optional[Tuple[bool, ...]] = None,
295296
target_chunks_: Optional[T_RegularChunks] = None,
296297
return_writes_stores: bool = False,
298+
output_blocks: Optional[Iterator[List[int]]] = None,
297299
**kwargs,
298300
) -> PrimitiveOperation:
299301
"""A more general form of ``blockwise`` that uses a function to specify the block
@@ -365,7 +367,9 @@ def general_blockwise(
365367
f"All outputs must have matching number of blocks in each dimension. Chunks specified: {chunkss}"
366368
)
367369
ta: Union[zarr.Array, LazyZarrArray]
368-
if isinstance(target_store, (zarr.Array, LazyZarrArray)):
370+
if is_backend_storage_array(target_store) or isinstance(
371+
target_store, LazyZarrArray
372+
):
369373
ta = target_store
370374
else:
371375
ta = lazy_zarr_array(
@@ -417,9 +421,10 @@ def general_blockwise(
417421
)
418422

419423
# this must be an iterator of lists, not of tuples, otherwise lithops breaks
420-
output_blocks = map(
421-
list, itertools.product(*[range(len(c)) for c in chunks_normal])
422-
)
424+
if output_blocks is None:
425+
output_blocks = map(
426+
list, itertools.product(*[range(len(c)) for c in chunks_normal])
427+
)
423428
num_tasks = math.prod(len(c) for c in chunks_normal)
424429

425430
pipeline = CubedPipeline(

cubed/storage/backend.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,32 @@ def backend_storage_name():
2020
return storage_name
2121

2222

23+
def is_backend_storage_array(obj):
24+
storage_name = backend_storage_name()
25+
26+
if storage_name == "zarr-python":
27+
import zarr
28+
29+
from cubed.storage.backends.zarr_python import ZarrArrayGroup
30+
31+
return isinstance(obj, (zarr.Array, ZarrArrayGroup))
32+
elif storage_name in ("zarr-python-v3", "zarrs-python"):
33+
import zarr
34+
35+
from cubed.storage.backends.zarr_python_v3 import ZarrV3ArrayGroup
36+
37+
return isinstance(obj, (zarr.Array, ZarrV3ArrayGroup))
38+
elif storage_name == "tensorstore":
39+
from cubed.storage.backends.tensorstore import (
40+
TensorStoreArray,
41+
TensorStoreGroup,
42+
)
43+
44+
return isinstance(obj, (TensorStoreArray, TensorStoreGroup))
45+
else:
46+
raise ValueError(f"Unrecognized storage name: {storage_name}")
47+
48+
2349
def open_backend_array(
2450
store: T_Store,
2551
mode: str,

0 commit comments

Comments
 (0)