Skip to content

Commit 0f62bc4

Browse files
authored
Merge pull request #167 from HiPCTProject/downsample-with-ts
Use tensorstore for downsampling
2 parents 5a7267a + 3d42fc8 commit 0f62bc4

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

src/stack_to_chunk/_array_helpers.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
import skimage.measure
66
import tensorstore as ts
7-
import zarr
87
from joblib import delayed
98
from loguru import logger
109

@@ -32,23 +31,14 @@ def _copy_slab(arr_path: Path, slab: da.Array, zstart: int, zend: int) -> None:
3231

3332
logger.info(f"Writing z={zstart} -> {zend - 1}")
3433
# Write out data
35-
arr_zarr = ts.open(
36-
{
37-
"driver": "zarr3",
38-
"kvstore": {
39-
"driver": "file",
40-
"path": str(arr_path),
41-
},
42-
"open": True,
43-
}
44-
).result()
34+
arr_zarr = _open_with_tensorstore(arr_path)
4535
arr_zarr[:, :, zstart:zend].write(data).result()
4636
logger.info(f"Finished copying z={zstart} -> {zend - 1}")
4737

4838

4939
@delayed # type: ignore[misc]
5040
def _downsample_block(
51-
arr_in: zarr.Array, arr_out: zarr.Array, block_idx: tuple[int, int, int]
41+
arr_in_path: Path, arr_out_path: Path, block_idx: tuple[int, int, int]
5242
) -> None:
5343
"""
5444
Copy a single block from one array to the next, downsampling by a factor of two.
@@ -59,15 +49,17 @@ def _downsample_block(
5949
6050
Parameters
6151
----------
62-
arr_in :
63-
Input array.
64-
arr_out :
65-
Output array. Must have the same chunk shape as `arr_in`.
52+
arr_in_path :
53+
Path to input array.
54+
arr_out_path :
55+
Path to output array. Must have the same chunk shape as `arr_in`.
6656
block_idx :
6757
Index of block to copy. Must be a multiple of the shard shape in `arr_out`.
6858
6959
"""
70-
shard_shape: tuple[int, int, int] = arr_out.shards
60+
arr_in = _open_with_tensorstore(arr_in_path)
61+
arr_out = _open_with_tensorstore(arr_out_path)
62+
shard_shape: tuple[int, int, int] = arr_out.chunk_layout.write_chunk.shape
7163
np.testing.assert_equal(
7264
np.array(block_idx) % np.array(shard_shape),
7365
np.array([0, 0, 0]),
@@ -85,17 +77,32 @@ def _downsample_block(
8577
block_idx[2] * 2, min((block_idx[2] + shard_shape[2]) * 2, arr_in.shape[2])
8678
),
8779
)
88-
data = arr_in[in_slice]
80+
data = arr_in[in_slice].read().result()
8981

9082
# Pad to an even number
9183
pads = np.array(data.shape) % 2
9284
pad_width = [(0, p) for p in pads]
9385
data = np.pad(data, pad_width, mode="edge")
94-
data = skimage.measure.block_reduce(data, block_size=2, func=np.mean)
86+
data = skimage.measure.block_reduce(data, block_size=2, func=np.mean).astype(
87+
data.dtype
88+
)
9589

9690
out_slice = (
9791
slice(block_idx[0], min((block_idx[0] + shard_shape[0]), arr_out.shape[0])),
9892
slice(block_idx[1], min((block_idx[1] + shard_shape[1]), arr_out.shape[1])),
9993
slice(block_idx[2], min((block_idx[2] + shard_shape[2]), arr_out.shape[2])),
10094
)
101-
arr_out[out_slice] = data
95+
arr_out[out_slice].write(data).result()
96+
97+
98+
def _open_with_tensorstore(arr_path: Path) -> ts.TensorStore:
99+
return ts.open(
100+
{
101+
"driver": "zarr3",
102+
"kvstore": {
103+
"driver": "file",
104+
"path": str(arr_path),
105+
},
106+
"open": True,
107+
}
108+
).result()

src/stack_to_chunk/main.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,14 +387,17 @@ def add_downsample_level(self, level: int, *, n_processes: int = 1) -> None:
387387
assert sink_arr.shards is not None
388388

389389
# Get slice of every shard in the sink array
390-
block_indices = [
390+
block_indices: list[tuple[int, int, int]] = [
391391
(x, y, z)
392392
for x in range(0, sink_arr.shape[0], sink_arr.shards[0])
393393
for y in range(0, sink_arr.shape[1], sink_arr.shards[1])
394394
for z in range(0, sink_arr.shape[2], sink_arr.shards[2])
395395
]
396396

397-
all_args = [(source_arr, sink_arr, idxs) for idxs in block_indices]
397+
all_args: list[tuple[Path, Path, tuple[int, int, int]]] = [
398+
(self._path / str(level_minus_one), self._path / level_str, idxs)
399+
for idxs in block_indices
400+
]
398401

399402
logger.info(f"Starting downsampling from level {level_minus_one} > {level}...")
400403
blosc_use_threads = blosc.use_threads

0 commit comments

Comments
 (0)