4 changes: 4 additions & 0 deletions viscy/data/typing.py
@@ -1,3 +1,3 @@
from typing import Callable, NamedTuple, NotRequired, Sequence, TypedDict, TypeVar

from torch import ShortTensor, Tensor
@@ -17,11 +17,15 @@
std: Tensor
median: Tensor
iqr: Tensor
percentile_lower_b: NotRequired[Tensor]
percentile_upper_b: NotRequired[Tensor]
percentile_range: NotRequired[Tensor]


class ChannelNormStats(TypedDict):
dataset_statistics: LevelNormStats
fov_statistics: LevelNormStats
per_timepoint: NotRequired[dict[int, LevelNormStats]]


NormMeta = dict[str, ChannelNormStats]
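For orientation, here is a sketch of a `NormMeta` instance carrying the new optional fields. The channel name and all values are hypothetical, and the `_stats` helper exists only for this illustration:

```python
import torch

def _stats(mean, std, median, iqr, lo, hi):
    """Build a LevelNormStats-like dict (illustrative values only)."""
    return {
        "mean": torch.tensor(mean),
        "std": torch.tensor(std),
        "median": torch.tensor(median),
        "iqr": torch.tensor(iqr),
        # optional fields introduced by this PR
        "percentile_lower_b": torch.tensor(lo),
        "percentile_upper_b": torch.tensor(hi),
        "percentile_range": torch.tensor(hi - lo),
    }

# Hypothetical NormMeta for one channel named "Phase3D".
norm_meta = {
    "Phase3D": {
        "dataset_statistics": _stats(0.02, 1.1, 0.0, 1.4, 0.0, 3.2),
        "fov_statistics": _stats(0.01, 1.0, 0.0, 1.3, 0.0, 3.0),
        # optional per-timepoint statistics keyed by integer time index
        "per_timepoint": {t: _stats(0.02, 1.1, 0.0, 1.4, 0.0, 3.2) for t in range(2)},
    }
}
```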
8 changes: 8 additions & 0 deletions viscy/trainer.py
@@ -21,6 +21,8 @@ def preprocess(
channel_names: list[str] | Literal[-1] = -1,
num_workers: int = 1,
block_size: int = 32,
per_timepoint: bool = False,
percentiles: tuple[float, float] = (50.0, 99.0),
model: LightningModule | None = None,
):
"""
@@ -36,6 +38,10 @@
Number of CPU workers, by default 1
block_size : int, optional
Block size to subsample images, by default 32
per_timepoint : bool, optional
If True, compute normalization statistics per timepoint, by default False
percentiles : tuple[float, float], optional
Lower and upper percentiles to compute, by default (50.0, 99.0)
model: LightningModule, optional
Ignored placeholder, by default None
"""
@@ -52,6 +58,8 @@
num_workers=num_workers,
channel_ids=channel_indices,
grid_spacing=block_size,
per_timepoint=per_timepoint,
percentiles=percentiles,
)

def export(
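A minimal usage sketch of the new options. The trainer class name `VisCyTrainer` and the leading store-path argument (hidden above this hunk) are assumptions, and `plate.zarr` is a placeholder path:

```python
from viscy.trainer import VisCyTrainer  # assumed class name in viscy/trainer.py

trainer = VisCyTrainer()
trainer.preprocess(
    "plate.zarr",               # assumed leading path argument (not shown in this hunk)
    channel_names=["Phase3D"],  # hypothetical channel name
    num_workers=4,
    block_size=32,
    per_timepoint=True,         # new: also compute per-timepoint statistics
    percentiles=(50.0, 99.0),   # new: configurable percentile bounds
)
```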
11 changes: 8 additions & 3 deletions viscy/transforms/_transforms.py
@@ -20,7 +20,7 @@ class NormalizeSampled(MapTransform):
----------
keys : str | Iterable[str]
Keys to normalize.
level : {'fov_statistics', 'dataset_statistics'}
level : {'fov_statistics', 'dataset_statistics', 'per_timepoint'}
Level of normalization.
subtrahend : str, optional
Subtrahend for normalization, defaults to "mean".
@@ -33,7 +33,7 @@ def __init__(
def __init__(
self,
keys: str | Iterable[str],
level: Literal["fov_statistics", "dataset_statistics"],
level: Literal["fov_statistics", "dataset_statistics", "per_timepoint"],
subtrahend="mean",
divisor="std",
remove_meta: bool = False,
@@ -47,7 +47,12 @@ def __init__(
# TODO: need to implement the case where the preprocessing already exists
def __call__(self, sample: Sample) -> Sample:
for key in self.keys:
level_meta = sample["norm_meta"][key][self.level]
if self.level == "per_timepoint":
time_idx = sample["index"].time
level_meta = sample["norm_meta"][key]["per_timepoint"][time_idx]
else:
level_meta = sample["norm_meta"][key][self.level]

subtrahend_val = level_meta[self.subtrahend]
divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero
sample[key] = (sample[key] - subtrahend_val) / divisor_val
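Downstream, the new level plugs into `NormalizeSampled` like the existing ones. A sketch, assuming preprocessing has already written per-timepoint statistics and that samples carry an `index` with a `.time` attribute, as the lookup in `__call__` requires (the import path and channel key are assumptions):

```python
from viscy.transforms import NormalizeSampled  # assumed import path

# Median/IQR-centred normalization using the statistics of the sample's
# own timepoint rather than the whole FOV or dataset.
normalize = NormalizeSampled(
    keys=["Phase3D"],  # hypothetical channel key
    level="per_timepoint",
    subtrahend="median",
    divisor="iqr",
)
# normalized = normalize(sample)  # sample["index"].time selects the stats
```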
61 changes: 50 additions & 11 deletions viscy/utils/meta_utils.py
@@ -1,3 +1,3 @@
import os
import sys

@@ -50,6 +50,8 @@
num_workers=4,
channel_ids=-1,
grid_spacing=32,
per_timepoint=False,
percentiles=(50.0, 99.0),
):
"""
Generate pixel intensity metadata to be later used in on-the-fly normalization
@@ -62,7 +64,8 @@
{
channel_idx : {
dataset_statistics: dataset level normalization values (positive float),
fov_statistics: field-of-view level normalization values (positive float)
fov_statistics: field-of-view level normalization values (positive float),
per_timepoint: per-timepoint level normalization values (when enabled)
},
.
.
@@ -74,6 +77,8 @@
:param list/int channel_ids: indices of channels to process in dataset arrays,
by default calculates all
:param int grid_spacing: distance between points in sampling grid
:param bool per_timepoint: if True, compute statistics per timepoint, defaults to False
:param tuple percentiles: lower and upper percentiles to compute, defaults to (50.0, 99.0)
"""
plate = ngff.open_ome_zarr(zarr_dir, mode="r+")
position_map = list(plate.positions())
@@ -99,16 +104,47 @@
channel_name = plate.channel_names[channel]
this_channels_args = tuple([args + [channel] for args in mp_grid_sampler_args])

# NOTE: Doing sequential mp with pool execution creates synchronization
# points between each step. This could be detrimental to performance
positions, fov_sample_values = mp_utils.mp_sample_im_pixels(
this_channels_args, num_workers
)
dataset_sample_values = np.concatenate(
[arr.flatten() for arr in fov_sample_values]
)
fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers)
dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values)
if per_timepoint:
# Sample per timepoint
positions, fov_timepoint_samples = mp_utils.mp_sample_im_pixels_per_timepoint(
this_channels_args, num_workers
)

# Compute dataset-level statistics across all timepoints
all_dataset_samples = []
for fov_samples in fov_timepoint_samples:
for timepoint_samples in fov_samples.values():
all_dataset_samples.append(timepoint_samples)
dataset_sample_values = np.concatenate(all_dataset_samples)

# Compute per-timepoint statistics for each FOV
fov_level_statistics = []
per_timepoint_statistics = []

for fov_samples in fov_timepoint_samples:
# FOV-level statistics (across all timepoints in this FOV)
fov_all_samples = np.concatenate(list(fov_samples.values()))
fov_stats = mp_utils.get_val_stats(fov_all_samples, percentiles)
fov_level_statistics.append(fov_stats)

# Per-timepoint statistics for this FOV
timepoint_stats = {}
for time_idx, samples in fov_samples.items():
timepoint_stats[time_idx] = mp_utils.get_val_stats(samples, percentiles)
per_timepoint_statistics.append(timepoint_stats)

else:
# Original behavior - sample across all timepoints
positions, fov_sample_values = mp_utils.mp_sample_im_pixels(
this_channels_args, num_workers
)
dataset_sample_values = np.concatenate(
[arr.flatten() for arr in fov_sample_values]
)
fov_args = [(samples, percentiles) for samples in fov_sample_values]
fov_level_statistics = mp_utils.mp_get_val_stats(fov_args, num_workers)

dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values, percentiles)

dataset_statistics = {
"dataset_statistics": dataset_level_statistics,
Expand All @@ -130,6 +166,9 @@
position_statistics = dataset_statistics | {
"fov_statistics": fov_level_statistics[j],
}

if per_timepoint:
position_statistics["per_timepoint"] = per_timepoint_statistics[j]

write_meta_field(
position=pos,
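The dataset/FOV/per-timepoint pooling above can be hard to follow at a glance; here is a toy check of the same logic with made-up samples and no zarr store involved:

```python
import numpy as np

# Two FOVs, two timepoints each; values are arbitrary stand-ins for
# the grid-sampled pixels.
fov_timepoint_samples = [
    {0: np.array([1.0, 2.0]), 1: np.array([3.0, 4.0])},
    {0: np.array([5.0, 6.0]), 1: np.array([7.0, 8.0])},
]

# Dataset-level statistics pool every timepoint of every FOV.
dataset_sample_values = np.concatenate(
    [s for fov in fov_timepoint_samples for s in fov.values()]
)
assert dataset_sample_values.size == 8

# FOV-level statistics pool only that FOV's timepoints.
fov_all_samples = np.concatenate(list(fov_timepoint_samples[0].values()))
assert fov_all_samples.tolist() == [1.0, 2.0, 3.0, 4.0]

# Per-timepoint statistics are then computed on each dict value separately.
```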
126 changes: 109 additions & 17 deletions viscy/utils/mp_utils.py
@@ -1,3 +1,3 @@
from concurrent.futures import ProcessPoolExecutor

import iohub.ngff as ngff
@@ -247,43 +247,85 @@
"""
with ProcessPoolExecutor(workers) as ex:
# can't use map directly as it works only with single arg functions
res = ex.map(get_val_stats, fn_args)
res = ex.map(get_val_stats, *zip(*fn_args))
return list(res)


def get_val_stats(sample_values):
def get_val_stats(sample_values, percentiles=(25.0, 75.0)):
"""
Computes the statistics of a numpy array and returns a dictionary
of metadata corresponding to input sample values.

:param list(float) sample_values: List of sample values at respective
indices
:return dict meta_row: Dict with intensity data for image
Parameters
----------
sample_values : array_like
List of sample values at respective indices
percentiles : tuple of float, optional
Lower and upper percentiles to compute, by default (25.0, 75.0)

Returns
-------
dict
Dict with intensity data for image including mean, std, median, iqr,
percentile bounds and range
"""

percentile_vals = np.nanpercentile(sample_values, percentiles)
Collaborator: This call can also compute median (50%) and IQR (25%, 75%) in one sorting operation.
meta_row = {
"mean": float(np.nanmean(sample_values)),
"std": float(np.nanstd(sample_values)),
"median": float(np.nanmedian(sample_values)),
"iqr": float(scipy.stats.iqr(sample_values)),
"percentile_lower_b": float(percentile_vals[0]),
"percentile_upper_b": float(percentile_vals[1]),
"percentile_range": float(percentile_vals[1] - percentile_vals[0]),
}
return meta_row
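Following the collaborator's note above, the median and IQR bounds could be folded into the same `nanpercentile` call so the array is only partitioned once. A hypothetical refactor, keeping the current field names:

```python
import numpy as np

# One nanpercentile call for the IQR bounds (25, 75), the median (50),
# and the requested percentile pair. Inputs here are synthetic.
sample_values = np.random.default_rng(0).normal(size=1_000)
percentiles = (50.0, 99.0)

q25, median, q75, lower_b, upper_b = np.nanpercentile(
    sample_values, (25.0, 50.0, 75.0) + tuple(percentiles)
)
meta_row = {
    "mean": float(np.nanmean(sample_values)),
    "std": float(np.nanstd(sample_values)),
    "median": float(median),
    "iqr": float(q75 - q25),
    "percentile_lower_b": float(lower_b),
    "percentile_upper_b": float(upper_b),
    "percentile_range": float(upper_b - lower_b),
}
```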


def mp_sample_im_pixels(fn_args, workers):
"""Read and computes statistics of images with multiprocessing

:param list of tuple fn_args: list with tuples of function arguments
:param int workers: max number of workers
:return: list of paths and corresponding returned df from get_im_stats
"""

Read images and compute their statistics with multiprocessing.

Parameters
----------
fn_args : list of tuple
List with tuples of function arguments
workers : int
Max number of workers

Returns
-------
list
List of positions and corresponding sample values returned by sample_im_pixels
"""
with ProcessPoolExecutor(workers) as ex:
# can't use map directly as it works only with single arg functions
res = ex.map(sample_im_pixels, *zip(*fn_args))
return list(map(list, zip(*list(res))))


def mp_sample_im_pixels_per_timepoint(fn_args, workers):
"""
Sample pixel values per timepoint with multiprocessing.

Parameters
----------
fn_args : list of tuple
List with tuples of function arguments
workers : int
Max number of workers

Returns
-------
list
List of (position, timepoint_samples_dict) tuples
"""
with ProcessPoolExecutor(workers) as ex:
res = ex.map(sample_im_pixels_per_timepoint, *zip(*fn_args))
return list(map(list, zip(*list(res))))
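The `ex.map(fn, *zip(*fn_args))` unpacking and the closing `zip(*...)` transpose used by both wrappers can be opaque; here is the same idiom without a process pool, with toy names:

```python
# (position, grid_spacing, channel) tuples, as built by the callers.
fn_args = [("pos_a", 32, 0), ("pos_b", 32, 0)]

def fake_sampler(position, grid_spacing, channel):
    return position, {0: [1.0], 1: [2.0]}

# zip(*fn_args) regroups per-call tuples into per-parameter columns,
# which is what map()/Executor.map() expect for multi-argument functions.
results = list(map(fake_sampler, *zip(*fn_args)))

# zip(*results) transposes [(pos, samples), ...] into
# ([pos_a, pos_b], [samples_a, samples_b]) — the wrappers' return shape.
positions, samples = map(list, zip(*results))
assert positions == ["pos_a", "pos_b"]
```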


def sample_im_pixels(
position: ngff.Position,
grid_spacing,
@@ -298,11 +340,19 @@
assumes that the data in the zarr store is stored in [T,C,Z,Y,X] format,
for time, channel, z, y, x.

:param Position zarr_dir: NGFF position node object
:param int grid_spacing: spacing of sampling grid in x and y
:param int channel: channel to sample from

:return list meta_rows: Dicts with intensity data for each grid point
Parameters
----------
position : ngff.Position
NGFF position node object
grid_spacing : int
Spacing of sampling grid in x and y
channel : int
Channel to sample from

Returns
-------
tuple
(position, sample_values)
"""
image_zarr = position.data

Expand All @@ -320,3 +370,45 @@
sample_values = np.stack(all_sample_values, 0).flatten()

return position, sample_values


def sample_im_pixels_per_timepoint(
position: ngff.Position,
grid_spacing,
channel,
):
"""
Sample pixel values per timepoint for per-timepoint normalization statistics.

Parameters
----------
position : ngff.Position
NGFF position node object
grid_spacing : int
Spacing of sampling grid in x and y
channel : int
Channel to sample from

Returns
-------
tuple
(position, dict mapping time_index to sample_values)
"""
image_zarr = position.data

timepoint_samples = {}
all_time_indices = list(range(image_zarr.shape[0]))
all_z_indices = list(range(image_zarr.shape[2]))

for time_index in all_time_indices:
time_sample_values = []
for z_index in all_z_indices:
image_slice = image_zarr[time_index, channel, z_index, :, :]
_, _, sample_values = image_utils.grid_sample_pixel_values(
image_slice, grid_spacing
)
time_sample_values.append(sample_values)

timepoint_samples[time_index] = np.stack(time_sample_values, 0).flatten()

return position, timepoint_samples
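For completeness, a hypothetical invocation of the new sampler, assuming `Plate.positions()` yields `(name, Position)` pairs as in `meta_utils` above; `plate.zarr` is a placeholder path:

```python
import iohub.ngff as ngff

plate = ngff.open_ome_zarr("plate.zarr", mode="r")  # placeholder store
name, pos = list(plate.positions())[0]

position, per_t = sample_im_pixels_per_timepoint(pos, grid_spacing=32, channel=0)
# per_t maps each integer time index to a flat array of sampled pixel values.
print({t: values.shape for t, values in per_t.items()})
```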