
Commit f0df162

committed: updated meta_utils.py

1 parent ecc3296 · commit f0df162

File tree

2 files changed: +123 -142 lines

viscy/data/hcs.py

Lines changed: 3 additions & 2 deletions

@@ -3,9 +3,10 @@
 import os
 import re
 import tempfile
-from collections.abc import Callable, Sequence
+
+# from collections.abc import Callable, Sequence
 from pathlib import Path
-from typing import Literal
+from typing import Callable, Literal, Sequence
 
 import numpy as np
 import torch
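The hcs.py hunk swaps the `collections.abc` imports for their `typing` aliases. The diff does not state the motivation, but a plausible reading (an assumption, not confirmed by the commit) is annotation compatibility: `collections.abc.Callable` only became subscriptable in Python 3.9, while the `typing` aliases subscript on older interpreters too. A minimal sketch of the kind of annotation this keeps working; the alias and function below are illustrative, not from viscy/data/hcs.py:

    # Illustrative only: mirrors the new import line, not actual viscy code.
    from typing import Callable, Literal, Sequence

    Split = Literal["train", "val", "test"]  # a subscripted Literal annotation

    def apply_all(fns: Sequence[Callable[[int], float]], x: int) -> float:
        # sum the outputs of each transform applied to x
        return sum(fn(x) for fn in fns)

    print(apply_all([float, lambda v: v * 0.5], 2))  # prints 3.0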

viscy/utils/meta_utils.py

Lines changed: 120 additions & 140 deletions
@@ -1,6 +1,5 @@
 import os
 import sys
-from pathlib import Path
 
 import iohub.ngff as ngff
 import numpy as np
@@ -11,7 +10,9 @@
 from viscy.utils.mp_utils import get_val_stats
 
 
-def write_meta_field(position: ngff.Position, metadata, field_name, subfield_name):
+def write_meta_field(
+    position: ngff.Position, metadata: dict, field_name: str, subfield_name: str
+):
     """Write metadata to position's plate-level or FOV level .zattrs metadata.
 
     Write metadata to position's plate-level or FOV level .zattrs metadata by either
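Only the signature and the docstring's opening are shown for write_meta_field; the body is unchanged by this hunk. For orientation, a rough sketch of the update-or-create behavior the docstring describes, assuming iohub's dict-like `.zattrs` mapping; this is not the actual viscy implementation:

    # Hedged sketch, not viscy's code: merge `metadata` into
    # .zattrs[field_name][subfield_name], creating the field if missing.
    def write_meta_field_sketch(position, metadata: dict, field_name: str, subfield_name: str):
        field = dict(position.zattrs.get(field_name, {}))  # copy, don't mutate in place
        field[subfield_name] = metadata
        # reassigning the top-level key is what persists the change to .zattrs
        position.zattrs[field_name] = field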
@@ -68,21 +69,13 @@ def _grid_sample(
 
 
 def generate_normalization_metadata(
-    zarr_dir: str,
-    num_workers: int = 4,
-    channel_ids: list[int] | int = -1,
-    grid_spacing: int = 32,
+    zarr_dir: str, num_workers: int = 4, channel_ids: int = -1, grid_spacing: int = 32
 ):
     """Generate pixel intensity metadata for on-the-fly normalization.
 
     Generate pixel intensity metadata to be later used in on-the-fly normalization
     during training and inference. Sampling is used for efficient estimation of median
     and interquartile range for intensity values on both a dataset and field-of-view
-    level.
-
-    Normalization values are recorded in the image-level metadata in the corresponding
-    position of each zarr_dir store. Format of metadata is as follows:
-    {
         channel_idx : {
             dataset_statistics: dataset level normalization values (positive float),
             fov_statistics: field-of-view level normalization values (positive float)
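With the collapsed signature, a call would presumably look like the following; the store path is a placeholder, and the -1 sentinel expanding to all channels is implemented in the next hunk:

    # Hypothetical invocation; "plate.zarr" stands in for a real HCS OME-Zarr store.
    from viscy.utils.meta_utils import generate_normalization_metadata

    generate_normalization_metadata(
        "plate.zarr",      # opened in "r+" mode and updated in place
        num_workers=4,
        channel_ids=-1,    # -1 expands to all channels via plate.channel_names
        grid_spacing=32,
    )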
@@ -106,152 +99,139 @@ def generate_normalization_metadata(
     plate = ngff.open_ome_zarr(zarr_dir, mode="r+")
     position_map = list(plate.positions())
 
-    # Prepare parameters for multiprocessing
-    zarr_dir_path = os.path.dirname(os.path.dirname(zarr_dir))
-
-    # Get channels to process
     if channel_ids == -1:
-        # Get channel IDs from first position
-        first_position = position_map[0][1]
-        first_images = list(first_position.images())
-        first_image = first_images[0][1]
-        # shape is (t, c, z, y, x)
-        channel_ids = list(range(first_image.data.shape[1]))
-
-    if isinstance(channel_ids, int):
+        channel_ids = range(len(plate.channel_names))
+    elif isinstance(channel_ids, int):
         channel_ids = [channel_ids]
 
-    # Prepare parameters for each position and channel
-    params_list = []
-    for position_idx, (position_key, position) in enumerate(position_map):
-        for channel_id in channel_ids:
-            params = {
-                "zarr_dir": zarr_dir,
-                "position_key": position_key,
-                "channel_id": channel_id,
-                "grid_spacing": grid_spacing,
-            }
-            params_list.append(params)
-
-    # Use multiprocessing to compute normalization statistics
-    progress_bar = show_progress_bar()
-    if num_workers > 1:
-        with mp_utils.get_context("spawn").Pool(num_workers) as pool:
-            results = pool.map(mp_utils.normalize_meta_worker, params_list)
-            progress_bar.update(len(params_list))
-    else:
-        results = []
-        for params in params_list:
-            result = mp_utils.normalize_meta_worker(params)
-            results.append(result)
-            progress_bar.update(1)
-
-    progress_bar.close()
-
-    # Aggregate results and write to metadata
-    all_dataset_stats = {}
-    for result in results:
-        if result is not None:
-            position_key, channel_id, dataset_stats, fov_stats = result
-
-            if channel_id not in all_dataset_stats:
-                all_dataset_stats[channel_id] = []
-            all_dataset_stats[channel_id].append(dataset_stats)
-
-    # Calculate dataset-level statistics
-    final_dataset_stats = {}
-    for channel_id, stats_list in all_dataset_stats.items():
-        if stats_list:
-            # Aggregate median and IQR across all positions
-            medians = [stats["median"] for stats in stats_list if "median" in stats]
-            iqrs = [stats["iqr"] for stats in stats_list if "iqr" in stats]
-
-            if medians and iqrs:
-                final_dataset_stats[channel_id] = {
-                    "median": np.median(medians),
-                    "iqr": np.median(iqrs),
-                }
-
-    # Write metadata to each position
-    for result in results:
-        if result is not None:
-            position_key, channel_id, dataset_stats, fov_stats = result
-
-            # Get position object
-            position = dict(plate.positions())[position_key]
-
-            # Prepare metadata
-            metadata = {
-                "dataset_statistics": final_dataset_stats.get(channel_id, {}),
-                "fov_statistics": fov_stats,
-            }
+    # get arguments for multiprocessed grid sampling
+    mp_grid_sampler_args = []
+    for _, position in position_map:
+        mp_grid_sampler_args.append([position, grid_spacing])
+
+    # sample values and use them to get normalization statistics
+    for i, channel_index in enumerate(channel_ids):
+        print(f"Sampling channel index {channel_index} ({i + 1}/{len(channel_ids)})")
 
-            # Write metadata
+        channel_name = plate.channel_names[channel_index]
+        dataset_sample_values = []
+        position_and_statistics = []
+
+        for _, pos in tqdm(position_map, desc="Positions"):
+            samples = _grid_sample(pos, grid_spacing, channel_index, num_workers)
+            dataset_sample_values.append(samples)
+            fov_level_statistics = {"fov_statistics": get_val_stats(samples)}
+            position_and_statistics.append((pos, fov_level_statistics))
+
+        dataset_statistics = {
+            "dataset_statistics": get_val_stats(np.stack(dataset_sample_values)),
+        }
+        write_meta_field(
+            position=plate,
+            metadata=dataset_statistics,
+            field_name="normalization",
+            subfield_name=channel_name,
+        )
+
+        for pos, position_statistics in position_and_statistics:
             write_meta_field(
-                position=position,
-                metadata=metadata,
+                position=pos,
+                metadata=dataset_statistics | position_statistics,
                 field_name="normalization",
-                subfield_name=str(channel_id),
+                subfield_name=channel_name,
             )
 
     plate.close()
 
 
-def compute_normalization_stats(
-    image_data: np.ndarray, grid_spacing: int = 32
-) -> dict[str, float]:
+def compute_zscore_params(
+    frames_meta, ints_meta, input_dir, normalize_im, min_fraction=0.99
+):
     """Compute normalization statistics from image data using grid sampling.
 
+    Compute zscore median and interquartile range.
+
     Parameters
     ----------
-    image_data : np.ndarray
-        3D or 4D image array of shape (z, y, x) or (t, z, y, x).
-    grid_spacing : int, optional
-        Spacing between grid points for sampling, by default 32.
+    frames_meta : pd.DataFrame
+        Dataframe containing all metadata.
+    ints_meta : pd.DataFrame
+        Metadata containing intensity statistics of each z-slice and foreground fraction for masks.
+    input_dir : str
+        Directory containing images.
+    normalize_im : None or str
+        Normalization scheme for input images.
+    min_fraction : float
+        Minimum foreground fraction (in case of masks)
+        for computing intensity statistics.
 
     Returns
     -------
-    dict[str, float]
-        Dictionary with median and IQR statistics for normalization.
+    tuple[pd.DataFrame, pd.DataFrame]
+        Tuple containing:
+        - pd.DataFrame frames_meta: Dataframe containing all metadata
+        - pd.DataFrame ints_meta: Metadata containing intensity statistics of each z-slice
     """
-    # Handle different input shapes
-    if image_data.ndim == 4:
-        # Assume (t, z, y, x) and take first timepoint
-        image_data = image_data[0]
-
-    if image_data.ndim == 3:
-        # Assume (z, y, x) and use middle z-slice if available
-        if image_data.shape[0] > 1:
-            z_mid = image_data.shape[0] // 2
-            image_data = image_data[z_mid]
-        else:
-            image_data = image_data[0]
-
-    # Now image_data should be 2D (y, x)
-    if image_data.ndim != 2:
-        raise ValueError(f"Expected 2D image after processing, got {image_data.ndim}D")
-
-    # Create sampling grid
-    y_indices = np.arange(0, image_data.shape[0], grid_spacing)
-    x_indices = np.arange(0, image_data.shape[1], grid_spacing)
-
-    # Sample values at grid points
-    sampled_values = image_data[np.ix_(y_indices, x_indices)].flatten()
-
-    # Remove any NaN or infinite values
-    sampled_values = sampled_values[np.isfinite(sampled_values)]
-
-    if len(sampled_values) == 0:
-        return {"median": 0.0, "iqr": 1.0}
-
-    # Compute statistics
-    median = np.median(sampled_values)
-    q25 = np.percentile(sampled_values, 25)
-    q75 = np.percentile(sampled_values, 75)
-    iqr = q75 - q25
-
-    # Avoid zero IQR
-    if iqr == 0:
-        iqr = 1.0
+    assert normalize_im in [
+        None,
+        "slice",
+        "volume",
+        "dataset",
+    ], 'normalize_im must be None or "slice" or "volume" or "dataset"'
+
+    if normalize_im is None:
+        # No normalization
+        frames_meta["zscore_median"] = 0
+        frames_meta["zscore_iqr"] = 1
+        return frames_meta
+    elif normalize_im == "dataset":
+        agg_cols = ["time_idx", "channel_idx", "dir_name"]
+    elif normalize_im == "volume":
+        agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx"]
+    else:
+        agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx", "slice_idx"]
+    # median and inter-quartile range are more robust than mean and std
+    ints_meta_sub = ints_meta[ints_meta["fg_frac"] >= min_fraction]
+    ints_agg_median = ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).median()
+    ints_agg_hq = (
+        ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.75)
+    )
+    ints_agg_lq = (
+        ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.25)
+    )
+    ints_agg = ints_agg_median
+    ints_agg.columns = ["zscore_median"]
+    ints_agg["zscore_iqr"] = ints_agg_hq["intensity"] - ints_agg_lq["intensity"]
+    ints_agg.reset_index(inplace=True)
+
+    cols_to_merge = frames_meta.columns[
+        [col not in ["zscore_median", "zscore_iqr"] for col in frames_meta.columns]
+    ]
+    frames_meta = pd.merge(
+        frames_meta[cols_to_merge],
+        ints_agg,
+        how="left",
+        on=agg_cols,
+    )
+    if frames_meta["zscore_median"].isnull().values.any():
+        raise ValueError(
+            "Found NaN in normalization parameters. \
+            min_fraction might be too low or images might be corrupted."
+        )
+    frames_meta_filename = os.path.join(input_dir, "frames_meta.csv")
+    frames_meta.to_csv(frames_meta_filename, sep=",")
+
+    cols_to_merge = ints_meta.columns[
+        [col not in ["zscore_median", "zscore_iqr"] for col in ints_meta.columns]
+    ]
+    ints_meta = pd.merge(
+        ints_meta[cols_to_merge],
+        ints_agg,
+        how="left",
+        on=agg_cols,
+    )
+    ints_meta["intensity_norm"] = (
+        ints_meta["intensity"] - ints_meta["zscore_median"]
+    ) / (ints_meta["zscore_iqr"] + sys.float_info.epsilon)
 
-    return {"median": float(median), "iqr": float(iqr)}
+    return frames_meta, ints_meta
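Note that `dataset_statistics | position_statistics` uses the dict-union operator introduced in Python 3.9, so each FOV's metadata carries both dataset-level and FOV-level statistics. compute_zscore_params centers intensities on the median and scales by the interquartile range; a standalone toy illustration of that arithmetic (not viscy code):

    # Robust z-score: center on the median, scale by the IQR.
    import sys

    import pandas as pd

    ints = pd.DataFrame({"intensity": [1.0, 2.0, 3.0, 4.0, 100.0]})
    median = ints["intensity"].median()  # 3.0, barely affected by the outlier
    iqr = ints["intensity"].quantile(0.75) - ints["intensity"].quantile(0.25)  # 4.0 - 2.0
    ints["intensity_norm"] = (ints["intensity"] - median) / (iqr + sys.float_info.epsilon)
    print(ints["intensity_norm"].round(1).tolist())  # [-1.0, -0.5, 0.0, 0.5, 48.5]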
