Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tests/transforms/test_zoom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch

from viscy.transforms._zoom import BatchedZoom, BatchedZoomd


def test_batched_zoom():
"""Test BatchedZoom transform."""
batch_size = 2
channels = 3
depth, height, width = 8, 16, 16
data = torch.rand(batch_size, channels, depth, height, width)

transform = BatchedZoom(scale_factor=0.5, mode="area")
result = transform(data)

expected_shape = (batch_size, channels, depth // 2, height // 2, width // 2)
assert result.shape == expected_shape
assert torch.allclose(result.mean(), data.mean(), rtol=1e-3)


def test_batched_zoomd():
"""Test BatchedZoomd dictionary transform."""
batch_size = 2
channels = 1
depth, height, width = 4, 8, 8
data = {
"image": torch.rand(batch_size, channels, depth, height, width),
"label": torch.rand(batch_size, channels, depth, height, width),
}

transform = BatchedZoomd(keys=["image", "label"], scale_factor=2.0, mode="nearest")
result = transform(data)

expected_shape = (batch_size, channels, depth * 2, height * 2, width * 2)
assert result["image"].shape == expected_shape
assert result["label"].shape == expected_shape


def test_batched_zoom_roundtrip():
"""Test roundtrip zoom (2x then 0.5x) returns close to original."""
batch_size = 4
channels = 3
depth, height, width = 4, 8, 8
data = torch.rand(batch_size, channels, depth, height, width)

zoom_in = BatchedZoom(scale_factor=2.0, mode="nearest")
zoom_out = BatchedZoom(scale_factor=0.5, mode="nearest")

zoomed_in = zoom_in(data)
zoomed_out = zoom_out(zoomed_in)

assert torch.all(data == zoomed_out)
3 changes: 2 additions & 1 deletion viscy/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@
BatchedRandAffined,
BatchedScaleIntensityRangePercentiles,
BatchedScaleIntensityRangePercentilesd,
BatchedZoom,
NormalizeSampled,
RandInvertIntensityd,
StackChannelsd,
TiledSpatialCropSamplesd,
)
from viscy.transforms._zoom import BatchedZoom, BatchedZoomd
from viscy.transforms.batched_rand_3d_elasticd import BatchedRand3DElasticd
from viscy.transforms.batched_rand_histogram_shiftd import BatchedRandHistogramShiftd
from viscy.transforms.batched_rand_local_pixel_shufflingd import (
Expand Down Expand Up @@ -80,6 +80,7 @@
"BatchedScaleIntensityRangePercentiles",
"BatchedScaleIntensityRangePercentilesd",
"BatchedZoom",
"BatchedZoomd",
"CenterSpatialCropd",
"Decollate",
"Decollated",
Expand Down
37 changes: 0 additions & 37 deletions viscy/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
MultiSampleTrait,
RandomizableTransform,
ScaleIntensityRangePercentiles,
Transform,
)
from numpy.typing import DTypeLike
from torch import Tensor
Expand Down Expand Up @@ -163,42 +162,6 @@ def __call__(self, sample: Sample) -> Sample:
return results


class BatchedZoom(Transform):
"Batched zoom transform using ``torch.nn.functional.interpolate``."

def __init__(
self,
scale_factor: float | tuple[float, float, float],
mode: Literal[
"nearest",
"nearest-exact",
"linear",
"bilinear",
"bicubic",
"trilinear",
"area",
],
align_corners: bool | None = None,
recompute_scale_factor: bool | None = None,
antialias: bool = False,
) -> None:
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
self.recompute_scale_factor = recompute_scale_factor
self.antialias = antialias

def __call__(self, sample: Tensor) -> Tensor:
return torch.nn.functional.interpolate(
sample,
scale_factor=self.scale_factor,
mode=self.mode,
align_corners=self.align_corners,
recompute_scale_factor=self.recompute_scale_factor,
antialias=self.antialias,
)


class BatchedScaleIntensityRangePercentiles(ScaleIntensityRangePercentiles):
def _normalize(self, img: Tensor) -> Tensor:
q_low = self.lower / 100.0
Expand Down
78 changes: 78 additions & 0 deletions viscy/transforms/_zoom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Sequence

import torch
from monai.transforms import MapTransform, Transform
from torch import Tensor
from typing_extensions import Literal


class BatchedZoom(Transform):
"Batched zoom transform using ``torch.nn.functional.interpolate``."

def __init__(
self,
scale_factor: float | tuple[float, float, float],
mode: Literal[
"nearest",
"nearest-exact",
"linear",
"bilinear",
"bicubic",
"trilinear",
"area",
],
align_corners: bool | None = None,
recompute_scale_factor: bool | None = None,
antialias: bool = False,
) -> None:
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
self.recompute_scale_factor = recompute_scale_factor
self.antialias = antialias

def __call__(self, sample: Tensor) -> Tensor:
return torch.nn.functional.interpolate(
sample,
scale_factor=self.scale_factor,
mode=self.mode,
align_corners=self.align_corners,
recompute_scale_factor=self.recompute_scale_factor,
antialias=self.antialias,
)


class BatchedZoomd(MapTransform):
"Dictionary wrapper of :py:class:`BatchedZoom`."

def __init__(
self,
keys: Sequence[str],
scale_factor: float | tuple[float, float, float],
mode: Literal[
"nearest",
"nearest-exact",
"linear",
"bilinear",
"bicubic",
"trilinear",
"area",
],
align_corners: bool | None = None,
recompute_scale_factor: bool | None = None,
antialias: bool = False,
) -> None:
super().__init__(keys)
self.transform = BatchedZoom(
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor,
antialias=antialias,
)

def __call__(self, data):
d = dict(data)
for key in self.keys:
d[key] = self.transform(d[key])
return d