diff --git a/viscy/data/typing.py b/viscy/data/typing.py
index c824b9416..4c131dc4a 100644
--- a/viscy/data/typing.py
+++ b/viscy/data/typing.py
@@ -17,11 +17,15 @@ class LevelNormStats(TypedDict):
     std: Tensor
     median: Tensor
     iqr: Tensor
+    percentile_lower_b: NotRequired[Tensor]
+    percentile_upper_b: NotRequired[Tensor]
+    percentile_range: NotRequired[Tensor]
 
 
 class ChannelNormStats(TypedDict):
     dataset_statistics: LevelNormStats
     fov_statistics: LevelNormStats
+    per_timepoint: NotRequired[dict[int, LevelNormStats]]
 
 
 NormMeta = dict[str, ChannelNormStats]
diff --git a/viscy/trainer.py b/viscy/trainer.py
index 03395a371..a6a1c7dc4 100644
--- a/viscy/trainer.py
+++ b/viscy/trainer.py
@@ -21,6 +21,8 @@ def preprocess(
         channel_names: list[str] | Literal[-1] = -1,
         num_workers: int = 1,
         block_size: int = 32,
+        per_timepoint: bool = False,
+        percentiles: tuple[float, float] = (50.0, 99.0),
         model: LightningModule | None = None,
     ):
         """
@@ -36,6 +38,10 @@ def preprocess(
             Number of CPU workers, by default 1
         block_size : int, optional
             Block size to subsample images, by default 32
+        per_timepoint : bool, optional
+            If True, compute normalization statistics per timepoint, by default False
+        percentiles : tuple[float, float], optional
+            Lower and upper percentiles to compute, by default (50.0, 99.0)
         model: LightningModule, optional
             Ignored placeholder, by default None
         """
