import os
import sys
-from pathlib import Path

import iohub.ngff as ngff
import numpy as np
from viscy.utils.mp_utils import get_val_stats


-def write_meta_field(position: ngff.Position, metadata, field_name, subfield_name):
+def write_meta_field(
+    position: ngff.Position, metadata: dict, field_name: str, subfield_name: str
+):
    """Write metadata to position's plate-level or FOV level .zattrs metadata.

    Write metadata to position's plate-level or FOV level .zattrs metadata by either
@@ -68,21 +69,13 @@ def _grid_sample(


def generate_normalization_metadata(
-    zarr_dir: str,
-    num_workers: int = 4,
-    channel_ids: list[int] | int = -1,
-    grid_spacing: int = 32,
+    zarr_dir: str, num_workers: int = 4, channel_ids: int = -1, grid_spacing: int = 32
):
    """Generate pixel intensity metadata for on-the-fly normalization.

    Generate pixel intensity metadata to be later used in on-the-fly normalization
    during training and inference. Sampling is used for efficient estimation of median
    and interquartile range for intensity values on both a dataset and field-of-view
-    level.
-
-    Normalization values are recorded in the image-level metadata in the corresponding
-    position of each zarr_dir store. Format of metadata is as follows:
-    {
        channel_idx : {
            dataset_statistics: dataset level normalization values (positive float),
            fov_statistics: field-of-view level normalization values (positive float)
@@ -106,152 +99,139 @@ def generate_normalization_metadata(
    plate = ngff.open_ome_zarr(zarr_dir, mode="r+")
    position_map = list(plate.positions())

-    # Prepare parameters for multiprocessing
-    zarr_dir_path = os.path.dirname(os.path.dirname(zarr_dir))
-
-    # Get channels to process
    if channel_ids == -1:
-        # Get channel IDs from first position
-        first_position = position_map[0][1]
-        first_images = list(first_position.images())
-        first_image = first_images[0][1]
-        # shape is (t, c, z, y, x)
-        channel_ids = list(range(first_image.data.shape[1]))
-
-    if isinstance(channel_ids, int):
+        channel_ids = range(len(plate.channel_names))
+    elif isinstance(channel_ids, int):
        channel_ids = [channel_ids]

-    # Prepare parameters for each position and channel
-    params_list = []
-    for position_idx, (position_key, position) in enumerate(position_map):
-        for channel_id in channel_ids:
-            params = {
-                "zarr_dir": zarr_dir,
-                "position_key": position_key,
-                "channel_id": channel_id,
-                "grid_spacing": grid_spacing,
-            }
-            params_list.append(params)
-
-    # Use multiprocessing to compute normalization statistics
-    progress_bar = show_progress_bar()
-    if num_workers > 1:
-        with mp_utils.get_context("spawn").Pool(num_workers) as pool:
-            results = pool.map(mp_utils.normalize_meta_worker, params_list)
-            progress_bar.update(len(params_list))
-    else:
-        results = []
-        for params in params_list:
-            result = mp_utils.normalize_meta_worker(params)
-            results.append(result)
-            progress_bar.update(1)
-
-    progress_bar.close()
-
-    # Aggregate results and write to metadata
-    all_dataset_stats = {}
-    for result in results:
-        if result is not None:
-            position_key, channel_id, dataset_stats, fov_stats = result
-
-            if channel_id not in all_dataset_stats:
-                all_dataset_stats[channel_id] = []
-            all_dataset_stats[channel_id].append(dataset_stats)
-
-    # Calculate dataset-level statistics
-    final_dataset_stats = {}
-    for channel_id, stats_list in all_dataset_stats.items():
-        if stats_list:
-            # Aggregate median and IQR across all positions
-            medians = [stats["median"] for stats in stats_list if "median" in stats]
-            iqrs = [stats["iqr"] for stats in stats_list if "iqr" in stats]
-
-            if medians and iqrs:
-                final_dataset_stats[channel_id] = {
-                    "median": np.median(medians),
-                    "iqr": np.median(iqrs),
-                }
-
-    # Write metadata to each position
-    for result in results:
-        if result is not None:
-            position_key, channel_id, dataset_stats, fov_stats = result
-
-            # Get position object
-            position = dict(plate.positions())[position_key]
-
-            # Prepare metadata
-            metadata = {
-                "dataset_statistics": final_dataset_stats.get(channel_id, {}),
-                "fov_statistics": fov_stats,
-            }
+    # get arguments for multiprocessed grid sampling
+    mp_grid_sampler_args = []
+    for _, position in position_map:
+        mp_grid_sampler_args.append([position, grid_spacing])
+
+    # sample values and use them to get normalization statistics
+    for i, channel_index in enumerate(channel_ids):
+        print(f"Sampling channel index {channel_index} ({i + 1}/{len(channel_ids)})")

-            # Write metadata
+        channel_name = plate.channel_names[channel_index]
+        dataset_sample_values = []
+        position_and_statistics = []
+
+        for _, pos in tqdm(position_map, desc="Positions"):
+            samples = _grid_sample(pos, grid_spacing, channel_index, num_workers)
+            dataset_sample_values.append(samples)
+            fov_level_statistics = {"fov_statistics": get_val_stats(samples)}
+            position_and_statistics.append((pos, fov_level_statistics))
+
+        dataset_statistics = {
+            "dataset_statistics": get_val_stats(np.stack(dataset_sample_values)),
+        }
+        write_meta_field(
+            position=plate,
+            metadata=dataset_statistics,
+            field_name="normalization",
+            subfield_name=channel_name,
+        )
+
+        for pos, position_statistics in position_and_statistics:
            write_meta_field(
-                position=position,
-                metadata=metadata,
+                position=pos,
+                metadata=dataset_statistics | position_statistics,
                field_name="normalization",
-                subfield_name=str(channel_id),
+                subfield_name=channel_name,
            )

    plate.close()


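Usage sketch (editor's addition, not part of the diff): after the new generate_normalization_metadata runs, dataset-level statistics should sit under a "normalization" field in the plate's .zattrs, and each position should carry both dataset- and FOV-level statistics, keyed by channel name. Assuming a hypothetical store "plate.zarr" with a channel named "Phase", reading the values back might look like this:

    import iohub.ngff as ngff

    generate_normalization_metadata("plate.zarr", num_workers=4, grid_spacing=32)

    plate = ngff.open_ome_zarr("plate.zarr", mode="r")
    # dataset-level stats are written once per channel at the plate level
    dataset_stats = plate.zattrs["normalization"]["Phase"]["dataset_statistics"]
    # FOV-level stats are merged with the dataset stats at each position
    for _, pos in plate.positions():
        fov_stats = pos.zattrs["normalization"]["Phase"]["fov_statistics"]
    plate.close()
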
-def compute_normalization_stats(
-    image_data: np.ndarray, grid_spacing: int = 32
-) -> dict[str, float]:
+def compute_zscore_params(
+    frames_meta, ints_meta, input_dir, normalize_im, min_fraction=0.99
+):
-    """Compute normalization statistics from image data using grid sampling.
+    """Compute z-score normalization parameters.

+    Compute zscore median and interquartile range.
+
    Parameters
    ----------
-    image_data : np.ndarray
-        3D or 4D image array of shape (z, y, x) or (t, z, y, x).
-    grid_spacing : int, optional
-        Spacing between grid points for sampling, by default 32.
+    frames_meta : pd.DataFrame
+        Dataframe containing all metadata.
+    ints_meta : pd.DataFrame
+        Metadata containing intensity statistics of each z-slice and foreground
+        fraction for masks.
+    input_dir : str
+        Directory containing images.
+    normalize_im : None or str
+        Normalization scheme for input images.
+    min_fraction : float
+        Minimum foreground fraction (in case of masks) for computing intensity
+        statistics.

    Returns
    -------
-    dict[str, float]
-        Dictionary with median and IQR statistics for normalization.
+    tuple[pd.DataFrame, pd.DataFrame]
+        Tuple containing:
+        - pd.DataFrame frames_meta: Dataframe containing all metadata
+        - pd.DataFrame ints_meta: Metadata containing intensity statistics of each z-slice
    """
-    # Handle different input shapes
-    if image_data.ndim == 4:
-        # Assume (t, z, y, x) and take first timepoint
-        image_data = image_data[0]
-
-    if image_data.ndim == 3:
-        # Assume (z, y, x) and use middle z-slice if available
-        if image_data.shape[0] > 1:
-            z_mid = image_data.shape[0] // 2
-            image_data = image_data[z_mid]
-        else:
-            image_data = image_data[0]
-
-    # Now image_data should be 2D (y, x)
-    if image_data.ndim != 2:
-        raise ValueError(f"Expected 2D image after processing, got {image_data.ndim}D")
-
-    # Create sampling grid
-    y_indices = np.arange(0, image_data.shape[0], grid_spacing)
-    x_indices = np.arange(0, image_data.shape[1], grid_spacing)
-
-    # Sample values at grid points
-    sampled_values = image_data[np.ix_(y_indices, x_indices)].flatten()
-
-    # Remove any NaN or infinite values
-    sampled_values = sampled_values[np.isfinite(sampled_values)]
-
-    if len(sampled_values) == 0:
-        return {"median": 0.0, "iqr": 1.0}
-
-    # Compute statistics
-    median = np.median(sampled_values)
-    q25 = np.percentile(sampled_values, 25)
-    q75 = np.percentile(sampled_values, 75)
-    iqr = q75 - q25
-
-    # Avoid zero IQR
-    if iqr == 0:
-        iqr = 1.0
+    assert normalize_im in [
+        None,
+        "slice",
+        "volume",
+        "dataset",
+    ], 'normalize_im must be None or "slice" or "volume" or "dataset"'
+
+    if normalize_im is None:
+        # No normalization
+        frames_meta["zscore_median"] = 0
+        frames_meta["zscore_iqr"] = 1
+        return frames_meta, ints_meta
+    elif normalize_im == "dataset":
+        agg_cols = ["time_idx", "channel_idx", "dir_name"]
+    elif normalize_im == "volume":
+        agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx"]
+    else:
+        agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx", "slice_idx"]
+    # median and inter-quartile range are more robust than mean and std
+    ints_meta_sub = ints_meta[ints_meta["fg_frac"] >= min_fraction]
+    ints_agg_median = ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).median()
+    ints_agg_hq = (
+        ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.75)
+    )
+    ints_agg_lq = (
+        ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.25)
+    )
+    ints_agg = ints_agg_median
+    ints_agg.columns = ["zscore_median"]
+    ints_agg["zscore_iqr"] = ints_agg_hq["intensity"] - ints_agg_lq["intensity"]
+    ints_agg.reset_index(inplace=True)
+
+    cols_to_merge = frames_meta.columns[
+        [col not in ["zscore_median", "zscore_iqr"] for col in frames_meta.columns]
+    ]
+    frames_meta = pd.merge(
+        frames_meta[cols_to_merge],
+        ints_agg,
+        how="left",
+        on=agg_cols,
+    )
+    if frames_meta["zscore_median"].isnull().values.any():
+        raise ValueError(
+            "Found NaN in normalization parameters. "
+            "min_fraction might be too low or images might be corrupted."
+        )
+    frames_meta_filename = os.path.join(input_dir, "frames_meta.csv")
+    frames_meta.to_csv(frames_meta_filename, sep=",")
+
+    cols_to_merge = ints_meta.columns[
+        [col not in ["zscore_median", "zscore_iqr"] for col in ints_meta.columns]
+    ]
+    ints_meta = pd.merge(
+        ints_meta[cols_to_merge],
+        ints_agg,
+        how="left",
+        on=agg_cols,
+    )
+    ints_meta["intensity_norm"] = (
+        ints_meta["intensity"] - ints_meta["zscore_median"]
+    ) / (ints_meta["zscore_iqr"] + sys.float_info.epsilon)

-    return {"median": float(median), "iqr": float(iqr)}
+    return frames_meta, ints_meta
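
Second usage sketch (editor's addition, not part of the diff): compute_zscore_params can be exercised on toy metadata. The column names follow the aggregation keys used in the function; pandas is assumed to be imported as pd at module level (the import hunk is elided above), and all values are made up for illustration:

    import tempfile
    import pandas as pd

    ints_meta = pd.DataFrame(
        {
            "time_idx": [0, 0, 0, 0],
            "channel_idx": [0, 0, 0, 0],
            "dir_name": ["d"] * 4,
            "pos_idx": [0, 0, 1, 1],
            "slice_idx": [0, 1, 0, 1],
            "intensity": [10.0, 20.0, 30.0, 40.0],
            "fg_frac": [1.0] * 4,
        }
    )
    frames_meta = ints_meta.drop(columns=["intensity", "fg_frac"])

    with tempfile.TemporaryDirectory() as tmp:
        # "volume" aggregates per position: medians 15.0/35.0, IQRs 5.0/5.0
        frames_meta, ints_meta = compute_zscore_params(
            frames_meta, ints_meta, tmp, normalize_im="volume"
        )
    print(frames_meta[["pos_idx", "zscore_median", "zscore_iqr"]].drop_duplicates())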