Skip to content

Commit 100a2be

Browse files
committed
Updated some batched transforms integration docs
1 parent 22b2e59 commit 100a2be

14 files changed

+168
-15
lines changed

tests/transforms/test_adjust_contrast.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import torch
33
from monai.transforms import AdjustContrast, Compose
4-
54
from viscy.transforms import BatchedRandAdjustContrast, BatchedRandAdjustContrastd
65

76

tests/transforms/test_crop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import torch
33
from monai.transforms import Compose
4-
54
from viscy.transforms._crop import (
65
BatchedCenterSpatialCrop,
76
BatchedCenterSpatialCropd,

tests/transforms/test_flip.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pytest
22
import torch
3-
43
from viscy.transforms import BatchedRandFlip, BatchedRandFlipd
54

65

tests/transforms/test_gaussian_smooth.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
get_gaussian_kernel3d,
88
)
99
from monai.transforms.intensity.array import GaussianSmooth
10-
1110
from viscy.transforms import BatchedRandGaussianSmooth, BatchedRandGaussianSmoothd
1211
from viscy.transforms._gaussian_smooth import filter3d_separable
1312

tests/transforms/test_noise.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import torch
33
from monai.transforms import Compose
4-
54
from viscy.transforms import BatchedRandGaussianNoise, BatchedRandGaussianNoised
65

76

tests/transforms/test_scale_intensity.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import torch
33
from monai.transforms import RandScaleIntensity
4-
54
from viscy.transforms import BatchedRandScaleIntensity, BatchedRandScaleIntensityd
65

76

tests/transforms/test_transforms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pytest
22
import torch
3-
43
from viscy.transforms._decollate import Decollate
54
from viscy.transforms._transforms import (
65
BatchedScaleIntensityRangePercentiles,

viscy/data/triplet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,9 @@ def _get_tensorstore(self, position: Position) -> ts.TensorStore:
207207
return self._tensorstores[fov_name]
208208

209209
def _filter_tracks(self, tracks_tables: list[pd.DataFrame]) -> pd.DataFrame:
210-
"""Exclude tracks that are too close to the border
211-
or do not have the next time point.
210+
"""
211+
212+
Exclude tracks that are too close to the border or do not have the next time point.
212213
213214
Parameters
214215
----------

viscy/transforms/batched_rand_3d_elasticd.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,29 @@
55

66

77
class BatchedRand3DElasticd(MapTransform, RandomizableTransform):
8-
"""Batched 3D elastic deformation for biological structures."""
8+
"""Apply random 3D elastic deformation to input data.
9+
10+
Uses Gaussian-smoothed displacement fields to simulate deformation.
11+
12+
Parameters
13+
----------
14+
keys : str or Iterable[str]
15+
Keys of the corresponding items to be transformed.
16+
sigma_range : tuple[float, float]
17+
Range for random sigma values used in Gaussian smoothing.
18+
magnitude_range : tuple[float, float]
19+
Range for random displacement magnitude values.
20+
spatial_size : tuple[int, int, int] or int or None, optional
21+
Expected spatial size of input data.
22+
prob : float, optional
23+
Probability of applying the transform, by default 0.1.
24+
mode : str, optional
25+
Interpolation mode for grid sampling, by default "bilinear".
26+
padding_mode : str, optional
27+
Padding mode for grid sampling, by default "reflection".
28+
allow_missing_keys : bool, optional
29+
Whether to ignore missing keys, by default False.
30+
"""
931

1032
def __init__(
1133
self,
@@ -29,7 +51,6 @@ def __init__(
2951
def _generate_elastic_field(
3052
self, shape: torch.Size, device: torch.device
3153
) -> Tensor:
32-
"""Generate batched elastic deformation field."""
3354
batch_size = shape[0]
3455
spatial_dims = shape[2:] # Skip batch and channel
3556

@@ -76,6 +97,18 @@ def _generate_elastic_field(
7697
return torch.stack(displacement_fields)
7798

7899
def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]:
100+
"""Apply elastic deformation to sample data.
101+
102+
Parameters
103+
----------
104+
sample : dict[str, Tensor]
105+
Dictionary containing image tensors to transform.
106+
107+
Returns
108+
-------
109+
dict[str, Tensor]
110+
Dictionary with transformed tensors.
111+
"""
79112
self.randomize(None)
80113
d = dict(sample)
81114

viscy/transforms/batched_rand_histogram_shiftd.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,21 @@
55

66

77
class BatchedRandHistogramShiftd(MapTransform, RandomizableTransform):
8-
"""Batched random histogram shifting for intensity distribution changes."""
8+
"""Apply random histogram shifts to modify intensity distributions.
9+
10+
Adds random intensity offsets to simulate illumination variations.
11+
12+
Parameters
13+
----------
14+
keys : str or Iterable[str]
15+
Keys of the corresponding items to be transformed.
16+
shift_range : tuple[float, float], optional
17+
Range for random intensity shift values, by default (-0.1, 0.1).
18+
prob : float, optional
19+
Probability of applying the transform, by default 0.1.
20+
allow_missing_keys : bool, optional
21+
Whether to ignore missing keys, by default False.
22+
"""
923

1024
def __init__(
1125
self,
@@ -19,6 +33,18 @@ def __init__(
1933
self.shift_range = shift_range
2034

2135
def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]:
36+
"""Apply histogram shift to sample data.
37+
38+
Parameters
39+
----------
40+
sample : dict[str, Tensor]
41+
Dictionary containing image tensors to transform.
42+
43+
Returns
44+
-------
45+
dict[str, Tensor]
46+
Dictionary with intensity-shifted tensors.
47+
"""
2248
self.randomize(None)
2349
d = dict(sample)
2450

0 commit comments

Comments
 (0)