@@ -52,6 +58,8 @@ def preprocess(
             num_workers=num_workers,
             channel_ids=channel_indices,
             grid_spacing=block_size,
+            per_timepoint=per_timepoint,
+            percentiles=percentiles,
         )
 
     def export(
diff --git a/viscy/transforms/_transforms.py b/viscy/transforms/_transforms.py
index 0b99a2ac4..05c096077 100644
--- a/viscy/transforms/_transforms.py
+++ b/viscy/transforms/_transforms.py
@@ -20,7 +20,7 @@ class NormalizeSampled(MapTransform):
     ----------
     keys : str | Iterable[str]
         Keys to normalize.
-    level : {'fov_statistics', 'dataset_statistics'}
+    level : {'fov_statistics', 'dataset_statistics', 'per_timepoint'}
         Level of normalization.
     subtrahend : str, optional
         Subtrahend for normalization, defaults to "mean".
@@ -33,7 +33,7 @@ class NormalizeSampled(MapTransform):
     def __init__(
         self,
         keys: str | Iterable[str],
-        level: Literal["fov_statistics", "dataset_statistics"],
+        level: Literal["fov_statistics", "dataset_statistics", "per_timepoint"],
         subtrahend="mean",
         divisor="std",
         remove_meta: bool = False,
@@ -47,7 +47,14 @@ def __init__(
     # TODO: need to implement the case where the preprocessing already exists
     def __call__(self, sample: Sample) -> Sample:
         for key in self.keys:
-            level_meta = sample["norm_meta"][key][self.level]
+            if self.level == "per_timepoint":
+                # NOTE(review): if norm_meta is read back from JSON attributes, timepoint
+                # keys arrive as str, not int -- confirm the loader converts them.
+                time_idx = sample["index"].time
+                level_meta = sample["norm_meta"][key]["per_timepoint"][time_idx]
+            else:
+                level_meta = sample["norm_meta"][key][self.level]
+
             subtrahend_val = level_meta[self.subtrahend]
             divisor_val = level_meta[self.divisor] + 1e-8  # avoid div by zero
             sample[key] = (sample[key] - subtrahend_val) / divisor_val
diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py
index 961b66967..701bce516 100644
--- a/viscy/utils/meta_utils.py
+++ b/viscy/utils/meta_utils.py
@@ -50,6 +50,8 @@ def generate_normalization_metadata(
     num_workers=4,
     channel_ids=-1,
     grid_spacing=32,
+    per_timepoint=False,
+    percentiles=(50.0, 99.0),
 ):
     """
     Generate pixel intensity metadata to be later used in on-the-fly normalization
@@ -62,7 +64,8 @@ def generate_normalization_metadata(
     {
         channel_idx : {
             dataset_statistics: dataset level normalization values (positive float),
-            fov_statistics: field-of-view level normalization values (positive float)
+            fov_statistics: field-of-view level normalization values (positive float),
+            per_timepoint: per-timepoint level normalization values (when enabled)
         },
         .
         .
@@ -74,6 +77,8 @@ def generate_normalization_metadata(
     :param list/int channel_ids: indices of channels to process in dataset arrays,
         by default calculates all
     :param int grid_spacing: distance between points in sampling grid
+    :param bool per_timepoint: if True, compute statistics per timepoint, defaults to False
+    :param tuple percentiles: lower and upper percentiles to compute, defaults to (50.0, 99.0)
     """
     plate = ngff.open_ome_zarr(zarr_dir, mode="r+")
     position_map = list(plate.positions())
@@ -99,16 +104,47 @@ def generate_normalization_metadata(
         channel_name = plate.channel_names[channel]
         this_channels_args = tuple([args + [channel] for args in mp_grid_sampler_args])
 
-        # NOTE: Doing sequential mp with pool execution creates synchronization
-        # points between each step. This could be detrimental to performance
-        positions, fov_sample_values = mp_utils.mp_sample_im_pixels(
-            this_channels_args, num_workers
-        )
-        dataset_sample_values = np.concatenate(
-            [arr.flatten() for arr in fov_sample_values]
-        )
-        fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers)
-        dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values)
+        if per_timepoint:
+            # Sample per timepoint
+            positions, fov_timepoint_samples = mp_utils.mp_sample_im_pixels_per_timepoint(
+                this_channels_args, num_workers
+            )
+
+            # Compute dataset-level statistics across all timepoints
+            all_dataset_samples = []
+            for fov_samples in fov_timepoint_samples:
+                for timepoint_samples in fov_samples.values():
+                    all_dataset_samples.append(timepoint_samples)
+            dataset_sample_values = np.concatenate(all_dataset_samples)
+
+            # Compute per-timepoint statistics for each FOV
+            fov_level_statistics = []
+            per_timepoint_statistics = []
+
+            for fov_samples in fov_timepoint_samples:
+                # FOV-level statistics (across all timepoints in this FOV)
+                fov_all_samples = np.concatenate(list(fov_samples.values()))
+                fov_stats = mp_utils.get_val_stats(fov_all_samples, percentiles)
+                fov_level_statistics.append(fov_stats)
+
+                # Per-timepoint statistics for this FOV
+                timepoint_stats = {}
+                for time_idx, samples in fov_samples.items():
+                    timepoint_stats[time_idx] = mp_utils.get_val_stats(samples, percentiles)
+                per_timepoint_statistics.append(timepoint_stats)
+
+        else:
+            # Original behavior - sample across all timepoints
+            positions, fov_sample_values = mp_utils.mp_sample_im_pixels(
+                this_channels_args, num_workers
+            )
+            dataset_sample_values = np.concatenate(
+                [arr.flatten() for arr in fov_sample_values]
+            )
+            fov_args = [(samples, percentiles) for samples in fov_sample_values]
+            fov_level_statistics = mp_utils.mp_get_val_stats(fov_args, num_workers)
+
+        dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values, percentiles)
 
         dataset_statistics = {
             "dataset_statistics": dataset_level_statistics,
@@ -130,6 +166,9 @@ def generate_normalization_metadata(
             position_statistics = dataset_statistics | {
                 "fov_statistics": fov_level_statistics[j],
             }
+
+            if per_timepoint:
+                position_statistics["per_timepoint"] = per_timepoint_statistics[j]
 
             write_meta_field(
                 position=pos,
diff --git a/viscy/utils/mp_utils.py b/viscy/utils/mp_utils.py
index ee46f7ea9..2607d2857 100644
--- a/viscy/utils/mp_utils.py
+++ b/viscy/utils/mp_utils.py
@@ -247,43 +247,85 @@ def mp_get_val_stats(fn_args, workers):
     """
     with ProcessPoolExecutor(workers) as ex:
         # can't use map directly as it works only with single arg functions
-        res = ex.map(get_val_stats, fn_args)
+        res = ex.map(get_val_stats, *zip(*fn_args))
     return list(res)
 
 
-def get_val_stats(sample_values):
+def get_val_stats(sample_values, percentiles=(25.0, 75.0)):
     """
     Computes the statistics of a numpy array and returns a dictionary
     of metadata corresponding to input sample values.
 
-    :param list(float) sample_values: List of sample values at respective
-        indices
-    :return dict meta_row: Dict with intensity data for image
+    Parameters
+    ----------
+    sample_values : array_like
+        List of sample values at respective indices
+    percentiles : tuple of float, optional
+        Lower and upper percentiles to compute, by default (25.0, 75.0)
+
+    Returns
+    -------
+    dict
+        Dict with intensity data for image including mean, std, median, iqr,
+        percentile bounds and range
     """
-
+    percentile_vals = np.nanpercentile(sample_values, percentiles)
+
     meta_row = {
         "mean": float(np.nanmean(sample_values)),
         "std": float(np.nanstd(sample_values)),
         "median": float(np.nanmedian(sample_values)),
-        "iqr": float(scipy.stats.iqr(sample_values)),
+        "iqr": float(scipy.stats.iqr(sample_values, nan_policy="omit")),
+        "percentile_lower_b": float(percentile_vals[0]),
+        "percentile_upper_b": float(percentile_vals[1]),
+        "percentile_range": float(percentile_vals[1] - percentile_vals[0]),
     }
     return meta_row
 
 
 def mp_sample_im_pixels(fn_args, workers):
-    """Read and computes statistics of images with multiprocessing
-
-    :param list of tuple fn_args: list with tuples of function arguments
-    :param int workers: max number of workers
-    :return: list of paths and corresponding returned df from get_im_stats
     """
-
+    Sample pixel values from images with multiprocessing.
+
+    Parameters
+    ----------
+    fn_args : list of tuple
+        List with tuples of function arguments
+    workers : int
+        Max number of workers
+
+    Returns
+    -------
+    list
+        Worker results regrouped into two lists: [positions, sample_values]
+    """
     with ProcessPoolExecutor(workers) as ex:
         # can't use map directly as it works only with single arg functions
         res = ex.map(sample_im_pixels, *zip(*fn_args))
     return list(map(list, zip(*list(res))))
 
 
+def mp_sample_im_pixels_per_timepoint(fn_args, workers):
+    """
+    Sample pixel values per timepoint with multiprocessing.
+
+    Parameters
+    ----------
+    fn_args : list of tuple
+        List with tuples of function arguments
+    workers : int
+        Max number of workers
+
+    Returns
+    -------
+    list
+        Worker results regrouped into two lists: [positions, timepoint_samples_dicts]
+    """
+    with ProcessPoolExecutor(workers) as ex:
+        res = ex.map(sample_im_pixels_per_timepoint, *zip(*fn_args))
+    return list(map(list, zip(*list(res))))
+
+
 def sample_im_pixels(
     position: ngff.Position,
     grid_spacing,
@@ -298,11 +340,19 @@ def sample_im_pixels(
     assumes that the data in the zarr store is stored in [T,C,Z,Y,X] format,
     for time, channel, z, y, x.
 
-    :param Position zarr_dir: NGFF position node object
-    :param int grid_spacing: spacing of sampling grid in x and y
-    :param int channel: channel to sample from
-
-    :return list meta_rows: Dicts with intensity data for each grid point
+    Parameters
+    ----------
+    position : ngff.Position
+        NGFF position node object
+    grid_spacing : int
+        Spacing of sampling grid in x and y
+    channel : int
+        Channel to sample from
+
+    Returns
+    -------
+    tuple
+        (position, sample_values)
     """
     image_zarr = position.data
 
@@ -320,3 +370,45 @@ def sample_im_pixels(
     sample_values = np.stack(all_sample_values, 0).flatten()
 
     return position, sample_values
+
+
+def sample_im_pixels_per_timepoint(
+    position: ngff.Position,
+    grid_spacing,
+    channel,
+):
+    """
+    Sample pixel values per timepoint for per-timepoint normalization statistics.
+
+    Parameters
+    ----------
+    position : ngff.Position
+        NGFF position node object
+    grid_spacing : int
+        Spacing of sampling grid in x and y
+    channel : int
+        Channel to sample from
+
+    Returns
+    -------
+    tuple
+        (position, dict mapping time_index to sample_values)
+    """
+    image_zarr = position.data
+
+    timepoint_samples = {}
+    all_time_indices = list(range(image_zarr.shape[0]))
+    all_z_indices = list(range(image_zarr.shape[2]))
+
+    for time_index in all_time_indices:
+        time_sample_values = []
+        for z_index in all_z_indices:
+            image_slice = image_zarr[time_index, channel, z_index, :, :]
+            _, _, sample_values = image_utils.grid_sample_pixel_values(
+                image_slice, grid_spacing
+            )
+            time_sample_values.append(sample_values)
+
+        timepoint_samples[time_index] = np.stack(time_sample_values, 0).flatten()
+
+    return position, timepoint_samples