Commit 13aac03

Feat(Next-Gen CAREamics): draft convenience functions for NG-compatible CAREamics (N2V only) (#662)
## Description

> [!NOTE]
> **tl;dr**: Refactor and adapt configuration factories for NG Dataset compatibility.

This PR creates a copy of the convenience functions for use with the NG Dataset. The problem we currently face is that the `Configuration` is not compatible with the NG Dataset, and there is no way to use the NG Dataset without creating the configs explicitly (e.g. the algorithm config). This PR introduces an N2V-only set of convenience functions and a configuration class to create an `NGConfiguration` that is NG Dataset-compatible.

In essence:

- Two new submodules, `ng_factories` and `ng_configs`: the former holds the convenience functions, the latter global configurations.
- Adds an `NGConfiguration` that accepts `NGDataConfig`. The only other difference from the old `Configuration` is that there is no `set_3D` method anymore (it was never used), and the axes/model conv dims are no longer silently changed by validation in case of a mismatch, but instead explicitly raise an error.
- New `N2VConfiguration`, a child of `NGConfiguration`, that performs N2V-specific validation.
- New implementation of `create_n2v_configuration` (shadowing the original name, but in submodule `ng_factories`), which also gains the recently introduced `in_memory` and `channels` parameters. `channels` has a complex interaction with `n_channels` and `axes`, since it allows defining which channels to use.
- All convenience functions have been broken up into separate modules for clarity (as opposed to the old, massive `configuration_factories.py`...).

Note: I have not put much thought into the `channels`, `n_channels` and `axes` interaction. I implemented what made sense to me naively; we should take a close look and change it in another PR.
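The stricter validation behavior can be sketched without any CAREamics imports. The helper below is purely illustrative (`check_dims` does not exist in the code base); it mirrors the new behavior, where a data/model dimension mismatch raises an error instead of being silently reconciled as in the old `Configuration`:

```python
# Minimal sketch of the "raise instead of silently fix" behavior.
# `check_dims` is a hypothetical stand-in for the NGConfiguration validator.

def check_dims(axes: str, model_conv_dims: int) -> None:
    """Raise if the data axes and the model conv dimensions disagree."""
    data_dims = 3 if "Z" in axes else 2
    if data_dims != model_conv_dims:
        # the old Configuration silently rewrote one side to match the other;
        # NGConfiguration now surfaces the mismatch to the user instead
        raise ValueError(
            f"Axes {axes!r} imply {data_dims}D data, but the model is "
            f"{model_conv_dims}D. Fix the axes or the model explicitly."
        )

check_dims("YX", 2)  # consistent: no error
try:
    check_dims("ZYX", 2)  # mismatch: now an explicit error
except ValueError as e:
    print(e)
```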
## Breaking changes

Unless code was importing `create_ng_data_configuration` from `careamics.config.configuration_factories` instead of from `careamics.config`, the rest of the code base should be unaffected.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
1 parent a3fe03c commit 13aac03

22 files changed: +1667 −154 lines

src/careamics/config/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -51,14 +51,14 @@
     create_microsplit_configuration,
     create_n2n_configuration,
     create_n2v_configuration,
-    create_ng_data_configuration,
     create_pn2v_configuration,
 )
 from .data import DataConfig, NGDataConfig
 from .data.inference_config import InferenceConfig
 from .lightning.callbacks import CheckpointConfig
 from .lightning.training_config import TrainingConfig
 from .losses.loss_config import LVAELossConfig
+from .ng_factories.data_factory import create_ng_data_configuration
 from .noise_model import (
     GaussianMixtureNMConfig,
     MultiChannelNMConfig,
```
src/careamics/config/architectures/unet_config.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -114,6 +114,9 @@ def is_3D(self) -> bool:
         """
         Return whether the model is 3D or not.
 
+        This method is used in the NG configuration validation to check that the model
+        dimensions match the data dimensions.
+
         Returns
         -------
         bool
```

src/careamics/config/configuration_factories.py

Lines changed: 1 addition & 94 deletions

```diff
@@ -13,7 +13,7 @@
     PN2VAlgorithm,
 )
 from careamics.config.architectures import LVAEConfig, UNetConfig
-from careamics.config.data import DataConfig, NGDataConfig
+from careamics.config.data import DataConfig
 from careamics.config.lightning.training_config import TrainingConfig
 from careamics.config.losses.loss_config import LVAELossConfig
 from careamics.config.noise_model.likelihood_config import (
@@ -357,99 +357,6 @@ def _create_microsplit_data_configuration(
     return MicroSplitDataConfig(**data)
 
 
-def create_ng_data_configuration(
-    data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
-    axes: str,
-    patch_size: Sequence[int],
-    batch_size: int,
-    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
-    channels: Sequence[int] | None = None,
-    in_memory: bool | None = None,
-    train_dataloader_params: dict[str, Any] | None = None,
-    val_dataloader_params: dict[str, Any] | None = None,
-    pred_dataloader_params: dict[str, Any] | None = None,
-    seed: int | None = None,
-) -> NGDataConfig:
-    """
-    Create a training NGDataConfig.
-
-    Parameters
-    ----------
-    data_type : {"array", "tiff", "zarr", "czi", "custom"}
-        Type of the data.
-    axes : str
-        Axes of the data.
-    patch_size : list of int
-        Size of the patches along the spatial dimensions.
-    batch_size : int
-        Batch size.
-    channels : Sequence of int, default=None
-        List of channels to use. If `None`, all channels are used.
-    in_memory : bool, default=None
-        Whether to load all data into memory. This is only supported for 'array',
-        'tiff' and 'custom' data types. If `None`, defaults to `True` for 'array',
-        'tiff' and 'custom', and `False` for 'zarr' and 'czi' data types. Must be `True`
-        for 'array'.
-    augmentations : list of transforms or None, default=None
-        List of transforms to apply. If `None`, default augmentations are applied
-        (flip in X and Y, rotations by 90 degrees in the XY plane).
-    train_dataloader_params : dict
-        Parameters for the training dataloader, see PyTorch notes, by default None.
-    val_dataloader_params : dict
-        Parameters for the validation dataloader, see PyTorch notes, by default None.
-    pred_dataloader_params : dict
-        Parameters for the test dataloader, see PyTorch notes, by default None.
-    seed : int, default=None
-        Random seed for reproducibility. If `None`, no seed is set.
-
-    Returns
-    -------
-    NGDataConfig
-        Next-Generation Data model with the specified parameters.
-    """
-    if augmentations is None:
-        augmentations = _list_spatial_augmentations()
-
-    # data model
-    data: dict[str, Any] = {
-        "mode": "training",
-        "data_type": data_type,
-        "axes": axes,
-        "batch_size": batch_size,
-        "channels": channels,
-        "transforms": augmentations,
-        "seed": seed,
-    }
-
-    if in_memory is not None:
-        data["in_memory"] = in_memory
-
-    # don't override defaults set in DataConfig class
-    if train_dataloader_params is not None:
-        # the presence of `shuffle` key in the dataloader parameters is enforced
-        # by the NGDataConfig class
-        if "shuffle" not in train_dataloader_params:
-            train_dataloader_params["shuffle"] = True
-
-        data["train_dataloader_params"] = train_dataloader_params
-
-    if val_dataloader_params is not None:
-        data["val_dataloader_params"] = val_dataloader_params
-
-    if pred_dataloader_params is not None:
-        data["pred_dataloader_params"] = pred_dataloader_params
-
-    # add training patching
-    data["patching"] = {
-        "name": "random",
-        "patch_size": patch_size,
-    }
-
-    return NGDataConfig(**data)
-
-
 def _create_training_configuration(
     trainer_params: dict,
     logger: Literal["wandb", "tensorboard", "none"],
```
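One detail of the relocated factory worth noting: user-supplied `train_dataloader_params` get a `shuffle` key injected when the caller omitted it, because `NGDataConfig` requires that key to be present. A standalone sketch of that rule, with no CAREamics imports (the helper name `with_shuffle_default` is hypothetical):

```python
from typing import Any

def with_shuffle_default(train_dataloader_params: dict[str, Any]) -> dict[str, Any]:
    """Mirror of the factory rule: ensure a `shuffle` key, defaulting to True."""
    # inject `shuffle=True` only when missing, never overriding an explicit value
    if "shuffle" not in train_dataloader_params:
        train_dataloader_params["shuffle"] = True
    return train_dataloader_params

print(with_shuffle_default({"num_workers": 4}))           # shuffle defaults to True
print(with_shuffle_default({"num_workers": 4, "shuffle": False}))  # False preserved
```

An explicit `shuffle=False` thus survives, which matters for debugging runs where deterministic batch order is wanted.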

src/careamics/config/data/ng_data_config.py

Lines changed: 40 additions & 17 deletions

```diff
@@ -554,24 +554,26 @@ def validate_dimensions(self: Self) -> Self:
         ValueError
             If the patch size dimension is not compatible with the axes.
         """
-        if "Z" in self.axes:
-            if (
-                hasattr(self.patching, "patch_size")
-                and len(self.patching.patch_size) != 3
-            ):
-                raise ValueError(
-                    f"`patch_size` in `patching` must have 3 dimensions if the data is"
-                    f" 3D, got axes {self.axes})."
-                )
+        # "whole" patching does not have dimensions to validate
+        if not hasattr(self.patching, "patch_size"):
+            return self
+
+        if self.data_type == "czi":
+            # Z and T are both depth axes for CZI data
+            expected_dims = 3 if ("Z" in self.axes or "T" in self.axes) else 2
+            additional_message = " (`Z` and `T` are depth axes for CZI data)"
         else:
-            if (
-                hasattr(self.patching, "patch_size")
-                and len(self.patching.patch_size) != 2
-            ):
-                raise ValueError(
-                    f"`patch_size` in `patching` must have 2 dimensions if the data is"
-                    f" 3D, got axes {self.axes})."
-                )
+            expected_dims = 3 if "Z" in self.axes else 2
+            additional_message = ""
+
+        # infer dimension from requested patch size
+        actual_dims = len(self.patching.patch_size)
+        if actual_dims != expected_dims:
+            raise ValueError(
+                f"`patch_size` in `patching` must have {expected_dims} dimensions, "
+                f"got {self.patching.patch_size} with axes {self.axes}"
+                f"{additional_message}."
+            )
 
         return self
@@ -780,6 +782,27 @@ def set_means_and_stds(
             target_stds=target_stds,
         )
 
+    def is_3D(self) -> bool:
+        """
+        Check if the data is 3D based on the axes.
+
+        Either "Z" is in the axes and patching `patch_size` has 3 dimensions, or for CZI
+        data, "Z" is in the axes or "T" is in the axes and patching `patch_size` has
+        3 dimensions.
+
+        This method is used during NGConfiguration validation to cross-check dimensions
+        with the algorithm configuration.
+
+        Returns
+        -------
+        bool
+            True if the data is 3D, False otherwise.
+        """
+        if self.data_type == "czi":
+            return "Z" in self.axes or "T" in self.axes
+        else:
+            return "Z" in self.axes
+
     # TODO: if switching from a state in which in_memory=True to an incompatible state
     # an error will be raised. Should that automatically be set to False instead?
     # TODO `channels=None` is ambiguous: all channels or same channels as in training?
```
src/careamics/config/lightning/training_config.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -29,6 +29,7 @@ class TrainingConfig(BaseModel):
     model_config = ConfigDict(
         validate_assignment=True,
     )
+
     lightning_trainer_config: dict | None = None
     """Configuration for the PyTorch Lightning Trainer, following PyTorch Lightning
     Trainer class"""
```
src/careamics/config/ng_configs/__init__.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -0,0 +1,5 @@
+"""Definitions of configurations for CAREamics, compatible with the NG dataset."""
+
+__all__ = ["N2VConfiguration"]
+
+from .n2v_configuration import N2VConfiguration
```
src/careamics/config/ng_configs/n2v_configuration.py

Lines changed: 64 additions & 0 deletions

```diff
@@ -0,0 +1,64 @@
+"""Configuration for N2V."""
+
+from typing import Self
+
+import numpy as np
+from pydantic import model_validator
+
+from careamics.config.algorithms import N2VAlgorithm
+from careamics.config.data.patching_strategies import RandomPatchingConfig
+
+from .ng_configuration import NGConfiguration
+
+
+class N2VConfiguration(NGConfiguration):
+    """N2V-specific configuration."""
+
+    algorithm_config: N2VAlgorithm
+
+    @model_validator(mode="after")
+    def validate_n2v_mask_pixel_perc(self: Self) -> Self:
+        """
+        Validate that there will always be at least one blind-spot pixel in every patch.
+
+        The probability of creating a blind-spot pixel is a function of the chosen
+        masked pixel percentage and patch size.
+
+        Returns
+        -------
+        Self
+            Validated configuration.
+
+        Raises
+        ------
+        ValueError
+            If the probability of masking a pixel within a patch is less than 1 for the
+            chosen masked pixel percentage and patch size.
+        """
+        if self.data_config.mode == "training":
+            assert isinstance(self.data_config.patching, RandomPatchingConfig)
+
+            mask_pixel_perc = self.algorithm_config.n2v_config.masked_pixel_percentage
+            patch_size = self.data_config.patching.patch_size
+            expected_area_per_pixel = 1 / (mask_pixel_perc / 100)
+
+            n_dims = 3 if self.algorithm_config.model.is_3D() else 2
+            patch_size_lower_bound = int(
+                np.ceil(expected_area_per_pixel ** (1 / n_dims))
+            )
+            required_patch_size = tuple(
+                2 ** int(np.ceil(np.log2(patch_size_lower_bound)))
+                for _ in range(n_dims)
+            )
+            required_mask_pixel_perc = (1 / np.prod(patch_size)) * 100
+
+            if expected_area_per_pixel > np.prod(patch_size):
+                raise ValueError(
+                    "The probability of creating a blind-spot pixel within a patch is "
+                    f"below 1, for a patch size of {patch_size} with a masked pixel "
+                    f"percentage of {mask_pixel_perc}%. Either increase the patch size "
+                    f"to {required_patch_size} or increase the masked pixel percentage "
+                    f"to at least {required_mask_pixel_perc}%."
+                )
+
+        return self
```
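The bound enforced by `validate_n2v_mask_pixel_perc` can be checked with plain arithmetic. A standalone sketch, assuming N2V's usual 0.2% default masked pixel percentage (the `min_patch_area` helper is illustrative, not part of the PR):

```python
import math

def min_patch_area(mask_pixel_perc: float) -> float:
    """Pixels needed, on average, to expect one blind-spot pixel per patch."""
    # with p% masking, one pixel is masked for every 1 / (p / 100) pixels
    return 1 / (mask_pixel_perc / 100)

patch_size = (64, 64)
patch_area = math.prod(patch_size)  # 4096 pixels

# a 64x64 patch comfortably satisfies the validator for 0.2% masking
assert math.isclose(min_patch_area(0.2), 500.0)
assert min_patch_area(0.2) <= patch_area

# a 16x16 patch would fail: 256 < 500, so the validator would demand a larger
# patch or a masked pixel percentage of at least (1 / 256) * 100 = 0.390625%
assert min_patch_area(0.2) > 16 * 16
assert (1 / (16 * 16)) * 100 == 0.390625
```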
