diff --git a/tests/transforms/test_zoom.py b/tests/transforms/test_zoom.py new file mode 100644 index 000000000..850e7d86a --- /dev/null +++ b/tests/transforms/test_zoom.py @@ -0,0 +1,52 @@ +import torch + +from viscy.transforms._zoom import BatchedZoom, BatchedZoomd + + +def test_batched_zoom(): + """Test BatchedZoom transform.""" + batch_size = 2 + channels = 3 + depth, height, width = 8, 16, 16 + data = torch.rand(batch_size, channels, depth, height, width) + + transform = BatchedZoom(scale_factor=0.5, mode="area") + result = transform(data) + + expected_shape = (batch_size, channels, depth // 2, height // 2, width // 2) + assert result.shape == expected_shape + assert torch.allclose(result.mean(), data.mean(), rtol=1e-3) + + +def test_batched_zoomd(): + """Test BatchedZoomd dictionary transform.""" + batch_size = 2 + channels = 1 + depth, height, width = 4, 8, 8 + data = { + "image": torch.rand(batch_size, channels, depth, height, width), + "label": torch.rand(batch_size, channels, depth, height, width), + } + + transform = BatchedZoomd(keys=["image", "label"], scale_factor=2.0, mode="nearest") + result = transform(data) + + expected_shape = (batch_size, channels, depth * 2, height * 2, width * 2) + assert result["image"].shape == expected_shape + assert result["label"].shape == expected_shape + + +def test_batched_zoom_roundtrip(): + """Test roundtrip zoom (2x then 0.5x) returns close to original.""" + batch_size = 4 + channels = 3 + depth, height, width = 4, 8, 8 + data = torch.rand(batch_size, channels, depth, height, width) + + zoom_in = BatchedZoom(scale_factor=2.0, mode="nearest") + zoom_out = BatchedZoom(scale_factor=0.5, mode="nearest") + + zoomed_in = zoom_in(data) + zoomed_out = zoom_out(zoomed_in) + + assert torch.all(data == zoomed_out) diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index e3c1d6012..7ce9f0f8b 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -42,12 +42,12 @@ BatchedRandAffined, BatchedScaleIntensityRangePercentiles, BatchedScaleIntensityRangePercentilesd, - BatchedZoom, NormalizeSampled, RandInvertIntensityd, StackChannelsd, TiledSpatialCropSamplesd, ) +from viscy.transforms._zoom import BatchedZoom, BatchedZoomd from viscy.transforms.batched_rand_3d_elasticd import BatchedRand3DElasticd from viscy.transforms.batched_rand_histogram_shiftd import BatchedRandHistogramShiftd from viscy.transforms.batched_rand_local_pixel_shufflingd import ( @@ -80,6 +80,7 @@ "BatchedScaleIntensityRangePercentiles", "BatchedScaleIntensityRangePercentilesd", "BatchedZoom", + "BatchedZoomd", "CenterSpatialCropd", "Decollate", "Decollated", diff --git a/viscy/transforms/_transforms.py b/viscy/transforms/_transforms.py index 1c174ec27..601d51168 100644 --- a/viscy/transforms/_transforms.py +++ b/viscy/transforms/_transforms.py @@ -8,7 +8,6 @@ MultiSampleTrait, RandomizableTransform, ScaleIntensityRangePercentiles, - Transform, ) from numpy.typing import DTypeLike from torch import Tensor @@ -163,42 +162,6 @@ def __call__(self, sample: Sample) -> Sample: return results -class BatchedZoom(Transform): - "Batched zoom transform using ``torch.nn.functional.interpolate``." - - def __init__( - self, - scale_factor: float | tuple[float, float, float], - mode: Literal[ - "nearest", - "nearest-exact", - "linear", - "bilinear", - "bicubic", - "trilinear", - "area", - ], - align_corners: bool | None = None, - recompute_scale_factor: bool | None = None, - antialias: bool = False, - ) -> None: - self.scale_factor = scale_factor - self.mode = mode - self.align_corners = align_corners - self.recompute_scale_factor = recompute_scale_factor - self.antialias = antialias - - def __call__(self, sample: Tensor) -> Tensor: - return torch.nn.functional.interpolate( - sample, - scale_factor=self.scale_factor, - mode=self.mode, - align_corners=self.align_corners, - recompute_scale_factor=self.recompute_scale_factor, - antialias=self.antialias, - ) - - class BatchedScaleIntensityRangePercentiles(ScaleIntensityRangePercentiles): def _normalize(self, img: Tensor) -> Tensor: q_low = self.lower / 100.0 diff --git a/viscy/transforms/_zoom.py b/viscy/transforms/_zoom.py new file mode 100644 index 000000000..c7964dd3a --- /dev/null +++ b/viscy/transforms/_zoom.py @@ -0,0 +1,78 @@ +from typing import Sequence + +import torch +from monai.transforms import MapTransform, Transform +from torch import Tensor +from typing_extensions import Literal + + +class BatchedZoom(Transform): + "Batched zoom transform using ``torch.nn.functional.interpolate``." + + def __init__( + self, + scale_factor: float | tuple[float, float, float], + mode: Literal[ + "nearest", + "nearest-exact", + "linear", + "bilinear", + "bicubic", + "trilinear", + "area", + ], + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, + antialias: bool = False, + ) -> None: + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + self.antialias = antialias + + def __call__(self, sample: Tensor) -> Tensor: + return torch.nn.functional.interpolate( + sample, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + antialias=self.antialias, + ) + + +class BatchedZoomd(MapTransform): + "Dictionary wrapper of :py:class:`BatchedZoom`." + + def __init__( + self, + keys: Sequence[str], + scale_factor: float | tuple[float, float, float], + mode: Literal[ + "nearest", + "nearest-exact", + "linear", + "bilinear", + "bicubic", + "trilinear", + "area", + ], + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, + antialias: bool = False, + ) -> None: + super().__init__(keys) + self.transform = BatchedZoom( + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + antialias=antialias, + ) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.transform(d[key]) + return d