Skip to content

Commit 4df580d

Browse files
authored
Fix preprocessing with zarr-python 3 (#294)
* use tensorstore for grid sampling
* remove old image sampling
* fix channel enumeration
1 parent 2c5d2b4 commit 4df580d

File tree

2 files changed

+33
-89
lines changed

2 files changed

+33
-89
lines changed

viscy/utils/meta_utils.py

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import iohub.ngff as ngff
55
import numpy as np
66
import pandas as pd
7+
import tensorstore
8+
from tqdm import tqdm
79

8-
import viscy.utils.mp_utils as mp_utils
9-
from viscy.utils.cli_utils import show_progress_bar
10+
from viscy.utils.mp_utils import get_val_stats
1011

1112

1213
def write_meta_field(position: ngff.Position, metadata, field_name, subfield_name):
@@ -45,11 +46,23 @@ def write_meta_field(position: ngff.Position, metadata, field_name, subfield_nam
4546
position.zattrs[field_name] = field_metadata
4647

4748

49+
def _grid_sample(
50+
position: ngff.Position, grid_spacing: int, channel_index: int, num_workers: int
51+
):
52+
return (
53+
position["0"]
54+
.tensorstore(
55+
context=tensorstore.Context(
56+
{"data_copy_concurrency": {"limit": num_workers}}
57+
)
58+
)[:, channel_index, :, ::grid_spacing, ::grid_spacing]
59+
.read()
60+
.result()
61+
)
62+
63+
4864
def generate_normalization_metadata(
49-
zarr_dir,
50-
num_workers=4,
51-
channel_ids=-1,
52-
grid_spacing=32,
65+
zarr_dir, num_workers=4, channel_ids=-1, grid_spacing=32
5366
):
5467
"""
5568
Generate pixel intensity metadata to be later used in on-the-fly normalization
@@ -89,54 +102,37 @@ def generate_normalization_metadata(
89102
mp_grid_sampler_args.append([position, grid_spacing])
90103

91104
# sample values and use them to get normalization statistics
92-
for i, channel in enumerate(channel_ids):
93-
show_progress_bar(
94-
dataloader=channel_ids,
95-
current=i,
96-
process="sampling channel values",
97-
)
105+
for i, channel_index in enumerate(channel_ids):
106+
print(f"Sampling channel index {channel_index} ({i + 1}/{len(channel_ids)})")
98107

99-
channel_name = plate.channel_names[channel]
100-
this_channels_args = tuple([args + [channel] for args in mp_grid_sampler_args])
108+
channel_name = plate.channel_names[channel_index]
109+
dataset_sample_values = []
110+
position_and_statistics = []
101111

102-
# NOTE: Doing sequential mp with pool execution creates synchronization
103-
# points between each step. This could be detrimental to performance
104-
positions, fov_sample_values = mp_utils.mp_sample_im_pixels(
105-
this_channels_args, num_workers
106-
)
107-
dataset_sample_values = np.concatenate(
108-
[arr.flatten() for arr in fov_sample_values]
109-
)
110-
fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers)
111-
dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values)
112+
for _, pos in tqdm(position_map, desc="Positions"):
113+
samples = _grid_sample(pos, grid_spacing, channel_index, num_workers)
114+
dataset_sample_values.append(samples)
115+
fov_level_statistics = {"fov_statistics": get_val_stats(samples)}
116+
position_and_statistics.append((pos, fov_level_statistics))
112117

113118
dataset_statistics = {
114-
"dataset_statistics": dataset_level_statistics,
119+
"dataset_statistics": get_val_stats(np.stack(dataset_sample_values)),
115120
}
116-
117121
write_meta_field(
118122
position=plate,
119123
metadata=dataset_statistics,
120124
field_name="normalization",
121125
subfield_name=channel_name,
122126
)
123127

124-
for j, pos in enumerate(positions):
125-
show_progress_bar(
126-
dataloader=position_map,
127-
current=j,
128-
process=f"calculating channel statistics {channel}/{list(channel_ids)}",
129-
)
130-
position_statistics = dataset_statistics | {
131-
"fov_statistics": fov_level_statistics[j],
132-
}
133-
128+
for pos, position_statistics in position_and_statistics:
134129
write_meta_field(
135130
position=pos,
136-
metadata=position_statistics,
131+
metadata=dataset_statistics | position_statistics,
137132
field_name="normalization",
138133
subfield_name=channel_name,
139134
)
135+
140136
plate.close()
141137

142138

viscy/utils/mp_utils.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -268,55 +268,3 @@ def get_val_stats(sample_values):
268268
"iqr": float(scipy.stats.iqr(sample_values)),
269269
}
270270
return meta_row
271-
272-
273-
def mp_sample_im_pixels(fn_args, workers):
274-
"""Read and computes statistics of images with multiprocessing
275-
276-
:param list of tuple fn_args: list with tuples of function arguments
277-
:param int workers: max number of workers
278-
:return: list of paths and corresponding returned df from get_im_stats
279-
"""
280-
281-
with ProcessPoolExecutor(workers) as ex:
282-
# can't use map directly as it works only with single arg functions
283-
res = ex.map(sample_im_pixels, *zip(*fn_args))
284-
return list(map(list, zip(*list(res))))
285-
286-
287-
def sample_im_pixels(
288-
position: ngff.Position,
289-
grid_spacing,
290-
channel,
291-
):
292-
# TODO move out of mp utils into normalization utils
293-
"""
294-
Read and computes statistics of images for each point in a grid.
295-
Grid spacing determines distance in pixels between grid points
296-
for rows and cols.
297-
By default, samples from every time position and every z-depth, and
298-
assumes that the data in the zarr store is stored in [T,C,Z,Y,X] format,
299-
for time, channel, z, y, x.
300-
301-
:param Position zarr_dir: NGFF position node object
302-
:param int grid_spacing: spacing of sampling grid in x and y
303-
:param int channel: channel to sample from
304-
305-
:return list meta_rows: Dicts with intensity data for each grid point
306-
"""
307-
image_zarr = position.data
308-
309-
all_sample_values = []
310-
all_time_indices = list(range(image_zarr.shape[0]))
311-
all_z_indices = list(range(image_zarr.shape[2]))
312-
313-
for time_index in all_time_indices:
314-
for z_index in all_z_indices:
315-
image_slice = image_zarr[time_index, channel, z_index, :, :]
316-
_, _, sample_values = image_utils.grid_sample_pixel_values(
317-
image_slice, grid_spacing
318-
)
319-
all_sample_values.append(sample_values)
320-
sample_values = np.stack(all_sample_values, 0).flatten()
321-
322-
return position, sample_values

0 commit comments

Comments (0)