diff --git a/pyproject.toml b/pyproject.toml index 0930bbfc2..859017812 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,34 @@ packages = ["viscy"] write_to = "viscy/_version.py" [tool.ruff] -src = ["viscy", "tests"] line-length = 88 -lint.extend-select = ["I001"] -lint.isort.known-first-party = ["viscy"] +src = ["viscy", "tests"] +extend-include = ["*.ipynb"] +target-version = "py311" +# Exclude the following for now. Later on we should check every Python file. +extend-exclude = ["viscy/scripts/*", "applications/*", "examples/*"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +docstring-code-format = true +docstring-code-line-length = "dynamic" + +[tool.ruff.lint] +select = [ + "D", # pydocstyle + "I", # isort +] +ignore = [ + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # __magic__ methods are often self-explanatory, allow missing docstrings + "D107", # Missing docstring in __init__ + # Disable one in each pair of mutually incompatible rules + "D203", # We don’t want a blank line before a class docstring + "D213", # <> We want docstrings to start immediately after the opening triple quote + "D400", # first line should end with a period [Bug: doesn’t work with single-line docstrings] + "D401", # First line should be in imperative mood; try rephrasing +] +per-file-ignores."tests/*" = ["D"] +pydocstyle.convention = "numpy" \ No newline at end of file diff --git a/tests/data/test_hcs.py b/tests/data/test_hcs.py index c71488c4c..8d040e827 100644 --- a/tests/data/test_hcs.py +++ b/tests/data/test_hcs.py @@ -3,7 +3,6 @@ from iohub import open_ome_zarr from monai.transforms import RandSpatialCropSamplesd from pytest import mark - from viscy.data.hcs import HCSDataModule from viscy.trainer import VisCyTrainer diff --git a/tests/data/test_select.py b/tests/data/test_select.py index 6ffc55b83..98b38f312 100644 --- a/tests/data/test_select.py +++ b/tests/data/test_select.py @@ -1,6 +1,5 @@ import pytest from iohub.ngff import open_ome_zarr - from viscy.data.select import SelectWell diff --git a/tests/data/test_triplet.py b/tests/data/test_triplet.py index 04f4c5019..b39e871c8 100644 --- a/tests/data/test_triplet.py +++ b/tests/data/test_triplet.py @@ -1,7 +1,6 @@ import pandas as pd from iohub import open_ome_zarr from pytest import mark - from viscy.data.triplet import TripletDataModule diff --git a/tests/evaluation/test_cell_feature_metrics.py b/tests/evaluation/test_cell_feature_metrics.py index d118fd381..3feb6596c 100644 --- a/tests/evaluation/test_cell_feature_metrics.py +++ b/tests/evaluation/test_cell_feature_metrics.py @@ -2,7 +2,6 @@ import pandas as pd import pytest from skimage import measure - from viscy.representation.evaluation.feature import CellFeatures, DynamicFeatures diff --git a/tests/evaluation/test_evaluation_metrics.py b/tests/evaluation/test_evaluation_metrics.py index af1c8411b..1d5afc5ee 100644 --- a/tests/evaluation/test_evaluation_metrics.py +++ b/tests/evaluation/test_evaluation_metrics.py @@ -3,7 +3,6 @@ import torch from skimage import data, measure from skimage.util import img_as_float - from viscy.translation.evaluation_metrics import ( POD_metric, VOI_metric, diff --git a/tests/preprocessing/generate_masks_tests.py b/tests/preprocessing/generate_masks_tests.py index 45f2e4f4a..6916de830 100644 --- a/tests/preprocessing/generate_masks_tests.py +++ b/tests/preprocessing/generate_masks_tests.py @@ -8,7 +8,6 @@ import pandas as pd import skimage.io as sk_im_io from 
testfixtures import TempDirectory - from viscy.preprocessing.generate_masks import MaskProcessor from viscy.utils import aux_utils as aux_utils @@ -129,7 +128,7 @@ def test_generate_masks_uni(self): nose.tools.assert_equal(len(frames_meta), exp_len) for idx in range(exp_len): nose.tools.assert_equal( - "im_c003_z00{}_t000_p001.npy".format(idx), + f"im_c003_z00{idx}_t000_p001.npy", frames_meta.iloc[idx]["file_name"], ) diff --git a/tests/preprocessing/resize_images_tests.py b/tests/preprocessing/resize_images_tests.py index 835f3de4b..5b41d0390 100644 --- a/tests/preprocessing/resize_images_tests.py +++ b/tests/preprocessing/resize_images_tests.py @@ -4,10 +4,9 @@ import cv2 import numpy as np import pandas as pd -from testfixtures import TempDirectory - import viscy.preprocessing.resize_images as resize_images import viscy.utils.aux_utils as aux_utils +from testfixtures import TempDirectory class TestResizeImages(unittest.TestCase): @@ -133,7 +132,7 @@ def test_resize_volumes(self): ), ignore_index=True, ) - op_fname = "im_c00{}_z000_t005_p007_3.3-0.8-1.0.npy".format(c) + op_fname = f"im_c00{c}_z000_t005_p007_3.3-0.8-1.0.npy" exp_meta_dict.append( { "time_idx": self.time_idx, @@ -169,7 +168,7 @@ def test_resize_volumes(self): exp_meta_dict = [] for c in channel_ids: for s in [0, 2]: - op_fname = "im_c00{}_z00{}_t005_p007_3.3-0.8-1.0.npy".format(c, s) + op_fname = f"im_c00{c}_z00{s}_t005_p007_3.3-0.8-1.0.npy" exp_meta_dict.append( { "time_idx": self.time_idx, diff --git a/tests/preprocessing/test_pixel_ratio.py b/tests/preprocessing/test_pixel_ratio.py index 0251fefc1..85e565e8b 100644 --- a/tests/preprocessing/test_pixel_ratio.py +++ b/tests/preprocessing/test_pixel_ratio.py @@ -1,5 +1,4 @@ from numpy.testing import assert_allclose - from viscy.preprocessing.pixel_ratio import sematic_class_weights diff --git a/tests/representation/test_feature.py b/tests/representation/test_feature.py index 5bd83ebdc..90a7c09fc 100644 --- a/tests/representation/test_feature.py +++ b/tests/representation/test_feature.py @@ -4,7 +4,6 @@ import pandas as pd import pytest from iohub import open_ome_zarr - from viscy.representation.evaluation.feature import ( CellFeatures, DynamicFeatures, diff --git a/tests/representation/test_lca.py b/tests/representation/test_lca.py index 1794804d1..f87ad3d33 100644 --- a/tests/representation/test_lca.py +++ b/tests/representation/test_lca.py @@ -1,7 +1,6 @@ import numpy as np import torch from sklearn.linear_model import LogisticRegression - from viscy.representation.evaluation.lca import linear_from_binary_logistic_regression diff --git a/tests/transforms/test_adjust_contrast.py b/tests/transforms/test_adjust_contrast.py index d2cd6e9cc..a40538162 100644 --- a/tests/transforms/test_adjust_contrast.py +++ b/tests/transforms/test_adjust_contrast.py @@ -1,7 +1,6 @@ import pytest import torch from monai.transforms import AdjustContrast, Compose - from viscy.transforms import BatchedRandAdjustContrast, BatchedRandAdjustContrastd diff --git a/tests/transforms/test_crop.py b/tests/transforms/test_crop.py index e5f80ae04..ca26b9efe 100644 --- a/tests/transforms/test_crop.py +++ b/tests/transforms/test_crop.py @@ -1,7 +1,6 @@ import pytest import torch from monai.transforms import Compose - from viscy.transforms._crop import ( BatchedCenterSpatialCrop, BatchedCenterSpatialCropd, diff --git a/tests/transforms/test_flip.py b/tests/transforms/test_flip.py index 0fbd1bf5a..eb4596054 100644 --- a/tests/transforms/test_flip.py +++ b/tests/transforms/test_flip.py @@
-1,6 +1,5 @@ import pytest import torch - from viscy.transforms import BatchedRandFlip, BatchedRandFlipd diff --git a/tests/transforms/test_gaussian_smooth.py b/tests/transforms/test_gaussian_smooth.py index 64bd979aa..136c1c55d 100644 --- a/tests/transforms/test_gaussian_smooth.py +++ b/tests/transforms/test_gaussian_smooth.py @@ -7,7 +7,6 @@ get_gaussian_kernel3d, ) from monai.transforms.intensity.array import GaussianSmooth - from viscy.transforms import BatchedRandGaussianSmooth, BatchedRandGaussianSmoothd from viscy.transforms._gaussian_smooth import filter3d_separable diff --git a/tests/transforms/test_noise.py b/tests/transforms/test_noise.py index da9a1e9fd..5e58e5c9b 100644 --- a/tests/transforms/test_noise.py +++ b/tests/transforms/test_noise.py @@ -1,7 +1,6 @@ import pytest import torch from monai.transforms import Compose - from viscy.transforms import BatchedRandGaussianNoise, BatchedRandGaussianNoised diff --git a/tests/transforms/test_scale_intensity.py b/tests/transforms/test_scale_intensity.py index 2cfdaa954..03d323074 100644 --- a/tests/transforms/test_scale_intensity.py +++ b/tests/transforms/test_scale_intensity.py @@ -1,7 +1,6 @@ import pytest import torch from monai.transforms import RandScaleIntensity - from viscy.transforms import BatchedRandScaleIntensity, BatchedRandScaleIntensityd diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 88b955e3b..8e3efb481 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -1,6 +1,5 @@ import pytest import torch - from viscy.transforms._decollate import Decollate from viscy.transforms._transforms import ( BatchedScaleIntensityRangePercentiles, diff --git a/tests/translation/test_evaluation.py b/tests/translation/test_evaluation.py index ebfbaff8c..d883ff62a 100644 --- a/tests/translation/test_evaluation.py +++ b/tests/translation/test_evaluation.py @@ -6,7 +6,6 @@ import pytest from lightning.pytorch.loggers import CSVLogger from numpy.testing import assert_array_equal - from viscy.data.segmentation import SegmentationDataModule from viscy.trainer import Trainer from viscy.translation.evaluation import SegmentationMetrics2D diff --git a/tests/unet/networks/Unet25D_tests.py b/tests/unet/networks/Unet25D_tests.py index f954d8873..135d6cdc6 100644 --- a/tests/unet/networks/Unet25D_tests.py +++ b/tests/unet/networks/Unet25D_tests.py @@ -4,7 +4,6 @@ import numpy as np import torch - import viscy.utils.cli_utils as io_utils from viscy.unet.networks.Unet25D import Unet25d diff --git a/tests/unet/networks/Unet2D_tests.py b/tests/unet/networks/Unet2D_tests.py index 3f69f2145..7ea8c4f3f 100644 --- a/tests/unet/networks/Unet2D_tests.py +++ b/tests/unet/networks/Unet2D_tests.py @@ -4,7 +4,6 @@ import numpy as np import torch - import viscy.utils.cli_utils as io_utils from viscy.unet.networks.Unet2D import Unet2d diff --git a/tests/unet/networks/layers/ConvBlock2D_tests.py b/tests/unet/networks/layers/ConvBlock2D_tests.py index f708e8008..876421ab7 100644 --- a/tests/unet/networks/layers/ConvBlock2D_tests.py +++ b/tests/unet/networks/layers/ConvBlock2D_tests.py @@ -4,7 +4,6 @@ import numpy as np import torch - import viscy.utils.cli_utils as io_utils from viscy.unet.networks.layers.ConvBlock2D import ConvBlock2D diff --git a/tests/unet/networks/layers/ConvBlock3D_tests.py b/tests/unet/networks/layers/ConvBlock3D_tests.py index 60fbd3ef2..4760fcf0e 100644 --- a/tests/unet/networks/layers/ConvBlock3D_tests.py +++ b/tests/unet/networks/layers/ConvBlock3D_tests.py @@ 
-4,7 +4,6 @@ import numpy as np import torch - import viscy.utils.cli_utils as io_utils from viscy.unet.networks.layers.ConvBlock3D import ConvBlock3D diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index f22efa4c8..b044bfc1b 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,5 +1,4 @@ import torch - from viscy.unet.networks.fcmae import ( FullyConvolutionalMAE, MaskedAdaptiveProjection, diff --git a/tests/utils/image_utils_tests.py b/tests/utils/image_utils_tests.py index f90ab0e34..761569ebb 100644 --- a/tests/utils/image_utils_tests.py +++ b/tests/utils/image_utils_tests.py @@ -1,5 +1,4 @@ import numpy as np - from viscy.utils.image_utils import grid_sample_pixel_values, preprocess_image diff --git a/tests/utils/masks_utils_tests.py b/tests/utils/masks_utils_tests.py index 71d437db5..42213ad4f 100644 --- a/tests/utils/masks_utils_tests.py +++ b/tests/utils/masks_utils_tests.py @@ -2,7 +2,6 @@ import numpy as np from skimage import draw from skimage.filters import gaussian - from viscy.utils.masks import ( create_unimodal_mask, get_unet_border_weight_map, diff --git a/tests/utils/mp_utils_tests.py b/tests/utils/mp_utils_tests.py index 89b452550..1504cfaa8 100644 --- a/tests/utils/mp_utils_tests.py +++ b/tests/utils/mp_utils_tests.py @@ -5,11 +5,10 @@ import numpy as np import numpy.testing import skimage.io as sk_im_io -from testfixtures import TempDirectory - import viscy.utils.aux_utils as aux_utils import viscy.utils.image_utils as image_utils import viscy.utils.mp_utils as mp_utils +from testfixtures import TempDirectory from viscy.utils.masks import create_otsu_mask diff --git a/viscy/__init__.py b/viscy/__init__.py index 31573ed3c..e69de29bb 100644 --- a/viscy/__init__.py +++ b/viscy/__init__.py @@ -1 +0,0 @@ -"""Learning vision for cells""" diff --git a/viscy/cli.py b/viscy/cli.py index 0c07787ad..f85d30786 100644 --- a/viscy/cli.py +++ b/viscy/cli.py @@ -6,17 +6,24 @@ import torch from jsonargparse import lazy_instance from lightning.pytorch import LightningDataModule, LightningModule -from lightning.pytorch.cli import LightningCLI +from lightning.pytorch.cli import LightningArgumentParser, LightningCLI from lightning.pytorch.loggers import TensorBoardLogger from viscy.trainer import VisCyTrainer class VisCyCLI(LightningCLI): - """Extending lightning CLI arguments and defualts.""" + """Extending Lightning CLI arguments and defaults for VisCy.""" @staticmethod def subcommands() -> dict[str, set[str]]: + """Define subcommands and their required arguments. + + Returns + ------- + dict[str, set[str]] + Dictionary mapping subcommand names to sets of required argument names. + """ subcommands = LightningCLI.subcommands() subcommand_base_args = {"model"} subcommands["preprocess"] = subcommand_base_args @@ -24,7 +31,14 @@ def subcommands() -> dict[str, set[str]]: subcommands["precompute"] = subcommand_base_args return subcommands - def add_arguments_to_parser(self, parser) -> None: + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + """Add default arguments to the Lightning CLI parser. + + Parameters + ---------- + parser : LightningArgumentParser + Lightning CLI parser instance to configure. + """ parser.set_defaults( { "trainer.logger": lazy_instance( @@ -45,8 +59,8 @@ def _setup_environment() -> None: def main() -> None: - """ - Main Lightning CLI entry point. + """Run the Lightning CLI entry point. + Parse log level and set TF32 precision. Set default random seed to 42. 
""" diff --git a/viscy/data/cell_classification.py b/viscy/data/cell_classification.py index 57888f188..a30b55aff 100644 --- a/viscy/data/cell_classification.py +++ b/viscy/data/cell_classification.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from pathlib import Path -from typing import Callable import pandas as pd import torch @@ -14,6 +14,29 @@ class ClassificationDataset(Dataset): + """Dataset for cell classification tasks. + + A PyTorch Dataset that provides cell patches and labels for classification. + Loads image patches from HCS OME-Zarr data based on cell annotations. + + Parameters + ---------- + plate : Plate + HCS OME-Zarr plate containing image data. + annotation : pd.DataFrame + DataFrame with cell annotations and labels. + channel_name : str + Name of the image channel to load. + z_range : tuple[int, int] + Range of Z slices to include (start, end). + transform : Callable | None, optional + Transform to apply to image patches. + initial_yx_patch_size : tuple[int, int] + Initial patch size in Y and X dimensions. + return_indices : bool + Whether to return cell indices with patches, by default False. + """ + def __init__( self, plate: Plate, @@ -46,11 +69,25 @@ def __init__( self.label_column = label_column def __len__(self): + """Return the number of samples in the dataset.""" return len(self.annotation) def __getitem__( self, idx ) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, dict[str, int | str]]: + """ + Get a sample from the dataset. + + Parameters + ---------- + idx : int + Index of the sample to retrieve. + + Returns + ------- + tuple[Tensor, Tensor] or tuple[Tensor, Tensor, dict[str, int | str]] + Image tensor, label tensor, and optionally cell indices. + """ row = self.annotation.iloc[idx] fov_name, t, y, x = row["fov_name"], row["t"], row["y"], row["x"] fov = self.plate[fov_name] @@ -76,6 +113,37 @@ def __getitem__( class ClassificationDataModule(LightningDataModule): + """Lightning DataModule for cell classification tasks. + + Manages data loading and preprocessing for cell classification workflows. + Handles train/validation splits and applies appropriate transforms. + + Parameters + ---------- + image_path : Path + Path to HCS OME-Zarr image data. + annotation_path : Path + Path to cell annotation CSV file. + val_fovs : list[str], optional + List of FOV names to use for validation. + channel_name : str + Name of the image channel to load. + z_range : tuple[int, int] + Range of Z slices to include (start, end). + train_exlude_timepoints : list[int] + Timepoints to exclude from training data. + train_transforms : list[Callable], optional + List of transforms to apply to training data. + val_transforms : list[Callable], optional + List of transforms to apply to validation data. + initial_yx_patch_size : tuple[int, int] + Initial patch size in Y and X dimensions. + batch_size : int + Batch size for data loading. + num_workers : int + Number of workers for data loading. + """ + def __init__( self, image_path: Path, @@ -128,7 +196,22 @@ def _subset( label_column=self.label_column, ) - def setup(self, stage=None): + def setup(self, stage=None) -> None: + """ + Set up datasets for the specified stage. + + Parameters + ---------- + stage : str, optional + Stage to set up for ('fit', 'validate', 'predict', 'test'). + + Raises + ------ + NotImplementedError + If stage is 'test'. + ValueError + If stage is unknown. 
+ """ plate = open_ome_zarr(self.image_path) annotation = pd.read_csv(self.annotation_path) all_fovs = [name for (name, _) in plate.positions()] @@ -171,9 +254,17 @@ def setup(self, stage=None): elif stage == "test": raise NotImplementedError("Test stage not implemented.") else: - raise (f"Unknown stage: {stage}") + raise ValueError(f"Unknown stage: {stage}") - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: + """ + Create training data loader. + + Returns + ------- + DataLoader + Training data loader with shuffling enabled. + """ return DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -181,7 +272,15 @@ def train_dataloader(self): shuffle=True, ) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: + """ + Create validation data loader. + + Returns + ------- + DataLoader + Validation data loader without shuffling. + """ return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -189,7 +288,15 @@ def val_dataloader(self): shuffle=False, ) - def predict_dataloader(self): + def predict_dataloader(self) -> DataLoader: + """ + Create prediction data loader. + + Returns + ------- + DataLoader + Prediction data loader without shuffling. + """ return DataLoader( self.predict_dataset, batch_size=self.batch_size, diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 512dff188..2d2f0c9c8 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,8 +1,9 @@ import bisect import logging from collections import defaultdict +from collections.abc import Sequence from enum import Enum -from typing import Literal, Sequence +from typing import Literal import torch from lightning.pytorch import LightningDataModule @@ -17,6 +18,12 @@ class CombineMode(Enum): + """Enumeration of data combination modes for CombinedDataModule. + + Defines how multiple data modules should be combined during training, + validation, and testing phases. + """ + MIN_SIZE = "min_size" MAX_SIZE_CYCLE = "max_size_cycle" MAX_SIZE = "max_size" @@ -25,6 +32,7 @@ class CombineMode(Enum): class CombinedDataModule(LightningDataModule): """Wrapper for combining multiple data modules. + For supported modes, see ``lightning.pytorch.utilities.combined_loader``. Parameters @@ -57,31 +65,71 @@ def __init__( self.predict_mode = CombineMode(predict_mode).value self.prepare_data_per_node = True - def prepare_data(self): + def prepare_data(self) -> None: + """Prepare data for all constituent data modules. + + Propagates trainer reference and calls prepare_data on each + data module for dataset downloading and preprocessing. + """ for dm in self.data_modules: dm.trainer = self.trainer dm.prepare_data() - def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: + """Set up data modules for specified training stage. + + Parameters + ---------- + stage : Literal["fit", "validate", "test", "predict"] + Current training stage for Lightning setup. + """ for dm in self.data_modules: dm.setup(stage) - def train_dataloader(self): + def train_dataloader(self) -> CombinedLoader: + """Create combined training dataloader. + + Returns + ------- + CombinedLoader + Combined dataloader using specified train_mode strategy. + """ return CombinedLoader( [dm.train_dataloader() for dm in self.data_modules], mode=self.train_mode ) - def val_dataloader(self): + def val_dataloader(self) -> CombinedLoader: + """Create combined validation dataloader. 
+ + Returns + ------- + CombinedLoader + Combined dataloader using specified val_mode strategy. + """ return CombinedLoader( [dm.val_dataloader() for dm in self.data_modules], mode=self.val_mode ) - def test_dataloader(self): + def test_dataloader(self) -> CombinedLoader: + """Create combined test dataloader. + + Returns + ------- + CombinedLoader + Combined dataloader using specified test_mode strategy. + """ return CombinedLoader( [dm.test_dataloader() for dm in self.data_modules], mode=self.test_mode ) - def predict_dataloader(self): + def predict_dataloader(self) -> CombinedLoader: + """Create combined prediction dataloader. + + Returns + ------- + CombinedLoader + Combined dataloader using specified predict_mode strategy. + """ return CombinedLoader( [dm.predict_dataloader() for dm in self.data_modules], mode=self.predict_mode, @@ -89,10 +137,45 @@ def predict_dataloader(self): class BatchedConcatDataset(ConcatDataset): - def __getitem__(self, idx): + """Batched concatenated dataset for efficient multi-dataset sampling. + + Extends PyTorch's ConcatDataset to support batched item retrieval + from multiple datasets with optimized index grouping for ML training. + """ + + def __getitem__(self, idx: int): + """Retrieve single item by index. + + Parameters + ---------- + idx : int + Sample index across concatenated datasets. + + Raises + ------ + NotImplementedError + Single item access not implemented; use __getitems__ instead. + """ raise NotImplementedError def _get_sample_indices(self, idx: int) -> tuple[int, int]: + """Map global index to dataset and sample indices. + + Parameters + ---------- + idx : int + Global index across all concatenated datasets. + + Returns + ------- + tuple[int, int] + Dataset index and local sample index within that dataset. + + Raises + ------ + ValueError + If absolute index value exceeds dataset length. + """ if idx < 0: if -idx > len(self): raise ValueError( @@ -107,6 +190,21 @@ def _get_sample_indices(self, idx: int) -> tuple[int, int]: return dataset_idx, sample_idx def __getitems__(self, indices: list[int]) -> list[dict[str, torch.Tensor]]: + """Retrieve multiple items by indices with batched dataset access. + + Groups indices by source dataset and performs batched retrieval + for improved data loading performance during ML training. + + Parameters + ---------- + indices : list[int] + List of global indices across concatenated datasets. + + Returns + ------- + list[dict[str, torch.Tensor]] + Samples from all requested indices, maintaining order. + """ grouped_indices = defaultdict(list) for idx in indices: dataset_idx, sample_indices = self._get_sample_indices(idx) @@ -154,11 +252,33 @@ def __init__(self, data_modules: Sequence[LightningDataModule]): self.prepare_data_per_node = True def prepare_data(self): + """Prepare data for all constituent data modules. + + Propagates trainer reference and calls prepare_data on each + data module for dataset preparation and preprocessing. + """ for dm in self.data_modules: dm.trainer = self.trainer dm.prepare_data() def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + """Set up concatenated datasets for training stage. + + Validates patch configuration consistency across data modules + and creates concatenated train/validation datasets. + + Parameters + ---------- + stage : Literal["fit", "validate", "test", "predict"] + Training stage - only "fit" currently supported. + + Raises + ------ + ValueError + If patches per stack are inconsistent across data modules. 
+ NotImplementedError + If stage other than "fit" is requested. + """ self.train_patches_per_stack = 0 for dm in self.data_modules: dm.setup(stage) @@ -177,6 +297,14 @@ ) def _dataloader_kwargs(self) -> dict: + """Get common dataloader configuration parameters. + + Returns + ------- + dict + Common PyTorch DataLoader configuration parameters including + worker settings, memory pinning, and prefetch configuration. + """ return { "num_workers": self.num_workers, "persistent_workers": self.persistent_workers, @@ -184,7 +312,15 @@ "pin_memory": self.pin_memory, } - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: + """Create training dataloader for concatenated datasets. + + Returns + ------- + DataLoader + PyTorch DataLoader with shuffling enabled, batch size adjusted + for patch stacking, and sample collation for ML training. + """ return DataLoader( self.train_dataset, shuffle=True, @@ -194,7 +330,15 @@ **self._dataloader_kwargs(), ) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: + """Create validation dataloader for concatenated datasets. + + Returns + ------- + DataLoader + PyTorch DataLoader without shuffling for deterministic + validation evaluation. + """ return DataLoader( self.val_dataset, shuffle=False, @@ -205,9 +349,23 @@ ) class BatchedConcatDataModule(ConcatDataModule): + """Concatenated data module with batched dataset access. + + Extends ConcatDataModule to use BatchedConcatDataset and + ThreadDataLoader for optimized multi-dataset training performance. + """ + _ConcatDataset = BatchedConcatDataset - def train_dataloader(self): + def train_dataloader(self) -> ThreadDataLoader: + """Create threaded training dataloader for batched access. + + Returns + ------- + ThreadDataLoader + MONAI ThreadDataLoader with thread-based workers for + optimized batched dataset access during training. + """ return ThreadDataLoader( self.train_dataset, use_thread_workers=True, @@ -218,7 +376,15 @@ **self._dataloader_kwargs(), ) - def val_dataloader(self): + def val_dataloader(self) -> ThreadDataLoader: + """Create threaded validation dataloader for batched access. + + Returns + ------- + ThreadDataLoader + MONAI ThreadDataLoader with thread-based workers for + optimized validation data loading. + """ return ThreadDataLoader( self.val_dataset, use_thread_workers=True, @@ -262,6 +428,26 @@ def on_after_batch_transfer(self, batch, dataloader_idx: int): class CachedConcatDataModule(LightningDataModule): + """Cached concatenated data module for distributed training. + + Concatenates multiple data modules with support for distributed + sampling and caching optimizations for large-scale ML training. + + Parameters + ---------- + data_modules : Sequence[LightningDataModule] + Data modules to concatenate. + + Raises + ------ + ValueError + If the number of workers, batch size, or patches per stack + is inconsistent across data modules. + NotImplementedError + If stage other than "fit" is requested. + """ + def __init__(self, data_modules: Sequence[LightningDataModule]): super().__init__() self.data_modules = data_modules @@ -275,11 +461,33 @@ self.prepare_data_per_node = True def prepare_data(self): + """Prepare data for all constituent data modules.
+ + Propagates trainer reference and calls prepare_data on each + data module for dataset preparation and caching setup. + """ for dm in self.data_modules: dm.trainer = self.trainer dm.prepare_data() def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + """Set up cached concatenated datasets for distributed training. + + Validates patch configuration and creates concatenated datasets + with caching optimizations for efficient distributed access. + + Parameters + ---------- + stage : Literal["fit", "validate", "test", "predict"] + Training stage - only "fit" currently supported. + + Raises + ------ + ValueError + If patches per stack are inconsistent across data modules. + NotImplementedError + If stage other than "fit" is requested. + """ self.train_patches_per_stack = 0 for dm in self.data_modules: dm.setup(stage) @@ -298,6 +506,21 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): def _maybe_sampler( self, dataset: Dataset, shuffle: bool ) -> ShardedDistributedSampler | None: + """Create distributed sampler if in distributed training mode. + + Parameters + ---------- + dataset : Dataset + PyTorch dataset to create sampler for. + shuffle : bool + Whether to shuffle samples across distributed processes. + + Returns + ------- + ShardedDistributedSampler | None + Distributed sampler if PyTorch distributed is initialized, + None otherwise for single-process training. + """ return ( ShardedDistributedSampler(dataset, shuffle=shuffle) if torch.distributed.is_initialized() @@ -305,6 +528,14 @@ def _maybe_sampler( ) def train_dataloader(self) -> DataLoader: + """Create training dataloader with distributed sampling support. + + Returns + ------- + DataLoader + PyTorch DataLoader with distributed sampler if available, + configured for cached dataset access during training. + """ sampler = self._maybe_sampler(self.train_dataset, shuffle=True) return DataLoader( self.train_dataset, @@ -318,6 +549,14 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: + """Create validation dataloader with distributed sampling support. + + Returns + ------- + DataLoader + PyTorch DataLoader with distributed sampler if available, + configured for deterministic validation evaluation. + """ sampler = self._maybe_sampler(self.val_dataset, shuffle=False) return DataLoader( self.val_dataset, diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 3c888175c..468b0d676 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -10,6 +10,7 @@ class CTMCv1DataModule(GPUTransformDataModule): """ Autoregression data module for the CTMCv1 dataset. + Training and validation datasets are stored in separate HCS OME-Zarr stores. Parameters @@ -18,13 +19,13 @@ class CTMCv1DataModule(GPUTransformDataModule): Path to the training dataset. val_data_path : str or Path Path to the validation dataset. - train_cpu_transforms : list of MapTransform + train_cpu_transforms : list[MapTransform] List of CPU transforms for training. - val_cpu_transforms : list of MapTransform + val_cpu_transforms : list[MapTransform] List of CPU transforms for validation. - train_gpu_transforms : list of MapTransform + train_gpu_transforms : list[MapTransform] List of GPU transforms for training. - val_gpu_transforms : list of MapTransform + val_gpu_transforms : list[MapTransform] List of GPU transforms for validation. batch_size : int, optional Batch size, by default 16. 
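A minimal usage sketch for the data module above (the store paths, channel key, and transform choices are illustrative assumptions, not part of this changeset):

from monai.transforms import NormalizeIntensityd, RandFlipd

from viscy.data.ctmc_v1 import CTMCv1DataModule

# Hypothetical OME-Zarr stores and channel key; substitute real CTMCv1 data.
data_module = CTMCv1DataModule(
    train_data_path="ctmc_train.zarr",
    val_data_path="ctmc_val.zarr",
    # CPU transforms run in dataloader workers before transfer to the GPU.
    train_cpu_transforms=[NormalizeIntensityd(keys=["DIC"])],
    val_cpu_transforms=[NormalizeIntensityd(keys=["DIC"])],
    # GPU transforms run on-device after batch transfer.
    train_gpu_transforms=[RandFlipd(keys=["DIC"], prob=0.5, spatial_axis=-1)],
    val_gpu_transforms=[],
    batch_size=16,
)
data_module.setup(stage="fit")  # only the "fit" stage is supported, per setup()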
@@ -68,21 +69,38 @@ def __init__( @property def train_cpu_transforms(self) -> Compose: + """Get composed training CPU transforms.""" return self._train_cpu_transforms @property def val_cpu_transforms(self) -> Compose: + """Get composed validation CPU transforms.""" return self._val_cpu_transforms @property def train_gpu_transforms(self) -> Compose: + """Get composed training GPU transforms.""" return self._train_gpu_transforms @property def val_gpu_transforms(self) -> Compose: + """Get composed validation GPU transforms.""" return self._val_gpu_transforms def setup(self, stage: str) -> None: + """ + Set up datasets for the specified stage. + + Parameters + ---------- + stage : str + The stage to set up for. Only "fit" is currently supported. + + Raises + ------ + NotImplementedError + If stage is not "fit". + """ if stage != "fit": raise NotImplementedError("Only fit stage is supported") self._setup_fit() diff --git a/viscy/data/distributed.py b/viscy/data/distributed.py index 68e6d39e5..beab41dc2 100644 --- a/viscy/data/distributed.py +++ b/viscy/data/distributed.py @@ -1,5 +1,3 @@ -"""Utilities for DDP training.""" - from __future__ import annotations import math @@ -14,9 +12,17 @@ class ShardedDistributedSampler(DistributedSampler): + """Distributed sampler that creates sharded random permutations. + + A specialized DistributedSampler that generates sharded random permutations + to ensure proper data distribution across multiple processes in DDP training. + """ + def _sharded_randperm(self, max_size: int, generator: Generator) -> list[int]: """Generate a sharded random permutation of indices. - Overlap may occur in between the last two shards to maintain divisibility.""" + + Overlap may occur in between the last two shards to maintain divisibility. + """ sharded_randperm = [ torch.randperm(self.num_samples, generator=generator) + min(i * self.num_samples, max_size - self.num_samples) @@ -26,7 +32,7 @@ def _sharded_randperm(self, max_size: int, generator: Generator) -> list[int]: return indices.tolist() def __iter__(self): - """Modified __iter__ method to shard data across distributed ranks.""" + """Iterate through sharded data across distributed ranks.""" max_size = len(self.dataset) # type: ignore[arg-type] if self.shuffle: # deterministically shuffle based on epoch and seed diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index 5eb200f33..abca552e5 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -49,6 +49,16 @@ def _maybe_sampler( ) def train_dataloader(self) -> DataLoader: + """Create GPU-optimized training data loader. + + Configures distributed sampling, persistent workers, and memory pinning + for efficient GPU-accelerated batch processing during training. + + Returns + ------- + DataLoader + Training data loader with GPU optimization settings. + """ sampler = self._maybe_sampler(self.train_dataset, shuffle=True) _logger.debug(f"Using training sampler {sampler}") return DataLoader( @@ -65,6 +75,16 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: + """Create GPU-optimized validation data loader. + + Configures distributed sampling and memory pinning for efficient + GPU-accelerated batch processing during validation phase. + + Returns + ------- + DataLoader + Validation data loader with GPU optimization settings. 
+ """ sampler = self._maybe_sampler(self.val_dataset, shuffle=False) _logger.debug(f"Using validation sampler {sampler}") return DataLoader( @@ -82,19 +102,63 @@ def val_dataloader(self) -> DataLoader: @property @abstractmethod - def train_cpu_transforms(self) -> Compose: ... + def train_cpu_transforms(self) -> Compose: + """CPU-based transform pipeline for training data. + + Returns pre-GPU augmentation transforms executed on CPU before + GPU transfer to optimize memory bandwidth and device utilization. + + Returns + ------- + Compose + Composed CPU transforms for training preprocessing. + """ + ... @property @abstractmethod - def train_gpu_transforms(self) -> Compose: ... + def train_gpu_transforms(self) -> Compose: + """GPU-accelerated transform pipeline for training data. + + Returns GPU-resident transforms for high-performance augmentation + with device memory optimization during training workflows. + + Returns + ------- + Compose + Composed GPU transforms for training augmentation. + """ + ... @property @abstractmethod - def val_cpu_transforms(self) -> Compose: ... + def val_cpu_transforms(self) -> Compose: + """CPU-based transform pipeline for validation data. + + Returns pre-GPU validation transforms executed on CPU for + deterministic preprocessing before GPU transfer. + + Returns + ------- + Compose + Composed CPU transforms for validation preprocessing. + """ + ... @property @abstractmethod - def val_gpu_transforms(self) -> Compose: ... + def val_gpu_transforms(self) -> Compose: + """GPU-accelerated transform pipeline for validation data. + + Returns GPU-resident transforms for consistent device-optimized + preprocessing during validation phase. + + Returns + ------- + Compose + Composed GPU transforms for validation processing. + """ + ... class CachedOmeZarrDataset(Dataset): @@ -147,7 +211,7 @@ def __init__( def __len__(self) -> int: return len(self._metadata_map) - def __getitem__(self, idx: int) -> dict[str, Tensor]: + def __getitem__(self, idx: int) -> dict[str, Tensor] | list[dict[str, Tensor]]: position, time_idx, norm_meta = self._metadata_map[idx] cache = self._cache_map[idx] if cache is None: @@ -240,18 +304,46 @@ def __init__( @property def train_cpu_transforms(self) -> Compose: + """CPU-based transform pipeline for training data. + + Returns + ------- + Compose + Composed CPU transforms applied before GPU transfer. + """ return self._train_cpu_transforms @property def train_gpu_transforms(self) -> Compose: + """GPU-accelerated transform pipeline for training data. + + Returns + ------- + Compose + Composed GPU transforms for device-optimized augmentation. + """ return self._train_gpu_transforms @property def val_cpu_transforms(self) -> Compose: + """CPU-based transform pipeline for validation data. + + Returns + ------- + Compose + Composed CPU transforms applied before GPU transfer. + """ return self._val_cpu_transforms @property def val_gpu_transforms(self) -> Compose: + """GPU-accelerated transform pipeline for validation data. + + Returns + ------- + Compose + Composed GPU transforms for device-optimized processing. + """ return self._val_gpu_transforms def _set_fit_global_state(self, num_positions: int) -> list[int]: @@ -279,6 +371,24 @@ def _filter_fit_fovs(self, plate: Plate) -> list[Position]: return positions def setup(self, stage: Literal["fit", "validate"]) -> None: + """Set up datasets with GPU-optimized caching and memory management. + + Configures train/validation split with shared memory caching for + efficient GPU batch loading. 
Initializes MONAI metadata tracking + and distributed data sampling. + + Parameters + ---------- + stage : Literal["fit", "validate"] + PyTorch Lightning stage for dataset configuration. + + Raises + ------ + NotImplementedError + If stage is not "fit" or "validate". + ValueError + If fewer than 2 FOVs available for train/validation split. + """ if stage not in ("fit", "validate"): raise NotImplementedError("Only fit and validate stages are supported.") cache_map = Manager().dict() diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 609b293e9..078234db1 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -33,8 +33,15 @@ def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: """ Ensure channel argument is a list of strings. - :param Union[str, Sequence[str]] str_or_seq: channel name or list of channel names - :return list[str]: list of channel names + Parameters + ---------- + str_or_seq : str | Sequence[str] + Channel name or list of channel names + + Returns + ------- + list[str] + List of channel names """ if isinstance(str_or_seq, str): return [str_or_seq] @@ -49,7 +56,26 @@ def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: def _search_int_in_str(pattern: str, file_name: str) -> str: """Search image indices in a file name with regex patterns and strip leading zeros. - E.g. ``'001'`` -> ``1``""" + + E.g. ``'001'`` -> ``1``. + + Parameters + ---------- + pattern : str + Regex pattern to search for in filename + file_name : str + Filename to search within + + Returns + ------- + str + Extracted string with leading zeros stripped + + Raises + ------ + ValueError + If pattern is not found in filename + """ match = re.search(pattern, file_name) if match: return match.group() @@ -60,10 +86,17 @@ def _search_int_in_str(pattern: str, file_name: str) -> str: def _collate_samples(batch: Sequence[Sample]) -> Sample: """Collate samples into a batch sample. - :param Sequence[Sample] batch: a sequence of dictionaries, + Parameters + ---------- + batch : Sequence[Sample] + A sequence of dictionaries, where each key may point to a value of a single tensor or a list of tensors, as is the case with ``train_patches_per_stack > 1``. - :return Sample: Batch sample (dictionary of tensors) + + Returns + ------- + Sample + Batch sample (dictionary of tensors) """ collated: Sample = {} for key in batch[0].keys(): @@ -78,9 +111,19 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: def _read_norm_meta(fov: Position) -> NormMeta | None: - """ - Read normalization metadata from the FOV. + """Read normalization metadata from the FOV. + Convert to float32 tensors to avoid automatic casting to float64. + + Parameters + ---------- + fov : Position + OME-Zarr Position object containing metadata + + Returns + ------- + NormMeta | None + Normalization metadata dictionary or None if not available """ norm_meta = fov.zattrs.get("normalization", None) if norm_meta is None: @@ -97,15 +140,21 @@ def _read_norm_meta(fov: Position) -> NormMeta | None: class SlidingWindowDataset(Dataset): - """Torch dataset where each element is a window of - (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. + """Torch dataset where each element is a window of (C, Z, Y, X). - :param list[Position] positions: FOVs to include in dataset - :param ChannelMap channels: source and target channel names, + Where C=2 (source and target) and Z is ``z_window_size``. 
+ + Parameters + ---------- + positions : list[Position] + FOVs to include in dataset + channels : ChannelMap + Source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` - :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param DictTransform | None transform: - a callable that transforms data, defaults to None + z_window_size : int + Z window size of the 2.5D U-Net, 1 for 2D + transform : DictTransform | None, optional + A callable that transforms data, by default None """ def __init__( @@ -131,8 +180,10 @@ def __init__( self._get_windows() def _get_windows(self) -> None: - """Count the sliding windows along T and Z, - and build an index-to-window LUT.""" + """Count the sliding windows along T and Z. + + And build an index-to-window LUT. + """ w = 0 self.window_keys = [] self.window_arrays = [] @@ -163,16 +214,28 @@ def _find_window(self, index: int) -> tuple[ImageArray, int, NormMeta | None]: def _read_img_window( self, img: ImageArray, ch_idx: list[int], tz: int - ) -> tuple[list[Tensor], HCSStackIndex]: + ) -> tuple[tuple[Tensor, ...], tuple[str, int, int]]: """Read image window as tensor. - :param ImageArray img: NGFF image array - :param list[int] ch_idx: list of channel indices to read, - output channel ordering will reflect the sequence - :param int tz: window index within the FOV, counted Z-first - :return list[Tensor], HCSStackIndex: + Parameters + ---------- + img : ImageArray + NGFF image array + ch_idx : list[int] + list of channel indices to read, output channel ordering will reflect the sequence + tz : int + window index within the FOV, counted Z-first + + Returns + ------- + tuple[tuple[Tensor], tuple[str, int, int]] list of (C=1, Z, Y, X) image tensors, tuple of image name, time index, and Z index + + Raises + ------ + IndexError + If the window index is out of bounds """ zs = img.shape[-3] - self.z_window_size + 1 t = (tz + zs) // zs - 1 @@ -185,6 +248,7 @@ def _read_img_window( return torch.from_numpy(data).unbind(dim=1), (img.name, t, z) def __len__(self) -> int: + """Return total number of sliding windows across all FOVs.""" return self._max_window # TODO: refactor to a top level function @@ -203,6 +267,7 @@ def _stack_channels( ] def __getitem__(self, index: int) -> Sample: + """Get sliding window sample by index.""" img, tz, norm_meta = self._find_window(index) ch_names = self.channels["source"].copy() ch_idx = self.source_ch_idx.copy() @@ -233,19 +298,27 @@ def __getitem__(self, index: int) -> Sample: class MaskTestDataset(SlidingWindowDataset): - """Torch dataset where each element is a window of - (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. - This a testing stage version of :py:class:`viscy.data.hcs.SlidingWindowDataset`, - and can only be used with batch size 1 for efficiency (no padding for collation), - since the mask is not available for each stack. - - :param list[Position] positions: FOVs to include in dataset - :param ChannelMap channels: source and target channel names, + """Torch dataset with ground truth masks for testing. + + Each element is a window of (C, Z, Y, X) where C=2 (source and target) + and Z is ``z_window_size``. This is a testing stage version of + :py:class:`viscy.data.hcs.SlidingWindowDataset`, and can only be used + with batch size 1 for efficiency (no padding for collation), since the + mask is not available for each stack. 
+ + Parameters + ---------- + positions : list[Position] + FOVs to include in dataset + channels : ChannelMap + Source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` - :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param DictTransform transform: - a callable that transforms data, defaults to None - :param str | None ground_truth_masks: path to the ground truth masks + z_window_size : int + Z window size of the 2.5D U-Net, 1 for 2D + transform : DictTransform | None, optional + A callable that transforms data, by default None + ground_truth_masks : str | None, optional + Path to the ground truth masks, by default None """ def __init__( @@ -270,6 +343,7 @@ def __init__( _logger.info(str(self.masks)) def __getitem__(self, index: int) -> Sample: + """Get sample with ground truth mask if available.""" sample = super().__getitem__(index) img_name, t_idx, z_idx = sample["index"] position_name = int(img_name.split("/")[-2]) @@ -367,7 +441,14 @@ def __init__( self.pin_memory = pin_memory @property - def cache_path(self): + def cache_path(self) -> Path: + """Get the temporary cache path for HCS data. + + Returns + ------- + Path + Cache directory path in system temp with SLURM job ID if available + """ return Path( tempfile.gettempdir(), os.getenv("SLURM_JOB_ID", "viscy_cache"), @@ -375,7 +456,14 @@ def cache_path(self): ) @property - def maybe_cached_data_path(self): + def maybe_cached_data_path(self) -> Path: + """Get data path, using cache if caching is enabled. + + Returns + ------- + Path + Cache path if caching enabled, otherwise original data path + """ return self.cache_path if self.caching else self.data_path def _data_log_path(self) -> Path: @@ -387,7 +475,12 @@ def _data_log_path(self) -> Path: log_dir.mkdir(parents=True, exist_ok=True) return log_dir / "data.log" - def prepare_data(self): + def prepare_data(self) -> None: + """Prepare HCS data by caching if enabled. + + Copies OME-Zarr data to temporary cache directory for improved + I/O performance during training. + """ if not self.caching: return # setup logger @@ -424,6 +517,18 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: } def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + """Set up datasets for the specified Lightning stage. + + Parameters + ---------- + stage : Literal["fit", "validate", "test", "predict"] + Current training stage for Lightning setup + + Raises + ------ + NotImplementedError + If stage is not supported + """ dataset_settings = self._base_dataset_settings if stage in ("fit", "validate"): self._setup_fit(dataset_settings) @@ -532,7 +637,7 @@ def _setup_predict( ) def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample: - """Removes redundant Z slices if the target is 2D to save VRAM.""" + """Remove redundant Z slices if the target is 2D to save VRAM.""" predicting = False if self.trainer: if self.trainer.predicting: @@ -546,7 +651,15 @@ def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample batch["target"] = batch["target"][:, :, slice(z_index, z_index + 1)] return batch - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: + """Create training DataLoader for HCS data. 
+ + Returns + ------- + DataLoader + Training DataLoader with shuffling, batch collation, and + multi-worker support for HCS sliding window sampling + """ return DataLoader( self.train_dataset, batch_size=self.batch_size // self.train_patches_per_stack, @@ -559,7 +672,15 @@ def train_dataloader(self): pin_memory=self.pin_memory, ) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: + """Create validation DataLoader for HCS data. + + Returns + ------- + DataLoader + Validation DataLoader without shuffling for deterministic + validation evaluation on HCS datasets + """ return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -570,7 +691,15 @@ def val_dataloader(self): pin_memory=self.pin_memory, ) - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: + """Create test DataLoader for HCS data with optional ground truth masks. + + Returns + ------- + DataLoader + Test DataLoader with batch size 1 for mask compatibility + and optional ground truth mask loading for segmentation metrics + """ return DataLoader( self.test_dataset, batch_size=1, @@ -578,7 +707,15 @@ def test_dataloader(self): shuffle=False, ) - def predict_dataloader(self): + def predict_dataloader(self) -> DataLoader: + """Create prediction DataLoader for HCS data. + + Returns + ------- + DataLoader + Prediction DataLoader for inference on HCS datasets + with metadata tracking enabled for transform inversion + """ return DataLoader( self.predict_dataset, batch_size=self.batch_size, @@ -587,8 +724,16 @@ def predict_dataloader(self): ) def _fit_transform(self) -> tuple[Compose, Compose]: - """(normalization -> maybe augmentation -> center crop) - Deterministic center crop as the last step of training and validation.""" + """Create training and validation transform pipelines. + + (normalization -> maybe augmentation -> center crop) + Deterministic center crop as the last step of training and validation. + + Returns + ------- + tuple[Compose, Compose] + Training and validation transform compositions + """ # TODO: These have a fixed order for now... () final_crop = [self._final_crop()] train_transform = Compose( @@ -609,8 +754,16 @@ def _final_crop(self) -> CenterSpatialCropd: ) def _train_transform(self) -> list[Callable]: - """Setup training augmentations: check input values, - and parse the number of Z slices and patches to sample per stack.""" + """Set up training augmentations. + + Check input values and parse the number of Z slices and patches to + sample per stack. + + Returns + ------- + list[Callable] + List of training augmentation transforms + """ self.train_patches_per_stack = 1 z_scale_range = None if self.augmentations: diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py index e8da1eb45..d7134fe87 100644 --- a/viscy/data/livecell.py +++ b/viscy/data/livecell.py @@ -124,6 +124,42 @@ def __getitem__(self, idx: int) -> Sample: class LiveCellDataModule(GPUTransformDataModule): + """Data module for LiveCell microscopy dataset. + + Provides train, validation, and test dataloaders for the LiveCell + dataset containing single-cell segmentation annotations for multiple + cell types in live-cell imaging. + + Parameters + ---------- + train_val_images : Path | None, optional + Path to the training and validation images. + test_images : Path | None, optional + Path to the test images. + train_annotations : Path | None, optional + Path to the training annotations. + val_annotations : Path | None, optional + Path to the validation annotations. 
+ test_annotations : Path | None, optional + Path to the test annotations. + train_cpu_transforms : list[MapTransform], optional + List of CPU transforms for training. + val_cpu_transforms : list[MapTransform], optional + List of CPU transforms for validation. + train_gpu_transforms : list[MapTransform], optional + List of GPU transforms for training. + val_gpu_transforms : list[MapTransform], optional + List of GPU transforms for validation. + test_transforms : list[MapTransform], optional + List of transforms for testing. + batch_size : int, optional + Batch size, by default 16. + num_workers : int, optional + Number of dataloading workers, by default 8. + pin_memory : bool, optional + Pin memory for dataloaders, by default True. + """ + def __init__( self, train_val_images: Path | None = None, @@ -172,21 +208,56 @@ def __init__( @property def train_cpu_transforms(self) -> Compose: + """Get CPU transforms for training data augmentation. + + Returns + ------- + Compose + Composed transforms applied on CPU during training. + """ return self._train_cpu_transforms @property def val_cpu_transforms(self) -> Compose: + """Get CPU transforms for validation data processing. + + Returns + ------- + Compose + Composed transforms applied on CPU during validation. + """ return self._val_cpu_transforms @property def train_gpu_transforms(self) -> Compose: + """Get GPU transforms for training data augmentation. + + Returns + ------- + Compose + Composed transforms applied on GPU during training. + """ return self._train_gpu_transforms @property def val_gpu_transforms(self) -> Compose: + """Get GPU transforms for validation data processing. + + Returns + ------- + Compose + Composed transforms applied on GPU during validation. + """ return self._val_gpu_transforms def setup(self, stage: str) -> None: + """Set up datasets based on the specified stage. + + Parameters + ---------- + stage : str + Either "fit" for training/validation or "test" for testing. + """ if stage == "fit": self._setup_fit() elif stage == "test": @@ -221,6 +292,13 @@ def _setup_test(self) -> None: ) def test_dataloader(self) -> DataLoader: + """Create test data loader. + + Returns + ------- + DataLoader + Test data loader with LiveCell test dataset. + """ return DataLoader( self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers ) diff --git a/viscy/data/mmap_cache.py b/viscy/data/mmap_cache.py index 735159903..b3cf427f4 100644 --- a/viscy/data/mmap_cache.py +++ b/viscy/data/mmap_cache.py @@ -30,6 +30,31 @@ class MmappedDataset(Dataset): + """Dataset for memory-mapped OME-Zarr arrays with caching. + + Provides efficient access to time-series microscopy data through + memory-mapped tensors with lazy loading and caching capabilities. + + Parameters + ---------- + positions : list[Position] + List of FOVs to load images from. + channel_names : list[str] + List of channel names to load. + cache_map : DictProxy + Shared dictionary for caching loaded volumes. + buffer : MemoryMappedTensor + Memory-mapped tensor for caching loaded volumes. 
+ preprocess_transforms : Compose | None, optional + Composed transforms to be applied on the CPU, by default None + cpu_transform : Compose | None, optional + Composed transforms to be applied on the CPU, by default None + array_key : str, optional + The image array key name (multi-scale level), by default "0" + load_normalization_metadata : bool, optional + Load normalization metadata in the sample dictionary, by default True + """ + def __init__( self, positions: list[Position], @@ -178,26 +203,71 @@ def __init__( @property def preprocessing_transforms(self) -> Compose: + """Get preprocessing transforms for data normalization. + + Returns + ------- + Compose + Composed transforms for preprocessing image data. + """ return self._preprocessing_transforms @property def train_cpu_transforms(self) -> Compose: + """Get CPU transforms for training data augmentation. + + Returns + ------- + Compose + Composed transforms applied on CPU during training. + """ return self._train_cpu_transforms @property def train_gpu_transforms(self) -> Compose: + """Get GPU transforms for training data augmentation. + + Returns + ------- + Compose + Composed transforms applied on GPU during training. + """ return self._train_gpu_transforms @property def val_cpu_transforms(self) -> Compose: + """Get CPU transforms for validation data processing. + + Returns + ------- + Compose + Composed transforms applied on CPU during validation. + """ return self._val_cpu_transforms @property def val_gpu_transforms(self) -> Compose: + """Get GPU transforms for validation data processing. + + Returns + ------- + Compose + Composed transforms applied on GPU during validation. + """ return self._val_gpu_transforms @property def cache_dir(self) -> Path: + """Get cache directory for memory-mapped files. + + Creates a unique cache directory based on SLURM job ID or + distributed rank for parallel training. + + Returns + ------- + Path + Cache directory path for storing memory-mapped tensor files. + """ scratch_dir = self.scratch_dir or Path(tempfile.gettempdir()) cache_dir = Path( scratch_dir, @@ -222,6 +292,22 @@ def _buffer_shape(self, arr_shape, fovs) -> tuple[int, ...]: return (len(fovs) * arr_shape[0], len(self.channels), *arr_shape[2:]) def setup(self, stage: Literal["fit", "validate"]) -> None: + """Set up datasets for training or validation. + + Creates memory-mapped datasets with train/val split based on the + specified stage. Initializes buffers and cache maps for efficient + data loading. + + Parameters + ---------- + stage : Literal["fit", "validate"] + Stage for which to set up the datasets. + + Raises + ------ + NotImplementedError + If stage is not "fit" or "validate". + """ if stage not in ("fit", "validate"): raise NotImplementedError("Only fit and validate stages are supported.") plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs") diff --git a/viscy/data/segmentation.py b/viscy/data/segmentation.py index 553d9241c..4a6f13716 100644 --- a/viscy/data/segmentation.py +++ b/viscy/data/segmentation.py @@ -15,6 +15,29 @@ class SegmentationDataset(Dataset): + """ + Dataset for segmentation evaluation tasks. + + Loads predicted and target segmentation masks for comparison and evaluation. + + Parameters + ---------- + pred_dataset : Plate + HCS OME-Zarr plate containing predicted segmentation masks. + target_dataset : Plate + HCS OME-Zarr plate containing ground truth segmentation masks. + pred_channel : str + Name of the prediction channel to load. 
+    target_channel : str
+        Name of the target channel to load.
+    pred_z_slice : int or slice
+        Z slice selection for prediction data.
+    target_z_slice : int or slice
+        Z slice selection for target data.
+    img_name : str, optional
+        Name of the image array within positions, by default "0".
+    """
+
     def __init__(
         self,
         pred_dataset: Plate,
@@ -50,9 +73,23 @@ def _build_indices(self) -> None:
         _logger.info(f"Number of test samples: {len(self)}")
 
     def __len__(self) -> int:
+        """Return the number of segmentation samples in the dataset."""
         return len(self._indices)
 
     def __getitem__(self, idx: int) -> SegmentationSample:
+        """Get a segmentation sample pair.
+
+        Parameters
+        ----------
+        idx : int
+            Index of the sample to retrieve.
+
+        Returns
+        -------
+        SegmentationSample
+            Dictionary containing prediction, target, position index, and time index.
+        """
         pred_img, target_img, p, t = self._indices[idx]
         _logger.debug(f"Target image: {target_img.name}")
         pred = torch.from_numpy(
@@ -65,6 +102,32 @@ def __getitem__(self, idx: int) -> SegmentationSample:
 
 
 class SegmentationDataModule(LightningDataModule):
+    """Lightning DataModule for segmentation evaluation.
+
+    Manages data loading for comparing predicted and target segmentation masks.
+    Only supports test stage for evaluation purposes.
+
+    Parameters
+    ----------
+    pred_dataset : Path
+        Path to HCS OME-Zarr containing predicted segmentation masks.
+    target_dataset : Path
+        Path to HCS OME-Zarr containing ground truth segmentation masks.
+    pred_channel : str
+        Name of the prediction channel to load.
+    target_channel : str
+        Name of the target channel to load.
+    pred_z_slice : int
+        Z slice index for prediction data.
+    target_z_slice : int
+        Z slice index for target data.
+    batch_size : int
+        Batch size for data loading.
+    num_workers : int
+        Number of workers for data loading.
+    """
+
     def __init__(
         self,
         pred_dataset: Path,
@@ -87,6 +150,19 @@ def __init__(
         self.num_workers = num_workers
 
     def setup(self, stage: str) -> None:
+        """Set up the segmentation dataset.
+
+        Parameters
+        ----------
+        stage : str
+            Stage to set up for. Only "test" is supported.
+
+        Raises
+        ------
+        NotImplementedError
+            If stage is not "test".
+        """
         if stage != "test":
             raise NotImplementedError("Only test stage is supported!")
         self.test_dataset = SegmentationDataset(
@@ -99,6 +175,14 @@ def setup(self, stage: str) -> None:
         )
 
     def test_dataloader(self) -> DataLoader:
+        """Create test data loader for segmentation evaluation.
+
+        Returns
+        -------
+        DataLoader
+            Test data loader containing prediction-target pairs.
+        """
         return DataLoader(
             self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers
         )
diff --git a/viscy/data/select.py b/viscy/data/select.py
index 6e00c10e8..4a0e4539b 100644
--- a/viscy/data/select.py
+++ b/viscy/data/select.py
@@ -1,4 +1,4 @@
-from typing import Generator
+from collections.abc import Generator
 
 from iohub.ngff.nodes import Plate, Position, Well
 
@@ -21,6 +21,12 @@ def _filter_fovs(
 
 class SelectWell:
+    """Filter wells and fields-of-view for dataset selection.
+
+    This class provides functionality to filter wells by inclusion criteria
+    and exclude specific fields-of-view from the dataset.
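+
+    Examples
+    --------
+    A subclass sets the filtering attributes before enumerating positions
+    (the well and FOV names below are hypothetical):
+
+    >>> class FilteredDataModule(SelectWell):
+    ...     def __init__(self) -> None:
+    ...         self._include_wells = ["A/1", "B/2"]
+    ...         self._exclude_fovs = ["A/1/0"]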
+ """ + _include_wells: list[str] | None _exclude_fovs: list[str] | None diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 4a22fc269..8110e6096 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -1,8 +1,9 @@ import logging import os import warnings +from collections.abc import Sequence from pathlib import Path -from typing import Literal, Sequence +from typing import Literal import pandas as pd import tensorstore as ts @@ -65,6 +66,50 @@ def _transform_channel_wise( class TripletDataset(Dataset): + """Dataset for triplet sampling of tracked cells. + + Generates anchor, positive, and negative triplets from tracked cell + patches for contrastive learning. Supports temporal sampling with + configurable time intervals. + + Parameters + ---------- + positions : list[Position] + OME-Zarr images with consistent channel order + tracks_tables : list[pd.DataFrame] + Data frames containing ultrack results + channel_names : list[str] + Input channel names + initial_yx_patch_size : tuple[int, int] + YX size of the initially sampled image patch before augmentation + z_range : slice + Range of Z-slices + anchor_transform : DictTransform | None, optional + Transforms applied to the anchor sample, by default None + positive_transform : DictTransform | None, optional + Transforms applied to the positve sample, by default None + negative_transform : DictTransform | None, optional + Transforms applied to the negative sample, by default None + fit : bool, optional + Fitting mode in which the full triplet will be sampled, + only sample anchor if ``False``, by default True + predict_cells : bool, optional + Only predict on selected cells, by default False + include_fov_names : list[str] | None, optional + Only predict on selected FOVs, by default None + include_track_ids : list[int] | None, optional + Only predict on selected track IDs, by default None + time_interval : Literal["any"] | int, optional + Future time interval to sample positive and anchor from, + by default "any" + (sample negative from another track any time point + and use the augmented anchor patch as positive) + return_negative : bool, optional + Whether to return the negative sample during the fit stage + (can be set to False when using a loss function like NT-Xent), + by default True + """ + def __init__( self, positions: list[Position], @@ -162,8 +207,7 @@ def _get_tensorstore(self, position: Position) -> ts.TensorStore: return self._tensorstores[fov_name] def _filter_tracks(self, tracks_tables: list[pd.DataFrame]) -> pd.DataFrame: - """Exclude tracks that are too close to the border - or do not have the next time point. + """Exclude tracks that are too close to the border or do not have the next time point. Parameters ---------- @@ -223,6 +267,7 @@ def _specific_cells(self, tracks: pd.DataFrame) -> pd.DataFrame: return specific_tracks.reset_index(drop=True) def __len__(self) -> int: + """Return number of valid anchor samples.""" return len(self.valid_anchors) def _sample_positives(self, anchor_rows: pd.DataFrame) -> pd.DataFrame: @@ -232,8 +277,21 @@ def _sample_positives(self, anchor_rows: pd.DataFrame) -> pd.DataFrame: return query.merge(self.tracks, on=["global_track_id", "t"], how="inner") def _sample_negative(self, anchor_row: pd.Series) -> pd.Series: - """Select a negative sample from a different track in the next time point - if an interval is specified, otherwise from any random time point.""" + """Select a negative sample from a different track. 
+
+        Selects from the next time point if an interval is specified,
+        otherwise from any random time point.
+
+        Parameters
+        ----------
+        anchor_row : pd.Series
+            Row containing anchor cell information.
+
+        Returns
+        -------
+        pd.Series
+            Row containing negative sample information.
+        """
         if self.time_interval == "any":
             tracks = self.tracks
         else:
@@ -288,6 +346,7 @@ def _slice_patches(self, track_rows: pd.DataFrame):
         return torch.from_numpy(results), norms
 
     def __getitems__(self, indices: list[int]) -> dict[str, torch.Tensor]:
+        """Get batched triplet samples for efficient data loading."""
         anchor_rows = self.valid_anchors.iloc[indices]
         anchor_patches, anchor_norms = self._slice_patches(anchor_rows)
         sample = {"anchor": anchor_patches, "anchor_norm_meta": anchor_norms}
@@ -322,10 +381,70 @@ def __getitems__(self, indices: list[int]) -> dict[str, torch.Tensor]:
 
 
 class TripletDataModule(HCSDataModule):
+    """Lightning data module for triplet sampling from tracked cells.
+
+    Provides train, validation, and prediction dataloaders for contrastive
+    learning on cell tracking data. Supports configurable time intervals
+    and spatial patch sampling.
+
+    Parameters
+    ----------
+    data_path : str | Path
+        Image dataset path
+    tracks_path : str | Path
+        Tracks labels dataset path
+    source_channel : str | Sequence[str]
+        List of input channel names
+    z_range : tuple[int, int]
+        Range of valid z-slices
+    initial_yx_patch_size : tuple[int, int], optional
+        YX size of the initially sampled image patch, by default (512, 512)
+    final_yx_patch_size : tuple[int, int], optional
+        Output patch size, by default (224, 224)
+    split_ratio : float, optional
+        Ratio of training samples, by default 0.8
+    batch_size : int, optional
+        Batch size, by default 16
+    num_workers : int, optional
+        Number of data-loading workers, by default 8
+    normalizations : list[MapTransform], optional
+        Normalization transforms, by default []
+    augmentations : list[MapTransform], optional
+        Augmentation transforms, by default []
+    caching : bool, optional
+        Whether to cache the dataset, by default False
+    fit_include_wells : list[str] | None, optional
+        Only include these wells for fitting, by default None
+    fit_exclude_fovs : list[str] | None, optional
+        Exclude these FOVs for fitting, by default None
+    predict_cells : bool, optional
+        Only predict for selected cells, by default False
+    include_fov_names : list[str] | None, optional
+        Only predict for selected FOVs, by default None
+    include_track_ids : list[int] | None, optional
+        Only predict for selected tracks, by default None
+    time_interval : Literal["any"] | int, optional
+        Future time interval between the anchor and the positive sample,
+        by default "any" ("any" means sampling the negative from another track
+        at any time point and using the augmented anchor patch as positive)
+    return_negative : bool, optional
+        Whether to return the negative sample during the fit stage
+        (can be set to False when using a loss function like NT-Xent),
+        by default True
+    persistent_workers : bool, optional
+        Whether to keep worker processes alive between iterations, by default False
+    prefetch_factor : int | None, optional
+        Number of batches loaded in advance by each worker, by default None
+    pin_memory : bool, optional
+        Whether to pin memory in CPU for faster GPU transfer, by default False
+    z_window_size : int | None, optional
+        Size of the final Z window, by default None (inferred from ``z_range``)
+    """
+
     def __init__(
         self,
-        data_path: str,
-        tracks_path: str,
+        data_path: str | Path,
+        tracks_path: str | Path,
         source_channel: str |
Sequence[str],
         z_range: tuple[int, int],
         initial_yx_patch_size: tuple[int, int] = (512, 512),
@@ -450,13 +569,15 @@ def __init__(
 
     def _align_tracks_tables_with_positions(
         self,
     ) -> tuple[list[Position], list[pd.DataFrame]]:
-        """Parse positions in ome-zarr store containing tracking information
-        and assemble tracks tables for each position.
+        """Parse positions in ome-zarr store containing tracking information.
+
+        Assembles tracks tables for each position by matching position names
+        with corresponding CSV files in the tracks directory.
 
         Returns
         -------
         tuple[list[Position], list[pd.DataFrame]]
-            List of positions and list of tracks tables for each position
+            List of positions and list of tracks tables for each position.
         """
         positions = []
         tracks_tables = []
@@ -527,7 +648,14 @@ def _setup_predict(self, dataset_settings: dict):
     def _setup_test(self, *args, **kwargs):
         raise NotImplementedError("Self-supervised model does not support testing")
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> ThreadDataLoader:
+        """Create training data loader for triplet sampling.
+
+        Returns
+        -------
+        ThreadDataLoader
+            Training data loader with shuffling and thread workers.
+        """
         return ThreadDataLoader(
             self.train_dataset,
             use_thread_workers=True,
@@ -541,7 +669,14 @@ def train_dataloader(self):
             collate_fn=lambda x: x,
         )
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> ThreadDataLoader:
+        """Create validation data loader for triplet sampling.
+
+        Returns
+        -------
+        ThreadDataLoader
+            Validation data loader without shuffling.
+        """
         return ThreadDataLoader(
             self.val_dataset,
             use_thread_workers=True,
@@ -555,7 +690,14 @@ def val_dataloader(self):
             collate_fn=lambda x: x,
         )
 
-    def predict_dataloader(self):
+    def predict_dataloader(self) -> ThreadDataLoader:
+        """Create prediction data loader for cell embedding extraction.
+
+        Returns
+        -------
+        ThreadDataLoader
+            Prediction data loader for anchor-only sampling.
+        """
         return ThreadDataLoader(
             self.predict_dataset,
             use_thread_workers=True,
diff --git a/viscy/data/typing.py b/viscy/data/typing.py
index d6a70488c..eafb04d09 100644
--- a/viscy/data/typing.py
+++ b/viscy/data/typing.py
@@ -1,4 +1,5 @@
-from typing import Callable, NamedTuple, Sequence, TypedDict, TypeVar
+from collections.abc import Callable, Sequence
+from typing import NamedTuple, TypedDict, TypeVar
 
 from torch import ShortTensor, Tensor
 
@@ -13,6 +14,8 @@
 
 class LevelNormStats(TypedDict):
+    """Statistics for normalization at a specific level (dataset or FOV)."""
+
     mean: Tensor
     std: Tensor
     median: Tensor
@@ -20,6 +23,8 @@
 
 class ChannelNormStats(TypedDict):
+    """Normalization statistics for a channel at different levels."""
+
     dataset_statistics: LevelNormStats
     fov_statistics: LevelNormStats
 
@@ -39,6 +44,7 @@ class HCSStackIndex(NamedTuple):
 class Sample(TypedDict, total=False):
     """
     Image sample type for mini-batches.
+
     All fields are optional.
     """
 
@@ -54,9 +60,7 @@ class Sample(TypedDict, total=False):
 
 class SegmentationSample(TypedDict):
-    """
-    Segmentation sample type for mini-batches.
-    """
+    """Segmentation sample type for mini-batches."""
 
     pred: ShortTensor
     target: ShortTensor
@@ -72,17 +76,18 @@ class ChannelMap(TypedDict):
 
 class TrackingIndex(TypedDict):
-    """Tracking index extracted from ultrack result
-    Potentially collated by the dataloader"""
+    """Tracking index extracted from ultrack result.
+
+    Potentially collated by the dataloader.
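+
+    Examples
+    --------
+    A single sample and a collated batch (the values are illustrative):
+
+    >>> sample: TrackingIndex = {"fov_name": "A/1/0", "id": 42}
+    >>> batch: TrackingIndex = {"fov_name": ["A/1/0", "B/2/1"], "id": [42, 7]}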
+ """ fov_name: OneOrSeq[str] id: OneOrSeq[int] class TripletSample(TypedDict): - """ - Triplet sample type for mini-batches. - """ + """Triplet sample type for mini-batches.""" anchor: Tensor positive: NotRequired[Tensor] diff --git a/viscy/preprocessing/generate_masks.py b/viscy/preprocessing/generate_masks.py index 491bc4069..9cc4e9ae9 100644 --- a/viscy/preprocessing/generate_masks.py +++ b/viscy/preprocessing/generate_masks.py @@ -1,41 +1,47 @@ -"""Generate masks from sum of flurophore channels""" +"""Generate masks from sum of flurophore channels.""" -import iohub.ngff as ngff +from pathlib import Path +from typing import Literal +import iohub.ngff as ngff import viscy.utils.aux_utils as aux_utils from viscy.utils.mp_utils import mp_create_and_write_mask class MaskProcessor: - """ - Appends Masks to zarr directories + """Appends Masks to zarr directories. + + Parameters + ---------- + zarr_dir : Path + Directory of HCS zarr store to pull data from. Note: data in store is assumed to be stored in TCZYX format. + channel_ids : list[int] | int + Channel indices to be masked (typically just one) + time_ids : list[int] | int + Timepoints to consider + pos_ids : list[int] | int + Position (FOV) indices to use + num_workers : int, optional + Number of workers for multiprocessing, by default 4 + mask_type : Literal["otsu", "unimodal", "mem_detection", "borders_weight_loss_map"], optional + Method to use for generating mask. Needed for mapping to the masking function. + One of: {'otsu', 'unimodal', 'mem_detection', 'borders_weight_loss_map'}, by default "otsu". + overwrite_ok : bool, optional + Overwrite existing masks, by default False. """ def __init__( self, - zarr_dir, - channel_ids, - time_ids=-1, - pos_ids=-1, - num_workers=4, - mask_type="otsu", - overwrite_ok=False, + zarr_dir: Path, + channel_ids: list[int] | int, + time_ids: list[int] | int, + pos_ids: list[int] | int, + num_workers: int = 4, + mask_type: Literal[ + "otsu", "unimodal", "mem_detection", "borders_weight_loss_map" + ] = "otsu", + overwrite_ok: bool = False, ): - """ - :param str zarr_dir: directory of HCS zarr store to pull data from. - Note: data in store is assumed to be stored in - (time, channel, z, y, x) format. - :param list[int] channel_ids: Channel indices to be masked (typically - just one) - :param int/list channel_ids: generate mask from the sum of these - (flurophore) channel indices - :param list/int time_ids: timepoints to consider - :param int pos_ids: Position (FOV) indices to use - :param int num_workers: number of workers for multiprocessing - :param str mask_type: method to use for generating mask. Needed for - mapping to the masking function. One of: - {'otsu', 'unimodal', 'borders_weight_loss_map'} - """ self.zarr_dir = zarr_dir self.num_workers = num_workers @@ -72,8 +78,9 @@ def __init__( print(f"Mask found in channel {mask_name}. Overwriting with this mask.") plate.close() - def generate_masks(self, structure_elem_radius=5): - """ + def generate_masks(self, structure_elem_radius: int = 5): + """Generate foreground masks from fluorophore channels. + The sum of flurophore channels is thresholded to generate a foreground mask. @@ -84,10 +91,11 @@ def generate_masks(self, structure_elem_radius=5): Masks are also saved as an additional untracked array named "mask" and tracked in the "mask" metadata field. 
- :param int structure_elem_radius: Radius of structuring element for - morphological operations + Parameters + ---------- + structure_elem_radius : int + Radius of structuring element for morphological operations """ - # Gather function arguments for each index pair at each position plate = ngff.open_ome_zarr(store_path=self.zarr_dir, mode="r+") diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py index 29c2ed419..285a15c09 100644 --- a/viscy/preprocessing/pixel_ratio.py +++ b/viscy/preprocessing/pixel_ratio.py @@ -6,13 +6,23 @@ def sematic_class_weights( dataset_path: str, target_channel: str, num_classes: int = 3 ) -> NDArray: - """Computes class balancing weights for semantic segmentation. + """Compute class balancing weights for semantic segmentation. + The weights can be used for cross-entropy loss. - :param str dataset_path: HCS OME-Zarr dataset path - :param str target_channel: target channel name - :param int num_classes: number of classes - :return NDArray: inverted ratio of background, uninfected and infected pixels + Parameters + ---------- + dataset_path : str + HCS OME-Zarr dataset path + target_channel : str + Target channel name + num_classes : int + Number of classes. Default is 3. + + Returns + ------- + NDArray + Inverted ratio of background, uninfected and infected pixels """ dataset = open_ome_zarr(dataset_path) arrays = [da.from_zarr(pos["0"]) for _, pos in dataset.positions()] diff --git a/viscy/preprocessing/precompute.py b/viscy/preprocessing/precompute.py index 1c68ad300..cc94b69f6 100644 --- a/viscy/preprocessing/precompute.py +++ b/viscy/preprocessing/precompute.py @@ -1,5 +1,3 @@ -"""Precompute normalization and store a plain C array""" - from __future__ import annotations from pathlib import Path @@ -8,7 +6,6 @@ import dask.array as da from dask.diagnostics import ProgressBar from iohub.ngff import open_ome_zarr - from viscy.data.select import _filter_fovs, _filter_wells @@ -40,6 +37,27 @@ def precompute_array( include_wells: list[str] | None = None, exclude_fovs: list[str] | None = None, ) -> None: + """Precompute normalized image arrays for efficient data loading. + + Parameters + ---------- + data_path : Path + Path to HCS OME-Zarr dataset. + output_path : Path + Output path for precomputed arrays. + channel_names : list[str] + List of channel names to process. + subtrahends : list[Literal["mean"] | float] + Subtraction values for normalization per channel. + divisors : list[Literal["std"] | tuple[float, float]] + Division values for normalization per channel. + image_array_key : str, optional + Array key in zarr store, by default "0". + include_wells : list[str] | None, optional + Wells to include, by default None (all wells). + exclude_fovs : list[str] | None, optional + FOVs to exclude, by default None (no exclusions). 
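+
+    Examples
+    --------
+    Illustrative call with hypothetical paths, channel names, and
+    normalization settings:
+
+    >>> precompute_array(
+    ...     data_path=Path("plate.zarr"),
+    ...     output_path=Path("normalized.zarr"),
+    ...     channel_names=["Phase3D"],
+    ...     subtrahends=["mean"],
+    ...     divisors=["std"],
+    ... )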
+ """ normalized_images: list[da.Array] = [] with open_ome_zarr(data_path, layout="hcs", mode="r") as dataset: channel_indices = [dataset.channel_names.index(c) for c in channel_names] diff --git a/viscy/representation/classification.py b/viscy/representation/classification.py index d12f1be44..718866124 100644 --- a/viscy/representation/classification.py +++ b/viscy/representation/classification.py @@ -1,24 +1,55 @@ from pathlib import Path +from typing import Any +import numpy as np import pandas as pd import torch -from lightning.pytorch import LightningModule +from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import BasePredictionWriter from torch import nn from torchmetrics.functional.classification import binary_accuracy, binary_f1_score - from viscy.representation.contrastive import ContrastiveEncoder from viscy.utils.log_images import render_images class ClassificationPredictionWriter(BasePredictionWriter): - def __init__(self, output_path: Path): + """Prediction writer callback for saving classification outputs to CSV. + + Collects predictions from all batches and writes them to a CSV file at the + end of each epoch. Converts tensor outputs to numpy arrays for storage. + + Parameters + ---------- + output_path : Path + Path to the output CSV file. + """ + + def __init__(self, output_path: Path) -> None: super().__init__("epoch") if Path(output_path).exists(): raise FileExistsError(f"Output path {output_path} already exists.") self.output_path = output_path - def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): + def write_on_epoch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + predictions: list[dict[str, Any]], + batch_indices: list[int], + ) -> None: + """Write all predictions to CSV file at epoch end. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer instance. + pl_module : LightningModule + Lightning module being trained. + predictions : list[dict[str, Any]] + List of prediction dictionaries from all batches. + batch_indices : list[int] + Indices of batches processed during prediction. + """ all_predictions = [] for prediction in predictions: for key, value in prediction.items(): @@ -29,13 +60,31 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): class ClassificationModule(LightningModule): + """Binary classification module using pre-trained contrastive encoder. + + Adapts a contrastive encoder for binary classification by replacing the + final linear layer and adding classification-specific training logic. + Computes binary cross-entropy loss and tracks accuracy and F1-score metrics. + + Parameters + ---------- + encoder : ContrastiveEncoder + Contrastive encoder model. + lr : float | None + Learning rate. + loss : nn.Module | None + Loss function. By default, BCEWithLogitsLoss with positive weight of 1.0. + example_input_array_shape : tuple[int, ...] + Shape of the example input array. + """ + def __init__( self, encoder: ContrastiveEncoder, lr: float | None, loss: nn.Module | None = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.0)), example_input_array_shape: tuple[int, ...] = (2, 1, 15, 160, 160), - ): + ) -> None: super().__init__() self.stem = encoder.stem self.backbone = encoder.encoder @@ -44,15 +93,34 @@ def __init__( self.lr = lr self.example_input_array = torch.rand(example_input_array_shape) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through stem and backbone for classification. 
+ + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, channels, depth, height, width). + + Returns + ------- + torch.Tensor + Logits tensor of shape (batch_size, 1) for binary classification. + """ x = self.stem(x) return self.backbone(x) - def on_fit_start(self): + def on_fit_start(self) -> None: + """Initialize example storage lists at start of training. + + Creates empty lists to store training and validation examples for + visualization logging during the training process. + """ self.train_examples = [] self.val_examples = [] - def _fit_step(self, batch, stage: str, loss_on_step: bool): + def _fit_step( + self, batch: tuple[torch.Tensor, torch.Tensor], stage: str, loss_on_step: bool + ) -> tuple[torch.Tensor, np.ndarray]: x, y = batch y_hat = self(x) loss = self.loss(y_hat, y) @@ -66,26 +134,79 @@ def _fit_step(self, batch, stage: str, loss_on_step: bool): ) return loss, x[0, 0, x.shape[2] // 2].detach().cpu().numpy() - def training_step(self, batch, batch_idx: int): + def training_step( + self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Execute single training step with loss computation and logging. + + Parameters + ---------- + batch : tuple + Training batch containing (inputs, targets). + batch_idx : int + Index of current batch within epoch. + + Returns + ------- + torch.Tensor + Training loss for backpropagation. + """ loss, example = self._fit_step(batch, "train", loss_on_step=True) if batch_idx < 4: self.train_examples.append([example]) return loss - def validation_step(self, batch, batch_idx: int): + def validation_step( + self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Execute single validation step with metrics computation. + + Parameters + ---------- + batch : tuple + Validation batch containing (inputs, targets). + batch_idx : int + Index of current batch within epoch. + + Returns + ------- + torch.Tensor + Validation loss for monitoring. + """ loss, example = self._fit_step(batch, "val", loss_on_step=False) if batch_idx < 4: self.val_examples.append([example]) return loss - def predict_step(self, batch, batch_idx: int, dataloader_idx: int | None = None): + def predict_step( + self, + batch: tuple[torch.Tensor, torch.Tensor, dict[str, Any]], + batch_idx: int, + dataloader_idx: int | None = None, + ) -> dict[str, torch.Tensor]: + """Execute prediction step with sigmoid activation for probabilities. + + Parameters + ---------- + batch : tuple + Prediction batch containing (inputs, targets, indices). + batch_idx : int + Index of current batch. + dataloader_idx : int or None, optional + Index of dataloader when multiple dataloaders used. + + Returns + ------- + dict[str, torch.Tensor] + Dictionary containing indices, labels, and sigmoid probabilities. + """ x, y, indices = batch y_hat = nn.functional.sigmoid(self(x)) indices["label"] = y indices["prediction"] = y_hat return indices - def _log_images(self, examples, stage): + def _log_images(self, examples: list[list[np.ndarray]], stage: str) -> None: image = render_images(examples) self.logger.experiment.add_image( f"{stage}/examples", @@ -94,13 +215,30 @@ def _log_images(self, examples, stage): dataformats="HWC", ) - def on_train_epoch_end(self): + def on_train_epoch_end(self) -> None: + """Log training examples and clear storage at epoch end. + + Renders and logs training examples to tensorboard, then clears the + examples list for the next epoch. 
+ """ self._log_images(self.train_examples, "train") self.train_examples.clear() - def on_validation_epoch_end(self): + def on_validation_epoch_end(self) -> None: + """Log validation examples and clear storage at epoch end. + + Renders and logs validation examples to tensorboard, then clears the + examples list for the next epoch. + """ self._log_images(self.val_examples, "val") self.val_examples.clear() - def configure_optimizers(self): + def configure_optimizers(self) -> torch.optim.AdamW: + """Configure AdamW optimizer for training. + + Returns + ------- + torch.optim.AdamW + AdamW optimizer with specified learning rate. + """ return torch.optim.AdamW(self.parameters(), lr=self.lr) diff --git a/viscy/representation/contrastive.py b/viscy/representation/contrastive.py index 8edeb8623..df6094ee9 100644 --- a/viscy/representation/contrastive.py +++ b/viscy/representation/contrastive.py @@ -3,7 +3,6 @@ import timm import torch.nn as nn from torch import Tensor - from viscy.unet.networks.unext2 import StemDepthtoChannels diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 7b6296356..e7a89c079 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Sequence from pathlib import Path -from typing import Any, Dict, Literal, Optional, Sequence +from typing import Any, Literal import numpy as np import pandas as pd @@ -8,8 +9,6 @@ from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import BasePredictionWriter from numpy.typing import NDArray -from xarray import Dataset, open_zarr - from viscy.data.triplet import INDEX_COLUMNS from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( @@ -17,14 +16,15 @@ compute_pca, compute_phate, ) +from xarray import Dataset, open_zarr __all__ = ["read_embedding_dataset", "EmbeddingWriter", "write_embedding_dataset"] _logger = logging.getLogger("lightning.pytorch") def read_embedding_dataset(path: Path) -> Dataset: - """ - Read the embedding dataset written by the EmbeddingWriter callback. + """Read the embedding dataset written by the EmbeddingWriter callback. + Supports both legacy datasets (without x/y coordinates) and new datasets. Parameters @@ -60,13 +60,13 @@ def _move_and_stack_embeddings( def write_embedding_dataset( - output_path: Path, - features: np.ndarray, + output_path: str | Path, + features: NDArray, index_df: pd.DataFrame, - projections: Optional[np.ndarray] = None, - umap_kwargs: Optional[Dict[str, Any]] = None, - phate_kwargs: Optional[Dict[str, Any]] = None, - pca_kwargs: Optional[Dict[str, Any]] = None, + projections: np.ndarray | None = None, + umap_kwargs: dict[str, Any] | None = None, + phate_kwargs: dict[str, Any] | None = None, + pca_kwargs: dict[str, Any] | None = None, overwrite: bool = False, ) -> None: """ @@ -74,9 +74,9 @@ def write_embedding_dataset( Parameters ---------- - output_path : Path + output_path : str | Path Path to the zarr store. - features : np.ndarray + features : NDArray Array of shape (n_samples, n_features) containing the embeddings. index_df : pd.DataFrame DataFrame containing the index information for each embedding. @@ -191,11 +191,12 @@ class EmbeddingWriter(BasePredictionWriter): Path to the zarr store. write_interval : Literal["batch", "epoch", "batch_and_epoch"], optional When to write the embeddings, by default 'epoch'. 
- umap_kwargs : dict, optional + umap_kwargs : dict[str, Any], optional Keyword arguments passed to UMAP, by default None (i.e. UMAP is not computed). - phate_kwargs : dict, optional + phate_kwargs : dict[str, Any], optional Keyword arguments passed to PHATE, by default PHATE is computed with default parameters. - pca_kwargs : dict, optional + Default configuration passed is: {"knn": 5, "decay": 40, "n_jobs": -1, "random_state": 42}. + pca_kwargs : dict[str, Any], optional Keyword arguments passed to PCA, by default PCA is computed with default parameters. """ @@ -221,6 +222,7 @@ def __init__( self.overwrite = overwrite def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Initialize prediction writing and validate output path.""" if self.output_path.exists(): raise FileExistsError(f"Output path {self.output_path} already exists.") _logger.debug(f"Writing embeddings to {self.output_path}") diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 7a35d93f1..a6d3f3db9 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,14 +1,14 @@ import logging -from typing import Literal, Sequence, TypedDict +from collections.abc import Sequence +from typing import Literal, TypedDict -import numpy as np import torch import torch.nn.functional as F from lightning.pytorch import LightningModule +from numpy.typing import NDArray from pytorch_metric_learning.losses import NTXentLoss from torch import Tensor, nn from umap import UMAP - from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder from viscy.utils.log_images import detach_sample, render_images @@ -17,13 +17,39 @@ class ContrastivePrediction(TypedDict): + """Typed dictionary for contrastive model predictions. + + Contains features, projections, and metadata for contrastive learning + inference outputs. + """ + features: Tensor projections: Tensor index: TrackingIndex class ContrastiveModule(LightningModule): - """Contrastive Learning Model for self-supervised learning.""" + """Contrastive Learning Model for self-supervised learning. + + Parameters + ---------- + encoder : nn.Module | ContrastiveEncoder + Encoder model. + loss_function : nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss + Loss function. By default, nn.TripletMarginLoss with margin 0.5. + lr : float + Learning rate. By default, 1e-3. + schedule : Literal["WarmupCosine", "Constant"] + Schedule for learning rate. By default, "Constant". + log_batches_per_epoch : int + Number of batches to log. By default, 8. + log_samples_per_batch : int + Number of samples to log. By default, 1. + log_embeddings : bool + Whether to log embeddings. By default, False. + example_input_array_shape : Sequence[int] + Shape of example input array. + """ def __init__( self, @@ -66,12 +92,36 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: return self.model(x) def log_feature_statistics(self, embeddings: Tensor, prefix: str): + """Log embedding statistics for monitoring training dynamics. + + Parameters + ---------- + embeddings : Tensor + Embedding vectors to analyze. + prefix : str + Prefix for logging keys. 
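+
+        Examples
+        --------
+        Illustrative call (assumes a trained ``module`` and a tensor of
+        embeddings; the names are hypothetical):
+
+        >>> module.log_feature_statistics(embeddings, prefix="val/projections")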
+ """ mean = torch.mean(embeddings, dim=0).detach().cpu().numpy() std = torch.std(embeddings, dim=0).detach().cpu().numpy() _logger.debug(f"{prefix}_mean: {mean}") _logger.debug(f"{prefix}_std: {std}") - def print_embedding_norms(self, anchor, positive, negative, phase): + def print_embedding_norms( + self, anchor: Tensor, positive: Tensor, negative: Tensor, phase: str + ): + """Log L2 norms of embeddings for triplet components. + + Parameters + ---------- + anchor : Tensor + Anchor embeddings. + positive : Tensor + Positive embeddings. + negative : Tensor + Negative embeddings. + phase : str + Training phase identifier for logging. + """ anchor_norm = torch.norm(anchor, dim=1).mean().item() positive_norm = torch.norm(positive, dim=1).mean().item() negative_norm = torch.norm(negative, dim=1).mean().item() @@ -80,7 +130,12 @@ def print_embedding_norms(self, anchor, positive, negative, phase): _logger.debug(f"{phase}/negative_norm: {negative_norm}") def _log_metrics( - self, loss, anchor, positive, stage: Literal["train", "val"], negative=None + self, + loss: Tensor, + anchor: Tensor, + positive: Tensor, + stage: Literal["train", "val"], + negative: Tensor | None = None, ): self.log( f"loss/{stage}", @@ -116,7 +171,7 @@ def _log_metrics( sync_dist=True, ) - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + def _log_samples(self, key: str, imgs: Sequence[Sequence[NDArray]]): grid = render_images(imgs, cmaps=["gray"] * 3) self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" @@ -133,6 +188,15 @@ def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]): output_list.extend(detach_sample(samples, self.log_samples_per_batch)) def log_embedding_umap(self, embeddings: Tensor, tag: str): + """Log UMAP visualization of embedding space to TensorBoard. + + Parameters + ---------- + embeddings : Tensor + High-dimensional embeddings to visualize. + tag : str + Tag for TensorBoard logging. + """ _logger.debug(f"Computing UMAP for {tag} embeddings.") umap = UMAP(n_components=2) embeddings_np = embeddings.detach().cpu().numpy() @@ -146,6 +210,23 @@ def log_embedding_umap(self, embeddings: Tensor, tag: str): ) def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Execute training step for contrastive learning. + + Computes triplet or NT-Xent loss based on configured loss function + and logs training metrics. + + Parameters + ---------- + batch : TripletSample + Batch containing anchor, positive, and negative samples. + batch_idx : int + Index of current batch. + + Returns + ------- + Tensor + Computed contrastive loss. + """ anchor_img = batch["anchor"] pos_img = batch["positive"] _, anchor_projection = self(anchor_img) @@ -177,6 +258,11 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: return loss def on_train_epoch_end(self) -> None: + """Log training samples and embeddings at epoch end. + + Logs sample images and optionally computes UMAP visualization + of embedding space for monitoring training progress. + """ super().on_train_epoch_end() self._log_samples("train_samples", self.training_step_outputs) # Log UMAP embeddings for validation @@ -220,6 +306,11 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: return loss def on_validation_epoch_end(self) -> None: + """Log validation samples and embeddings at epoch end. + + Logs sample images and optionally computes UMAP visualization + of embedding space for monitoring validation performance. 
+ """ super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) # Log UMAP embeddings for training @@ -231,12 +322,19 @@ def on_validation_epoch_end(self) -> None: self.validation_step_outputs = [] - def configure_optimizers(self): + def configure_optimizers(self) -> torch.optim.AdamW: + """Configure optimizer for contrastive learning. + + Returns + ------- + torch.optim.AdamW + AdamW optimizer with configured learning rate. + """ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) return optimizer def predict_step( - self, batch: TripletSample, batch_idx, dataloader_idx=0 + self, batch: TripletSample, batch_idx: int, dataloader_idx: int = 0 ) -> ContrastivePrediction: """Prediction step for extracting embeddings.""" features, projections = self.model(batch["anchor"]) diff --git a/viscy/representation/evaluation/__init__.py b/viscy/representation/evaluation/__init__.py index c474aec82..1c43f6295 100644 --- a/viscy/representation/evaluation/__init__.py +++ b/viscy/representation/evaluation/__init__.py @@ -1,5 +1,6 @@ -""" -This module enables evaluation of learned representations using annotations, such as +"""Evaluation tools for learned representations using various annotation types. + +Enables evaluation of learned representations using annotations, such as: * cell division labels, * infection state labels, * labels predicted using supervised classifiers, @@ -14,18 +15,22 @@ https://github.com/mehta-lab/dynacontrast/blob/master/analysis/gmm.py """ -import pandas as pd +from pathlib import Path +import pandas as pd from viscy.data.triplet import TripletDataModule +from xarray import DataArray -def load_annotation(da, path, name, categories: dict | None = None): +def load_annotation( + da: DataArray, path: str, name: str, categories: dict | None = None +) -> pd.Series: """ Load annotations from a CSV file and map them to the dataset. Parameters ---------- - da : xarray.DataArray + da : DataArray The dataset array containing 'fov_name' and 'id' coordinates. path : str Path to the CSV file containing annotations. @@ -64,15 +69,41 @@ def load_annotation(da, path, name, categories: dict | None = None): def dataset_of_tracks( - data_path, - tracks_path, - fov_list, - track_id_list, - source_channel=["Phase3D", "RFP"], - z_range=(28, 43), - initial_yx_patch_size=(128, 128), - final_yx_patch_size=(128, 128), + data_path: str | Path, + tracks_path: str | Path, + fov_list: list[str], + track_id_list: list[int], + source_channel: list[str] = ["Phase3D", "RFP"], + z_range: tuple[int, int] = (28, 43), + initial_yx_patch_size: tuple[int, int] = (128, 128), + final_yx_patch_size: tuple[int, int] = (128, 128), ): + """Create a prediction dataset from tracks for evaluation. + + Parameters + ---------- + data_path : str + Path to the data directory containing image files. + tracks_path : str + Path to the tracks data file. + fov_list : list + List of field of view names to include. + track_id_list : list + List of track IDs to include. + source_channel : list, optional + List of source channel names, by default ["Phase3D", "RFP"]. + z_range : tuple, optional + Z-stack range as (start, end), by default (28, 43). + initial_yx_patch_size : tuple, optional + Initial patch size in YX dimensions, by default (128, 128). + final_yx_patch_size : tuple, optional + Final patch size in YX dimensions, by default (128, 128). + + Returns + ------- + Dataset + Configured prediction dataset for evaluation. 
+ """ data_module = TripletDataModule( data_path=data_path, tracks_path=tracks_path, @@ -84,7 +115,7 @@ def dataset_of_tracks( final_yx_patch_size=final_yx_patch_size, batch_size=1, num_workers=16, - normalizations=None, + normalizations=[], predict_cells=True, ) # for train and val diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index ebf49455f..c5599229b 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -12,12 +12,16 @@ from sklearn.neighbors import KNeighborsClassifier -def knn_accuracy(embeddings, annotations, k=5): +def knn_accuracy(embeddings: NDArray, annotations: NDArray, k: int = 5) -> float: """ Evaluate the k-NN classification accuracy. Parameters ---------- + embeddings : NDArray + Embeddings to cluster. + annotations : NDArray + Ground truth labels. k : int, optional Number of neighbors to use for k-NN. Default is 5. @@ -85,8 +89,9 @@ def select_block(distances: NDArray, index: NDArray) -> NDArray: def compare_time_offset( single_track_distances: NDArray, time_offset: int = 1 ) -> NDArray: - """Extract the nearest neighbor distances/rankings - of the next sample compared to each sample. + """Extract the nearest neighbor distances/rankings of the next sample. + + Compared to each sample. Parameters ---------- @@ -105,12 +110,14 @@ def compare_time_offset( return single_track_distances.diagonal(offset=-time_offset) -def dbscan_clustering(embeddings, eps=0.5, min_samples=5): +def dbscan_clustering(embeddings: NDArray, eps=0.5, min_samples=5): """ Apply DBSCAN clustering to the embeddings. Parameters ---------- + embeddings : NDArray + Embeddings to cluster. eps : float, optional The maximum distance between two samples for them to be considered as in the same neighborhood. Default is 0.5. min_samples : int, optional @@ -118,7 +125,7 @@ def dbscan_clustering(embeddings, eps=0.5, min_samples=5): Returns ------- - np.ndarray + NDArray Clustering labels assigned by DBSCAN. """ dbscan = DBSCAN(eps=eps, min_samples=min_samples) @@ -126,12 +133,16 @@ def dbscan_clustering(embeddings, eps=0.5, min_samples=5): return clusters -def clustering_evaluation(embeddings, annotations, method="nmi"): +def clustering_evaluation(embeddings: NDArray, annotations: NDArray, method="nmi"): """ Evaluate the clustering of the embeddings compared to the ground truth labels. Parameters ---------- + embeddings : NDArray + Embeddings to cluster. + annotations : NDArray + Ground truth labels. method : str, optional Metric to use for evaluation ('nmi' or 'ari'). Default is 'nmi'. diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index eb5d43f91..cc409c30b 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -1,4 +1,4 @@ -"""PCA and UMAP dimensionality reduction.""" +from typing import TYPE_CHECKING import pandas as pd import umap @@ -7,21 +7,24 @@ from sklearn.preprocessing import StandardScaler from xarray import Dataset +if TYPE_CHECKING: + from phate import PHATE + def compute_phate( - embedding_dataset, + embedding_dataset: NDArray | Dataset, n_components: int = 2, knn: int = 5, decay: int = 40, update_dataset: bool = False, **phate_kwargs, -) -> tuple[object, NDArray]: +) -> tuple[PHATE, NDArray]: """ Compute PHATE embeddings for features and optionally update dataset. 
     Parameters
     ----------
-    embedding_dataset : xarray.Dataset or NDArray
+    embedding_dataset : NDArray | Dataset
         The dataset containing embeddings, timepoints, fov_name, and track_id,
         or a numpy array of embeddings.
     n_components : int, optional
@@ -37,7 +40,7 @@
 
     Returns
     -------
-    tuple[object, NDArray]
+    tuple[phate.PHATE, NDArray]
         PHATE model and PHATE embeddings
 
     Raises
@@ -75,12 +78,14 @@
     return phate_model, phate_embedding
 
 
-def compute_pca(embedding_dataset, n_components=None, normalize_features=True):
+def compute_pca(
+    embedding_dataset: NDArray | Dataset, n_components=None, normalize_features=True
+):
     """Compute PCA embeddings for features and optionally update dataset.
 
     Parameters
     ----------
-    embedding_dataset : xarray.Dataset or NDArray
+    embedding_dataset : Dataset | NDArray
         The dataset containing embeddings, timepoints, fov_name, and track_id,
         or a numpy array of embeddings.
     n_components : int, optional
@@ -93,7 +98,6 @@
     tuple[NDArray, pd.DataFrame]
         PCA embeddings and PCA DataFrame
     """
-
     embeddings = (
         embedding_dataset["features"].values
         if isinstance(embedding_dataset, Dataset)
diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py
index a920eb072..fd8df30af 100644
--- a/viscy/representation/evaluation/distance.py
+++ b/viscy/representation/evaluation/distance.py
@@ -2,11 +2,32 @@
 from typing import Literal
 
 import numpy as np
+from numpy.typing import NDArray
 from sklearn.metrics.pairwise import cosine_similarity
+from xarray import Dataset
 
 
-def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
-    """Extract embeddings and calculate cosine similarities for a specific cell"""
+def calculate_cosine_similarity_cell(
+    embedding_dataset: Dataset, fov_name: str, track_id: int
+) -> tuple[NDArray, NDArray]:
+    """Extract embeddings and calculate cosine similarities for a specific cell.
+
+    Parameters
+    ----------
+    embedding_dataset : Dataset
+        Dataset containing embeddings and metadata
+    fov_name : str
+        Field of view identifier
+    track_id : int
+        Track identifier for the specific cell
+
+    Returns
+    -------
+    tuple[NDArray, NDArray]
+        Time points and cosine similarities for the specific cell
+    """
     filtered_data = embedding_dataset.where(
         (embedding_dataset["fov_name"] == fov_name)
         & (embedding_dataset["track_id"] == track_id),
@@ -22,7 +43,7 @@ def calculate_cosine_similarity_cell(
 
 
 def compute_displacement(
-    embedding_dataset,
+    embedding_dataset: Dataset,
     distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared",
 ) -> dict[int, list[float]]:
     """Compute the displacement or mean square displacement (MSD) of embeddings.
@@ -34,15 +55,13 @@
 
     Parameters
     ----------
-    embedding_dataset : xarray.Dataset
+    embedding_dataset : Dataset
         Dataset containing embeddings and metadata
-    distance_metric : str
+    distance_metric : Literal["euclidean_squared", "cosine"]
         The metric to use for computing distances between embeddings.
         Valid options are:
-        - "euclidean": Euclidean distance (L2 norm)
         - "euclidean_squared": Squared Euclidean distance (for MSD, default)
         - "cosine": Cosine similarity
-        - "cosine_dissimilarity": 1 - cosine similarity
 
     Returns
     -------
@@ -130,31 +149,38 @@ def compute_displacement_statistics(
     return mean_displacement_per_tau, std_displacement_per_tau
 
 
-def compute_dynamic_range(mean_displacement_per_tau):
-    """
-    Compute the dynamic range as the difference between the maximum
-    and minimum mean displacement per τ.
+def compute_dynamic_range(mean_displacement_per_tau: dict[int, float]):
+    """Compute the dynamic range of the mean displacement per τ.
+
+    The dynamic range is the difference between the maximum and minimum
+    mean displacement across τ values.
 
-    Parameters:
-    mean_displacement_per_tau: dict with τ as key and mean displacement as value
+    Parameters
+    ----------
+    mean_displacement_per_tau : dict[int, float]
+        Dictionary with τ as key and mean displacement as value
 
-    Returns:
-    float: dynamic range (max displacement - min displacement)
+    Returns
+    -------
+    float
+        Dynamic range (max displacement - min displacement)
     """
     displacements = list(mean_displacement_per_tau.values())
     return max(displacements) - min(displacements)
 
 
-def compute_rms_per_track(embedding_dataset):
+def compute_rms_per_track(embedding_dataset: Dataset):
     """
     Compute RMS of the time derivative of embeddings per track.
 
-    Parameters:
-    embedding_dataset : xarray.Dataset
+    Parameters
+    ----------
+    embedding_dataset : Dataset
         The dataset containing embeddings, timepoints, fov_name, and track_id.
 
-    Returns:
-    list: A list of RMS values, one for each track.
+    Returns
+    -------
+    list
+        A list of RMS values, one for each track.
     """
     fov_names = embedding_dataset["fov_name"].values
     track_ids = embedding_dataset["track_id"].values
@@ -193,7 +219,25 @@ def compute_rms_per_track(embedding_dataset):
     return rms_values
 
 
-def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id):
+def calculate_normalized_euclidean_distance_cell(
+    embedding_dataset: Dataset, fov_name: str, track_id: int
+):
+    """Calculate the normalized Euclidean distance for a specific cell track.
+
+    Parameters
+    ----------
+    embedding_dataset : Dataset
+        Dataset containing embedding data with fov_name and track_id coordinates
+    fov_name : str
+        Field of view identifier
+    track_id : int
+        Track identifier for the specific cell
+
+    Returns
+    -------
+    NDArray
+        Normalized Euclidean distances for the cell track
+    """
     filtered_data = embedding_dataset.where(
         (embedding_dataset["fov_name"] == fov_name)
         & (embedding_dataset["track_id"] == track_id),
diff --git a/viscy/representation/evaluation/feature.py b/viscy/representation/evaluation/feature.py
index 4b0896c84..c7ec70647 100644
--- a/viscy/representation/evaluation/feature.py
+++ b/viscy/representation/evaluation/feature.py
@@ -5,7 +5,7 @@
 import pandas as pd
 import scipy.stats
 from numpy import fft
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from scipy.ndimage import distance_transform_edt
 from scipy.stats import linregress
 from skimage.exposure import rescale_intensity
@@ -127,7 +127,7 @@ def __init__(self, image: ArrayLike, segmentation_mask: ArrayLike | None = None)
 
         self._eps = 1e-10
 
-    def _compute_kurtosis(self):
+    def _compute_kurtosis(self) -> float:
         """Compute the kurtosis of the image.
         Returns
@@ -140,7 +140,7 @@
             return np.nan
         return scipy.stats.kurtosis(self.image, fisher=True, axis=None)
 
-    def _compute_skewness(self):
+    def _compute_skewness(self) -> float:
         """Compute the skewness of the image.
 
         Returns
@@ -153,7 +153,7 @@
             return np.nan
         return scipy.stats.skew(self.image, axis=None)
 
-    def _compute_glcm_features(self):
+    def _compute_glcm_features(self) -> tuple[float, float, float]:
         """Compute GLCM-based texture features from the image.
 
         Converts normalized image to uint8 for GLCM computation.
@@ -169,7 +169,7 @@
 
         return contrast, dissimilarity, homogeneity
 
-    def _compute_iqr(self):
+    def _compute_iqr(self) -> float:
         """Compute the interquartile range of pixel intensities.
 
         The IQR is observed to increase when a cell is infected,
@@ -184,7 +184,7 @@
 
         return iqr
 
-    def _compute_weighted_intensity_gradient(self):
+    def _compute_weighted_intensity_gradient(self) -> float:
         """Compute the weighted radial intensity gradient profile.
 
         Calculates the slope of the azimuthally averaged radial gradient
@@ -241,7 +241,7 @@
 
         return slope
 
-    def _compute_spectral_entropy(self):
+    def _compute_spectral_entropy(self) -> float:
         """Compute the spectral entropy of the image.
 
         Spectral entropy measures the complexity of the image's frequency
@@ -268,17 +268,22 @@
 
         return entropy
 
-    def _compute_texture_features(self):
+    def _compute_texture_features(self) -> float:
         """Compute Haralick texture features from the image.
 
         Converts normalized image to uint8 for Haralick computation.
+
+        Returns
+        -------
+        float
+            Mean peak-to-peak range of the Haralick texture features.
         """
         # Convert 0-1 normalized image to uint8 (0-255)
         image_uint8 = (self.image_normalized * 255).astype(np.uint8)
         texture_features = mh.features.haralick(image_uint8)
         return np.mean(np.ptp(texture_features, axis=0))
 
-    def _compute_perimeter_area_ratio(self):
+    def _compute_perimeter_area_ratio(self) -> tuple[float, float, float]:
         """Compute the perimeter of the nuclear segmentations found inside the patch.
 
         This function calculates the average perimeter, average area, and their ratio
@@ -286,14 +291,8 @@
 
         Returns
         -------
-        average_perimeter, average_area, ratio: tuple
-            Tuple containing:
-            - average_perimeter : float
-                Average perimeter of all regions in the patch
-            - average_area : float
-                Average area of all regions
-            - ratio : float
-                Ratio of total perimeter to total area
+        tuple[float, float, float]
+            Tuple containing average perimeter, average area, and ratio of total perimeter to total area
         """
         total_perimeter = 0
         total_area = 0
@@ -314,7 +313,7 @@
 
         return average_perimeter, average_area, total_perimeter / total_area
 
-    def _compute_nucleus_eccentricity(self):
+    def _compute_nucleus_eccentricity(self) -> float:
         """Compute the eccentricity of the nucleus.
 
         Eccentricity measures how much the nucleus deviates from
@@ -336,7 +335,7 @@
         eccentricities = [region.eccentricity for region in regions]
         return float(np.mean(eccentricities))
 
-    def _compute_Eucledian_distance_transform(self):
+    def _compute_Eucledian_distance_transform(self) -> NDArray:
         """Compute the Euclidean distance transform of the segmentation mask.
This transform computes the distance from each pixel to the @@ -345,7 +344,7 @@ def _compute_Eucledian_distance_transform(self): Returns ------- - dist_transform: ndarray + dist_transform: NDArray Distance transform of the segmentation mask. """ # Ensure the image is binary @@ -376,7 +375,7 @@ def _compute_intensity_localization(self): intensity_weighted_center = np.sum(self.image * edt) / (np.sum(edt) + self._eps) return intensity_weighted_center - def _compute_area(self, sigma=0.6): + def _compute_area(self, sigma: float = 0.6) -> tuple[float, float]: """Create a binary mask using morphological operations. This function creates a binary mask from the input image using Gaussian blur @@ -391,12 +390,8 @@ def _compute_area(self, sigma=0.6): Returns ------- - masked_intensity, masked_area: tuple - Tuple containing: - - masked_intensity : float - Mean intensity inside the sensor area - - masked_area : float - Area of the sensor mask in pixels + tuple[float, float] + Tuple containing masked intensity and masked area """ input_image_blur = gaussian(self.image, sigma=sigma) @@ -411,7 +406,7 @@ def _compute_area(self, sigma=0.6): return masked_intensity, np.sum(mask) - def _compute_zernike_moments(self): + def _compute_zernike_moments(self) -> NDArray: """Compute the Zernike moments of the image. Zernike moments are a set of orthogonal moments that capture @@ -420,16 +415,21 @@ def _compute_zernike_moments(self): Returns ------- - zernike_moments: np.ndarray + zernike_moments: NDArray Zernike moments of the image. """ zernike_moments = mh.features.zernike_moments(self.image, 32) return zernike_moments - def _compute_radial_intensity_gradient(self): + def _compute_radial_intensity_gradient(self) -> float: """Compute the radial intensity gradient of the image. Uses 0-1 normalized image directly for gradient calculation. + + Returns + ------- + radial_intensity_gradient: float + Radial intensity gradient of the image. """ # Use 0-1 normalized image directly y, x = np.indices(self.image_normalized.shape) @@ -447,7 +447,7 @@ def _compute_radial_intensity_gradient(self): return radial_intensity_gradient[0] - def compute_intensity_features(self): + def compute_intensity_features(self) -> IntensityFeatures: """Compute intensity features. This function computes various intensity-based features from the input image. @@ -471,7 +471,7 @@ def compute_intensity_features(self): weighted_intensity_gradient=self._compute_weighted_intensity_gradient(), ) - def compute_texture_features(self): + def compute_texture_features(self) -> TextureFeatures: """Compute texture features. This function computes texture features from the input image. @@ -493,7 +493,7 @@ def compute_texture_features(self): texture=self._compute_texture_features(), ) - def compute_morphology_features(self): + def compute_morphology_features(self) -> MorphologyFeatures: """Compute morphology features. This function computes morphology features from the input image. @@ -528,7 +528,7 @@ def compute_morphology_features(self): masked_area=masked_area, ) - def compute_symmetry_descriptor(self): + def compute_symmetry_descriptor(self) -> SymmetryDescriptor: """Compute the symmetry descriptor of the image. This function computes the symmetry descriptor of the image. 
@@ -615,20 +615,20 @@ class DynamicFeatures: Parameters ---------- - tracking_df : pandas.DataFrame + tracking_df : pd.DataFrame DataFrame containing cell tracking data with track_id, t, x, y columns Attributes ---------- - tracking_df : pandas.DataFrame + tracking_df : pd.DataFrame The input tracking dataframe containing cell position data over time - track_features : TrackFeatures or None + track_features : TrackFeatures | None Computed velocity-based features including mean, max, min velocities and their standard deviation - displacement_features : DisplacementFeatures or None + displacement_features : DisplacementFeatures | None Computed displacement features including total distance traveled, net displacement, and directional persistence - angular_features : AngularFeatures or None + angular_features : AngularFeatures | None Computed angular features including mean, max, and standard deviation of angular velocities @@ -657,7 +657,7 @@ def __init__(self, tracking_df: pd.DataFrame): if not np.issubdtype(tracking_df[col].dtype, np.number): raise ValueError(f"Column {col} must be numeric") - def _compute_instantaneous_velocity(self, track_id: str) -> np.ndarray: + def _compute_instantaneous_velocity(self, track_id: str) -> NDArray: """Compute the instantaneous velocity for all timepoints in a track. Parameters @@ -667,7 +667,7 @@ def _compute_instantaneous_velocity(self, track_id: str) -> np.ndarray: Returns ------- - velocities : np.ndarray + velocities : NDArray Array of instantaneous velocities for each timepoint """ # Get track data sorted by time @@ -708,15 +708,12 @@ def _compute_displacement(self, track_id: str) -> tuple[float, float, float]: Returns ------- - total_distance, net_displacement, directional_persistence: tuple - Tuple containing: - - total_distance : float - Total distance traveled by the cell along its path - - net_displacement : float - Straight-line distance between start and end positions - - directional_persistence : float - Ratio of net displacement to total distance (0 to 1), - where 1 indicates perfectly straight movement + tuple[float, float, float] + Tuple containing total distance, net displacement, and directional persistence + - total_distance: Total distance traveled by the cell along its path. + - net_displacement: Straight-line distance between start and end positions. + - directional_persistence: Ratio of net displacement to total distance (0 to 1), + where 1 indicates perfectly straight movement. """ track_data = self.tracking_df[ self.tracking_df["track_id"] == track_id @@ -758,11 +755,11 @@ def _compute_angular_velocity(self, track_id: str) -> tuple[float, float, float] Returns ------- - mean_angular_velocity, max_angular_velocity, std_angular_velocity: tuple - Tuple containing: - - mean_angular_velocity - - max_angular_velocity - - std_angular_velocity + tuple[float, float, float] + Tuple containing mean, maximum, and standard deviation of angular velocities + - mean_angular_velocity: Average angular velocity over the track. + - max_angular_velocity: Maximum angular velocity observed in the track. + - std_angular_velocity: Standard deviation of angular velocities in the track. 
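The displacement features described above correspond to a short computation over one track's time-sorted positions. A self-contained sketch, assuming the `t`/`x`/`y` column contract stated in the class docstring and a DataFrame holding a single track:

```python
import numpy as np
import pandas as pd

def displacement_features(track: pd.DataFrame) -> tuple[float, float, float]:
    # One track only, with numeric columns 't', 'x', 'y'.
    track = track.sort_values("t")
    x, y = track["x"].to_numpy(), track["y"].to_numpy()
    total_distance = float(np.hypot(np.diff(x), np.diff(y)).sum())
    net_displacement = float(np.hypot(x[-1] - x[0], y[-1] - y[0]))
    # Persistence of 1 means perfectly straight movement.
    persistence = net_displacement / total_distance if total_distance > 0 else np.nan
    return total_distance, net_displacement, persistence
```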
""" track_data = self.tracking_df[ self.tracking_df["track_id"] == track_id diff --git a/viscy/representation/evaluation/lca.py b/viscy/representation/evaluation/lca.py index 7c5216193..9090e069f 100644 --- a/viscy/representation/evaluation/lca.py +++ b/viscy/representation/evaluation/lca.py @@ -1,6 +1,6 @@ """Linear probing of trained encoder based on cell state labels.""" -from typing import Mapping +from collections.abc import Mapping import pandas as pd import torch @@ -11,9 +11,8 @@ from sklearn.metrics import classification_report from sklearn.preprocessing import StandardScaler from torch import Tensor -from xarray import DataArray - from viscy.representation.contrastive import ContrastiveEncoder +from xarray import DataArray def fit_logistic_regression( @@ -139,11 +138,37 @@ def __init__(self, backbone: ContrastiveEncoder, classifier: nn.Linear) -> None: @staticmethod def scale_features(x: Tensor) -> Tensor: + """Scale features using standardization. + + Parameters + ---------- + x : Tensor + Input tensor to scale + + Returns + ------- + Tensor + Scaled tensor with zero mean and unit variance + """ m = x.mean(-2, keepdim=True) s = x.std(-2, unbiased=False, keepdim=True) return (x - m) / s def forward(self, x: Tensor, scale_features: bool = False) -> Tensor: + """Forward pass through the LCA backbone. + + Parameters + ---------- + x : Tensor + Input tensor + scale_features : bool, optional + Whether to apply feature scaling, by default False + + Returns + ------- + Tensor + Encoded feature representations + """ x = self.backbone.stem(x) x = self.backbone.encoder(x) if scale_features: diff --git a/viscy/representation/evaluation/visualization.py b/viscy/representation/evaluation/visualization.py index 9d787fe05..df08d0e56 100644 --- a/viscy/representation/evaluation/visualization.py +++ b/viscy/representation/evaluation/visualization.py @@ -4,6 +4,7 @@ import logging from io import BytesIO from pathlib import Path +from typing import Any import dash import dash.dependencies as dd @@ -12,10 +13,10 @@ import pandas as pd import plotly.graph_objects as go from dash import dcc, html +from numpy.typing import NDArray from PIL import Image from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler - from viscy.data.triplet import TripletDataModule from viscy.representation.embedding_writer import read_embedding_dataset @@ -24,11 +25,18 @@ class EmbeddingVisualizationApp: + """Interactive visualization app for embedding analysis. + + Provides a Dash-based web application for exploring embeddings with PCA + visualization, track selection, and image display capabilities for + representation learning analysis. + """ + def __init__( self, - data_path: str, - tracks_path: str, - features_path: str, + data_path: str | Path, + tracks_path: str | Path, + features_path: str | Path, channels_to_display: list[str] | str, fov_tracks: dict[str, list[int] | str], z_range: tuple[int, int] = (0, 1), @@ -46,11 +54,11 @@ def __init__( Parameters ---------- - data_path: str + data_path: str | Path Path to the data directory. - tracks_path: str + tracks_path: str | Path Path to the tracks directory. - features_path: str + features_path: str | Path Path to the features directory. channels_to_display: list[str] | str List of channels to display. @@ -68,6 +76,7 @@ def __init__( Number of workers to use for loading data. output_dir: str | None, optional Directory to save CSV files and other outputs. If None, uses current working directory. 
+ Returns ------- None @@ -101,7 +110,7 @@ def __init__( self._init_app() atexit.register(self._cleanup_cache) - def _prepare_data(self): + def _prepare_data(self) -> None: """Prepare the feature data and PCA transformation""" embedding_dataset = read_embedding_dataset(self.features_path) features = embedding_dataset["features"] @@ -182,11 +191,11 @@ def _prepare_data(self): # Combine all filtered features self.filtered_features_df = pd.concat(all_filtered_features, axis=0) - def _create_figure(self): + def _create_figure(self) -> None: """Create the initial scatter plot figure""" self.fig = self._create_track_colored_figure() - def _init_app(self): + def _init_app(self) -> None: """Initialize the Dash application""" self.app = dash.Dash(__name__) @@ -509,14 +518,29 @@ def _init_app(self): prevent_initial_call=True, ) def update_figure( - color_mode, - show_arrows, - x_axis, - y_axis, - relayout_data, - selected_data, - current_figure, - ): + color_mode: str, + show_arrows: list[str] | None, + x_axis: str, + y_axis: str, + relayout_data: dict[str, Any] | None, + selected_data: dict[str, Any] | None, + current_figure: dict[str, Any], + ) -> tuple[dict[str, Any], dict[str, Any] | None]: + """Update the figure based on the selected data. + + Parameters + ---------- + color_mode: str + The color mode. + show_arrows: list[str] | None + The show arrows. + x_axis: str + The x axis. + y_axis: str + The y axis. + """ + if show_arrows is None: + show_arrows = [] show_arrows = len(show_arrows or []) > 0 ctx = dash.callback_context @@ -554,8 +578,19 @@ def update_figure( [dd.Input("scatter-plot", "clickData")], prevent_initial_call=True, ) - def update_track_timeline(clickData): - """Update the track timeline based on the clicked point""" + def update_track_timeline(clickData: dict[str, Any] | None) -> html.Div: + """Update the track timeline based on the clicked point + + Parameters + ---------- + clickData: dict[str, Any] | None + The click data from the scatter plot. + + Returns + ------- + html.Div: The track timeline. + + """ if clickData is None: return html.Div("Click on a point to see the track timeline") @@ -727,19 +762,61 @@ def update_track_timeline(clickData): prevent_initial_call=True, ) def update_clusters_tab( - assign_clicks, - clear_clicks, - save_name_clicks, - cancel_name_clicks, - edit_name_clicks, - selected_data, - current_figure, - color_mode, - show_arrows, - x_axis, - y_axis, - cluster_name, - ): + assign_clicks: int | None, + clear_clicks: int | None, + save_name_clicks: int | None, + cancel_name_clicks: int | None, + edit_name_clicks: list[int], + selected_data: dict[str, Any] | None, + current_figure: dict[str, Any], + color_mode: str, + show_arrows: list[str] | None, + x_axis: str, + y_axis: str, + cluster_name: str | None, + ) -> tuple[ + dict[str, str], + html.Div | None, + str, + dict[str, Any] | Any, + dict[str, str], + str, + dict[str, Any] | None, + ]: + """Update the clusters tab and handle modal. + + Parameters + ---------- + assign_clicks: int | None + The number of clicks on the assign cluster button. + clear_clicks: int | None + The number of clicks on the clear clusters button. + save_name_clicks: int | None + The number of clicks on the save cluster name button. + cancel_name_clicks: int | None + The number of clicks on the cancel cluster name button. + edit_name_clicks: list[int] + The indices of the edit cluster name buttons. + selected_data: dict[str, Any] | None + The selected data from the scatter plot. 
+ current_figure: dict[str, Any] + The current figure. + color_mode: str + The color mode. + show_arrows: list[str] | None + The show arrows. + x_axis: str + The x axis. + y_axis: str + The y axis. + cluster_name: str | None + The cluster name. + + Returns + ------- + tuple[dict[str, str], html.Div | None, str, dict[str, Any] | Any, dict[str, str], str, dict[str, Any] | None]: + The updated clusters tab and handle modal. + """ ctx = dash.callback_context if not ctx.triggered: return ( @@ -962,8 +1039,18 @@ def update_clusters_tab( [dd.Input("save-clusters-csv", "n_clicks")], prevent_initial_call=True, ) - def save_clusters_csv(n_clicks): - """Callback to save clusters to CSV file""" + def save_clusters_csv(n_clicks: int | None) -> html.Div: + """Callback to save clusters to CSV file + + Parameters + ---------- + n_clicks: int | None + The number of clicks on the save clusters CSV button. + + Returns + ------- + html.Div: The cluster container. + """ if n_clicks and self.clusters: try: output_path = self.save_clusters_to_csv() @@ -1035,8 +1122,33 @@ def save_clusters_csv(n_clicks): ], prevent_initial_call=True, ) - def clear_selection(n_clicks, color_mode, show_arrows, x_axis, y_axis): - """Callback to clear the selection and restore original opacity""" + def clear_selection( + n_clicks: int | None, + color_mode: str, + show_arrows: list[str] | None, + x_axis: str, + y_axis: str, + ) -> tuple[dict[str, Any] | Any, dict[str, Any] | None]: + """Callback to clear the selection and restore original opacity + + Parameters + ---------- + n_clicks: int | None + The number of clicks on the clear selection button. + color_mode: str + The color mode. + show_arrows: list[str] | None + The show arrows. + x_axis: str + The x axis. + y_axis: str + The y axis. + + Returns + ------- + tuple[dict[str, Any] | Any, dict[str, Any] | None]: + The new figure and clear selectedData. + """ if n_clicks: # Create a new figure with no selections if color_mode == "track": @@ -1063,7 +1175,9 @@ def clear_selection(n_clicks, color_mode, show_arrows, x_axis, y_axis): return fig, None # Return new figure and clear selectedData return dash.no_update, dash.no_update - def _calculate_equal_aspect_ranges(self, x_data, y_data): + def _calculate_equal_aspect_ranges( + self, x_data: NDArray, y_data: NDArray + ) -> tuple[tuple[float, float], tuple[float, float]]: """Calculate ranges for x and y axes to ensure equal aspect ratio. Parameters @@ -1110,11 +1224,26 @@ def _calculate_equal_aspect_ranges(self, x_data, y_data): def _create_track_colored_figure( self, - show_arrows=False, - x_axis=None, - y_axis=None, - ): - """Create scatter plot with track-based coloring""" + show_arrows: bool = False, + x_axis: str | None = None, + y_axis: str | None = None, + ) -> go.Figure: + """Create scatter plot with track-based coloring + + Parameters + ---------- + show_arrows: bool + The show arrows. + x_axis: str | None + The x axis. + y_axis: str | None + The y axis. + + Returns + ------- + go.Figure + The scatter plot. 
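`_calculate_equal_aspect_ranges`, defined just below, has to pad the shorter axis so one embedding unit renders with the same length in x and y. A plausible standalone version; the app's exact margin policy may differ:

```python
import numpy as np

def equal_aspect_ranges(
    x: np.ndarray, y: np.ndarray, margin: float = 0.05
) -> tuple[tuple[float, float], tuple[float, float]]:
    # Center both axes and give them the same half-span, so one embedding
    # unit renders with equal length horizontally and vertically.
    x_mid = (x.min() + x.max()) / 2
    y_mid = (y.min() + y.max()) / 2
    half = max(np.ptp(x), np.ptp(y)) * (1 + margin) / 2
    return (x_mid - half, x_mid + half), (y_mid - half, y_mid + half)
```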
+ """ x_axis = x_axis or self.default_x y_axis = y_axis or self.default_y @@ -1329,10 +1458,10 @@ def _create_track_colored_figure( def _create_time_colored_figure( self, - show_arrows=False, - x_axis=None, - y_axis=None, - ): + show_arrows: bool = False, + x_axis: str | None = None, + y_axis: str | None = None, + ) -> go.Figure: """Create scatter plot with time-based coloring""" x_axis = x_axis or self.default_x y_axis = y_axis or self.default_y @@ -1481,7 +1610,7 @@ def _create_time_colored_figure( return fig @staticmethod - def _normalize_image(img_array): + def _normalize_image(img_array: NDArray) -> NDArray: """Normalize a single image array to [0, 255] more efficiently""" min_val = img_array.min() max_val = img_array.max() @@ -1491,7 +1620,7 @@ def _normalize_image(img_array): return ((img_array - min_val) * 255 / (max_val - min_val)).astype(np.uint8) @staticmethod - def _numpy_to_base64(img_array): + def _numpy_to_base64(img_array: NDArray) -> str: """Convert numpy array to base64 string with compression""" if not isinstance(img_array, np.uint8): img_array = img_array.astype(np.uint8) @@ -1503,12 +1632,12 @@ def _numpy_to_base64(img_array): "utf-8" ) - def save_cache(self, cache_path: str | None = None): + def save_cache(self, cache_path: str | Path | None = None) -> None: """Save the image cache to disk using pickle. Parameters ---------- - cache_path : str | None, optional + cache_path : str | Path | None, optional Path to save the cache. If None, uses self.cache_path, by default None """ import pickle @@ -1543,12 +1672,12 @@ def save_cache(self, cache_path: str | None = None): except Exception as e: logger.error(f"Error saving cache: {e}") - def load_cache(self, cache_path: str | None = None) -> bool: + def load_cache(self, cache_path: str | Path | None = None) -> bool: """Load the image cache from disk using pickle. Parameters ---------- - cache_path : str | None, optional + cache_path : str | Path | None, optional Path to load the cache from. 
If None, uses self.cache_path, by default None Returns @@ -1596,7 +1725,7 @@ def load_cache(self, cache_path: str | None = None) -> bool: logger.error(f"Error loading cache: {e}") return False - def preload_images(self): + def preload_images(self) -> None: """Preload all images into memory""" # Try to load from cache first if self.cache_path and self.load_cache(): @@ -1625,7 +1754,7 @@ def preload_images(self): final_yx_patch_size=self.yx_patch_size, batch_size=1, num_workers=self.num_loading_workers, - normalizations=None, + normalizations=[], predict_cells=True, ) data_module.setup("predict") @@ -1696,12 +1825,14 @@ def preload_images(self): if self.cache_path: self.save_cache() - def _cleanup_cache(self): + def _cleanup_cache(self) -> None: """Clear the image cache when the program exits""" logging.info("Cleaning up image cache...") self.image_cache.clear() - def _get_trajectory_images_lasso(self, x_axis, y_axis, selected_data): + def _get_trajectory_images_lasso( + self, x_axis: str, y_axis: str, selected_data: dict[str, Any] | None + ) -> html.Div: """Get images of points selected by lasso""" if not selected_data or not selected_data.get("points"): return html.Div("Use the lasso tool to select points") @@ -1908,7 +2039,7 @@ def _get_output_info_display(self) -> html.Div: }, ) - def _get_cluster_images(self): + def _get_cluster_images(self) -> html.Div: """Display images for all clusters in a grid layout""" if not self.clusters: return html.Div( @@ -2117,7 +2248,7 @@ def get_output_dir(self) -> Path: """ return self.output_dir - def save_clusters_to_csv(self, output_path: str | None = None) -> str: + def save_clusters_to_csv(self, output_path: str | Path | None = None) -> str: """ Save cluster information to CSV file. @@ -2126,7 +2257,7 @@ def save_clusters_to_csv(self, output_path: str | None = None) -> str: Parameters ---------- - output_path : str | None, optional + output_path : str | Path | None, optional Path to save the CSV file. If None, generates a timestamped filename in the output directory, by default None @@ -2195,7 +2326,7 @@ def save_clusters_to_csv(self, output_path: str | None = None) -> str: logger.error(f"Error saving clusters to CSV: {e}") raise - def run(self, debug=False, port=None): + def run(self, debug: bool = False, port: int | None = None) -> None: """Run the Dash server Parameters @@ -2207,12 +2338,12 @@ def run(self, debug=False, port=None): """ import socket - def is_port_in_use(port): + def is_port_in_use(port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.bind(("127.0.0.1", port)) return False - except socket.error: + except OSError: return True if port is None: diff --git a/viscy/representation/multi_modal.py b/viscy/representation/multi_modal.py index 55481d434..51e429a45 100644 --- a/viscy/representation/multi_modal.py +++ b/viscy/representation/multi_modal.py @@ -1,10 +1,10 @@ +from collections.abc import Sequence from logging import getLogger -from typing import Literal, Sequence +from typing import Literal import torch from pytorch_metric_learning.losses import NTXentLoss from torch import Tensor, nn - from viscy.data.typing import TripletSample from viscy.representation.contrastive import ContrastiveEncoder from viscy.representation.engine import ContrastiveModule @@ -13,6 +13,20 @@ class JointEncoders(nn.Module): + """Joint multi-modal encoders for cross-modal representation learning. + + Pairs source and target encoders for CLIP-style contrastive learning + across different modalities or channels. 
Enables cross-modal alignment + and similarity computation through joint feature extraction. + + Parameters + ---------- + source_encoder : nn.Module | ContrastiveEncoder + Encoder for source modality/channel data. + target_encoder : nn.Module | ContrastiveEncoder + Encoder for target modality/channel data. + """ + def __init__( self, source_encoder: nn.Module | ContrastiveEncoder, @@ -25,19 +39,86 @@ def __init__( def forward( self, source: Tensor, target: Tensor ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: + """Forward pass through both encoders for multi-modal features. + + Parameters + ---------- + source : Tensor + Source modality input tensor. + target : Tensor + Target modality input tensor. + + Returns + ------- + tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]] + Tuple of (source_features, source_projections) and + (target_features, target_projections) for cross-modal learning. + """ return self.source_encoder(source), self.target_encoder(target) def forward_features(self, source: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: + """Extract feature representations from both modalities. + + Parameters + ---------- + source : Tensor + Source modality input tensor. + target : Tensor + Target modality input tensor. + + Returns + ------- + tuple[Tensor, Tensor] + Feature representations from source and target encoders for + multi-modal representation learning. + """ return self.source_encoder(source)[0], self.target_encoder(target)[0] def forward_projections( self, source: Tensor, target: Tensor ) -> tuple[Tensor, Tensor]: + """Extract projection representations for contrastive learning. + + Parameters + ---------- + source : Tensor + Source modality input tensor. + target : Tensor + Target modality input tensor. + + Returns + ------- + tuple[Tensor, Tensor] + Projection representations from source and target encoders for + cross-modal contrastive alignment and similarity computation. + """ return self.source_encoder(source)[1], self.target_encoder(target)[1] class JointContrastiveModule(ContrastiveModule): - """CLIP-style model pair for self-supervised cross-modality representation learning.""" + """CLIP-style model pair for self-supervised cross-modality representation learning. + + Parameters + ---------- + encoder : nn.Module | JointEncoders + Encoder model. + loss_function : nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss + Loss function. By default, nn.TripletMarginLoss with margin 0.5. + lr : float + Learning rate. By default, 1e-3. + schedule : Literal["WarmupCosine", "Constant"] + Schedule for learning rate. By default, "Constant". + log_batches_per_epoch : int + Number of batches to log. By default, 8. + log_samples_per_batch : int + Number of samples to log. By default, 1. + log_embeddings : bool + Whether to log embeddings. By default, False. + example_input_array_shape : Sequence[int] + Shape of example input array. + prediction_arm : Literal["source", "target"] + Arm to use for prediction. By default, "source". + """ def __init__( self, @@ -67,6 +148,21 @@ def __init__( self._prediction_arm = prediction_arm def forward(self, source: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: + """Forward pass for cross-modal contrastive projections. + + Parameters + ---------- + source : Tensor + Source modality input tensor. + target : Tensor + Target modality input tensor. + + Returns + ------- + tuple[Tensor, Tensor] + Projection tensors from source and target encoders for + cross-modal contrastive learning and alignment. 
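The projections returned by the joint encoders typically feed a CLIP-style InfoNCE objective. With pytorch_metric_learning's NTXentLoss, matched source/target projections can share a label index so each cross-modal pair counts as a positive; a hedged sketch (batch size, embedding dimension, and temperature are placeholders):

```python
import torch
from pytorch_metric_learning.losses import NTXentLoss

loss_fn = NTXentLoss(temperature=0.07)
z_source = torch.randn(8, 128)  # projections from the source encoder
z_target = torch.randn(8, 128)  # projections from the target encoder
# Repeating the label index marks each source/target pair as a positive.
labels = torch.arange(8).repeat(2)
loss = loss_fn(torch.cat([z_source, z_target]), labels)
```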
+ """ return self.model.forward_projections(source, target) def _info_nce_style_loss(self, z1: Tensor, z2: Tensor) -> Tensor: @@ -110,12 +206,53 @@ def _fit_forward_step( return loss def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Training step for cross-modal contrastive learning. + + Parameters + ---------- + batch : TripletSample + Batch containing anchor and positive samples for multi-modal + contrastive learning. + batch_idx : int + Batch index in current epoch. + + Returns + ------- + Tensor + Cross-modal contrastive loss for training optimization. + """ return self._fit_forward_step(batch=batch, batch_idx=batch_idx, stage="train") def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Validation step for cross-modal contrastive learning. + + Parameters + ---------- + batch : TripletSample + Batch containing anchor and positive samples for multi-modal + validation. + batch_idx : int + Batch index in current validation epoch. + + Returns + ------- + Tensor + Cross-modal contrastive loss for validation monitoring. + """ return self._fit_forward_step(batch=batch, batch_idx=batch_idx, stage="val") def on_predict_start(self) -> None: + """Configure prediction encoder arm for multi-modal inference. + + Sets up the appropriate encoder (source or target) and channel slice + based on the prediction_arm configuration for single-modality + inference from the trained cross-modal model. + + Raises + ------ + ValueError + If prediction_arm is not 'source' or 'target'. + """ _logger.info(f"Using {self._prediction_arm} encoder for predictions.") if self._prediction_arm == "source": self._prediction_encoder = self.model.source_encoder @@ -129,6 +266,27 @@ def on_predict_start(self) -> None: def predict_step( self, batch: TripletSample, batch_idx: int, dataloader_idx: int = 0 ): + """Prediction step using selected encoder arm. + + Extracts features and projections using the configured prediction + encoder (source or target) for single-modality inference from the + trained cross-modal model. + + Parameters + ---------- + batch : TripletSample + Batch containing anchor samples for prediction. + batch_idx : int + Batch index in current prediction run. + dataloader_idx : int, default=0 + Index of dataloader when using multiple prediction dataloaders. + + Returns + ------- + dict + Dictionary containing 'features', 'projections', and 'index' + for the predicted samples from the selected modality encoder. + """ features, projections = self._prediction_encoder( batch["anchor"][:, self._prediction_channel_slice] ) diff --git a/viscy/trainer.py b/viscy/trainer.py index 03395a371..5f12db396 100644 --- a/viscy/trainer.py +++ b/viscy/trainer.py @@ -15,6 +15,12 @@ class VisCyTrainer(Trainer): + """Extended Lightning Trainer for VisCy with preprocessing and export capabilities. + + Provides additional functionality for dataset preprocessing, model export, + and normalization metadata computation for computer vision training workflows. + """ + def preprocess( self, data_path: Path, @@ -118,6 +124,29 @@ def precompute( exclude_fovs: list[str] | None = None, model: LightningModule | None = None, ): + """Precompute and normalize image arrays for efficient training. 
+ + Parameters + ---------- + data_path : Path + Path to input HCS OME-Zarr dataset + output_path : Path + Path to save precomputed arrays + channel_names : list[str] + List of channel names to process + subtrahends : list[Literal["mean"] | float] + Subtraction values for normalization (per channel) + divisors : list[Literal["std"] | tuple[float, float]] + Division values for normalization (per channel) + image_array_key : str, optional + Array key in OME-Zarr structure, by default "0" + include_wells : list[str] | None, optional + Wells to include, by default None + exclude_fovs : list[str] | None, optional + Fields of view to exclude, by default None + model : LightningModule | None, optional + Ignored placeholder parameter, by default None + """ + precompute_array( data_path=data_path, output_path=output_path, diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index e3c1d6012..cccb749de 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -1,3 +1,5 @@ +"""VisCy transform package for data preprocessing and augmentation.""" + from viscy.transforms._adjust_contrast import ( BatchedRandAdjustContrast, BatchedRandAdjustContrastd, diff --git a/viscy/transforms/_gaussian_blur.py b/viscy/transforms/_gaussian_blur.py new file mode 100644 index 000000000..1522d7de5 --- /dev/null +++ b/viscy/transforms/_gaussian_blur.py @@ -0,0 +1,114 @@ +"""3D version of `kornia.augmentation._2d.intensity.gaussian_blur`.""" + +from collections.abc import Iterable +from typing import Any + +from kornia.augmentation import random_generator as rg +from kornia.augmentation._3d.intensity.base import IntensityAugmentationBase3D +from kornia.constants import BorderType +from kornia.filters import filter3d, get_gaussian_kernel3d +from monai.transforms import MapTransform, RandomizableTransform +from torch import Tensor + + +class RandomGaussianBlur(IntensityAugmentationBase3D): + """ + Random Gaussian Blur. + + Parameters + ---------- + kernel_size : tuple[int, int, int] | int + Kernel size. + sigma : tuple[float, float, float] | Tensor + Sigma. + border_type : str, optional + Border type. By default, "reflect". + same_on_batch : bool, optional + Whether to apply the same transformation to all batches. By default, False. + p : float, optional + Probability of applying the transformation. By default, 0.5. + keepdim : bool, optional + Whether to keep the dimensions of the input tensor. By default, False. + """ + + def __init__( + self, + kernel_size: tuple[int, int, int] | int, + sigma: tuple[float, float, float] | Tensor, + border_type: str = "reflect", + same_on_batch: bool = False, + p: float = 0.5, + keepdim: bool = False, + ) -> None: + super().__init__(p=p, same_on_batch=same_on_batch, p_batch=1.0, keepdim=keepdim) + + self.flags = { + "kernel_size": kernel_size, + "border_type": BorderType.get(border_type), + } + self._param_generator = rg.RandomGaussianBlurGenerator(sigma) + + def apply_transform( + self, + input: Tensor, + params: dict[str, Tensor], + flags: dict[str, Any], + transform: Tensor | None = None, + ) -> Tensor: + # Broadcast the per-sample sigma to all three spatial dimensions. + sigma = params["sigma"].unsqueeze(-1).expand(-1, 3) + kernel = get_gaussian_kernel3d( + kernel_size=self.flags["kernel_size"], sigma=sigma + ) + return filter3d(input, kernel, border_type=self.flags["border_type"]) + + +class BatchedRandGaussianBlurd(MapTransform, RandomizableTransform): + """ + Batched Random Gaussian Blur. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to apply the transformation to.
+ kernel_size : tuple[int, int, int] | int + Kernel size. + sigma : tuple[float, float] + Range to sample the standard deviation of the Gaussian kernel from. + border_type : str, optional + Border type. By default, "reflect". + same_on_batch : bool, optional + Whether to apply the same transformation to all batches. By default, False. + prob : float, optional + Probability of applying the transformation. By default, 0.1. + allow_missing_keys : bool, optional + Whether to allow missing keys. By default, False. + """ + + def __init__( + self, + keys: str | Iterable[str], + kernel_size: tuple[int, int, int] | int, + sigma: tuple[float, float], + border_type: str = "reflect", + same_on_batch: bool = False, + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys=allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.filter = RandomGaussianBlur( + kernel_size=kernel_size, + sigma=sigma, + border_type=border_type, + same_on_batch=same_on_batch, + p=prob, + ) + + def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + # Per-sample application probability is handled by the kornia module + # (p=prob), so the blur is applied directly instead of gating here again. + for key in self.keys: + if key in sample: + sample[key] = self.filter(sample[key]) + return sample diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index 696c81abc..a4031ea8f 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -1,6 +1,7 @@ """Redefine transforms from MONAI for jsonargparse.""" -from typing import Sequence +from collections.abc import Sequence +from typing import Any from monai.transforms import ( CenterSpatialCropd, @@ -20,14 +21,19 @@ class Decollated(Decollated): + """Decollate data wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, detach: bool = True, pad_batch: bool = True, fill_value: float | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, detach=detach, @@ -38,19 +44,29 @@ def __init__( class ToDeviced(ToDeviced): - def __init__(self, keys: Sequence[str] | str, **kwargs): + """Transfer data to device wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + + def __init__(self, keys: Sequence[str] | str, **kwargs: Any) -> None: super().__init__(keys=keys, **kwargs) class RandWeightedCropd(RandWeightedCropd): + """Random weighted crop wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, w_key: str, spatial_size: Sequence[int], num_samples: int = 1, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, w_key=w_key, @@ -61,6 +77,11 @@ def __init__( class RandAffined(RandAffined): + """Random affine transform wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -68,8 +89,8 @@ def __init__( rotate_range: Sequence[float | Sequence[float]] | float, shear_range: Sequence[float | Sequence[float]] | float, scale_range: Sequence[float | Sequence[float]] | float, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, prob=prob, @@ -81,40 +102,60 @@ def __init__( class RandAdjustContrastd(RandAdjustContrastd): + """Random contrast adjustment wrapper for jsonargparse compatibility. + + See parent class documentation for details.
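These thin subclasses exist because jsonargparse derives CLI and YAML options from typed signatures, and the bare `**kwargs` in the MONAI parents exposes nothing. A sketch of how a wrapper surfaces its arguments (assuming the class is re-exported from `viscy.transforms`):

```python
from jsonargparse import ArgumentParser

from viscy.transforms import RandWeightedCropd  # re-export assumed here

parser = ArgumentParser()
# Typed, explicit parameters become options such as transform.spatial_size;
# a bare **kwargs signature would surface nothing to the CLI or YAML config.
parser.add_class_arguments(RandWeightedCropd, "transform")
```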
+ """ + def __init__( self, keys: Sequence[str] | str, prob: float, gamma: tuple[float, float] | float, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(keys=keys, prob=prob, gamma=gamma, **kwargs) class RandScaleIntensityd(RandScaleIntensityd): + """Random intensity scaling wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, factors: tuple[float, float] | float, prob: float, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(keys=keys, factors=factors, prob=prob, **kwargs) class RandGaussianNoised(RandGaussianNoised): + """Random Gaussian noise wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, prob: float, mean: float, std: float, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(keys=keys, prob=prob, mean=mean, std=std, **kwargs) class RandGaussianSmoothd(RandGaussianSmoothd): + """Random Gaussian smoothing wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -122,8 +163,8 @@ def __init__( sigma_x: tuple[float, float] | float, sigma_y: tuple[float, float] | float, sigma_z: tuple[float, float] | float, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, prob=prob, @@ -135,6 +176,11 @@ def __init__( class ScaleIntensityRangePercentilesd(ScaleIntensityRangePercentilesd): + """Scale intensity by percentile range wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -147,7 +193,7 @@ def __init__( channel_wise: bool = False, dtype: DTypeLike | None = None, allow_missing_keys: bool = False, - ): + ) -> None: super().__init__( keys=keys, lower=lower, @@ -163,13 +209,18 @@ def __init__( class RandSpatialCropd(RandSpatialCropd): + """Random spatial crop wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, roi_size: Sequence[int] | int, random_center: bool = True, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, roi_size=roi_size, @@ -179,21 +230,31 @@ def __init__( class CenterSpatialCropd(CenterSpatialCropd): + """Center spatial crop wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, roi_size: Sequence[int] | int, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(keys=keys, roi_size=roi_size, **kwargs) class RandFlipd(RandFlipd): + """Random flip wrapper for jsonargparse compatibility. + + See parent class documentation for details. 
+ """ + def __init__( self, keys: Sequence[str] | str, prob: float, spatial_axis: Sequence[int] | int, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) diff --git a/viscy/transforms/_transforms.py b/viscy/transforms/_transforms.py index 1c174ec27..0a32e1edd 100644 --- a/viscy/transforms/_transforms.py +++ b/viscy/transforms/_transforms.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable, Sequence +from typing import Literal from warnings import warn import numpy as np @@ -12,7 +14,6 @@ ) from numpy.typing import DTypeLike from torch import Tensor -from typing_extensions import Iterable, Literal, Sequence from viscy.data.typing import ChannelMap, Sample @@ -75,6 +76,15 @@ def _normalize(): class RandInvertIntensityd(MapTransform, RandomizableTransform): """ Randomly invert the intensity of the image. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to invert the intensity of. + prob : float, optional + Probability of inverting the intensity. By default, 0.1. + allow_missing_keys : bool, optional + Whether to allow missing keys. By default, False. """ def __init__( @@ -97,9 +107,18 @@ def __call__(self, sample: Sample) -> Sample: class TiledSpatialCropSamplesd(MapTransform, MultiSampleTrait): - """ - Crop multiple tiled ROIs from an image. + """Crop multiple tiled ROIs from an image. + Used for deterministic cropping in validation. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to crop. + roi_size : tuple[int, int, int] + ROI size. + num_samples : int + Number of samples. """ def __init__( @@ -147,7 +166,13 @@ def __call__(self, sample: Sample) -> Sample: class StackChannelsd(MapTransform): - """Stack source and target channels.""" + """Stack source and target channels. + + Parameters + ---------- + channel_map : ChannelMap + Channel map. + """ def __init__(self, channel_map: ChannelMap) -> None: channel_names = [] @@ -164,7 +189,21 @@ def __call__(self, sample: Sample) -> Sample: class BatchedZoom(Transform): - "Batched zoom transform using ``torch.nn.functional.interpolate``." + """Batched zoom transform using ``torch.nn.functional.interpolate``. + + Parameters + ---------- + scale_factor : float | tuple[float, float, float] + Scale factor. + mode : Literal["nearest", "nearest-exact", "linear", "bilinear", "bicubic", "trilinear", "area"] + Mode. + align_corners : bool | None + Align corners. + recompute_scale_factor : bool | None + Recompute scale factor. + antialias : bool + Whether to use antialiasing. + """ def __init__( self, @@ -200,6 +239,8 @@ def __call__(self, sample: Tensor) -> Tensor: class BatchedScaleIntensityRangePercentiles(ScaleIntensityRangePercentiles): + """Batched scale intensity range percentiles.""" + def _normalize(self, img: Tensor) -> Tensor: q_low = self.lower / 100.0 q_high = self.upper / 100.0 @@ -244,6 +285,32 @@ def __call__(self, img: Tensor) -> Tensor: class BatchedScaleIntensityRangePercentilesd(MapTransform): + """Batched scale intensity range percentiles. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to scale. + lower : float + Lower percentile. + upper : float + Upper percentile. + b_min : float | None + Minimum value. + b_max : float | None + Maximum value. + clip : bool + Whether to clip the values. + relative : bool + Whether to use relative scaling. + channel_wise : bool + Whether to use channel-wise scaling. + dtype : DTypeLike + Data type. + allow_missing_keys : bool, optional + Whether to allow missing keys. 
By default, False. + """ + def __init__( self, keys: str | Iterable[str], lower: float, upper: float, @@ -270,6 +337,28 @@ def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]: class BatchedRandAffined(MapTransform): + """Batched random affine. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to affine. + prob : float, optional + Probability of applying the transform. By default, 0.1. + rotate_range : Sequence[tuple[float, float] | float] | float | None + Rotation range for each spatial axis. + shear_range : Sequence[tuple[float, float] | float] | float | None + Shear range for each spatial axis. + translate_range : Sequence[tuple[float, float] | float] | float | None + Translation range for each spatial axis. + scale_range : Sequence[tuple[float, float] | float] | float | None + Scaling range for each spatial axis. + mode : str, optional + Interpolation mode. By default, "bilinear". + allow_missing_keys : bool, optional + Whether to allow missing keys. By default, False. + """ + def __init__( self, keys: str | Iterable[str], diff --git a/viscy/transforms/batched_rand_3d_elasticd.py b/viscy/transforms/batched_rand_3d_elasticd.py index 7e38f01d1..c28ebcbea 100644 --- a/viscy/transforms/batched_rand_3d_elasticd.py +++ b/viscy/transforms/batched_rand_3d_elasticd.py @@ -5,7 +5,29 @@ class BatchedRand3DElasticd(MapTransform, RandomizableTransform): - """Batched 3D elastic deformation for biological structures.""" + """Apply random 3D elastic deformation to image data. + + Uses Gaussian-smoothed displacement fields to simulate natural tissue deformation. + + Parameters + ---------- + keys : str or Iterable[str] + Keys of the corresponding items to be transformed. + sigma_range : tuple[float, float] + Range for random sigma values used in Gaussian smoothing. + magnitude_range : tuple[float, float] + Range for random displacement magnitude values. + spatial_size : tuple[int, int, int] or int or None, optional + Expected spatial size of input data. + prob : float, optional + Probability of applying the transform, by default 0.1. + mode : str, optional + Interpolation mode for grid sampling, by default "bilinear". + padding_mode : str, optional + Padding mode for grid sampling, by default "reflection". + allow_missing_keys : bool, optional + Whether to ignore missing keys, by default False. + """ def __init__( self, @@ -76,6 +98,18 @@ def _generate_elastic_field( return torch.stack(displacement_fields) def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply elastic deformation to sample data. + + Parameters + ---------- + sample : dict[str, Tensor] + Dictionary containing image tensors to transform. + + Returns + ------- + dict[str, Tensor] + Dictionary with transformed tensors. + """ self.randomize(None) d = dict(sample) diff --git a/viscy/transforms/batched_rand_histogram_shiftd.py b/viscy/transforms/batched_rand_histogram_shiftd.py index e7fe2b39d..c3ee2d332 100644 --- a/viscy/transforms/batched_rand_histogram_shiftd.py +++ b/viscy/transforms/batched_rand_histogram_shiftd.py @@ -5,7 +5,21 @@ class BatchedRandHistogramShiftd(MapTransform, RandomizableTransform): - """Batched random histogram shifting for intensity distribution changes.""" + """Apply random histogram shifts to modify intensity distributions. + + Shifts image intensities by a random offset drawn from shift_range. + + Parameters + ---------- + keys : str or Iterable[str] + Keys of the corresponding items to be transformed. + shift_range : tuple[float, float], optional + Range for random intensity shift values, by default (-0.1, 0.1). + prob : float, optional + Probability of applying the transform, by default 0.1. + allow_missing_keys : bool, optional + Whether to ignore missing keys, by default False.
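The elastic transform above builds Gaussian-smoothed random displacement fields and resamples through `grid_sample`. A compact single-function sketch of the technique, with average pooling standing in for the Gaussian smoothing:

```python
import torch
import torch.nn.functional as F

def elastic_deform_3d(x: torch.Tensor, magnitude: float = 0.05) -> torch.Tensor:
    # x: (B, C, D, H, W). Displacements are in normalized [-1, 1] grid units.
    b = x.shape[0]
    disp = torch.randn(b, 3, *x.shape[2:]) * magnitude
    disp = F.avg_pool3d(disp, kernel_size=5, stride=1, padding=2)  # crude smoothing
    identity = F.affine_grid(
        torch.eye(3, 4).unsqueeze(0).expand(b, -1, -1), x.shape, align_corners=False
    )
    grid = identity + disp.permute(0, 2, 3, 4, 1)  # (B, D, H, W, 3)
    return F.grid_sample(
        x, grid, mode="bilinear", padding_mode="reflection", align_corners=False
    )

out = elastic_deform_3d(torch.randn(2, 1, 8, 32, 32))
```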
+ """ def __init__( self, @@ -19,6 +33,18 @@ def __init__( self.shift_range = shift_range def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply histogram shift to sample data. + + Parameters + ---------- + sample : dict[str, Tensor] + Dictionary containing image tensors to transform. + + Returns + ------- + dict[str, Tensor] + Dictionary with intensity-shifted tensors. + """ self.randomize(None) d = dict(sample) diff --git a/viscy/transforms/batched_rand_local_pixel_shufflingd.py b/viscy/transforms/batched_rand_local_pixel_shufflingd.py index 73cd9caf4..8cebc41e1 100644 --- a/viscy/transforms/batched_rand_local_pixel_shufflingd.py +++ b/viscy/transforms/batched_rand_local_pixel_shufflingd.py @@ -5,7 +5,23 @@ class BatchedRandLocalPixelShufflingd(MapTransform, RandomizableTransform): - """Batched random local pixel shuffling for texture augmentation.""" + """Apply random local pixel shuffling to simulate texture variations. + + Shuffles pixels within small local patches to add texture noise. + + Parameters + ---------- + keys : str or Iterable[str] + Keys of the corresponding items to be transformed. + patch_size : int, optional + Size of local patches for pixel shuffling, by default 3. + shuffle_prob : float, optional + Probability of shuffling within patches, by default 0.1. + prob : float, optional + Probability of applying the transform, by default 0.1. + allow_missing_keys : bool, optional + Whether to ignore missing keys, by default False. + """ def __init__( self, @@ -72,6 +88,18 @@ def _shuffle_patches(self, data: Tensor) -> Tensor: return result def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply pixel shuffling to sample data. + + Parameters + ---------- + sample : dict[str, Tensor] + Dictionary containing image tensors to transform. + + Returns + ------- + dict[str, Tensor] + Dictionary with pixel-shuffled tensors. + """ self.randomize(None) d = dict(sample) diff --git a/viscy/transforms/batched_rand_sharpend.py b/viscy/transforms/batched_rand_sharpend.py index fe2a54c07..92b992b5f 100644 --- a/viscy/transforms/batched_rand_sharpend.py +++ b/viscy/transforms/batched_rand_sharpend.py @@ -6,7 +6,21 @@ class BatchedRandSharpend(MapTransform, RandomizableTransform): - """Batched random sharpening for microscopy images.""" + """Apply random sharpening to enhance image edges and details. + + Uses 3D convolution with sharpening kernel to enhance fine structures. + + Parameters + ---------- + keys : str or Iterable[str] + Keys of the corresponding items to be transformed. + alpha_range : tuple[float, float], optional + Range for random alpha blending values, by default (0.1, 0.5). + prob : float, optional + Probability of applying the transform, by default 0.1. + allow_missing_keys : bool, optional + Whether to ignore missing keys, by default False. + """ def __init__( self, @@ -40,6 +54,18 @@ def _get_sharpen_kernel(self, device: torch.device, channels: int) -> Tensor: return self._cached_kernel def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply sharpening to sample data. + + Parameters + ---------- + sample : dict[str, Tensor] + Dictionary containing image tensors to transform. + + Returns + ------- + dict[str, Tensor] + Dictionary with sharpened tensors. 
+ """ self.randomize(None) d = dict(sample) diff --git a/viscy/transforms/batched_rand_zstack_shiftd.py b/viscy/transforms/batched_rand_zstack_shiftd.py index e94e5b714..1fbad6638 100644 --- a/viscy/transforms/batched_rand_zstack_shiftd.py +++ b/viscy/transforms/batched_rand_zstack_shiftd.py @@ -5,7 +5,25 @@ class BatchedRandZStackShiftd(MapTransform, RandomizableTransform): - """Batched random Z-axis shifts for 3D microscopy data.""" + """Apply random shifts along Z-axis to simulate focal plane variations. + + Shifts image data in the depth dimension to augment focal plane diversity. + + Parameters + ---------- + keys : str or Iterable[str] + Keys of the corresponding items to be transformed. + max_shift : int, optional + Maximum shift distance in Z direction, by default 3. + prob : float, optional + Probability of applying the transform, by default 0.1. + mode : str, optional + Padding mode for shifted regions, by default "constant". + cval : float, optional + Fill value for constant padding, by default 0.0. + allow_missing_keys : bool, optional + Whether to ignore missing keys, by default False. + """ def __init__( self, @@ -23,6 +41,18 @@ def __init__( self.cval = cval def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply Z-axis shift to sample data. + + Parameters + ---------- + sample : dict[str, Tensor] + Dictionary containing image tensors to transform. + + Returns + ------- + dict[str, Tensor] + Dictionary with Z-shifted tensors. + """ self.randomize(None) d = dict(sample) diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 56af9b985..1cc4e23ee 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -1,7 +1,10 @@ +"""Training engine for virtual staining and image translation models.""" + import logging import os import random -from typing import Callable, Literal, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Any, Literal, Union import numpy as np import torch @@ -12,7 +15,7 @@ from monai.optimizers import WarmupCosineSchedule from monai.transforms import DivisiblePad, Rotate90 from torch import Tensor, nn -from torch.optim.lr_scheduler import ConstantLR +from torch.optim.lr_scheduler import ConstantLR, LRScheduler from torchmetrics.functional import ( accuracy, cosine_similarity, @@ -24,7 +27,6 @@ structural_similarity_index_measure, ) from torchmetrics.functional.segmentation import dice_score - from viscy.data.combined import CombinedDataModule from viscy.data.gpu_aug import GPUTransformDataModule from viscy.data.typing import Sample @@ -48,17 +50,23 @@ class MixedLoss(nn.Module): """Mixed reconstruction loss. + Adapted from Zhao et al, https://arxiv.org/pdf/1511.08861.pdf Reduces to simple distances if only one weight is non-zero. 
- :param float l1_alpha: L1 loss weight, defaults to 0.5 - :param float l2_alpha: L2 loss weight, defaults to 0.0 - :param float ms_dssim_alpha: MS-DSSIM weight, defaults to 0.5 + Parameters + ---------- + l1_alpha : float, optional + L1 loss weight, by default 0.5 + l2_alpha : float, optional + L2 loss weight, by default 0.0 + ms_dssim_alpha : float, optional + MS-DSSIM weight, by default 0.5 """ def __init__( self, l1_alpha: float = 0.5, l2_alpha: float = 0.0, ms_dssim_alpha: float = 0.5 - ): + ) -> None: super().__init__() if not any([l1_alpha, l2_alpha, ms_dssim_alpha]): raise ValueError("Loss term weights cannot be all zero!") @@ -67,7 +75,21 @@ def __init__( self.ms_dssim_alpha = ms_dssim_alpha @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) - def forward(self, preds, target): + def forward(self, preds: Tensor, target: Tensor) -> Tensor: + """Compute mixed reconstruction loss. + + Parameters + ---------- + preds : Tensor + Predicted tensor + target : Tensor + Target tensor + + Returns + ------- + Tensor + Combined loss value + """ loss = 0 if self.l1_alpha: # the gaussian in the reference is not used @@ -84,7 +106,28 @@ def forward(self, preds, target): class MaskedMSELoss(nn.Module): - def forward(self, preds, original, mask): + """Masked mean squared error loss. + + Computes MSE loss only for masked regions. + """ + + def forward(self, preds: Tensor, original: Tensor, mask: Tensor) -> Tensor: + """Compute masked MSE loss. + + Parameters + ---------- + preds : Tensor + Predicted tensor. + original : Tensor + Original tensor. + mask : Tensor + Binary mask tensor. + + Returns + ------- + Tensor + Masked MSE loss value. + """ loss = F.mse_loss(preds, original, reduction="none") loss = (loss.mean(2) * mask).sum() / mask.sum() return loss @@ -93,44 +136,52 @@ def forward(self, preds, original, mask): class VSUNet(LightningModule): """Regression U-Net module for virtual staining. 
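`MaskedMSELoss` above averages the per-element error over dim 2 before applying the mask, so the mask must match the reduced shape; a shape-checked sketch with illustrative dimensions:

```python
import torch
import torch.nn.functional as F

preds = torch.randn(2, 4, 8, 16, 16)   # e.g. (B, C, Z, Y, X)
target = torch.randn(2, 4, 8, 16, 16)
mask = (torch.rand(2, 4, 16, 16) > 0.5).float()  # matches loss after .mean(2)

loss = F.mse_loss(preds, target, reduction="none")
# Average over dim 2, weight by the binary mask, normalize by masked count:
# unmasked locations contribute nothing to the final scalar.
loss = (loss.mean(2) * mask).sum() / mask.sum()
```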
- :param dict model_config: model config, - defaults to :py:class:`viscy.unet.utils.model.ModelDefaults25D` - :param Union[nn.Module, MixedLoss] loss_function: - loss function in training/validation, - if a dictionary, should specify weights of each term - ('l1_alpha', 'l2_alpha', 'ssim_alpha') - defaults to L2 (mean squared error) - :param float lr: learning rate in training, defaults to 1e-3 - :param Literal['WarmupCosine', 'Constant'] schedule: - learning rate scheduler, defaults to "Constant" - :param str ckpt_path: path to the checkpoint to load weights, defaults to None - :param int log_batches_per_epoch: - number of batches to log each training/validation epoch, - has to be smaller than steps per epoch, defaults to 8 - :param int log_samples_per_batch: - number of samples to log each training/validation batch, - has to be smaller than batch size, defaults to 1 - :param Sequence[int] example_input_yx_shape: - XY shape of the example input for network graph tracing, defaults to (256, 256) - :param str test_cellpose_model_path: - path to the CellPose model for testing segmentation, defaults to None - :param float test_cellpose_diameter: - diameter parameter of the CellPose model for testing segmentation, - defaults to None - :param bool test_evaluate_cellpose: - evaluate the performance of the CellPose model instead of the trained model - in test stage, defaults to False - :param bool test_time_augmentations: - apply test time augmentations in test stage, defaults to False - :param Literal['mean', 'median', 'product'] tta_type: - type of test time augmentations aggregation, defaults to "mean" + Parameters + ---------- + architecture : Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae", "UNeXt2_2D"] + Model architecture type. + model_config : dict, optional + Model config, defaults to :py:class:`viscy.unet.utils.model.ModelDefaults25D`, + by default {}. + loss_function : Union[nn.Module, MixedLoss], optional + Loss function in training/validation. If a dictionary, should specify weights + of each term ('l1_alpha', 'l2_alpha', 'ssim_alpha'), defaults to L2 + (mean squared error), by default None. + lr : float, optional + Learning rate in training, by default 1e-3. + schedule : Literal['WarmupCosine', 'Constant'], optional + Learning rate scheduler, by default "Constant". + freeze_encoder : bool, optional + Whether to freeze encoder weights, by default False. + ckpt_path : str, optional + Path to the checkpoint to load weights, by default None. + log_batches_per_epoch : int, optional + Number of batches to log each training/validation epoch, + has to be smaller than steps per epoch, by default 8. + log_samples_per_batch : int, optional + Number of samples to log each training/validation batch, + has to be smaller than batch size, by default 1. + example_input_yx_shape : Sequence[int], optional + XY shape of the example input for network graph tracing, by default (256, 256). + test_cellpose_model_path : str, optional + Path to the CellPose model for testing segmentation, by default None. + test_cellpose_diameter : float, optional + Diameter parameter of the CellPose model for testing segmentation, + by default None. + test_evaluate_cellpose : bool, optional + Evaluate the performance of the CellPose model instead of the trained model + in test stage, by default False. + test_time_augmentations : bool, optional + Apply test time augmentations in test stage, by default False. 
+ tta_type : Literal['mean', 'median', 'product'], optional + Type of test time augmentations aggregation, by default "mean". """ def __init__( self, architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae", "UNeXt2_2D"], model_config: dict = {}, - loss_function: Union[nn.Module, MixedLoss] | None = None, + loss_function: nn.Module | MixedLoss | None = None, lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", freeze_encoder: bool = False, @@ -185,9 +236,35 @@ def __init__( ) # loading only weights def forward(self, x: Tensor) -> Tensor: + """Forward pass through the model. + + Parameters + ---------- + x : Tensor + Input tensor. + + Returns + ------- + Tensor + Model output. + """ return self.model(x) - def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int): + def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int) -> Tensor: + """Execute single training step. + + Parameters + ---------- + batch : Sample | Sequence[Sample] + Training batch data. + batch_idx : int + Batch index. + + Returns + ------- + Tensor + Training loss. + """ losses = [] batch_size = 0 if not isinstance(batch, Sequence): @@ -216,7 +293,20 @@ def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int): ) return loss_step - def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + def validation_step( + self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 + ) -> None: + """Execute single validation step. + + Parameters + ---------- + batch : Sample + Validation batch data. + batch_idx : int + Batch index. + dataloader_idx : int + Dataloader index for multi-dataloader validation. By default, 0. + """ source: Tensor = batch["source"] target: Tensor = batch["target"] pred = self.forward(source) @@ -235,7 +325,16 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 detach_sample((source, target, pred), self.log_samples_per_batch) ) - def test_step(self, batch: Sample, batch_idx: int): + def test_step(self, batch: Sample, batch_idx: int) -> None: + """Execute single test step. + + Parameters + ---------- + batch : Sample + Test batch data. + batch_idx : int + Batch index. 
+ """ source = batch["source"] target = batch["target"] center_index = target.shape[-3] // 2 @@ -266,7 +365,7 @@ def test_step(self, batch: Sample, batch_idx: int): else: self._log_segmentation_metrics(None, None) - def _log_regression_metrics(self, pred: Tensor, target: Tensor): + def _log_regression_metrics(self, pred: Tensor, target: Tensor) -> None: # paired image translation metrics self.log_dict( { @@ -289,7 +388,7 @@ def _log_regression_metrics(self, pred: Tensor, target: Tensor): on_epoch=True, ) - def _cellpose_predict(self, pred: Tensor, name: str) -> torch.ShortTensor: + def _cellpose_predict(self, pred: Tensor, name: str) -> Tensor: pred_labels_np = self.cellpose_model.eval( pred.cpu().numpy(), channels=[0, 0], diameter=self.test_cellpose_diameter )[0].astype(np.int16) @@ -297,8 +396,8 @@ def _cellpose_predict(self, pred: Tensor, name: str) -> torch.ShortTensor: return torch.from_numpy(pred_labels_np).to(self.device) def _log_segmentation_metrics( - self, pred_labels: torch.ShortTensor, target_labels: torch.ShortTensor - ): + self, pred_labels: Tensor, target_labels: Tensor + ) -> None: compute = pred_labels is not None if compute: pred_binary = pred_labels > 0 @@ -337,7 +436,25 @@ def _log_segmentation_metrics( on_epoch=False, ) - def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + def predict_step( + self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 + ) -> Tensor: + """Execute single prediction step. + + Parameters + ---------- + batch : Sample + Prediction batch data. + batch_idx : int + Batch index. + dataloader_idx : int, default=0 + Dataloader index. + + Returns + ------- + Tensor + Model prediction. + """ source = batch["source"] if self.test_time_augmentations: prediction = self.perform_test_time_augmentations(source) @@ -349,13 +466,21 @@ def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): return prediction def perform_test_time_augmentations(self, source: Tensor) -> Tensor: - """Perform test time augmentations on the input source - and aggregate the predictions using the specified method. + """Perform test time augmentations and aggregate predictions. - :param source: input tensor - :return: aggregated prediction - """ + Apply rotational augmentations to input source and aggregate the + predictions using the specified method. + Parameters + ---------- + source : Tensor + Input tensor. + + Returns + ------- + Tensor + Aggregated prediction. + """ # Save the yx coords to crop post rotations self._original_shape_yx = source.shape[-2:] predictions = [] @@ -384,11 +509,13 @@ def perform_test_time_augmentations(self, source: Tensor) -> Tensor: prediction = torch.exp(log_prediction_sum) return prediction - def on_train_epoch_end(self): + def on_train_epoch_end(self) -> None: + """Log training samples at end of epoch.""" self._log_samples("train_samples", self.training_step_outputs) self.training_step_outputs = [] - def on_validation_epoch_end(self): + def on_validation_epoch_end(self) -> None: + """Log validation samples and compute average loss at end of epoch.""" super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) # average within each dataloader @@ -401,8 +528,14 @@ def on_validation_epoch_end(self): self.validation_step_outputs.clear() self.validation_losses.clear() - def on_test_start(self): - """Load CellPose model for segmentation.""" + def on_test_start(self) -> None: + """Load CellPose model for segmentation. 
+ + Raises + ------ + ImportError + If CellPose is not installed. + """ if self.test_cellpose_model_path is not None: try: from cellpose.models import CellposeModel @@ -417,14 +550,25 @@ def on_test_start(self): '`pip install viscy"[metrics]"`' ) - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. + def on_predict_start(self) -> None: + """Set up prediction padding transform. + + Pad the input shape to be divisible by the downsampling factor. The inverse of this transform crops the prediction to original shape. """ down_factor = 2**self.model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - def configure_optimizers(self): + def configure_optimizers( + self, + ) -> tuple[list[torch.optim.Optimizer], list[LRScheduler]]: + """Configure optimizer and learning rate scheduler. + + Returns + ------- + tuple[list[torch.optim.Optimizer], list[LRScheduler]] + Tuple containing a list of optimizers and schedulers. + """ if self.freeze_encoder: self.model: FullyConvolutionalMAE self.model.encoder.requires_grad_(False) @@ -442,7 +586,7 @@ def configure_optimizers(self): ) return [optimizer], [scheduler] - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]) -> None: grid = render_images(imgs) self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" @@ -477,8 +621,7 @@ def _crop_to_original(self, tensor: Tensor) -> Tensor: class AugmentedPredictionVSUNet(LightningModule): - """Apply arbitrary collection of test-time augmentations - for image translation prediction. + """Apply arbitrary collection of test-time augmentations for image translation prediction. Parameters ---------- @@ -528,15 +671,51 @@ def __init__( self._reduction = reduction def forward(self, x: Tensor) -> Tensor: + """Forward pass through the model. + + Parameters + ---------- + x : Tensor + Input tensor. + + Returns + ------- + Tensor + Model output. + """ return self.model(x) def setup(self, stage: str) -> None: + """Set up the Lightning module for the specified stage. + + Parameters + ---------- + stage : str + Stage name (only 'predict' is supported). + + Raises + ------ + NotImplementedError + If stage is not 'predict'. + """ if stage != "predict": raise NotImplementedError( f"Only the 'predict' stage is supported by {type(self)}" ) def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: + """Reduce multiple predictions using specified method. + + Parameters + ---------- + preds : list[Tensor] + List of prediction tensors. + + Returns + ------- + Tensor + Reduced prediction tensor. + """ prediction = torch.stack(preds, dim=0) if self._reduction == "mean": prediction = prediction.mean(dim=0) @@ -547,6 +726,22 @@ def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: def predict_step( self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: + """Execute single prediction step with augmentations. + + Parameters + ---------- + batch : Sample + Prediction batch data. + batch_idx : int + Batch index. + dataloader_idx : int, default=0 + Dataloader index. + + Returns + ------- + Tensor + Aggregated prediction from augmented inputs. + """ source = batch["source"] preds = [] for forward_t, inverse_t in zip( @@ -566,16 +761,36 @@ def predict_step( class FcmaeUNet(VSUNet): + """Fully Convolutional Masked Autoencoder U-Net. 
+ + Extends VSUNet to support masked autoencoder pre-training and supervised + fine-tuning for virtual staining tasks. + + Parameters + ---------- + fit_mask_ratio : float, default=0.0 + Masking ratio for FCMAE pre-training. + **kwargs + Additional arguments passed to VSUNet. + """ + def __init__( self, fit_mask_ratio: float = 0.0, **kwargs, - ): + ) -> None: super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio self.save_hyperparameters(ignore=["loss_function"]) - def on_fit_start(self): + def on_fit_start(self) -> None: + """Set up data modules and validate configuration for training. + + Raises + ------ + ValueError + If data module configuration is incompatible with FCMAE training. + """ dm = self.trainer.datamodule if not isinstance(dm, CombinedDataModule): raise ValueError( @@ -595,12 +810,42 @@ def on_fit_start(self): f"got {type(self.loss_function)}" ) - def forward(self, x: Tensor, mask_ratio: float = 0.0): + def forward( + self, x: Tensor, mask_ratio: float = 0.0 + ) -> tuple[Tensor, Tensor] | Tensor: + """Forward pass with optional masking. + + Parameters + ---------- + x : Tensor + Input tensor. + mask_ratio : float, default=0.0 + Masking ratio for FCMAE mode. + + Returns + ------- + Tensor or tuple + Model output, optionally with mask if mask_ratio > 0. + """ return self.model(x, mask_ratio) def forward_fit_fcmae( self, batch: Sample, return_target: bool = False ) -> tuple[Tensor, Tensor | None, Tensor]: + """Forward pass for FCMAE pre-training. + + Parameters + ---------- + batch : Sample + Input batch. + return_target : bool, default=False + Whether to return masked target for logging. + + Returns + ------- + tuple[Tensor, Tensor or None, Tensor] + Prediction, target (if requested), and loss. + """ x = batch["source"] pred, mask = self.forward(x, mask_ratio=self.fit_mask_ratio) loss = self.loss_function(pred, x, mask) @@ -611,6 +856,18 @@ def forward_fit_fcmae( return pred, target, loss def forward_fit_supervised(self, batch: Sample) -> tuple[Tensor, Tensor, Tensor]: + """Forward pass for supervised training. + + Parameters + ---------- + batch : Sample + Input batch containing source and target. + + Returns + ------- + tuple[Tensor, Tensor, Tensor] + Prediction, target, and loss. + """ x = batch["source"] target = batch["target"] pred = self.forward(x) @@ -620,6 +877,23 @@ def forward_fit_supervised(self, batch: Sample) -> tuple[Tensor, Tensor, Tensor] def forward_fit_task( self, batch: Sample, batch_idx: int ) -> tuple[Tensor, Tensor | None, Tensor]: + """Forward pass for current training task. + + Automatically selects FCMAE pre-training or supervised training + based on model configuration. + + Parameters + ---------- + batch : Sample + Input batch. + batch_idx : int + Batch index. + + Returns + ------- + tuple[Tensor, Tensor | None, Tensor] + Prediction, target, and loss. + """ if self.model.pretraining: if batch_idx < self.log_batches_per_epoch: return_target = True @@ -630,6 +904,18 @@ def forward_fit_task( @torch.no_grad() def train_transform_and_collate(self, batch: list[dict[str, Tensor]]) -> Sample: + """Apply training transforms and collate batch data. + + Parameters + ---------- + batch : list[dict[str, Tensor]] + List of batch dictionaries from multiple data modules. + + Returns + ------- + Sample + Collated and transformed sample. 
+ """ transformed = [] for dataset_batch, dm in zip(batch, self.datamodules): dataset_batch = dm.train_gpu_transforms(dataset_batch) @@ -642,10 +928,38 @@ def train_transform_and_collate(self, batch: list[dict[str, Tensor]]) -> Sample: def val_transform_and_collate( self, batch: list[Sample], dataloader_idx: int ) -> Tensor: + """Apply validation transforms and collate batch data. + + Parameters + ---------- + batch : list[Sample] + List of samples. + dataloader_idx : int + Index of the validation dataloader. + + Returns + ------- + Tensor + Collated and transformed batch. + """ batch = self.datamodules[dataloader_idx].val_gpu_transforms(batch) return collate_meta_tensor(batch) def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: + """Execute single training step for FCMAE. + + Parameters + ---------- + batch : list[list[Sample]] + Nested list of samples from multiple data modules. + batch_idx : int + Batch index. + + Returns + ------- + Tensor + Training loss. + """ batch = self.train_transform_and_collate(batch) pred, target, loss = self.forward_fit_task(batch, batch_idx) if batch_idx < self.log_batches_per_epoch: @@ -669,6 +983,17 @@ def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: def validation_step( self, batch: list[Sample], batch_idx: int, dataloader_idx: int = 0 ) -> None: + """Execute single validation step for FCMAE. + + Parameters + ---------- + batch : list[Sample] + List of validation samples. + batch_idx : int + Batch index. + dataloader_idx : int, default=0 + Dataloader index. + """ batch = self.val_transform_and_collate(batch, dataloader_idx) pred, target, loss = self.forward_fit_task(batch, batch_idx) if dataloader_idx + 1 > len(self.validation_losses): diff --git a/viscy/translation/evaluation.py b/viscy/translation/evaluation.py index 11812f4fe..42ef08c48 100644 --- a/viscy/translation/evaluation.py +++ b/viscy/translation/evaluation.py @@ -5,7 +5,6 @@ from lightning.pytorch import LightningModule from torchmetrics.functional import accuracy, jaccard_index from torchmetrics.functional.segmentation import dice_score - from viscy.data.typing import SegmentationSample from viscy.translation.evaluation_metrics import mean_average_precision @@ -13,13 +12,28 @@ class SegmentationMetrics2D(LightningModule): - """Test runner for 2D segmentation.""" + """Test runner for 2D segmentation. + + Parameters + ---------- + aggregate_epoch : bool, optional + Whether to aggregate the metrics over the epoch. Defaults to False. + """ def __init__(self, aggregate_epoch: bool = False) -> None: super().__init__() self.aggregate_epoch = aggregate_epoch def test_step(self, batch: SegmentationSample, batch_idx: int) -> None: + """Compute segmentation metrics for a test batch. 
+
+        Parameters
+        ----------
+        batch : SegmentationSample
+            Batch containing prediction and target segmentation masks.
+        batch_idx : int
+            Batch index.
+        """
         pred = batch["pred"]
         target = batch["target"]
         if not pred.shape[0] == 1 and target.shape[0] == 1:
diff --git a/viscy/translation/evaluation_metrics.py b/viscy/translation/evaluation_metrics.py
index bb89858f2..2011b4def 100644
--- a/viscy/translation/evaluation_metrics.py
+++ b/viscy/translation/evaluation_metrics.py
@@ -7,18 +7,28 @@
 import torch
 import torch.nn.functional as F
 from monai.metrics.regression import compute_ssim_and_cs
+from numpy.typing import NDArray
 from scipy.optimize import linear_sum_assignment
 from skimage.measure import label, regionprops
 from torchmetrics.detection.mean_ap import MeanAveragePrecision
 from torchvision.ops import masks_to_boxes
 
 
-def VOI_metric(target, prediction):
-    """variation of information metric
+def VOI_metric(target: NDArray, prediction: NDArray) -> list[float]:
+    """Variation of information metric.
+
     Reports overlap between predicted and ground truth mask
-    : param np.array target: ground truth mask
-    : param np.array prediction: model infered FL image cellpose mask
-    : return float VI: VI for image masks
+
+    Parameters
+    ----------
+    target : NDArray
+        Ground truth mask.
+    prediction : NDArray
+        Model inferred FL image cellpose mask.
+
+    Returns
+    -------
+    list[float]
+        Single-element list with the variation of information (VI)
+        for the image masks.
     """
     # cellpose segmentation of predicted image: outputs labl mask
     pred_bin = prediction > 0
@@ -55,7 +65,27 @@ def VOI_metric(target, prediction):
     return [VI]
 
 
-def POD_metric(target_bin, pred_bin):
+def POD_metric(
+    target_bin: NDArray, pred_bin: NDArray
+) -> tuple[float, float, float, int, int]:
+    """Probability of detection metric for object matching.
+
+    Parameters
+    ----------
+    target_bin : NDArray
+        Binary ground truth mask.
+    pred_bin : NDArray
+        Binary predicted mask.
+
+    Returns
+    -------
+    tuple[float, float, float, int, int]
+        POD and various detection statistics.
+        - POD: Probability of detection
+        - FAR: False alarm rate
+        - PCD: Probability of correct detection
+        - n_targObj: Number of target objects
+        - n_predObj: Number of predicted objects
+    """
 
     # pred_bin = cpmask_array(prediction)
     # relabel mask for ordered labelling across images for efficient LAP mapping
@@ -120,9 +150,16 @@
 def labels_to_masks(labels: torch.ShortTensor) -> torch.BoolTensor:
     """Convert integer labels to a stack of boolean masks.
 
-    :param torch.ShortTensor labels: 2D labels where each value is an object
-        (0 is background)
-    :return torch.BoolTensor: Boolean masks of shape (objects, H, W)
+    Parameters
+    ----------
+    labels : torch.ShortTensor
+        2D labels where each value is an object (0 is background).
+
+    Returns
+    -------
+    torch.BoolTensor
+        Boolean masks of shape (objects, H, W).
     """
     if labels.ndim != 2:
         raise ValueError(f"Labels must be 2D, got shape {labels.shape}.")
@@ -141,9 +178,15 @@ def labels_to_masks(labels: torch.ShortTensor) -> torch.BoolTensor:
 def labels_to_detection(labels: torch.ShortTensor) -> dict[str, torch.Tensor]:
     """Convert integer labels to a torchvision/torchmetrics detection dictionary.
-    :param torch.ShortTensor labels: 2D labels where each value is an object
-        (0 is background)
-    :return dict[str, torch.Tensor]: detection boxes, scores, labels, and masks
+    Parameters
+    ----------
+    labels : torch.ShortTensor
+        2D labels where each value is an object (0 is background).
+
+    Returns
+    -------
+    dict[str, torch.Tensor]
+        Detection boxes, scores, labels, and masks.
     """
     masks = labels_to_masks(labels)
     boxes = masks_to_boxes(masks)
@@ -166,11 +209,20 @@
 ) -> dict[str, torch.Tensor]:
     """Compute the mAP metric for instance segmentation.
 
-    :param torch.ShortTensor pred_labels: 2D integer prediction labels
-    :param torch.ShortTensor target_labels: 2D integer prediction labels
-    :param dict **kwargs: keyword arguments passed to
+    Parameters
+    ----------
+    pred_labels : torch.ShortTensor
+        2D integer prediction labels.
+    target_labels : torch.ShortTensor
+        2D integer target labels.
+    **kwargs : dict
+        Keyword arguments passed to
         :py:class:`torchmetrics.detection.MeanAveragePrecision`
-    :return dict[str, torch.Tensor]: COCO-style metrics
+
+    Returns
+    -------
+    dict[str, torch.Tensor]
+        COCO-style metrics.
     """
     defaults = dict(
         iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000]
@@ -191,15 +243,24 @@ def ssim_25d(
     return_contrast_sensitivity: bool = False,
 ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
     """Multi-scale SSIM loss function for 2.5D volumes (3D with small depth).
+
     Uses uniform kernel (windows), depth-dimension window size equals to depth size.
 
-    :param torch.Tensor preds: predicted batch (B, C, D, W, H)
-    :param torch.Tensor target: target batch
-    :param tuple[int, int] in_plane_window_size: kernel width and height,
-        by default (11, 11)
-    :param bool return_contrast_sensitivity: whether to return contrast sensitivity
-    :return torch.Tensor: SSIM for the batch
-    :return Optional[torch.Tensor]: contrast sensitivity
+    Parameters
+    ----------
+    preds : torch.Tensor
+        Predicted batch (B, C, D, W, H).
+    target : torch.Tensor
+        Target batch.
+    in_plane_window_size : tuple[int, int], optional
+        Kernel width and height, by default (11, 11).
+    return_contrast_sensitivity : bool, optional
+        Whether to return contrast sensitivity, by default False.
+
+    Returns
+    -------
+    torch.Tensor
+        SSIM for the batch.
+    torch.Tensor, optional
+        Contrast sensitivity, only returned when
+        ``return_contrast_sensitivity`` is True.
     """
     if preds.ndim != 5:
         raise ValueError(
@@ -233,6 +294,7 @@
     betas: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
 ) -> torch.Tensor:
     """Multi-scale SSIM for 2.5D volumes (3D with small depth).
+
     Uses uniform kernel (windows), depth-dimension window size equals to depth size.
     Depth dimension is not downsampled.
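# ---- Editorial sketch (reviewer note, not part of this patch) ----
# Exercises the ssim_25d contract documented above: a (B, C, D, H, W) batch
# with small depth, scored with a uniform window whose depth equals the
# stack depth. Toy data; assumes viscy is importable in the environment.
import torch

from viscy.translation.evaluation_metrics import ssim_25d

preds = torch.rand(2, 1, 5, 64, 64)
target = preds + 0.01 * torch.randn_like(preds)
print(ssim_25d(preds, target))  # close to 1.0 for nearly identical volumes
# ---- End editorial sketch ----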
@@ -240,15 +302,23 @@ def ms_ssim_25d(
     Original license:
     Copyright The Lightning team, http://www.apache.org/licenses/LICENSE-2.0
 
-    :param torch.Tensor preds: predicted images
-    :param torch.Tensor target: target images
-    :param tuple[int, int] in_plane_window_size: kernel width and height,
-        defaults to (11, 11)
-    :param bool clamp: clamp to [1e-6, 1] for training stability when used in loss,
-        defaults to False
-    :param Sequence[float] betas: exponents of each resolution,
-        defaults to (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
-    :return torch.Tensor: multi-scale SSIM
+    Parameters
+    ----------
+    preds : torch.Tensor
+        Predicted images.
+    target : torch.Tensor
+        Target images.
+    in_plane_window_size : tuple[int, int], optional
+        Kernel width and height, by default (11, 11).
+    clamp : bool, optional
+        Clamp to [1e-6, 1] for training stability when used in loss,
+        by default False.
+    betas : Sequence[float], optional
+        Exponents of each resolution,
+        by default (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
+
+    Returns
+    -------
+    torch.Tensor
+        Multi-scale SSIM.
     """
     base_min = 1e-4
     mcs_list = []
diff --git a/viscy/translation/predict_writer.py b/viscy/translation/predict_writer.py
index 75d6d4152..4aaa3394a 100644
--- a/viscy/translation/predict_writer.py
+++ b/viscy/translation/predict_writer.py
@@ -1,7 +1,8 @@
 import logging
 import os
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Literal, Optional, Sequence
+from typing import Literal, Optional
 
 import numpy as np
 import torch
@@ -9,7 +10,6 @@
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import BasePredictionWriter
 from numpy.typing import DTypeLike, NDArray
-
 from viscy.data.hcs import HCSDataModule, Sample
 
 __all__ = ["HCSPredictionWriter"]
@@ -19,6 +19,7 @@ def _pad_shape(shape: tuple[int, ...], target: int = 5) -> tuple[int, ...]:
     """
     Pad shape tuple to a target length.
+
     Vendored from ``iohub.ngff.nodes._pad_shape()``.
     """
     pad = target - len(shape)
@@ -49,7 +50,7 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray
     weights are determined by the position within the range of slices.
     If the start of `z_slice` is 0, the function returns the `new_stack` unchanged.
 
-    Parameters:
+    Parameters
     ----------
     old_stack : NDArray
         The original stack of images to be blended.
@@ -59,12 +60,11 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray
         A slice object indicating the range of slices over which to perform
         the blending. The start and stop attributes of the slice determine
         the range.
 
-    Returns:
+    Returns
     -------
     NDArray
         The blended stack of images. If `z_slice.start` is 0, returns `new_stack`
         unchanged.
     """
-
     if z_slice.start == 0:
         return new_stack
     depth = z_slice.stop - z_slice.start
@@ -81,12 +81,15 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray
 class HCSPredictionWriter(BasePredictionWriter):
     """Callback to store virtual staining predictions as HCS OME-Zarr.
 
-    :param str output_store: Path to the zarr store to store output
-    :param bool write_input: Write the source and target channels too
-        (must be writing to a new store),
-        defaults to False
-    :param Literal['batch', 'epoch', 'batch_and_epoch'] write_interval:
-        When to write, defaults to "batch"
+    Parameters
+    ----------
+    output_store : str
+        Path to the zarr store to store output.
+    write_input : bool, optional
+        Write the source and target channels too (must be writing to a new store),
+        by default False.
+ write_interval : Literal['batch', 'epoch', 'batch_and_epoch'], optional + When to write, by default "batch". """ def __init__( @@ -117,6 +120,16 @@ def _get_scale_metadata(self, metadata_store: Path) -> None: _logger.debug(f"Dataset scale {self._dataset_scale}.") def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """ + Initialize output store and set up prediction writing at start of prediction. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer instance. + pl_module : LightningModule + PyTorch Lightning module being used for predictions. + """ dm: HCSDataModule = trainer.datamodule self._get_scale_metadata(dm.data_path) self.z_padding = dm.z_window_size // 2 if dm.target_2d else 0 @@ -156,21 +169,63 @@ def write_on_batch_end( trainer: Trainer, pl_module: LightningModule, prediction: torch.Tensor, - batch_indices: Optional[Sequence[int]], + batch_indices: Sequence[int] | None, batch: Sample, batch_idx: int, dataloader_idx: int, ) -> None: + """ + Write predictions to output store at the end of each batch. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer instance. + pl_module : LightningModule + PyTorch Lightning module being used for predictions. + prediction : torch.Tensor + Batch of predictions from the model. + batch_indices : Sequence[int] | None + Indices of the batch samples. + batch : Sample + Input batch data. + batch_idx : int + Index of the current batch. + dataloader_idx : int + Index of the current dataloader. + """ _logger.debug(f"Writing batch {batch_idx}.") for sample_index, _ in enumerate(batch["index"][0]): self.write_sample(batch, prediction[sample_index], sample_index) def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """ + Close output store at the end of prediction. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer instance. + pl_module : LightningModule + PyTorch Lightning module being used for predictions. + """ self.plate.close() def write_sample( self, batch: Sample, sample_prediction: torch.Tensor, sample_index: int ) -> None: + """ + Write a single sample prediction to the output store. + + Parameters + ---------- + batch : Sample + Input batch data containing metadata for the sample. + sample_prediction : torch.Tensor + Prediction tensor for the sample. + sample_index : int + Index of the sample within the batch. + """ _logger.debug(f"Writing sample {sample_index}.") sample_prediction = sample_prediction.cpu().numpy() img_name, t_index, z_index = [batch["index"][i][sample_index] for i in range(3)] diff --git a/viscy/unet/networks/Unet25D.py b/viscy/unet/networks/Unet25D.py index 8a34042d7..9cef5e93b 100644 --- a/viscy/unet/networks/Unet25D.py +++ b/viscy/unet/networks/Unet25D.py @@ -1,3 +1,5 @@ +from typing import Literal + import torch import torch.nn as nn @@ -5,53 +7,67 @@ class Unet25d(nn.Module): - def __name__(self): + """2.5D U-Net neural network for volumetric image translation. + + A hybrid approach that processes 3D input stacks but outputs 2D predictions. + Combines 3D spatial information with 2D computational efficiency. + + Architecture takes in stack of 2D inputs given as a 3D tensor + and returns a 2D interpretation. Learns 3D information based upon input stack, + but speeds up training by compressing 3D information before the decoding path. + Uses interruption conv layers in the U-Net skip paths to + compress information with z-channel convolution. 
+ + References + ---------- + https://elifesciences.org/articles/55502 + + Parameters + ---------- + in_channels : int, optional + Number of feature channels in (1 or more), by default 1. + out_channels : int, optional + Number of feature channels out (1 or more), by default 1. + in_stack_depth : int, optional + Depth of input stack in z, by default 5. + out_stack_depth : int, optional + Depth of output stack, by default 1. + xy_kernel_size : int or tuple of int, optional + Size of x and y dimensions of conv kernels in blocks, by default (3, 3). + residual : bool, optional + Whether to use residual connections, by default False. + dropout : float, optional + Probability of dropout, between 0 and 0.5, by default 0.2. + num_blocks : int, optional + Number of convolutional blocks on encoder and decoder paths, by default 4. + num_block_layers : int, optional + Number of layer sequences repeated per block, by default 2. + num_filters : list of int, optional + List of filters/feature levels at each conv block depth, by default []. + task : str, optional + Network task (for virtual staining this is regression), + one of 'seg','reg', by default "seg". + """ + + def __name__(self) -> str: + """Return the name of the network architecture.""" return "Unet25d" def __init__( self, - in_channels=1, - out_channels=1, - in_stack_depth=5, - out_stack_depth=1, - xy_kernel_size=(3, 3), - residual=False, - dropout=0.2, - num_blocks=4, - num_block_layers=2, - num_filters=[], - task="seg", - ): - """ - Instance of 2.5D Unet. - 1.) https://elifesciences.org/articles/55502 - - Architecture takes in stack of 2d inputs given as a 3d tensor - and returns a 2d interpretation. - Learns 3d information based upon input stack, - but speeds up training by compressing 3d information before the decoding path. - Uses interruption conv layers in the Unet skip paths to - compress information with z-channel convolution. - - :param int in_channels: number of feature channels in (1 or more) - :param int out_channels: number of feature channels out (1 or more) - :param int input_stack_depth: depth of input stack in z - :param int output_stack_depth: depth of output stack - :param int/tuple(int, int) xy_kernel_size: size of x and y dimensions - of conv kernels in blocks - :param bool residual: see name - :param float dropout: probability of dropout, between 0 and 0.5 - :param int num_blocks: number of convolutional blocks - on encoder and decoder paths - :param int num_block_layers: number of layer sequences repeated per block - :param list[int] num_filters: list of filters/feature levels - at each conv block depth - :param str task: network task (for virtual staining this is regression), - one of 'seg','reg' - :param str debug_mode: if true logs features at each step of architecture, - must be manually set - """ - super(Unet25d, self).__init__() + in_channels: int = 1, + out_channels: int = 1, + in_stack_depth: int = 5, + out_stack_depth: int = 1, + xy_kernel_size: tuple[int, int] = (3, 3), + residual: bool = False, + dropout: float = 0.2, + num_blocks: int = 4, + num_block_layers: int = 2, + num_filters: list[int] = [], + task: Literal["seg", "reg"] = "seg", + ) -> None: + super().__init__() self.in_channels = in_channels self.num_blocks = num_blocks self.kernel_size = xy_kernel_size @@ -202,7 +218,7 @@ def __init__( # ----- Feature Logging ----- # self.log_save_folder = None - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward call of network. 
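# ---- Editorial sketch (reviewer note, not part of this patch) ----
# Exercises the 2.5D contract documented in the Unet25d docstring above:
# a 5-slice 3D stack in, a single-slice interpretation out. Assumes viscy
# is importable; the output shape below is a hedged expectation.
import torch

from viscy.unet.networks.Unet25D import Unet25d

model = Unet25d(in_channels=1, out_channels=1, in_stack_depth=5, out_stack_depth=1)
x = torch.rand(1, 1, 5, 128, 128)  # (B, C, Z, Y, X)
y = model(x)
print(y.shape)  # expected: a depth-1 output stack, e.g. (1, 1, 1, 128, 128)
# ---- End editorial sketch ----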
@@ -213,9 +229,16 @@ def forward(self, x):
         between them (decoder)
         => terminal block collapses to output dimensions
 
-        :param torch.tensor x: input image
-        """
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input image.
+
+        Returns
+        -------
+        torch.Tensor
+            Output image.
+        """
         # encoder
         skip_tensors = []
         for i in range(self.num_blocks):
@@ -240,16 +263,20 @@ def forward(self, x):
         x = self.terminal_block(x)
         return x
 
-    def register_modules(self, module_list, name):
-        """
-        Helper function that registers modules stored in a list to the model object
-        so that the can be seen by PyTorch optimizer.
+    def register_modules(self, module_list: list[nn.Module], name: str) -> None:
+        """Register modules stored in a list to the model object.
+
+        This makes them visible to the PyTorch optimizer.
 
         Used to enable model graph creation with
-        non-sequential model types and dynamic layer numbers
+        non-sequential model types and dynamic layer numbers.
 
-        :param list(torch.nn.module) module_list: list of modules to register
-        :param str name: name of module type
+        Parameters
+        ----------
+        module_list : list[torch.nn.Module]
+            List of modules to register.
+        name : str
+            Name of module type.
         """
         for i, module in enumerate(module_list):
             self.add_module(f"{name}_{str(i)}", module)
diff --git a/viscy/unet/networks/Unet2D.py b/viscy/unet/networks/Unet2D.py
index 0edd95362..2d3b5ac22 100644
--- a/viscy/unet/networks/Unet2D.py
+++ b/viscy/unet/networks/Unet2D.py
@@ -1,3 +1,5 @@
+"""2D U-Net implementation for image-to-image translation tasks."""
+
 import torch
 import torch.nn as nn
 
@@ -5,7 +7,43 @@
 class Unet2d(nn.Module):
+    """2D U-Net neural network for image-to-image translation.
+
+    A convolutional neural network following the U-Net architecture for 2D images.
+    Supports both segmentation and regression tasks with configurable depth
+    and filters.
+
+    References
+    ----------
+    1) U-Net: https://arxiv.org/pdf/1505.04597.pdf
+    2) Residual U-Net: https://arxiv.org/pdf/1711.10684.pdf
+
+    Parameters
+    ----------
+    in_channels : int, optional
+        Number of feature channels in, by default 1.
+    out_channels : int, optional
+        Number of feature channels out, by default 1.
+    kernel_size : int or tuple of int, optional
+        Size of x and y dimensions of conv kernels in blocks, by default (3, 3).
+    residual : bool, optional
+        Whether to use residual connections, by default False.
+    dropout : float, optional
+        Probability of dropout, between 0 and 0.5, by default 0.2.
+    num_blocks : int, optional
+        Number of convolutional blocks on encoder and decoder, by default 4.
+    num_block_layers : int, optional
+        Number of layers per block, by default 2.
+    num_filters : list of int, optional
+        List of filters/feature levels at each conv block depth, by default [].
+    task : str, optional
+        Network task (for virtual staining this is regression),
+        one of 'seg', 'reg', by default "seg".
+    """
+
     def __name__(self):
+        """Return the name of the network architecture."""
         return "Unet2d"
 
     def __init__(
@@ -20,27 +58,7 @@ def __init__(
         num_filters=[],
         task="seg",
     ):
-        """
-        2D Unet with variable input/output channels and depth (block numbers).
-        Follows 2D UNet Architecture:
-        1) Unet: https://arxiv.org/pdf/1505.04597.pdf
-        2) residual Unet: https://arxiv.org/pdf/1711.10684.pdf
-
-        :param int in_channels: number of feature channels in
-        :param int out_channels: number of feature channels out
-        :param int/tuple(int,int) kernel_size: size of x and y dimensions
-            of conv kernels in blocks
-        :param bool residual: see name
-        :param float dropout: probability of dropout, between 0 and 0.5
-        :param int num_blocks: number of convolutional blocks on encoder and decoder
-        :param int num_block_layers: number of layers per block
-        :param list[int] num_filters: list of filters/feature levels
-            at each conv block depth
-        :param str task: network task (for virtual staining this is regression),
-            one of 'seg','reg'
-        """
-
-        super(Unet2d, self).__init__()
+        super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size
@@ -167,20 +185,27 @@ def __init__(
             kernel_size=self.kernel_size,
         )
 
-    def forward(self, x, validate_input=False):
-        """
-        Forward call of network
-
-        x -> Torch.tensor: input image stack
+    def forward(self, x: torch.Tensor, validate_input: bool = False) -> torch.Tensor:
+        """Forward pass through the 2D U-Net.
 
         Call order:
-        => num_block 2D convolutional blocks, with downsampling in between (encoder)
-        => num_block 2D convolutional blocks, with upsampling between them (decoder)
-        => skip connections between corresponding blocks on encoder and decoder
-        => terminal block collapses to output dimensions
-
-        :param torch.tensor x: input image
-        :param bool validate_input: Deactivates assertions which are redundant
-            if forward pass is being traced by tensorboard writer.
+            => num_block 2D convolutional blocks, with downsampling in between (encoder)
+            => num_block 2D convolutional blocks, with upsampling between them (decoder)
+            => skip connections between corresponding blocks on encoder and decoder
+            => terminal block collapses to output dimensions
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input image stack.
+        validate_input : bool, optional
+            Deactivates assertions which are redundant if forward pass is being
+            traced by tensorboard writer, by default False.
+
+        Returns
+        -------
+        torch.Tensor
+            Network output with same spatial dimensions as input.
         """
         # handle input exceptions
         if validate_input:
@@ -210,16 +235,20 @@ def forward(self, x, validate_input=False):
 
         return x.unsqueeze(2)
 
-    def register_modules(self, module_list, name):
-        """
-        Helper function that registers modules stored in a list to the model object
-        so that they can be seen by PyTorch optimizer.
+    def register_modules(self, module_list: list[nn.Module], name: str) -> None:
+        """Register modules stored in a list to the model object.
 
-        Used to enable model graph creation with
-        non-sequential model types and dynamic layer numbers
+        This makes them visible to the PyTorch optimizer.
 
-        :param list(torch.nn.module) module_list: list of modules to register
-        :param str name: name of module type
+        Used to enable model graph creation with
+        non-sequential model types and dynamic layer numbers.
+
+        Parameters
+        ----------
+        module_list : list[torch.nn.Module]
+            List of modules to register.
+        name : str
+            Name of module type.
         """
         for i, module in enumerate(module_list):
             self.add_module(f"{name}_{str(i)}", module)
diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py
index d63b65a7d..51abeed86 100644
--- a/viscy/unet/networks/fcmae.py
+++ b/viscy/unet/networks/fcmae.py
@@ -1,12 +1,12 @@
-"""
-Fully Convolutional Masked Autoencoder as described in ConvNeXt V2
-based on the official JAX example in
+"""Fully Convolutional Masked Autoencoder as described in ConvNeXt V2.
+
+Based on the official JAX example in
 https://github.com/facebookresearch/ConvNeXt-V2/blob/main/TRAINING.md#implementing-fcmae-with-masked-convolution-in-jax
 and timm's dense implementation of the encoder in ``timm.models.convnext``
 """
 
 import math
-from typing import Sequence
+from collections.abc import Sequence
 
 import torch
 from monai.networks.blocks import UpSample
@@ -40,11 +40,21 @@ def _init_weights(module: nn.Module) -> None:
 def generate_mask(
     target: Size, stride: int, mask_ratio: float, device: str
 ) -> BoolTensor:
-    """
-    :param Size target: target shape
-    :param int stride: total stride
-    :param float mask_ratio: ratio of the pixels to mask
-    :return BoolTensor: boolean mask (B1HW)
+    """Generate random boolean mask for masked autoencoder training.
+
+    Parameters
+    ----------
+    target : Size
+        Target tensor shape.
+    stride : int
+        Total downsampling stride.
+    mask_ratio : float
+        Ratio of pixels to mask for training.
+    device : str
+        Device on which to allocate the mask.
+
+    Returns
+    -------
+    BoolTensor
+        Boolean mask tensor of shape (B1HW).
     """
     m_height = target[-2] // stride
     m_width = target[-1] // stride
@@ -55,10 +65,19 @@ def generate_mask(
 
 
 def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor:
-    """
-    :param BoolTensor mask: low-resolution boolean mask (B1HW)
-    :param Size target: target size (BCHW)
-    :return BoolTensor: upsampled boolean mask (B1HW)
+    """Upsample boolean mask to match target spatial dimensions.
+
+    Parameters
+    ----------
+    mask : BoolTensor
+        Low-resolution boolean mask of shape (B1HW).
+    target : Size
+        Target tensor size (BCHW).
+
+    Returns
+    -------
+    BoolTensor
+        Upsampled boolean mask of shape (B1HW).
     """
     if target[-2:] != mask.shape[-2:]:
         if not all(i % j == 0 for i, j in zip(target, mask.shape)):
@@ -73,10 +92,19 @@ def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor:
 
 
 def masked_patchify(features: Tensor, unmasked: BoolTensor | None = None) -> Tensor:
-    """
-    :param Tensor features: input image features (BCHW)
-    :param BoolTensor unmasked: boolean foreground mask (B1HW)
-    :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio)
+    """Convert spatial features to channel-last patches, optionally masked.
+
+    Parameters
+    ----------
+    features : Tensor
+        Input image features of shape (BCHW).
+    unmasked : BoolTensor | None, optional
+        Boolean foreground mask of shape (B1HW), by default None.
+
+    Returns
+    -------
+    Tensor
+        Masked channel-last features of shape (BLC, L = H * W * mask_ratio).
""" if unmasked is None: return features.flatten(2).permute(0, 2, 1) @@ -91,11 +119,21 @@ def masked_patchify(features: Tensor, unmasked: BoolTensor | None = None) -> Ten def masked_unpatchify( features: Tensor, out_shape: Size, unmasked: BoolTensor | None = None ) -> Tensor: - """ - :param Tensor features: dense channel-last features (BLC) - :param Size out_shape: output shape (BCHW) - :param BoolTensor | None unmasked: boolean foreground mask, defaults to None - :return Tensor: masked features (BCHW) + """Convert channel-last patches back to spatial features. + + Parameters + ---------- + features : Tensor + Dense channel-last features of shape (BLC). + out_shape : Size + Output tensor shape (BCHW). + unmasked : BoolTensor | None, optional + Boolean foreground mask, by default None. + + Returns + ------- + Tensor + Masked spatial features of shape (BCHW). """ if unmasked is None: return features.permute(0, 2, 1).reshape(out_shape) @@ -111,12 +149,20 @@ def masked_unpatchify( class MaskedConvNeXtV2Block(nn.Module): """Masked ConvNeXt V2 Block. - :param int in_channels: input channels - :param int | None out_channels: output channels, defaults to None - :param int kernel_size: depth-wise convolution kernel size, defaults to 7 - :param int stride: downsample stride, defaults to 1 - :param int mlp_ratio: MLP expansion ratio, defaults to 4 - :param float drop_path: drop path rate, defaults to 0.0 + Parameters + ---------- + in_channels : int + Input channels. + out_channels : int | None, optional + Output channels, by default None. + kernel_size : int, optional + Depth-wise convolution kernel size, by default 7. + stride : int, optional + Downsample stride, by default 1. + mlp_ratio : int, optional + MLP expansion ratio, by default 4. + drop_path : float, optional + Drop path rate, by default 0.0. """ def __init__( @@ -151,10 +197,19 @@ def __init__( self.shortcut = nn.Identity() def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: - """ - :param Tensor x: input tensor (BCHW) - :param BoolTensor | None unmasked: boolean foreground mask, defaults to None - :return Tensor: output tensor (BCHW) + """Forward pass through masked ConvNeXt V2 block. + + Parameters + ---------- + x : Tensor + Input tensor of shape (BCHW). + unmasked : BoolTensor | None, optional + Boolean foreground mask, by default None. + + Returns + ------- + Tensor + Output tensor of shape (BCHW). """ shortcut = self.shortcut(x) if unmasked is not None: @@ -172,15 +227,22 @@ def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: class MaskedConvNeXtV2Stage(nn.Module): - """Masked ConvNeXt V2 Stage. - - :param int in_channels: input channels - :param int out_channels: output channels - :param int kernel_size: depth-wise convolution kernel size, defaults to 7 - :param int stride: downsampling factor of this stage, defaults to 2 - :param int num_blocks: number of residual blocks, defaults to 2 - :param Sequence[float] | None drop_path_rates: drop path rates of each block, - defaults to None + """Masked ConvNeXt V2 Stage for hierarchical feature extraction. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int, optional + Depth-wise convolution kernel size, by default 7. + stride : int, optional + Downsampling factor of this stage, by default 2. + num_blocks : int, optional + Number of residual blocks, by default 2. 
+ drop_path_rates : Sequence[float] | None, optional + Drop path rates of each block, by default None. """ def __init__( @@ -229,10 +291,19 @@ def __init__( in_channels = out_channels def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: - """ - :param Tensor x: input tensor (BCHW) - :param BoolTensor | None unmasked: boolean foreground mask, defaults to None - :return Tensor: output tensor (BCHW) + """Forward pass through masked ConvNeXt V2 stage. + + Parameters + ---------- + x : Tensor + Input tensor of shape (BCHW). + unmasked : BoolTensor | None, optional + Boolean foreground mask, by default None. + + Returns + ------- + Tensor + Output tensor of shape (BCHW). """ x = self.downsample(x) if unmasked is not None: @@ -243,14 +314,20 @@ def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: class MaskedAdaptiveProjection(nn.Module): - """ - Masked patchifying layer for projecting 2D or 3D input into 2D feature maps. - - :param int in_channels: input channels - :param int out_channels: output channels - :param Sequence[int, int] | int kernel_size_2d: kernel width and height - :param int kernel_depth: kernel depth for 3D input - :param int in_stack_depth: input stack depth for 3D input + """Masked patchifying layer for projecting 2D or 3D input into 2D feature maps. + + Parameters + ---------- + in_channels : int + Input channels. + out_channels : int + Output channels. + kernel_size_2d : tuple[int, int] | int, optional + Kernel width and height, by default 4. + kernel_depth : int, optional + Kernel depth for 3D input, by default 5. + in_stack_depth : int, optional + Input stack depth for 3D input, by default 5. """ def __init__( @@ -281,10 +358,19 @@ def __init__( self.norm = nn.LayerNorm(out_channels) def forward(self, x: Tensor, unmasked: BoolTensor = None) -> Tensor: - """ - :param Tensor x: input tensor (BCDHW) - :param BoolTensor unmasked: boolean foreground mask (B1HW), defaults to None - :return Tensor: output tensor (BCHW) + """Forward pass through masked adaptive projection layer. + + Parameters + ---------- + x : Tensor + Input tensor of shape (BCDHW). + unmasked : BoolTensor, optional + Boolean foreground mask of shape (B1HW), by default None. + + Returns + ------- + Tensor + Output tensor of shape (BCHW). """ # no need to mask before convolutions since patches do not spill over if x.shape[2] > 1: @@ -305,6 +391,27 @@ def forward(self, x: Tensor, unmasked: BoolTensor = None) -> Tensor: class MaskedMultiscaleEncoder(nn.Module): + """Multi-scale encoder with masking support for FC-MAE architecture. + + Implements hierarchical feature extraction through multiple ConvNeXt V2 stages + with optional random masking for self-supervised pretraining. + + Parameters + ---------- + in_channels : int + Input channels. + stage_blocks : Sequence[int], optional + Number of blocks per encoder stage, by default (3, 3, 9, 3). + dims : Sequence[int], optional + Feature dimensions at each stage, by default (96, 192, 384, 768). + drop_path_rate : float, optional + Stochastic depth rate, by default 0.0. + stem_kernel_size : Sequence[int], optional + Kernel sizes for adaptive projection, by default (5, 4, 4). + in_stack_depth : int, optional + Input stack depth for 3D input, by default 5. 
+ """ + def __init__( self, in_channels: int, @@ -342,12 +449,20 @@ def __init__( def forward( self, x: Tensor, mask_ratio: float = 0.0 ) -> tuple[list[Tensor], BoolTensor | None]: - """ - :param Tensor x: input tensor (BCDHW) - :param float mask_ratio: ratio of the feature maps to mask, - defaults to 0.0 (no masking) - :return list[Tensor]: output tensors (list of BCHW) - :return BoolTensor | None: boolean foreground mask, None if no masking + """Extract multi-scale features with optional masking. + + Parameters + ---------- + x : Tensor + Input tensor of shape (BCDHW). + mask_ratio : float, optional + Ratio of the feature maps to mask, by default 0.0 (no masking). + + Returns + ------- + tuple[list[Tensor], BoolTensor | None] + Output tensors as list of BCHW tensors and boolean foreground mask + (None if no masking). """ if mask_ratio > 0.0: mask = generate_mask( @@ -367,6 +482,25 @@ def forward( class PixelToVoxelShuffleHead(nn.Module): + """Pixel-to-voxel reconstruction head using pixel shuffle upsampling. + + Converts 2D feature maps to 3D output volumes through pixel shuffle + upsampling and channel-to-depth reshaping. + + Parameters + ---------- + in_channels : int + Input feature channels. + out_channels : int + Output channels per voxel. + out_stack_depth : int, optional + Output stack depth (Z dimension), by default 5. + xy_scaling : int, optional + Spatial upsampling factor, by default 4. + pool : bool, optional + Whether to apply pooling in upsampling, by default False. + """ + def __init__( self, in_channels: int, @@ -389,6 +523,18 @@ def __init__( ) def forward(self, x: Tensor) -> Tensor: + """Reconstruct 3D volume from 2D features. + + Parameters + ---------- + x : Tensor + Input 2D features of shape (BCHW). + + Returns + ------- + Tensor + Reconstructed 3D volume of shape (BCDHW). + """ x = self.upsample(x) b, _, h, w = x.shape x = x.reshape(b, self.out_channels, self.out_stack_depth, h, w) @@ -396,6 +542,40 @@ def forward(self, x: Tensor) -> Tensor: class FullyConvolutionalMAE(nn.Module): + """Fully Convolutional Masked Autoencoder for self-supervised learning. + + Implements FC-MAE architecture combining a masked multi-scale encoder + with a UNet-style decoder for reconstruction tasks. Supports both + pretraining with masking and fine-tuning for downstream tasks. + + Parameters + ---------- + in_channels : int + Input channels. + out_channels : int + Output channels. + encoder_blocks : Sequence[int], optional + Blocks per encoder stage, by default [3, 3, 9, 3]. + dims : Sequence[int], optional + Feature dimensions per stage, by default [96, 192, 384, 768]. + encoder_drop_path_rate : float, optional + Encoder stochastic depth rate, by default 0.0. + stem_kernel_size : Sequence[int], optional + Adaptive projection kernel sizes, by default (5, 4, 4). + in_stack_depth : int, optional + Input stack depth for 3D data, by default 5. + decoder_conv_blocks : int, optional + Decoder convolution blocks per stage, by default 1. + pretraining : bool, optional + Whether in pretraining mode (returns mask), by default True. + head_conv : bool, optional + Whether to use convolutional reconstruction head, by default False. + head_conv_expansion_ratio : int, optional + Expansion ratio for conv head, by default 4. + head_conv_pool : bool, optional + Whether to use pooling in conv head, by default True. 
+ """ + def __init__( self, in_channels: int, @@ -459,7 +639,26 @@ def __init__( self.num_blocks = len(dims) * int(math.log2(stem_kernel_size[-1])) self.pretraining = pretraining - def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + def forward( + self, x: Tensor, mask_ratio: float = 0.0 + ) -> Tensor | tuple[Tensor, BoolTensor]: + """Forward pass through FC-MAE architecture. + + Encodes input with optional masking, decodes through UNet decoder, + and reconstructs output through pixel-to-voxel head. + + Parameters + ---------- + x : Tensor + Input tensor of shape (BCDHW). + mask_ratio : float, optional + Masking ratio for pretraining, by default 0.0 (no mask). + + Returns + ------- + Tensor | tuple[Tensor, BoolTensor] + Reconstructed output of shape (BCDHW) or tuple with mask. + """ x, mask = self.encoder(x, mask_ratio=mask_ratio) x.reverse() x = self.decoder(x) diff --git a/viscy/unet/networks/layers/ConvBlock2D.py b/viscy/unet/networks/layers/ConvBlock2D.py index 114777a79..7f9861fe5 100644 --- a/viscy/unet/networks/layers/ConvBlock2D.py +++ b/viscy/unet/networks/layers/ConvBlock2D.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np import torch import torch.nn as nn @@ -5,49 +7,57 @@ class ConvBlock2D(nn.Module): + """2D convolutional block for U-Net lateral layers with configurable architecture. + + Supports dynamic layer configuration, normalization, activation functions, + residual connections, and various filter progression strategies. + """ + def __init__( self, - in_filters, - out_filters, - dropout=False, - norm="batch", - residual=True, - activation="relu", - transpose=False, - kernel_size=3, - num_repeats=3, - filter_steps="first", - layer_order="can", - ): + in_filters: int, + out_filters: int, + dropout: float | bool = False, + norm: Literal["batch", "instance"] = "batch", + residual: bool = True, + activation: Literal["relu", "leakyrelu", "elu", "selu", "linear"] = "relu", + transpose: bool = False, + kernel_size: int | tuple[int, int] = 3, + num_repeats: int = 3, + filter_steps: Literal["linear", "first", "last"] = "first", + layer_order: str = "can", + ) -> None: + """Initialize convolutional block for lateral layers in U-Net. + + Format for layer initialization allows dynamic layer number specification + in the conv blocks, enabling parameter number flexibility across the network. + + Parameters + ---------- + in_filters : int + Number of input feature channels. + out_filters : int + Number of output feature channels. + dropout : float or bool, default=False + Dropout probability. If False, no dropout is applied. + norm : {"batch", "instance"}, default="batch" + Normalization type to apply. + residual : bool, default=True + Whether to include residual connections. + activation : {"relu", "leakyrelu", "elu", "selu", "linear"}, default="relu" + Activation function type. + transpose : bool, default=False + Whether to use transpose convolution layers. + kernel_size : int or tuple[int, int], default=3 + 2D convolutional kernel size. + num_repeats : int, default=3 + Number of times the layer_order sequence is repeated in the block. + filter_steps : {"linear", "first", "last"}, default="first" + Strategy for channel dimension changes across layers. + layer_order : str, default="can" + Order of conv (c), activation (a), normalization (n) layers. 
""" - Convolutional block for lateral layers in Unet - - Format for layer initialization is as follows: - if layer type specified - => for number of layers - => add layer to list of that layer type - => register elements of list - This is done to allow for dynamic layer number specification in the conv blocks, - which allows us to change the parameter numbers of the network. - - :param int in_filters: number of images in in stack - :param int out_filters: number of images in out stack - :param float dropout: dropout probability (False => 0) - :param str norm: normalization type: 'batch', 'instance' - :param bool residual: as name - :param str activation: activation function: 'relu', 'leakyrelu', 'elu', 'selu' - :param bool transpose: as name - :param int/tuple kernel_size: convolutional kernel size - :param int num_repeats: number of times the layer_order layer sequence - is repeated in the block - :param str filter_steps: determines where in the block - the filters inflate channels (learn abstraction information): - 'linear','first','last' - :param str layer_order: order of conv, norm, and act layers in block: - 'can', 'cna', 'nca', etc - """ - - super(ConvBlock2D, self).__init__() + super().__init__() self.in_filters = in_filters self.out_filters = out_filters self.dropout = dropout @@ -262,21 +272,18 @@ def __init__( ) self.register_modules(self.act_list, f"{self.activation}_act") - def forward(self, x, validate_input=False): - """ - Forward call of convolutional block + def forward(self, x: torch.Tensor, validate_input: bool = False) -> torch.Tensor: + """Forward pass through the convolutional block. Order of layers within the block is defined by the 'layer_order' parameter, - which is a string of 'c's, 'a's and 'n's - in reference to convolution, activation, and normalization layers. - This sequence is repeated num_repeats times. + which is a string of 'c's, 'a's and 'n's in reference to convolution, + activation, and normalization layers. This sequence is repeated num_repeats times. - Recommended layer order: convolution -> activation -> normalization + Recommended layer order: convolution -> activation -> normalization - Regardless of layer order, - the final layer sequence in the block always ends in activation. - This allows for usage of passthrough layers - or a final output activation function determined separately. + Regardless of layer order, the final layer sequence in the block always ends + in activation. This allows for usage of passthrough layers or a final output + activation function determined separately. Residual blocks: if input channels are greater than output channels, @@ -284,9 +291,18 @@ def forward(self, x, validate_input=False): if input channels are less than output channels, we zero-pad input channels to output channel size. - :param torch.tensor x: input tensor - :param bool validate_input: Deactivates assertions - which are redundant if forward pass is being traced by tensorboard writer. + Parameters + ---------- + x : torch.Tensor + Input tensor for convolutional processing. + validate_input : bool, default=False + Deactivates assertions which are redundant if forward pass is being + traced by tensorboard writer. + + Returns + ------- + torch.Tensor + Output tensor after convolutional block processing. 
""" if validate_input: if isinstance(self.kernel_size, int): @@ -335,19 +351,24 @@ def forward(self, x, validate_input=False): return x - def model(self): - """ + def model(self) -> nn.Sequential: + """Create a sequential model from the convolutional block layers. + Allows calling of parameters inside ConvBlock object: - 'ConvBlock.model().parameters()'' + 'ConvBlock.model().parameters()' - Layer order: convolution -> normalization -> activation + Layer order: convolution -> normalization -> activation We can make a list of layer modules and unpack them into nn.Sequential. - Note: this is distinct from the forward call - because we want to use the forward call with addition, - since this is a residual block. - The forward call performs the residial calculation, - and all the parameters can be seen by the optimizer when given this model. + Note: this is distinct from the forward call because we want to use + the forward call with addition, since this is a residual block. + The forward call performs the residual calculation, and all the + parameters can be seen by the optimizer when given this model. + + Returns + ------- + nn.Sequential + Sequential model containing all layers in the block. """ layers = [] @@ -362,16 +383,21 @@ def model(self): return nn.Sequential(*layers) - def register_modules(self, module_list, name): - """ + def register_modules(self, module_list: list[nn.Module], name: str) -> None: + """Register modules from a list to enable PyTorch optimizer access. + Helper function that registers modules stored in a list to the model object so that they can be seen by PyTorch optimizer. - Used to enable model graph creation - with non-sequential model types and dynamic layer numbers + Used to enable model graph creation with non-sequential model types + and dynamic layer numbers. - :param list(torch.nn.module) module_list: list of modules to register - :param str name: name of module type + Parameters + ---------- + module_list : list of torch.nn.Module + List of PyTorch modules to register. + name : str + Name prefix for the module type. """ for i, module in enumerate(module_list): self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/layers/ConvBlock3D.py b/viscy/unet/networks/layers/ConvBlock3D.py index 893c612ef..9cf96e33c 100644 --- a/viscy/unet/networks/layers/ConvBlock3D.py +++ b/viscy/unet/networks/layers/ConvBlock3D.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np import torch import torch.nn as nn @@ -5,64 +7,62 @@ class ConvBlock3D(nn.Module): + """3D convolutional building block for volumetric neural networks. + + A flexible 3D convolutional block designed for processing volumetric data + such as medical imaging, microscopy, and video sequences. Supports residual + connections, various normalization schemes, activation functions, and + configurable layer ordering for deep 3D U-Net architectures. + + The block processes tensors in [..., z, x, y] or [..., z, y, x] format + and provides dynamic layer configuration with support for transpose + convolutions, dropout, and multiple padding strategies optimized for + volumetric convolution operations. + + Parameters + ---------- + in_filters : int + Number of input feature channels. + out_filters : int + Number of output feature channels. + dropout : float or bool, default=False + Dropout probability. If False, no dropout is applied. + norm : {"batch", "instance"}, default="batch" + Normalization type to apply. + residual : bool, default=True + Whether to include residual connections. 
+ activation : {"relu", "leakyrelu", "elu", "selu", "linear"}, default="relu" + Activation function type. + transpose : bool, default=False + Whether to use transpose convolution layers. + kernel_size : int or tuple of int, default=(3, 3, 3) + 3D convolutional kernel size. + num_repeats : int, default=3 + Number of convolutional layers in the block. + filter_steps : {"linear", "first", "last"}, default="first" + Strategy for channel dimension changes across layers. + layer_order : str, default="can" + Order of conv (c), activation (a), normalization (n) layers. + padding : str, int, tuple or None, default=None + Padding strategy for convolutions. + """ + def __init__( self, - in_filters, - out_filters, - dropout=False, - norm="batch", - residual=True, - activation="relu", - transpose=False, - kernel_size=(3, 3, 3), - num_repeats=3, - filter_steps="first", - layer_order="can", - padding=None, - ): - """ - Convolutional block for lateral layers in Unet. - This block only accepts tensors of dimensions in - order [...,z,x,y] or [...,z,y,x] - - Format for layer initialization is as follows: - if layer type specified - => for number of layers - => add layer to list of that layer type - This is done to allow for dynamic layer number specification in the conv blocks, - which allows us to change the parameter numbers of the network. - - Only 'same' convolutional padding is recommended, - as the conv blocks are intended for deep Unets. - However padding can be specified as follows: - padding -> token{'same', 'valid', 'valid_stack'} or tuple(int) or int: - -> 'same': pads with same convolution - -> 'valid': pads for valid convolution on all dimensions - -> 'valid_stack': pads for valid convolution on xy dims (-1, -2), - same on z dim (-3). - -> tuple (int): pads above and below corresponding dimensions - -> int: pads above and below all dimensions - - :param int in_filters: number of images in in stack - :param int out_filters: number of images in out stack - :param float dropout: dropout probability (False => 0) - :param str norm: normalization type: 'batch', 'instance' - :param bool residual: as name - :param str activation: activation function: 'relu', 'leakyrelu', 'elu', 'selu' - :param bool transpose: as name - :param int/tuple kernel_size: convolutional kernel size - :param int num_repeats: as name - :param str filter_steps: determines where in the block - the filters inflate channels - (learn abstraction information): 'linear','first','last' - :param str layer_order: order of conv, norm, and act layers in block: - 'can', 'cna', etc. - NOTE: for now conv must always come first as required by norm feature counts - :paramn str/tuple(int)/tuple/None padding: convolutional padding, - see docstring for details - """ - - super(ConvBlock3D, self).__init__() + in_filters: int, + out_filters: int, + dropout: float | bool = False, + norm: Literal["batch", "instance"] = "batch", + residual: bool = True, + activation: Literal["relu", "leakyrelu", "elu", "selu", "linear"] = "relu", + transpose: bool = False, + kernel_size: int | tuple[int, int, int] = (3, 3, 3), + num_repeats: int = 3, + filter_steps: Literal["linear", "first", "last"] = "first", + layer_order: str = "can", + padding: str | int | tuple[int, ...] 
| None = None, + ) -> None: + super().__init__() self.in_filters = in_filters self.out_filters = out_filters self.dropout = dropout @@ -244,9 +244,9 @@ def __init__( ) self.register_modules(self.act_list, f"{self.activation}_act") - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Forward call of convolutional block + Forward call of convolutional block. Order of layers within the block is defined by the 'layer_order' parameter, which is a string of 'c's, 'a's and 'n's in reference to @@ -266,7 +266,15 @@ def forward(self, x): if input channels are less than output channels, we zero-pad input channels to output channel size - :param torch.tensor x: input tensor + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor. """ x_0 = x for i in range(self.num_repeats): @@ -310,12 +318,11 @@ def forward(self, x): return x - def model(self): + def model(self) -> nn.Sequential: """ - Allows calling of parameters inside ConvBlock object: - 'ConvBlock.model().parameters()'' + Create sequential model from ConvBlock parameters. - Layer order: convolution -> normalization -> activation + Layer order: convolution -> normalization -> activation We can make a list of layer modules and unpack them into nn.Sequential. Note: this is distinct from the forward call @@ -323,6 +330,11 @@ def model(self): since this is a residual block. The forward call performs the residual calculation, and all the parameters can be seen by the optimizer when given this model. + + Returns + ------- + nn.Sequential + Sequential model containing all layers in the block. """ layers = [] @@ -337,16 +349,19 @@ def model(self): return nn.Sequential(*layers) - def register_modules(self, module_list, name): + def register_modules(self, module_list: list[nn.Module], name: str) -> None: """ - Helper function that registers modules stored in a list to the model object - so that the can be seen by PyTorch optimizer. + Register modules for PyTorch optimizer visibility. Used to enable model graph creation with non-sequential model types and dynamic layer numbers - :param list(torch.nn.module) module_list: list of modules to register - :param str name: name of module type + Parameters + ---------- + module_list : list[torch.nn.Module] + List of modules to register. + name : str + Name of module type. """ for i, module in enumerate(module_list): self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/unext2.py b/viscy/unet/networks/unext2.py index c2403fc9b..ed683a602 100644 --- a/viscy/unet/networks/unext2.py +++ b/viscy/unet/networks/unext2.py @@ -1,4 +1,5 @@ -from typing import Callable, Literal, Sequence +from collections.abc import Callable, Sequence +from typing import Literal import timm import torch @@ -14,17 +15,22 @@ def icnr_init( upsample_dims: int, init: Callable = nn.init.kaiming_normal_, ): - """ - ICNR initialization for 2D/3D kernels adapted from Aitken et al.,2017 , - "Checkerboard artifact free sub-pixel convolution". + """ICNR initialization for 2D/3D kernels. + Adapted from Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". Adapted from MONAI v1.2.0, added support for upsampling dimensions that are not the same as the kernel dimension. - :param conv: convolution layer - :param upsample_factor: upsample factor - :param upsample_dims: upsample dimensions, 2 or 3 - :param init: initialization function + Parameters + ---------- + conv : nn.Module + Convolution layer to initialize. 
+ upsample_factor : int + Upsample factor. + upsample_dims : int + Upsample dimensions, 2 or 3. + init : Callable, optional + Initialization function, by default nn.init.kaiming_normal_. """ out_channels, in_channels, *dims = conv.weight.shape scale_factor = upsample_factor**upsample_dims @@ -65,7 +71,19 @@ def _get_convnext_stage( class UNeXt2Stem(nn.Module): - """Stem for UNeXt2 and ContrastiveEncoder networks.""" + """Stem for UNeXt2 and ContrastiveEncoder networks. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : tuple[int, int, int] + Kernel size. + in_stack_depth : int + Number of input stack depth. + """ def __init__( self, @@ -84,6 +102,19 @@ def __init__( ) def forward(self, x: Tensor): + """Forward pass through UNeXt2 stem with depth-to-channel projection. + + Parameters + ---------- + x : Tensor + Input tensor of shape (B, C, D, H, W) where D is the stack depth. + + Returns + ------- + Tensor + Output tensor with depth projected to channels, shape (B, C*D', H', W') + where D' = D // kernel_size[0] after 3D convolution. + """ x = self.conv(x) b, c, d, h, w = x.shape # project Z/depth into channels @@ -92,7 +123,21 @@ def forward(self, x: Tensor): class StemDepthtoChannels(nn.Module): - """Stem with 3D convolution that maps depth to channels.""" + """Stem with 3D convolution that maps depth to channels. + + Parameters + ---------- + in_channels : int + Number of input channels. + in_stack_depth : int + Number of input stack depth. + in_channels_encoder : int + Number of input channels for the encoder. + stem_kernel_size : tuple[int, int, int] + Kernel size. + stem_stride : tuple[int, int, int] + Stride. + """ def __init__( self, @@ -115,8 +160,35 @@ def __init__( ) def compute_stem_channels( - self, in_stack_depth, stem_kernel_size, stem_stride_depth, in_channels_encoder + self, + in_stack_depth: int, + stem_kernel_size: tuple[int, int, int], + stem_stride_depth: int, + in_channels_encoder: int, ): + """Compute required 3D stem output channels for encoder compatibility. + + Parameters + ---------- + in_stack_depth : int + Input stack depth dimension. + stem_kernel_size : tuple[int, int, int] + 3D convolution kernel size. + stem_stride_depth : int + Stride in the depth dimension. + in_channels_encoder : int + Required input channels for the encoder after depth projection. + + Returns + ------- + int + Required output channels for the 3D stem convolution. + + Raises + ------ + ValueError + If channel dimensions cannot be matched with current configuration. + """ stem3d_out_depth = ( in_stack_depth - stem_kernel_size[0] ) // stem_stride_depth + 1 @@ -129,6 +201,19 @@ def compute_stem_channels( return stem3d_out_channels def forward(self, x: Tensor): + """Forward pass with 3D convolution and depth-to-channel mapping. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, D, H, W) where D is the input stack depth. + + Returns + ------- + torch.Tensor + Output tensor with depth projected to channels, maintaining spatial + dimensions after strided 3D convolution. + """ x = self.conv(x) b, c, d, h, w = x.shape # project Z/depth into channels @@ -137,6 +222,33 @@ def forward(self, x: Tensor): class UNeXt2UpStage(nn.Module): + """UNeXt2 decoder upsampling stage with skip connection fusion. + + Implements hierarchical feature upsampling using either deconvolution or + pixel shuffle, followed by ConvNeXt blocks for feature refinement. 
Combines + low-resolution features with high-resolution skip connections for multi-scale + feature fusion. + + Parameters + ---------- + in_channels : int + Number of input channels. + skip_channels : int + Number of skip channels. + out_channels : int + Number of output channels. + scale_factor : int + Scale factor. + mode : Literal["deconv", "pixelshuffle"] + Mode. "deconv" for deconvolution, "pixelshuffle" for pixel shuffle. + conv_blocks : int + Number of ConvNeXt blocks. + norm_name : str + Name of the normalization layer. + upsample_pre_conv : Literal["default"] | Callable | None + Upsample pre-convolution. + """ + def __init__( self, in_channels: int, @@ -191,10 +303,20 @@ def __init__( ) def forward(self, inp: Tensor, skip: Tensor) -> Tensor: - """ - :param Tensor inp: Low resolution features - :param Tensor skip: High resolution skip connection features - :return Tensor: High resolution features + """Forward pass with upsampling and skip connection fusion. + + Parameters + ---------- + inp : torch.Tensor + Low resolution input features from deeper decoder stage. + skip : torch.Tensor + High resolution skip connection features from encoder. + + Returns + ------- + torch.Tensor + Upsampled and refined features combining both inputs through + ConvNeXt blocks or residual units. """ inp = self.upsample(inp) inp = torch.cat([inp, skip], dim=1) @@ -202,6 +324,26 @@ def forward(self, inp: Tensor, skip: Tensor) -> Tensor: class PixelToVoxelHead(nn.Module): + """Head module for converting 2D features to 3D voxel output. + + Performs 2D-to-3D reconstruction using pixel shuffle upsampling and 3D + convolutions. Applies depth channel expansion and spatial upsampling to + generate volumetric outputs from 2D feature representations. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + out_stack_depth : int + Number of output stack depth. + expansion_ratio : int + Expansion ratio. + pool : bool + Whether to apply pooling in upsampling. + """ + def __init__( self, in_channels: int, @@ -238,6 +380,19 @@ def __init__( self.out_stack_depth = out_stack_depth def forward(self, x: Tensor) -> Tensor: + """Forward pass for 2D to 3D voxel reconstruction. + + Parameters + ---------- + x : torch.Tensor + Input 2D feature tensor of shape (B, C, H, W). + + Returns + ------- + torch.Tensor + Output 3D voxel tensor with upsampled spatial dimensions and + reconstructed depth, shape (B, out_channels, out_stack_depth, H', W'). + """ x = self.upsample(x) d = self.out_stack_depth + 2 b, c, h, w = x.shape @@ -249,17 +404,51 @@ def forward(self, x: Tensor) -> Tensor: class UnsqueezeHead(nn.Module): - """Unsqueeze 2D (B, C, H, W) feature map to 3D (B, C, 1, H, W) output""" + """Unsqueeze 2D (B, C, H, W) feature map to 3D (B, C, 1, H, W) output.""" def __init__(self) -> None: super().__init__() def forward(self, x: Tensor) -> Tensor: + """Forward pass adding singleton depth dimension. + + Parameters + ---------- + x : torch.Tensor + Input 2D tensor of shape (B, C, H, W). + + Returns + ------- + torch.Tensor + Output 3D tensor with singleton depth dimension, shape (B, C, 1, H, W). + """ x = x.unsqueeze(2) return x class UNeXt2Decoder(nn.Module): + """UNeXt2 hierarchical decoder with multi-stage upsampling. + + Implements progressive upsampling through multiple UNeXt2UpStage modules, + combining features from different encoder scales through skip connections. + Each stage performs feature upsampling and refinement using ConvNeXt blocks. 
+
+    Parameters
+    ----------
+    num_channels : list[int]
+        Number of channels for each stage.
+    norm_name : str
+        Name of the normalization layer.
+    mode : Literal["deconv", "pixelshuffle"]
+        Upsampling mode: "deconv" for transposed convolution,
+        "pixelshuffle" for pixel shuffle.
+    conv_blocks : int
+        Number of ConvNeXt blocks in each stage.
+    strides : list[int]
+        Strides for each stage.
+    upsample_pre_conv : Literal["default"] | Callable | None
+        Convolution applied before upsampling: "default" for the built-in
+        option, None to disable.
+    """
+
     def __init__(
         self,
         num_channels: list[int],
@@ -286,6 +475,20 @@ def __init__(
         self.decoder_stages.append(stage)
 
     def forward(self, features: Sequence[Tensor]) -> Tensor:
+        """Forward pass through hierarchical decoder stages.
+
+        Parameters
+        ----------
+        features : Sequence[torch.Tensor]
+            List of multi-scale encoder features, ordered from lowest to highest
+            resolution. First element is the bottleneck feature.
+
+        Returns
+        -------
+        torch.Tensor
+            Decoded high-resolution features after progressive upsampling and
+            skip connection fusion through all decoder stages.
+        """
         feat = features[0]
         # padding
         features.append(None)
@@ -295,12 +498,51 @@ def forward(self, features: Sequence[Tensor]) -> Tensor:
 
 
 class UNeXt2(nn.Module):
+    """UNeXt2: ConvNeXt-based U-Net for 3D-to-2D-to-3D processing.
+
+    U-Net architecture using ConvNeXt backbones for hierarchical feature
+    extraction. Performs 3D-to-2D projection via the stem, 2D multi-scale
+    processing through a ConvNeXt encoder-decoder, and 2D-to-3D
+    reconstruction via specialized head modules.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    in_stack_depth : int
+        Depth (number of Z slices) of the input stack.
+    out_stack_depth : int, optional
+        Depth of the output stack. If None (default), matches the input
+        stack depth.
+    backbone : str
+        Name of the backbone model. By default, "convnextv2_tiny".
+    pretrained : bool
+        Whether to use pretrained weights.
+    stem_kernel_size : tuple[int, int, int]
+        Kernel size of the stem convolution.
+    decoder_mode : Literal["deconv", "pixelshuffle"]
+        Decoder upsampling mode: "deconv" for transposed convolution,
+        "pixelshuffle" for pixel shuffle.
+    decoder_conv_blocks : int
+        Number of ConvNeXt blocks in each decoder stage. By default, 2.
+    decoder_norm_layer : str, optional
+        Name of the normalization layer. By default, "instance".
+    decoder_upsample_pre_conv : bool, optional
+        Whether to use a convolution before decoder upsampling. By default, False.
+    head_pool : bool, optional
+        Whether to apply pooling in head upsampling. By default, False.
+    head_expansion_ratio : int, optional
+        Channel expansion ratio of the head. By default, 4.
+    drop_path_rate : float, optional
+        Drop path (stochastic depth) rate. By default, 0.0.
+    """
+
     def __init__(
         self,
         in_channels: int = 1,
         out_channels: int = 1,
         in_stack_depth: int = 5,
-        out_stack_depth: int = None,
+        out_stack_depth: int | None = None,
         backbone: str = "convnextv2_tiny",
         pretrained: bool = False,
         stem_kernel_size: tuple[int, int, int] = (5, 4, 4),
@@ -357,10 +599,23 @@ def __init__(
 
     @property
     def num_blocks(self) -> int:
-        """2-times downscaling factor of the smallest feature map"""
+        """2-times downscaling factor of the smallest feature map."""
         return 6
 
     def forward(self, x: Tensor) -> Tensor:
+        """Forward pass through complete UNeXt2 architecture.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (B, C, D, H, W) where D is the input stack depth.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of shape (B, out_channels, out_stack_depth, H', W')
+            after 3D-to-2D-to-3D processing through ConvNeXt backbone.
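+
+        Examples
+        --------
+        A minimal smoke test; the tensor shape here is an illustrative
+        assumption, not a requirement:
+
+        >>> model = UNeXt2(in_channels=1, out_channels=1, in_stack_depth=5)
+        >>> y = model(torch.rand(1, 1, 5, 256, 256))
+        >>> y.ndim
+        5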
+ """ x = self.stem(x) x: list = self.encoder_stages(x) x.reverse() diff --git a/viscy/utils/__init__.py b/viscy/utils/__init__.py index 5e2c7e1ed..e69de29bb 100644 --- a/viscy/utils/__init__.py +++ b/viscy/utils/__init__.py @@ -1 +0,0 @@ -"""Module for utility functions""" diff --git a/viscy/utils/aux_utils.py b/viscy/utils/aux_utils.py index f49137beb..ea597f4f6 100644 --- a/viscy/utils/aux_utils.py +++ b/viscy/utils/aux_utils.py @@ -1,17 +1,36 @@ -"""Auxiliary utility functions""" +"""Auxiliary utility functions.""" + +from pathlib import Path import iohub.ngff as ngff import yaml def _assert_unique_subset(subset, superset, name): - """ - Helper function to allow for clean code: - Throws error if unique elements of subset are not a subset of - unique elements of superset. - - Returns unique elements of subset if given a list. If subset is -1, - returns all unique elements of superset + """Check that unique elements of subset are a subset of superset. + + Helper function to allow for clean code: Throws error if unique elements + of subset are not a subset of unique elements of superset. + + Parameters + ---------- + subset : list or int + Subset to validate. If -1, returns all unique elements of superset. + superset : list + Superset to validate against. + name : str + Name of the parameter being validated (for error messages). + + Returns + ------- + set + Unique elements of subset if given a list. If subset is -1, + returns all unique elements of superset. + + Raises + ------ + AssertionError + If subset is not a subset of superset. """ if subset == -1: subset = superset @@ -27,34 +46,44 @@ def _assert_unique_subset(subset, superset, name): def validate_metadata_indices( - zarr_dir, + zarr_dir: str | Path, time_ids=[], channel_ids=[], slice_ids=[], pos_ids=[], ): - """ - Check the availability of indices provided timepoints, channels, positions - and slices for all data, and returns only the available of the specified - indices. + """Check availability of indices for timepoints, channels, positions and slices. + Returns only the available indices from the specified indices. If input ids are None, the indices for that parameter will not be evaluated. If input ids are -1, all indices for that parameter will be returned. - Assumes uniform structure, as such structure is required for HCS compatibility - - :param str zarr_dir: HCS-compatible zarr directory to validate indices against - :param list time_ids: check availability of these timepoints in image - metadata - :param list channel_ids: check availability of these channels in image - metadata - :param list pos_ids: Check availability of positions in zarr_dir - :param list slice_ids: Check availability of z slices in image metadata - - :return dict indices_metadata: All indices found given input - :raise AssertionError: If not all channels, timepoints, positions - or slices are present + Assumes uniform structure, as such structure is required for HCS compatibility. + + Parameters + ---------- + zarr_dir : str | Path + HCS-compatible zarr directory to validate indices against. + time_ids : list, optional + Check availability of these timepoints in image metadata, by default []. + channel_ids : list, optional + Check availability of these channels in image metadata, by default []. + slice_ids : list, optional + Check availability of z slices in image metadata, by default []. + pos_ids : list, optional + Check availability of positions in zarr_dir, by default []. 
+
+    Returns
+    -------
+    dict
+        Dictionary with keys 'time_ids', 'channel_ids', 'slice_ids', 'pos_ids'
+        containing all indices found given input.
+
+    Raises
+    ------
+    AssertionError
+        If not all channels, timepoints, positions or slices are present.
     """
     plate = ngff.open_ome_zarr(zarr_dir, layout="hcs", mode="r")
     position_path, position = next(plate.positions())
@@ -86,14 +115,20 @@ def validate_metadata_indices(
     return indices_metadata
 
 
-def read_config(config_fname):
-    """Read the config file in yml format
+def read_config(config_fname: str | Path):
+    """Read a config file in YAML format.
 
-    :param str config_fname: fname of config yaml with its full path
-    :return: dict config: Configuration parameters
-    """
+    Parameters
+    ----------
+    config_fname : str | Path
+        Filename of the config YAML, with its full path.
 
-    with open(config_fname, "r") as f:
+    Returns
+    -------
+    dict
+        Configuration parameters.
+    """
+    with open(config_fname) as f:
         config = yaml.safe_load(f)
     return config
diff --git a/viscy/utils/blend.py b/viscy/utils/blend.py
index 18167f852..006ee59a3 100644
--- a/viscy/utils/blend.py
+++ b/viscy/utils/blend.py
@@ -6,6 +6,22 @@ def blend_channels(
 def blend_channels(
     image: np.ndarray, cmaps: list[Colormap], rescale: bool
 ) -> np.ndarray:
+    """Blend multi-channel images using specified colormaps.
+
+    Parameters
+    ----------
+    image : np.ndarray
+        Multi-channel image array to blend.
+    cmaps : list[Colormap]
+        List of colormaps for each channel.
+    rescale : bool
+        Whether to rescale intensity values to [0, 1] range.
+
+    Returns
+    -------
+    np.ndarray
+        Blended RGB image clipped to [0, 1] range.
+    """
     rendered_channels = []
     for channel, cmap in zip(image, cmaps):
         colormap = Colormap(cmap)
diff --git a/viscy/utils/cli_utils.py b/viscy/utils/cli_utils.py
index 4223e6784..9d81b3335 100644
--- a/viscy/utils/cli_utils.py
+++ b/viscy/utils/cli_utils.py
@@ -1,18 +1,31 @@
 import collections
 import os
 import re
+from pathlib import Path
 
 import numpy as np
 import torch
+from numpy.typing import NDArray
 from PIL import Image
+from torch.utils.data import DataLoader
 
 
-def unique_tags(directory):
-    """
-    Returns list of unique nume tags from data directory
+def unique_tags(directory: str | Path) -> dict[str, int]:
+    """Return unique name tags and their counts from a data directory.
+
+    Parameters
+    ----------
+    directory : str | Path
+        Directory containing '.tif' files.
+
+    Returns
+    -------
+    dict[str, int]
+        Dictionary mapping unique name tags to their counts.
 
-    :param str directory: directory containing '.tif' files
-    TODO: Remove, unused and poorly written
+    Notes
+    -----
+    TODO: Remove, unused and poorly written.
     """
     files = [
         f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))
@@ -29,31 +42,51 @@ def unique_tags(directory):
     return tags
 
 
-class MultiProcessProgressBar(object):
-    """
+class MultiProcessProgressBar:
+    """Progress bar for multi-processed tasks.
+
     Provides the ability to create & update a single progress bar for multi-depth
-    multi-processed tasks by calling updates on a single object
+    multi-processed tasks by calling updates on a single object.
+
+    Parameters
+    ----------
+    total_updates : int
+        Total number of updates expected for this progress bar.
     """
 
-    def __init__(self, total_updates):
+    def __init__(self, total_updates: int) -> None:
         self.dataloader = list(range(total_updates))
         self.current = 0
 
-    def tick(self, process):
+    def tick(self, process: str) -> None:
+        """Update progress bar with current process status.
+
+        Parameters
+        ----------
+        process : str
+            Description of the current process being executed.
+        """
         self.current += 1
         show_progress_bar(self.dataloader, self.current, process)
 
 
-def show_progress_bar(dataloader, current, process="training", interval=1):
-    """
-    Utility function to print tensorflow-like progress bar.
+def show_progress_bar(
+    dataloader: DataLoader, current: int, process: str = "training", interval: int = 1
+) -> None:
+    """Print TensorFlow-like progress bar for batch processing.
 
     Written instead of using tqdm to allow for custom progress bar readouts.
 
-    :param iterable dataloader: dataloader currently being processed
-    :param int current: current index in dataloader
-    :param str proces: current process being performed
-    :param int interval: interval at which to update progress bar
+    Parameters
+    ----------
+    dataloader : DataLoader
+        Dataloader currently being processed.
+    current : int
+        Current index in dataloader.
+    process : str, optional
+        Current process being performed, by default "training".
+    interval : int, optional
+        Interval at which to update progress bar, by default 1.
     """
     current += 1
     bar_length = 50
@@ -80,23 +113,44 @@ def show_progress_bar(dataloader, current, process="training", interval=1):
     print(output_string)
 
 
-def save_figure(data, save_folder, name, title=None, vmax=0, ext=".png"):
-    """
+def save_figure(
+    data: NDArray | torch.Tensor,
+    save_folder: str | Path,
+    name: str,
+    title: str | None = None,
+    vmax: float = 0,
+    ext: str = ".png",
+) -> None:
+    """Save image data as PNG or JPEG figure.
+
     Saves .png or .jpeg figure of data to folder save_folder under 'name'.
-    'data' must be a 3d tensor or numpy array, in channels_first format
-
-    :param numpy.ndarray/torch.tensor data: input image/stack data to save
-    :param str save_folder: global path to folder where data is saved.
-    :param str name: name of data, no extension specified
-    :param str/None title: image title, if none specified, defaults used
-    :param float vmax: value to normalize figure to, by default uses data max
-    :param str ext: image save file extension
+    'data' must be a 3d tensor or numpy array, in channels_first format.
+
+    Parameters
+    ----------
+    data : NDArray | torch.Tensor
+        Input image/stack data to save in channels_first format.
+    save_folder : str | Path
+        Global path to folder where data is saved.
+    name : str
+        Name of data, no extension specified.
+    title : str, optional
+        Image title; if None, a default is used. By default None.
+    vmax : float, optional
+        Value to normalize figure to, by default 0 (uses data max).
+    ext : str, optional
+        Image save file extension, by default ".png".
+
+    Raises
+    ------
+    AttributeError
+        If data is not a torch tensor or numpy array.
     """
     assert len(data.shape) == 3, f"'{len(data.shape)}d' data must be 3-dimensional"
     if isinstance(data, torch.Tensor):
         data = data.detach().cpu().numpy()
     elif not isinstance(data, np.ndarray):
         raise AttributeError(
             f"'data' of type {type(data)} must be torch tensor or numpy array."
) diff --git a/viscy/utils/image_utils.py b/viscy/utils/image_utils.py index a95691162..9f24631bd 100644 --- a/viscy/utils/image_utils.py +++ b/viscy/utils/image_utils.py @@ -1,14 +1,39 @@ -"""Utility functions for processing images""" +"""Utility functions for processing images.""" import itertools import sys +from typing import Any import numpy as np +from numpy.typing import ArrayLike, NDArray import viscy.utils.normalize as normalize -def im_bit_convert(im, bit=16, norm=False, limit=[]): +def im_bit_convert( + im: ArrayLike, bit: int = 16, norm: bool = False, limit: list[float] = [] +) -> NDArray[Any]: + """Convert image to specified bit depth with optional normalization. + + FIXME: Verify parameter types and exact behavior for edge cases. + + Parameters + ---------- + im : ArrayLike + Input image to convert. + bit : int, optional + Target bit depth (8 or 16), by default 16. + norm : bool, optional + Whether to normalize image to [0, 2^bit-1] range, by default False. + limit : list, optional + Min/max values for normalization. If empty, uses image min/max, + by default []. + + Returns + ------- + NDArray + Image converted to specified bit depth. + """ im = im.astype( np.float32, copy=False ) # convert to float32 without making a copy to save memory @@ -29,27 +54,53 @@ def im_bit_convert(im, bit=16, norm=False, limit=[]): return im -def im_adjust(img, tol=1, bit=8): - """ - Stretches contrast of the image and converts to 'bit'-bit. - Useful for weight-maps in masking +def im_adjust(img: ArrayLike, tol: int | float = 1, bit: int = 8) -> NDArray[Any]: + """Stretch contrast of the image and convert to specified bit depth. + + Useful for weight-maps in masking. + + Parameters + ---------- + img : ArrayLike + Input image to adjust. + tol : int or float, optional + Tolerance percentile for contrast stretching, by default 1. + bit : int, optional + Target bit depth, by default 8. + + Returns + ------- + NDArray + Contrast-adjusted image in specified bit depth. """ limit = np.percentile(img, [tol, 100 - tol]) im_adjusted = im_bit_convert(img, bit=bit, norm=True, limit=limit.tolist()) return im_adjusted -def grid_sample_pixel_values(im, grid_spacing): - """Sample pixel values in the input image at the grid. Any incomplete - grids (remainders of modulus operation) will be ignored. +def grid_sample_pixel_values( + im: NDArray[Any], grid_spacing: int +) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any]]: + """Sample pixel values in the input image at grid points. - :param np.array im: 2D image - :param int grid_spacing: spacing of the grid - :return int row_ids: row indices of the grids - :return int col_ids: column indices of the grids - :return np.array sample_values: sampled pixel values - """ + Any incomplete grids (remainders of modulus operation) will be ignored. + + Parameters + ---------- + im : NDArray + 2D image to sample from. + grid_spacing : int + Spacing of the grid points. + Returns + ------- + row_ids : NDArray + Row indices of the grid points. + col_ids : NDArray + Column indices of the grid points. + sample_values : NDArray + Sampled pixel values at grid points. 
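+
+    Examples
+    --------
+    Illustrative sketch; the image values are arbitrary:
+
+    >>> im = np.arange(100).reshape(10, 10)
+    >>> rows, cols, values = grid_sample_pixel_values(im, grid_spacing=4)
+    >>> values.shape == rows.shape == cols.shape
+    True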
+ """ im_shape = im.shape assert grid_spacing < im_shape[0], "grid spacing larger than image height" assert grid_spacing < im_shape[1], "grid spacing larger than image width" @@ -69,22 +120,38 @@ def grid_sample_pixel_values(im, grid_spacing): def preprocess_image( - im, - hist_clip_limits=None, - is_mask=False, - normalize_im=None, - zscore_mean=None, - zscore_std=None, -): - """ - Do histogram clipping, z score normalization, and potentially binarization. - - :param np.array im: Image (stack) - :param tuple hist_clip_limits: Percentile histogram clipping limits - :param bool is_mask: True if mask - :param str/None normalize_im: Normalization, if any - :param float/None zscore_mean: Data mean - :param float/None zscore_std: Data std + im: ArrayLike, + hist_clip_limits: tuple[float, float] | None = None, + is_mask: bool = False, + normalize_im: str | None = None, + zscore_mean: float | None = None, + zscore_std: float | None = None, +) -> NDArray[Any]: + """Preprocess image with histogram clipping, z-score normalization, and binarization. + + Performs histogram clipping, z-score normalization, and potentially binarization + depending on the input parameters. + + Parameters + ---------- + im : ArrayLike + Input image or image stack. + hist_clip_limits : tuple[float, float], optional + Percentile histogram clipping limits (min_percentile, max_percentile), + by default None. + is_mask : bool, optional + True if input is a mask (will be binarized), by default False. + normalize_im : str, optional + Normalization method to apply, by default None. + zscore_mean : float, optional + Precomputed mean for z-score normalization, by default None. + zscore_std : float, optional + Precomputed standard deviation for z-score normalization, by default None. + + Returns + ------- + NDArray + Preprocessed image. """ # remove singular dimension for 3D images if len(im.shape) > 3: diff --git a/viscy/utils/log_images.py b/viscy/utils/log_images.py index 3949f93fb..99dfa98f3 100644 --- a/viscy/utils/log_images.py +++ b/viscy/utils/log_images.py @@ -1,16 +1,15 @@ -"""Logging example images during training.""" - -from typing import Sequence +from collections.abc import Sequence import numpy as np from matplotlib.pyplot import get_cmap +from numpy.typing import NDArray from skimage.exposure import rescale_intensity from torch import Tensor def detach_sample( imgs: Sequence[Tensor], log_samples_per_batch: int -) -> list[list[np.ndarray]]: +) -> list[list[NDArray]]: """Detach example images from the batch and convert them to numpy arrays. Parameters @@ -22,7 +21,7 @@ def detach_sample( Returns ------- - list[list[np.ndarray]] + list[list[NDArray]] Grid of example images. Rows are samples, columns are channels. """ @@ -38,21 +37,19 @@ def detach_sample( return samples -def render_images( - imgs: Sequence[Sequence[np.ndarray]], cmaps: list[str] = [] -) -> np.ndarray: +def render_images(imgs: Sequence[Sequence[NDArray]], cmaps: list[str] = []) -> NDArray: """Render images in a grid. Parameters ---------- - imgs : Sequence[Sequence[np.ndarray]] + imgs : Sequence[Sequence[NDArray]] Grid of images to render, output of `detach_sample`. cmaps : list[str], optional Colormaps for each column, by default [] Returns ------- - np.ndarray + NDArray Rendered RGB images grid. 
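+
+    Examples
+    --------
+    Illustrative sketch; the tensor shape and colormap name are assumptions:
+
+    >>> batch = [torch.rand(4, 1, 32, 32)]
+    >>> samples = detach_sample(batch, log_samples_per_batch=2)
+    >>> rgb_grid = render_images(samples, cmaps=["gray"])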
""" images_grid = [] diff --git a/viscy/utils/logging.py b/viscy/utils/logging.py index 5bdeac90b..7320fc5b2 100644 --- a/viscy/utils/logging.py +++ b/viscy/utils/logging.py @@ -1,6 +1,7 @@ import datetime import os import time +from typing import Any import torch @@ -8,19 +9,28 @@ from viscy.utils.normalize import hist_clipping -def log_feature(feature_map, name, log_save_folder, debug_mode): - """ - If self.debug_mode, creates a visual of the given feature map, and saves it at - 'log_save_folder' - If no log_save_folder specified, saves relative to working directory with timestamp. +def log_feature( + feature_map: torch.Tensor, name: str, log_save_folder: str, debug_mode: bool +) -> None: + """Create visual feature map logs for debugging deep learning models. - Currently only saving in working directory is supported. - This is meant to be an analysis tool, - and results should not be saved permanently. + If debug_mode is enabled, creates a visual of the given feature map and saves it at + 'log_save_folder'. If no log_save_folder specified, saves relative to working + directory with timestamp. - :param torch.tensor feature_map: feature map to create visualization log of - :param str name: string - :param str log_save_folder + Currently only saving in working directory is supported. + This is meant to be an analysis tool, and results should not be saved permanently. + + Parameters + ---------- + feature_map : torch.Tensor + Feature map to create visualization log of. + name : str + Name identifier for the feature map visualization. + log_save_folder : str + Directory path for saving the visualization output. + debug_mode : bool + Whether to enable debug mode visualization logging. """ try: if debug_mode: @@ -46,35 +56,73 @@ def log_feature(feature_map, name, log_save_folder, debug_mode): class FeatureLogger: + """Logger for visualizing neural network feature maps during training and debugging. + + This utility class provides comprehensive feature map visualization capabilities + for monitoring convolutional neural network activations. It supports both + individual channel visualization and grid-based multi-channel displays, + with flexible normalization and spatial dimension handling. + + The logger is designed for debugging deep learning models by capturing + intermediate layer activations and saving them as organized image files. + It handles multi-dimensional tensors commonly found in computer vision + tasks, including 2D/3D spatial dimensions with batch and channel axes. + + Parameters + ---------- + save_folder : str + Output directory for saving visualization files. + spatial_dims : int, optional + Number of spatial dimensions in feature tensors, by default 3. + full_batch : bool, optional + If true, log all samples in batch (warning: slow!), by default False. + save_as_grid : bool, optional + If true, feature maps are saved as a grid containing all channels, + else saved individually, by default True. + grid_width : int, optional + Desired width of grid if save_as_grid. If 0, defaults to 1/4 the + number of channels, by default 0. + normalize_by_grid : bool, optional + If true, images saved in grid are normalized to brightest pixel in + entire grid, by default False. + + Attributes + ---------- + save_folder : str + Directory path for saving visualization outputs. + spatial_dims : int + Number of spatial dimensions in feature tensors (2D or 3D). + full_batch : bool + Whether to log all samples in batch or just the first. 
+ save_as_grid : bool + Whether to arrange channels in a grid layout. + grid_width : int + Number of columns in grid visualization. + normalize_by_grid : bool + Whether to normalize intensities across entire grid. + + Examples + -------- + >>> logger = FeatureLogger( + ... save_folder="./feature_logs", + ... spatial_dims=3, + ... save_as_grid=True, + ... grid_width=8, + ... ) + >>> logger.log_feature_map( + ... conv_features, "conv1_activations", dim_names=["batch", "channels"] + ... ) + """ + def __init__( self, - save_folder, - spatial_dims=3, - full_batch=False, - save_as_grid=True, - grid_width=0, - normalize_by_grid=False, - ): - """ - Logger object for handling logging feature maps inside network architectures. - - Saves each 2d slice of a feature map in either a single grid per feature map - stack or a directory tree of labeled slices. - - By default saves images into grid. - - :param str save_folder: output directory - :param bool full_batch: if true, log all sample in batch (warning slow!), - defaults to False - :param bool save_as_grid: if true feature maps are to be saved as a grid - containing all channels, else saved individually, - defaults to True - :param int grid_width: desired width of grid if save_as_grid, by default - 1/4 the number of channels, defaults to 0 - :param bool normalize_by_grid: if true, images saved in grid are normalized - to brightest pixel in entire grid, defaults to False - - """ + save_folder: str, + spatial_dims: int = 3, + full_batch: bool = False, + save_as_grid: bool = True, + grid_width: int = 0, + normalize_by_grid: bool = False, + ) -> None: self.save_folder = save_folder self.spatial_dims = spatial_dims self.full_batch = full_batch @@ -86,34 +134,38 @@ def __init__( def log_feature_map( self, - feature_map, - feature_name, - dim_names=[], - vmax=0, - ): - """ - Creates a log of figures the given feature map tensor at 'save_folder'. - Log is saved as images of feature maps in nested directory tree. - - By default _assumes that batch dimension is the first dimension_, and - only logs the first sample in the batch, for performance reasons. - - Feature map logs cannot overwrite. - - :param torch.Tensor feature_map: feature map to log (typically 5d tensor) - :parapm str feature_name: name of feature (will be used as dir name) - :param list dim_names: names of each dimension, by default just numbers - :param int spatial_dims: number of spatial dims, defaults to 3 - :param float vmax: maximum intensity to normalize figures by, by default - (if given 0) does relative normalization + feature_map: torch.Tensor, + feature_name: str, + dim_names: list[str] | None = None, + vmax: float = 0, + ) -> None: + """Create a log of figures for the given feature map tensor. + + Log is saved as images of feature maps in nested directory tree at save_folder. + + By default assumes that batch dimension is the first dimension, and only logs + the first sample in the batch for performance reasons. Feature map logs cannot + overwrite existing files. + + Parameters + ---------- + feature_map : torch.Tensor + Feature map to log, typically 5D tensor (BCDHW or BCTHW). + feature_name : str + Name of feature, used as directory name for organizing outputs. + dim_names : list[str] | None, optional + Names of each non-spatial dimension, by default just numbers. + vmax : float, optional + Maximum intensity to normalize figures by. If 0, uses relative + normalization, by default 0. 
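+
+        Examples
+        --------
+        Illustrative call; the tensor shape is an assumption:
+
+        >>> logger = FeatureLogger("./feature_logs")
+        >>> logger.log_feature_map(torch.rand(1, 8, 4, 32, 32), "bottleneck")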
""" # take tensor off of gpu and detach gradient feature_map = feature_map.detach().cpu() # handle dim names num_dims = len(feature_map.shape) - if len(dim_names) == 0: - dim_names = ["dim_" + str(i) for i in range(len(num_dims))] + if dim_names is None: + dim_names = ["dim_" + str(i) for i in range(num_dims)] else: assert len(dim_names) + self.spatial_dims == num_dims, ( "dim_names must be same length as nonspatial tensor dim length" @@ -132,24 +184,32 @@ def log_feature_map( def map_feature_dims( self, - feature_map, - save_as_grid, - vmax=0, - depth=0, - ): - """ - Recursive directory creation for organizing feature map logs - - If save_as_grid, will compile 'channels' (assumed to be last - non-spatial dimension) into a single large image grid before saving. - - :param numpy.ndarray feature_map: see name - :param str save_dir: see name - :param bool save_as_grid: if true, saves images as channel grid - :param float vmax: maximum intensity to normalize figures by - :param int depth: recursion counter. depth in dimensions + feature_map: torch.Tensor, + save_as_grid: bool, + vmax: float = 0, + depth: int = 0, + ) -> None: + """Recursively create directory structure for organizing feature map logs. + + If save_as_grid is True, compiles 'channels' (assumed to be last non-spatial + dimension) into a single large image grid before saving. + + Parameters + ---------- + feature_map : torch.Tensor + Feature tensor to process and save. + save_as_grid : bool + If true, saves images as channel grid layout. + vmax : float, optional + Maximum intensity to normalize figures by, by default 0. + depth : int, optional + Recursion counter tracking depth in tensor dimensions, by default 0. + + Raises + ------ + AttributeError + If the feature map has an invalid number of dimensions. """ - for i in range(feature_map.shape[0]): if len(feature_map.shape) == 3: # individual saving @@ -257,16 +317,33 @@ def map_feature_dims( break return - def interleave_bars(self, arrays, axis, pixel_width=3, value=0): - """ - Takes list of 2d torch tensors and interleaves bars to improve - grid visualization quality. - Assumes arrays are all of the same shape. - - :param list grid_arrays: list of tensors to place bars between - :param int axis: axis on which to interleave bars (0 or 1) - :param int pixel_width: width of bar, defaults to 3 - :param int value: value of bar pixels, defaults to 0 + def interleave_bars( + self, + arrays: list[torch.Tensor], + axis: int, + pixel_width: int = 3, + value: float = 0, + ) -> list[torch.Tensor]: + """Interleave separator bars between tensors to improve grid visualization. + + Takes list of 2D torch tensors and interleaves bars to improve grid + visualization quality. Assumes arrays are all of the same shape. + + Parameters + ---------- + arrays : list[torch.Tensor] + List of tensors to place separator bars between. + axis : int + Axis on which to interleave bars (0 or 1). + pixel_width : int, optional + Width of separator bar in pixels, by default 3. + value : float, optional + Pixel value for separator bars, by default 0. + + Returns + ------- + list[torch.Tensor] + List of tensors with separator bars interleaved for grid visualization. 
""" shape_match_axis = abs(axis - 1) length = arrays[0].shape[shape_match_axis] diff --git a/viscy/utils/masks.py b/viscy/utils/masks.py index a0881fa02..f000fce94 100644 --- a/viscy/utils/masks.py +++ b/viscy/utils/masks.py @@ -1,5 +1,8 @@ +from typing import Any + import numpy as np import scipy.ndimage as ndimage +from numpy.typing import NDArray from scipy.ndimage import binary_fill_holes from skimage.filters import gaussian, laplace, threshold_otsu from skimage.morphology import ( @@ -11,14 +14,24 @@ ) -def create_otsu_mask(input_image, sigma=0.6): - """Create a binary mask using morphological operations - :param np.array input_image: generate masks from this 3D image - :param float sigma: Gaussian blur standard deviation, - increase in value increases blur - :return: volume mask of input_image, 3D np.array - """ +def create_otsu_mask( + input_image: NDArray[Any], sigma: float = 0.6 +) -> NDArray[np.bool_]: + """Create a binary mask using Otsu thresholding and morphological operations. + Parameters + ---------- + input_image : NDArray + Generate masks from this 3D image. + sigma : float, optional + Gaussian blur standard deviation, increase in value increases blur, + by default 0.6. + + Returns + ------- + NDArray + Volume mask of input_image, 3D binary array. + """ input_sz = input_image.shape mid_slice_id = input_sz[0] // 2 @@ -28,20 +41,36 @@ def create_otsu_mask(input_image, sigma=0.6): return mask -def create_membrane_mask(input_image, str_elem_size=23, sigma=0.4, k_size=3, msize=120): - """Create a binary mask using Laplacian of Gaussian (LOG) feature detection - - :param np.array input_image: generate masks from this image - :param int str_elem_size: size of the laplacian filter - used for contarst enhancement, odd number. - Increase in value increases sensitivity of contrast enhancement - :param float sigma: Gaussian blur standard deviation - :param int k_size: disk/ball size for mask dilation, - ball for 3D and disk for 2D data - :param int msize: size of small objects removed to clean segmentation - :return: mask of input_image, np.array +def create_membrane_mask( + input_image: NDArray[Any], + str_elem_size: int = 23, + sigma: float = 0.4, + k_size: int = 3, + msize: int = 120, +) -> NDArray[np.bool_]: + """Create a binary mask using Laplacian of Gaussian (LOG) feature detection. + + Parameters + ---------- + input_image : NDArray + Generate masks from this image. + str_elem_size : int, optional + Size of the laplacian filter used for contrast enhancement, odd number. + Increase in value increases sensitivity of contrast enhancement, + by default 23. + sigma : float, optional + Gaussian blur standard deviation, by default 0.4. + k_size : int, optional + Disk/ball size for mask dilation, ball for 3D and disk for 2D data, + by default 3. + msize : int, optional + Size of small objects removed to clean segmentation, by default 120. + + Returns + ------- + NDArray + Binary mask of input_image. """ - input_image_blur = gaussian(input_image, sigma=sigma) input_Lapl = laplace(input_image_blur, ksize=str_elem_size) @@ -61,17 +90,24 @@ def create_membrane_mask(input_image, str_elem_size=23, sigma=0.4, k_size=3, msi return mask -def get_unimodal_threshold(input_image): - """Determines optimal unimodal threshold +def get_unimodal_threshold(input_image: NDArray[Any]) -> float: + """Determine optimal unimodal threshold using Rosin's method. 
+
+    References
+    ----------
     https://users.cs.cf.ac.uk/Paul.Rosin/resources/papers/unimodal2.pdf
     https://www.mathworks.com/matlabcentral/fileexchange/45443-rosin-thresholding
 
-    :param np.array input_image: generate mask for this image
-    :return float best_threshold: optimal lower threshold for the foreground
-     hist
-    """
+    Parameters
+    ----------
+    input_image : NDArray
+        Generate mask for this image.
 
+    Returns
+    -------
+    float
+        Optimal lower threshold for the foreground histogram.
+    """
     hist_counts, bin_edges = np.histogram(
         input_image,
         bins=256,
@@ -105,17 +141,27 @@ def get_unimodal_threshold(input_image):
     return best_threshold
 
 
-def create_unimodal_mask(input_image, str_elem_size=3, sigma=0.6):
-    """
-    Create a mask with unimodal thresholding and morphological operations.
-    Unimodal thresholding seems to oversegment, erode it by a fraction
+def create_unimodal_mask(
+    input_image: NDArray[Any], str_elem_size: int = 3, sigma: float = 0.6
+) -> NDArray[np.bool_]:
+    """Create a mask with unimodal thresholding and morphological operations.
 
-    :param np.array input_image: generate masks from this image
-    :param int str_elem_size: size of the structuring element. typically 3, 5
-    :param float sigma: gaussian blur standard deviation
-    :return mask of input_image, np.array
-    """
+    Unimodal thresholding seems to oversegment, erode it by a fraction.
 
+    Parameters
+    ----------
+    input_image : NDArray
+        Generate masks from this image.
+    str_elem_size : int, optional
+        Size of the structuring element, typically 3 or 5, by default 3.
+    sigma : float, optional
+        Gaussian blur standard deviation, by default 0.6.
+
+    Returns
+    -------
+    NDArray
+        Binary mask of input_image.
+    """
     input_image = gaussian(input_image, sigma=sigma)
 
     if np.min(input_image) == np.max(input_image):
@@ -133,21 +179,31 @@ def create_unimodal_mask(input_image, str_elem_size=3, sigma=0.6):
     return mask
 
 
-def get_unet_border_weight_map(annotation, w0=10, sigma=5):
-    """
-    Return weight map for borders as specified in UNet paper
-    :param annotation A 2D array of shape (image_height, image_width)
-    contains annotation with each class labeled as an integer.
-    :param w0 multiplier to the exponential distance loss
-    default 10 as mentioned in UNet paper
-    :param sigma standard deviation in the exponential distance term
-    e^(-d1 + d2) ** 2 / 2 (sigma ^ 2)
-    default 5 as mentioned in UNet paper
-    :return weight mapt for borders as specified in UNet
-
-    TODO: Calculate boundaries directly and calculate distance
-    from boundary of cells to another
-    Note: The below method only works for UNet Segmentation only
+def get_unet_border_weight_map(
+    annotation: NDArray[Any], w0: int = 10, sigma: int = 5
+) -> NDArray[np.float64]:
+    """Return weight map for borders as specified in U-Net paper.
+
+    TODO: Calculate boundaries directly and calculate distance from the
+    boundary of one cell to another. Note: the method below only works
+    for U-Net segmentation.
+
+    Parameters
+    ----------
+    annotation : NDArray
+        A 2D array of shape (image_height, image_width) containing annotation
+        with each class labeled as an integer.
+    w0 : int, optional
+        Multiplier to the exponential distance loss, by default 10
+        (as in the U-Net paper).
+    sigma : int, optional
+        Standard deviation in the exponential distance term
+        w0 * exp(-(d1 + d2)^2 / (2 * sigma^2)), by default 5
+        (as in the U-Net paper).
+
+    Returns
+    -------
+    NDArray
+        Weight map for borders as specified in U-Net paper.
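+
+    Examples
+    --------
+    Illustrative sketch; the label layout is an arbitrary assumption:
+
+    >>> labels = np.zeros((8, 8), dtype=np.uint8)
+    >>> labels[:, :3] = 1
+    >>> labels[:, 5:] = 2
+    >>> get_unet_border_weight_map(labels).shape
+    (8, 8)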
""" # if there is only one label, zero return the array as is if np.sum(annotation) == 0: @@ -160,7 +216,7 @@ def get_unet_border_weight_map(annotation, w0=10, sigma=5): assert annotation.dtype in [ np.uint8, np.uint16, - ], "Expected data type uint, it is {}".format(annotation.dtype) + ], f"Expected data type uint, it is {annotation.dtype}" # cells instances for distance computation # 4 connected i.e default (cross-shaped) diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py index f948fb849..3547c39a7 100644 --- a/viscy/utils/meta_utils.py +++ b/viscy/utils/meta_utils.py @@ -10,24 +10,31 @@ from viscy.utils.mp_utils import get_val_stats -def write_meta_field(position: ngff.Position, metadata, field_name, subfield_name): - """ - Writes 'metadata' to position's plate-level or FOV level .zattrs metadata by either - creating a new field (field_name) according to 'metadata', or updating the metadata - to an existing field if found, - or concatenating the metadata from different channels. +def write_meta_field( + position: ngff.Position, metadata: dict, field_name: str, subfield_name: str +): + """Write metadata to position's plate-level or FOV level .zattrs metadata. - Assumes that the zarr store group given follows the OMG-NGFF HCS - format as specified here: - https://ngff.openmicroscopy.org/latest/#hcs-layout + Write metadata to position's plate-level or FOV level .zattrs metadata by either + creating a new field (field_name) according to metadata, or updating the metadata + to an existing field if found, or concatenating the metadata from different channels. - Warning: Dangerous. Writing metadata fields above the image-level of - an HCS hierarchy can break HCS compatibility + Assumes that the zarr store group given follows the OME-NGFF HCS + format as specified here: https://ngff.openmicroscopy.org/latest/#hcs-layout - :param Position zarr_dir: NGFF position node object - :param dict metadata: metadata dictionary to write to JSON .zattrs - :param str subfield_name: name of subfield inside the the main field - (values for different channels) + Warning: Dangerous. Writing metadata fields above the image-level of + an HCS hierarchy can break HCS compatibility. + + Parameters + ---------- + position : ngff.Position + NGFF position node object. + metadata : dict + Metadata dictionary to write to JSON .zattrs. + field_name : str + Name of the main metadata field. + subfield_name : str + Name of subfield inside the main field (values for different channels). """ if field_name in position.zattrs: if subfield_name in position.zattrs[field_name]: @@ -62,31 +69,32 @@ def _grid_sample( def generate_normalization_metadata( - zarr_dir, num_workers=4, channel_ids=-1, grid_spacing=32 + zarr_dir: str, num_workers: int = 4, channel_ids: int = -1, grid_spacing: int = 32 ): - """ + """Generate pixel intensity metadata for on-the-fly normalization. + Generate pixel intensity metadata to be later used in on-the-fly normalization during training and inference. Sampling is used for efficient estimation of median and interquartile range for intensity values on both a dataset and field-of-view - level. - - Normalization values are recorded in the image-level metadata in the corresponding - position of each zarr_dir store. Format of metadata is as follows: - { channel_idx : { dataset_statistics: dataset level normalization values (positive float), fov_statistics: field-of-view level normalization values (positive float) - }, - . - . - . 
+        }
     }
 
-    :param str zarr_dir: path to zarr store directory containing dataset.
-    :param int num_workers: number of cpu workers for multiprocessing, defaults to 4
-    :param list/int channel_ids: indices of channels to process in dataset arrays,
-                                 by default calculates all
-    :param int grid_spacing: distance between points in sampling grid
+    Warning: Dangerous. Writing metadata fields above the image-level of
+    an HCS hierarchy can break HCS compatibility.
+
+    Parameters
+    ----------
+    zarr_dir : str
+        Path to zarr store directory containing dataset.
+    num_workers : int, optional
+        Number of CPU workers for multiprocessing, by default 4.
+    channel_ids : list[int] | int, optional
+        Indices of channels to process in dataset arrays, by default -1 (all channels).
+    grid_spacing : int, optional
+        Distance between points in sampling grid, by default 32.
     """
     plate = ngff.open_ome_zarr(zarr_dir, mode="r+")
     position_map = list(plate.positions())
@@ -139,22 +147,31 @@ def generate_normalization_metadata(
 def compute_zscore_params(
     frames_meta, ints_meta, input_dir, normalize_im, min_fraction=0.99
 ):
-    """
-    Get zscore median and interquartile range
-
-    :param pd.DataFrame frames_meta: Dataframe containing all metadata
-    :param pd.DataFrame ints_meta: Metadata containing intensity statistics
-     each z-slice and foreground fraction for masks
-    :param str input_dir: Directory containing images
-    :param None or str normalize_im: normalization scheme for input images
-    :param float min_fraction: Minimum foreground fraction (in case of masks)
+    """Compute z-score normalization parameters.
+
+    Computes the z-score median and interquartile range over the dataset.
+
+    Parameters
+    ----------
+    frames_meta : pd.DataFrame
+        Dataframe containing all metadata.
+    ints_meta : pd.DataFrame
+        Metadata containing intensity statistics for each z-slice and
+        foreground fraction for masks.
+    input_dir : str
+        Directory containing images.
+    normalize_im : None or str
+        Normalization scheme for input images.
+    min_fraction : float, optional
+        Minimum foreground fraction (in case of masks) for computing intensity
+        statistics. By default 0.99.
-    for computing intensity statistics.
 
-    :return pd.DataFrame frames_meta: Dataframe containing all metadata
-    :return pd.DataFrame ints_meta: Metadata containing intensity statistics
-     each z-slice
+    Returns
+    -------
+    tuple[pd.DataFrame, pd.DataFrame]
+        Tuple containing:
+        - pd.DataFrame frames_meta: Dataframe containing all metadata
+        - pd.DataFrame ints_meta: Metadata containing intensity statistics of each z-slice
     """
-
     assert normalize_im in [
         None,
         "slice",
diff --git a/viscy/utils/mp_utils.py b/viscy/utils/mp_utils.py
index 4db77e4de..65fc78071 100644
--- a/viscy/utils/mp_utils.py
+++ b/viscy/utils/mp_utils.py
@@ -1,19 +1,35 @@
+from collections.abc import Callable
 from concurrent.futures import ProcessPoolExecutor
+from typing import Any
 
 import iohub.ngff as ngff
 import numpy as np
 import scipy.stats
+import zarr
+from numpy.typing import NDArray
 
 import viscy.utils.image_utils as image_utils
 import viscy.utils.masks as mask_utils
 
 
-def mp_wrapper(fn, fn_args, workers):
-    """Create and save masks with multiprocessing
-
-    :param list of tuple fn_args: list with tuples of function arguments
-    :param int workers: max number of workers
-    :return: list of returned dicts from create_save_mask
+def mp_wrapper(
+    fn: Callable[..., Any], fn_args: list[tuple[Any, ...]], workers: int
+) -> list[Any]:
+    """Apply a function to a list of argument tuples with multiprocessing.
+ + Parameters + ---------- + fn : callable + Function to be applied with multiprocessing. + fn_args : list of tuple + List with tuples of function arguments. + workers : int + Max number of workers. + + Returns + ------- + list + List of returned dicts from create_save_mask. """ with ProcessPoolExecutor(workers) as ex: # can't use map directly as it works only with single arg functions @@ -21,13 +37,22 @@ def mp_wrapper(fn, fn_args, workers): return list(res) -def mp_create_and_write_mask(fn_args, workers): - """Create and save masks with multiprocessing. For argument parameters - see mp_utils.create_and_write_mask. +def mp_create_and_write_mask(fn_args: list[tuple[Any, ...]], workers: int) -> list[Any]: + """Create and save masks with multiprocessing. + + For argument parameters see mp_utils.create_and_write_mask. - :param list of tuple fn_args: list with tuples of function arguments - :param int workers: max number of workers - :return: list of returned dicts from create_save_mask + Parameters + ---------- + fn_args : list of tuple + List with tuples of function arguments. + workers : int + Max number of workers. + + Returns + ------- + list + List of returned dicts from create_save_mask. """ with ProcessPoolExecutor(workers) as ex: # can't use map directly as it works only with single arg functions @@ -37,29 +62,35 @@ def mp_create_and_write_mask(fn_args, workers): def add_channel( position: ngff.Position, - new_channel_array, - new_channel_name, - overwrite_ok=False, -): - """ - Adds a channels to the data array at position "position". Note that there is - only one 'tracked' data array in current HCS spec at each position. Also - updates the 'omero' channel-tracking metadata to track the new channel. + new_channel_array: NDArray, + new_channel_name: str, + overwrite_ok: bool = False, +) -> None: + """Add a channel to the data array at specified position. + + Note that there is only one 'tracked' data array in current HCS spec at each position. + Also updates the 'omero' channel-tracking metadata to track the new channel. The 'new_channel_array' must match the dimensions of the current array in - all positions but the channel position (1) and have the same datatype + all positions but the channel position (1) and have the same datatype. Note: to maintain HCS compatibility of the zarr store, all positions (wells) must maintain arrays with congruent channels. That is, if you add a channel to one position of an HCS compatible zarr store, an additional channel must be added to every position in that store to maintain HCS compatibility. - :param Position zarr_dir: NGFF position node object - :param np.ndarray new_channel_array: array to add as new channel with matching - dimensions (except channel dim) and dtype - :param str new_channel_name: name of new channel - :param bool overwrite_ok: if true, if a channel with the same name as - 'new_channel_name' is found, will overwrite + Parameters + ---------- + position : ngff.Position + NGFF position node object. + new_channel_array : NDArray + Array to add as new channel with matching dimensions (except channel dim) + and dtype. + new_channel_name : str + Name of new channel. + overwrite_ok : bool, optional + If true, if a channel with the same name as 'new_channel_name' is found, + will overwrite, by default False. 
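+
+    Examples
+    --------
+    Illustrative sketch; the store path and channel name are assumptions:
+
+    >>> plate = ngff.open_ome_zarr("plate.zarr", mode="r+")
+    >>> _, position = next(plate.positions())
+    >>> mask = np.zeros_like(position.data[:, 0])
+    >>> add_channel(position, mask, "nuclei_mask", overwrite_ok=True)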
""" assert len(new_channel_array.shape) == len(position.data.shape) - 1, ( "New channel array must match all dimensions of the position array, " @@ -82,20 +113,18 @@ def add_channel( def create_and_write_mask( position: ngff.Position, - time_indices, - channel_indices, - structure_elem_radius, - mask_type, - mask_name, - verbose=False, -): - # TODO: rewrite docstring - """ - Create mask *for all depth slices* at each time and channel index specified - in this position, and save them both as an additional channel in the data array - of the given zarr store and a separate 'untracked' array with specified name. - If output_channel_index is specified as an existing channel index, will overwrite - this channel instead. + time_indices: list[int], + channel_indices: list[int], + structure_elem_radius: int, + mask_type: str, + mask_name: str, + verbose: bool = False, +) -> None: + """Create mask for all depth slices at specified time and channel indices. + + Creates masks at each time and channel index specified in this position, + and saves them both as an additional channel in the data array of the given + zarr store and a separate 'untracked' array with specified name. Saves custom metadata related to the mask creation in the well-level .zattrs in the 'mask' field. @@ -105,24 +134,25 @@ def create_and_write_mask( a timepoint-position basis. That is, it will be recorded as an average foreground fraction over all slices in any given timepoint. - - :param str zarr_dir: directory to HCS compatible zarr store for usage - :param str position_path: path within store to position to generate masks for - :param list time_indices: list of time indices for mask generation, - if an index is skipped over, will populate with - zeros - :param list channel_indices: list of channel indices for mask generation, - if more than 1 channel specified, masks from all - channels are aggregated - :param int structure_elem_radius: size of structuring element used for binary - opening. str_elem: disk or ball - :param str mask_type: thresholding type used for masking or str to map to - masking function - :param str mask_name: name under which to save untracked copy of mask in - position - :param bool verbose: whether this process should send updates to stdout + Parameters + ---------- + position : ngff.Position + NGFF position node object. + time_indices : list + List of time indices for mask generation. If an index is skipped over, + will populate with zeros. + channel_indices : list + List of channel indices for mask generation. If more than 1 channel + specified, masks from all channels are aggregated. + structure_elem_radius : int + Size of structuring element used for binary opening. str_elem: disk or ball. + mask_type : str + Thresholding type used for masking or str to map to masking function. + mask_name : str + Name under which to save untracked copy of mask in position. + verbose : bool, optional + Whether this process should send updates to stdout, by default False. """ - shape = position.data.shape position_masks_shape = tuple([shape[0], len(channel_indices), *shape[2:]]) @@ -195,25 +225,35 @@ def create_and_write_mask( def get_mask_slice( - position_zarr, - time_index, - channel_index, - mask_type, - structure_elem_radius, -): - """ + position_zarr: zarr.Array, + time_index: int, + channel_index: int, + mask_type: str, + structure_elem_radius: int, +) -> NDArray: + """Compute mask for a single image slice. 
     """
     # read and correct/preprocess slice
     im = position_zarr[time_index, channel_index]
@@ -237,13 +277,20 @@ def get_mask_slice(
     return mask
 
 
-def mp_get_val_stats(fn_args, workers):
-    """
-    Computes statistics of numpy arrays with multiprocessing
+def mp_get_val_stats(fn_args: list[Any], workers: int) -> list[dict[str, float]]:
+    """Compute statistics of numpy arrays with multiprocessing.
+
+    Parameters
+    ----------
+    fn_args : list of tuple
+        List with tuples of function arguments.
+    workers : int
+        Max number of workers.
 
-    :param list of tuple fn_args: list with tuples of function arguments
-    :param int workers: max number of workers
-    :return: list of returned df from get_im_stats
+    Returns
+    -------
+    list[dict[str, float]]
+        List of dictionaries returned from get_im_stats.
     """
     with ProcessPoolExecutor(workers) as ex:
         # can't use map directly as it works only with single arg functions
@@ -251,16 +298,22 @@ def mp_get_val_stats(fn_args, workers):
         return list(res)
 
 
-def get_val_stats(sample_values):
-    """
+def get_val_stats(sample_values: list[float]) -> dict[str, float]:
+    """Compute statistics of a numpy array.
+
     Computes the statistics of a numpy array and returns a dictionary
     of metadata corresponding to input sample values.
 
-    :param list(float) sample_values: List of sample values at respective
-        indices
-    :return dict meta_row: Dict with intensity data for image
-    """
+    Parameters
+    ----------
+    sample_values : list of float
+        List of sample values at respective indices.
 
+    Returns
+    -------
+    dict[str, float]
+        Dictionary with intensity data for image.
+    """
     meta_row = {
         "mean": float(np.nanmean(sample_values)),
         "std": float(np.nanstd(sample_values)),
diff --git a/viscy/utils/normalize.py b/viscy/utils/normalize.py
index 73753acb7..e86eae818 100644
--- a/viscy/utils/normalize.py
+++ b/viscy/utils/normalize.py
@@ -1,19 +1,33 @@
-"""Image normalization related functions"""
+"""Image normalization related functions."""
 
 import sys
+from typing import Any
 
 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from skimage.exposure import equalize_adapthist
 
 
-def zscore(input_image, im_mean=None, im_std=None):
-    """
-    Performs z-score normalization. Adds epsilon in denominator for robustness
+def zscore(
+    input_image: NDArray, im_mean: float | None = None, im_std: float | None = None
+) -> NDArray[Any]:
+    """Perform z-score normalization.
+
+    Adds epsilon in denominator for robustness.
+
+    Parameters
+    ----------
+    input_image : NDArray
+        Input image for intensity normalization.
+    im_mean : float, optional
+        Image mean, by default None.
+    im_std : float, optional
+        Image std, by default None.
 
-    :param np.array input_image: input image for intensity normalization
-    :param float/None im_mean: Image mean
-    :param float/None im_std: Image std
-    :return np.array norm_img: z score normalized image
+    Returns
+    -------
+    NDArray
+        Z-score normalized image.
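+
+    Examples
+    --------
+    A minimal sketch with synthetic data (values are illustrative):
+
+    >>> import numpy as np
+    >>> im = np.random.rand(8, 8)
+    >>> norm = zscore(im)
+    >>> bool(np.isclose(np.nanmean(norm), 0.0))
+    True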
     """
     if not im_mean:
         im_mean = np.nanmean(input_image)
@@ -23,50 +37,85 @@ def zscore(
     return norm_img
 
 
-def unzscore(im_norm, zscore_median, zscore_iqr):
-    """
-    Revert z-score normalization applied during preprocessing. Necessary
-    before computing SSIM
+def unzscore(im_norm: NDArray, zscore_median: float, zscore_iqr: float) -> NDArray[Any]:
+    """Revert z-score normalization applied during preprocessing.
+
+    Necessary before computing SSIM.
 
-    :param im_norm: Normalized image for un-zscore
-    :param zscore_median: Image median
-    :param zscore_iqr: Image interquartile range
-    :return im: image at its original scale
+    Parameters
+    ----------
+    im_norm : NDArray
+        Normalized image for un-zscore.
+    zscore_median : float
+        Image median.
+    zscore_iqr : float
+        Image interquartile range.
+
+    Returns
+    -------
+    NDArray
+        Image at its original scale.
     """
     im = im_norm * (zscore_iqr + sys.float_info.epsilon) + zscore_median
     return im
 
 
-def hist_clipping(input_image, min_percentile=2, max_percentile=98):
-    """Clips and rescales histogram from min to max intensity percentiles
-
-    rescale_intensity with input check
-
-    :param np.array input_image: input image for intensity normalization
-    :param int/float min_percentile: min intensity percentile
-    :param int/flaot max_percentile: max intensity percentile
-    :return: np.float, intensity clipped and rescaled image
+def hist_clipping(
+    input_image: NDArray,
+    min_percentile: int | float = 2,
+    max_percentile: int | float = 98,
+) -> NDArray[Any]:
+    """Clip and rescale histogram from min to max intensity percentiles.
+
+    rescale_intensity with input check.
+
+    Parameters
+    ----------
+    input_image : NDArray
+        Input image for intensity normalization.
+    min_percentile : int or float, optional
+        Min intensity percentile, by default 2.
+    max_percentile : int or float, optional
+        Max intensity percentile, by default 98.
+
+    Returns
+    -------
+    NDArray
+        Intensity clipped and rescaled image.
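+
+    Examples
+    --------
+    A minimal sketch with synthetic data:
+
+    >>> import numpy as np
+    >>> im = np.arange(100.0).reshape(10, 10)
+    >>> clipped = hist_clipping(im, 5, 95)
+    >>> bool(clipped.min() > 0) and bool(clipped.max() < 99)
+    True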
     """
-    assert (min_percentile < max_percentile) and max_percentile <= 100
     pmin, pmax = np.percentile(input_image, (min_percentile, max_percentile))
     hist_clipped_image = np.clip(input_image, pmin, pmax)
     return hist_clipped_image
 
 
-def hist_adapteq_2D(input_image, kernel_size=None, clip_limit=None):
-    """CLAHE on 2D images
+def hist_adapteq_2D(
+    input_image: NDArray,
+    kernel_size: int | list[int] | tuple[int, ...] | None = None,
+    clip_limit: float | None = None,
+) -> NDArray:
+    """Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) on 2D images.
 
     skimage.exposure.equalize_adapthist works only for 2D. Extend to 3D or use
-    openCV? Not ideal, as it enhances noise in homogeneous areas
-
-    :param np.array input_image: input image for intensity normalization
-    :param int/list kernel_size: Neighbourhood to be used for histogram
-        equalization. If none, use default of 1/8th image size.
-    :param float clip_limit: Clipping limit, normalized between 0 and 1
-        (higher values give more contrast, ~ max percent of voxels in any
-        histogram bin, if > this limit, the voxel intensities are redistributed).
-        if None, default=0.01
+    openCV? Not ideal, as it enhances noise in homogeneous areas.
+
+    Parameters
+    ----------
+    input_image : NDArray
+        Input image for intensity normalization.
+    kernel_size : int, list, or tuple, optional
+        Neighbourhood to be used for histogram equalization. If None, defaults
+        to 1/8th of the image size.
+    clip_limit : float, optional
+        Clipping limit, normalized between 0 and 1 (higher values give more
+        contrast, ~ max percent of voxels in any histogram bin; above this
+        limit, voxel intensities are redistributed). If None, defaults to 0.01.
+
+    Returns
+    -------
+    NDArray
+        Adaptive histogram equalized image.
     """
     nrows, ncols = input_image.shape
     if kernel_size is not None:
@@ -78,9 +127,7 @@ def hist_adapteq_2D(kernel_size=None, clip_limit=None):
             raise ValueError("kernel size invalid: not an int / list / tuple")
 
     if clip_limit is not None:
-        assert 0 <= clip_limit <= 1, "Clip limit {} is out of range [0, 1]".format(
-            clip_limit
-        )
+        assert 0 <= clip_limit <= 1, f"Clip limit {clip_limit} is out of range [0, 1]"
 
     adapt_eq_image = equalize_adapthist(
         input_image, kernel_size=kernel_size, clip_limit=clip_limit
diff --git a/viscy/utils/slurm_utils.py b/viscy/utils/slurm_utils.py
index 9cfafb84b..c943b8b11 100644
--- a/viscy/utils/slurm_utils.py
+++ b/viscy/utils/slurm_utils.py
@@ -36,7 +36,8 @@ def calculate_dataloader_settings(
 
     Returns
     -------
-    dict: Recommended settings for DataLoader
+    dict
+        Dictionary with recommended settings for DataLoader.
     """
     # Get system resources if not provided
     if available_ram_gb is None: