Skip to content

Commit 3411983

Browse files
authored
Merge pull request #91 from bioio-devs/zarr-writer-single-time
allow zarr writer to do incremental writes of single timepoints
2 parents e7bd2af + becaa13 commit 3411983

File tree

3 files changed

+110
-6
lines changed

3 files changed

+110
-6
lines changed

bioio/tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ def da_random_from_shape(
5757

5858

5959
array_constructor = pytest.mark.parametrize(
60-
"array_constructor", [np_random_from_shape, da_random_from_shape]
60+
"array_constructor",
61+
[np_random_from_shape, da_random_from_shape],
6162
)
6263

6364
DUMMY_PLUGIN_NAME = "dummy-plugin"

bioio/tests/writers/test_ome_zarr_writer_2.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,93 @@ def test_write_ome_zarr(
193193
axes = node.metadata["axes"]
194194
dims = "".join([a["name"] for a in axes]).upper()
195195
assert dims == "TCZYX"
196+
197+
198+
@array_constructor
199+
@pytest.mark.parametrize(
200+
"shape, num_levels, scaling, expected_shapes",
201+
[
202+
(
203+
(4, 2, 2, 64, 32), # easy, powers of two
204+
3,
205+
(1, 1, 1, 2, 2), # downscale xy by two
206+
[(4, 2, 2, 64, 32), (4, 2, 2, 32, 16), (4, 2, 2, 16, 8)],
207+
),
208+
(
209+
(4, 2, 2, 8, 6),
210+
1, # no downscaling
211+
(1, 1, 1, 1, 1),
212+
[(4, 2, 2, 8, 6)],
213+
),
214+
],
215+
)
216+
@pytest.mark.parametrize("filename", ["e.zarr"])
217+
def test_write_ome_zarr_iterative(
218+
array_constructor: Callable,
219+
filename: str,
220+
shape: DimTuple,
221+
num_levels: int,
222+
scaling: Tuple[float, float, float, float, float],
223+
expected_shapes: List[DimTuple],
224+
tmp_path: pathlib.Path,
225+
) -> None:
226+
# TCZYX order, downsampling x and y only
227+
im = array_constructor(shape, dtype=np.uint8)
228+
C = shape[1]
229+
230+
shapes = compute_level_shapes(shape, scaling, num_levels)
231+
chunk_sizes = compute_level_chunk_sizes_zslice(shapes)
232+
233+
# Create an OmeZarrWriter object
234+
writer = OmeZarrWriter()
235+
236+
# Initialize the store. Use s3 url or local directory path!
237+
save_uri = tmp_path / filename
238+
writer.init_store(str(save_uri), shapes, chunk_sizes, im.dtype)
239+
240+
# Write the image iteratively as if we only have one timepoint at a time
241+
for t in range(shape[0]):
242+
t4d = im[t]
243+
t5d = np.expand_dims(t4d, axis=0)
244+
writer.write_t_batches_array(t5d, channels=[], tbatch=1, toffset=t)
245+
246+
# TODO: get this from source image
247+
physical_scale = {
248+
"c": 1.0, # default value for channel
249+
"t": 1.0,
250+
"z": 1.0,
251+
"y": 1.0,
252+
"x": 1.0,
253+
}
254+
physical_units = {
255+
"x": "micrometer",
256+
"y": "micrometer",
257+
"z": "micrometer",
258+
"t": "minute",
259+
}
260+
meta = writer.generate_metadata(
261+
image_name="TEST",
262+
channel_names=[f"c{i}" for i in range(C)],
263+
physical_dims=physical_scale,
264+
physical_units=physical_units,
265+
channel_colors=[0xFFFFFF for i in range(C)],
266+
)
267+
writer.write_metadata(meta)
268+
269+
# Read written result and check basics
270+
reader = Reader(parse_url(save_uri))
271+
node = list(reader())[0]
272+
num_levels_read = len(node.data)
273+
assert num_levels_read == num_levels
274+
for level, shape in zip(range(num_levels), expected_shapes):
275+
read_shape = node.data[level].shape
276+
assert read_shape == shape
277+
axes = node.metadata["axes"]
278+
dims = "".join([a["name"] for a in axes]).upper()
279+
assert dims == "TCZYX"
280+
281+
# check lvl 0 data values got written in order
282+
for t in range(shape[0]):
283+
t4d = im[t]
284+
read_t4d = node.data[0][t]
285+
np.testing.assert_array_equal(t4d, read_t4d)

bioio/writers/ome_zarr_writer_2.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def build_ome(
220220
class ZarrLevel:
221221
shape: DimTuple
222222
chunk_size: DimTuple
223-
dtype: str
223+
dtype: np.dtype
224224
zarray: zarr.core.Array
225225

226226

@@ -440,7 +440,7 @@ def _create_levels(
440440
self.levels.append(level)
441441

442442
def _downsample_and_write_batch_t(
443-
self, data_tczyx: da.Array, start_t: int, end_t: int
443+
self, data_tczyx: da.Array, start_t: int, end_t: int, toffset: int = 0
444444
) -> None:
445445
dtype = data_tczyx.dtype
446446
if len(data_tczyx.shape) != 5:
@@ -451,7 +451,11 @@ def _downsample_and_write_batch_t(
451451
# write level 0 first
452452
for k in range(start_t, end_t):
453453
subset = data_tczyx[[k - start_t]]
454-
da.to_zarr(subset, self.levels[0].zarray, region=(slice(k, k + 1),))
454+
da.to_zarr(
455+
subset,
456+
self.levels[0].zarray,
457+
region=(slice(k + toffset, k + toffset + 1),),
458+
)
455459

456460
# downsample to next level then write
457461
for j in range(1, len(self.levels)):
@@ -463,7 +467,11 @@ def _downsample_and_write_batch_t(
463467
# write ti to zarr
464468
for k in range(start_t, end_t):
465469
subset = data_tczyx[[k - start_t]]
466-
da.to_zarr(subset, self.levels[j].zarray, region=(slice(k, k + 1),))
470+
da.to_zarr(
471+
subset,
472+
self.levels[j].zarray,
473+
region=(slice(k + toffset, k + toffset + 1),),
474+
)
467475

468476
log.info(f"Completed {start_t} to {end_t}")
469477

@@ -540,6 +548,7 @@ def write_t_batches_array(
540548
im: Union[da.Array, np.ndarray],
541549
channels: List[int] = [],
542550
tbatch: int = 4,
551+
toffset: int = 0,
543552
debug: bool = False,
544553
) -> None:
545554
"""
@@ -551,6 +560,8 @@ def write_t_batches_array(
551560
An ArrayLike object. Should be 5D TCZYX.
552561
tbatch:
553562
The number of T to write at a time.
563+
toffset:
564+
The offset to start writing T from. All T in the input array will be written
554565
"""
555566
# if isinstance(im, (np.ndarray)):
556567
# im_da = da.from_array(im)
@@ -571,7 +582,9 @@ def write_t_batches_array(
571582
if channels:
572583
for t in range(len(ti)):
573584
ti[t] = [ti[t][c] for c in channels]
574-
self._downsample_and_write_batch_t(da.asarray(ti), start_t, end_t)
585+
self._downsample_and_write_batch_t(
586+
da.asarray(ti), start_t, end_t, toffset
587+
)
575588
log.info("Finished loop over T")
576589

577590
def _get_scale_ratio(self, level: int) -> Tuple[float, float, float, float, float]:

0 commit comments

Comments
 (0)