|
4 | 4 | import iohub.ngff as ngff
|
5 | 5 | import numpy as np
|
6 | 6 | import pandas as pd
|
| 7 | +import tensorstore |
| 8 | +from tqdm import tqdm |
7 | 9 |
|
8 |
| -import viscy.utils.mp_utils as mp_utils |
9 |
| -from viscy.utils.cli_utils import show_progress_bar |
| 10 | +from viscy.utils.mp_utils import get_val_stats |
10 | 11 |
|
11 | 12 |
|
12 | 13 | def write_meta_field(position: ngff.Position, metadata, field_name, subfield_name):
|
@@ -45,11 +46,23 @@ def write_meta_field(position: ngff.Position, metadata, field_name, subfield_nam
|
45 | 46 | position.zattrs[field_name] = field_metadata
|
46 | 47 |
|
47 | 48 |
|
def _grid_sample(
    position: ngff.Position, grid_spacing: int, channel_index: int, num_workers: int
):
    """Sparsely sample pixel values from one channel of a position.

    Reads the full-resolution ("0") array of the position through tensorstore,
    keeping every time point and Z slice but only every ``grid_spacing``-th
    pixel along Y and X for the selected channel.

    :param ngff.Position position: OME-NGFF position to sample from
    :param int grid_spacing: step between sampled pixels along Y and X
    :param int channel_index: index of the channel to sample
    :param int num_workers: concurrency limit for tensorstore data copies
    :return: sampled pixel values as an in-memory array
    """
    # Bound tensorstore's copy concurrency so reads don't oversubscribe CPUs.
    ts_context = tensorstore.Context(
        {"data_copy_concurrency": {"limit": num_workers}}
    )
    store = position["0"].tensorstore(context=ts_context)
    # Axes are (T, C, Z, Y, X): keep T and Z, select one channel,
    # subsample Y and X on the grid.
    sampled = store[:, channel_index, :, ::grid_spacing, ::grid_spacing]
    return sampled.read().result()
| 62 | + |
| 63 | + |
48 | 64 | def generate_normalization_metadata(
|
49 |
| - zarr_dir, |
50 |
| - num_workers=4, |
51 |
| - channel_ids=-1, |
52 |
| - grid_spacing=32, |
| 65 | + zarr_dir, num_workers=4, channel_ids=-1, grid_spacing=32 |
53 | 66 | ):
|
54 | 67 | """
|
55 | 68 | Generate pixel intensity metadata to be later used in on-the-fly normalization
|
@@ -89,54 +102,37 @@ def generate_normalization_metadata(
|
89 | 102 | mp_grid_sampler_args.append([position, grid_spacing])
|
90 | 103 |
|
91 | 104 | # sample values and use them to get normalization statistics
|
92 |
| - for i, channel in enumerate(channel_ids): |
93 |
| - show_progress_bar( |
94 |
| - dataloader=channel_ids, |
95 |
| - current=i, |
96 |
| - process="sampling channel values", |
97 |
| - ) |
| 105 | + for i, channel_index in enumerate(channel_ids): |
| 106 | + print(f"Sampling channel index {channel_index} ({i + 1}/{len(channel_ids)})") |
98 | 107 |
|
99 |
| - channel_name = plate.channel_names[channel] |
100 |
| - this_channels_args = tuple([args + [channel] for args in mp_grid_sampler_args]) |
| 108 | + channel_name = plate.channel_names[channel_index] |
| 109 | + dataset_sample_values = [] |
| 110 | + position_and_statistics = [] |
101 | 111 |
|
102 |
| - # NOTE: Doing sequential mp with pool execution creates synchronization |
103 |
| - # points between each step. This could be detrimental to performance |
104 |
| - positions, fov_sample_values = mp_utils.mp_sample_im_pixels( |
105 |
| - this_channels_args, num_workers |
106 |
| - ) |
107 |
| - dataset_sample_values = np.concatenate( |
108 |
| - [arr.flatten() for arr in fov_sample_values] |
109 |
| - ) |
110 |
| - fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers) |
111 |
| - dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values) |
| 112 | + for _, pos in tqdm(position_map, desc="Positions"): |
| 113 | + samples = _grid_sample(pos, grid_spacing, channel_index, num_workers) |
| 114 | + dataset_sample_values.append(samples) |
| 115 | + fov_level_statistics = {"fov_statistics": get_val_stats(samples)} |
| 116 | + position_and_statistics.append((pos, fov_level_statistics)) |
112 | 117 |
|
113 | 118 | dataset_statistics = {
|
114 |
| - "dataset_statistics": dataset_level_statistics, |
| 119 | + "dataset_statistics": get_val_stats(np.stack(dataset_sample_values)), |
115 | 120 | }
|
116 |
| - |
117 | 121 | write_meta_field(
|
118 | 122 | position=plate,
|
119 | 123 | metadata=dataset_statistics,
|
120 | 124 | field_name="normalization",
|
121 | 125 | subfield_name=channel_name,
|
122 | 126 | )
|
123 | 127 |
|
124 |
| - for j, pos in enumerate(positions): |
125 |
| - show_progress_bar( |
126 |
| - dataloader=position_map, |
127 |
| - current=j, |
128 |
| - process=f"calculating channel statistics {channel}/{list(channel_ids)}", |
129 |
| - ) |
130 |
| - position_statistics = dataset_statistics | { |
131 |
| - "fov_statistics": fov_level_statistics[j], |
132 |
| - } |
133 |
| - |
| 128 | + for pos, position_statistics in position_and_statistics: |
134 | 129 | write_meta_field(
|
135 | 130 | position=pos,
|
136 |
| - metadata=position_statistics, |
| 131 | + metadata=dataset_statistics | position_statistics, |
137 | 132 | field_name="normalization",
|
138 | 133 | subfield_name=channel_name,
|
139 | 134 | )
|
| 135 | + |
140 | 136 | plate.close()
|
141 | 137 |
|
142 | 138 |
|
|
0 commit comments