Skip to content

Commit 22b2e59

Browse files
committed
Resolve merge conflicts in batched transforms integration
1 parent cf23055 commit 22b2e59

34 files changed

+3273
-386
lines changed

examples/transforms/batched_transforms.ipynb

Lines changed: 622 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ classifiers = [
1616
"Programming Language :: Python :: 3.13",
1717
]
1818
dependencies = [
19-
"iohub[tensorstore]>=0.2.2rc0",
19+
"iohub[tensorstore]>=0.3.0a2",
2020
"kornia",
2121
"torch>=2.4.1",
2222
"timm>=0.9.5",

tests/data/test_triplet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def test_datamodule_setup_fit(
5353
assert fov_name not in exclude_fovs
5454
assert len(all_tracks) == len_total
5555
for batch in dm.train_dataloader():
56+
dm.on_after_batch_transfer(batch, 0)
5657
assert batch["anchor"].shape == (
5758
batch_size,
5859
len(channel_names),
@@ -94,6 +95,7 @@ def test_datamodule_z_window_size(
9495
else:
9596
expected_z_shape = z_window_size
9697
for batch in dm.train_dataloader():
98+
dm.on_after_batch_transfer(batch, 0)
9799
assert batch["anchor"].shape == (
98100
batch_size,
99101
len(channel_names),

tests/transforms/__init__.py

Whitespace-only changes.
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import pytest
2+
import torch
3+
from monai.transforms import AdjustContrast, Compose
4+
5+
from viscy.transforms import BatchedRandAdjustContrast, BatchedRandAdjustContrastd
6+
7+
8+
@pytest.mark.parametrize("ndim", [4, 5])
9+
@pytest.mark.parametrize("prob", [0.0, 0.5, 1.0])
10+
@pytest.mark.parametrize(
11+
"device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
12+
)
13+
@pytest.mark.parametrize("compose", [True, False])
14+
def test_batched_adjust_contrast(device, ndim, prob, compose):
15+
img = torch.rand([16] + [2] * (ndim - 1), device=device) + 0.1
16+
transform = BatchedRandAdjustContrast(prob=prob, gamma=(0.5, 2.0))
17+
if compose:
18+
transform = Compose([transform])
19+
result = transform(img)
20+
assert result.shape == img.shape
21+
changed = ~torch.isclose(result, img, atol=1e-6).all(
22+
dim=list(range(1, result.ndim))
23+
)
24+
if prob == 1.0:
25+
assert changed.all()
26+
elif prob == 0.5:
27+
assert changed.any()
28+
assert not changed.all()
29+
elif prob == 0.0:
30+
assert not changed.any()
31+
assert result.device == img.device
32+
if not compose:
33+
repeat = transform(img, randomize=False)
34+
assert torch.equal(result, repeat)
35+
36+
37+
@pytest.mark.parametrize("gamma", [0.8, 1.5, (0.5, 2.0)])
38+
@pytest.mark.parametrize("retain_stats", [True, False])
39+
@pytest.mark.parametrize("invert_image", [True, False])
40+
def test_batched_adjust_contrast_options(gamma, retain_stats, invert_image):
41+
img = torch.rand(8, 1, 8, 8, 8) + 0.1
42+
original_mean = img.mean()
43+
original_std = img.std()
44+
45+
transform = BatchedRandAdjustContrast(
46+
prob=1.0, gamma=gamma, retain_stats=retain_stats, invert_image=invert_image
47+
)
48+
result = transform(img)
49+
50+
assert result.shape == img.shape
51+
52+
if retain_stats:
53+
assert torch.isclose(result.mean(), original_mean, atol=1e-5)
54+
assert torch.isclose(result.std(), original_std, atol=1e-5)
55+
56+
if not (isinstance(gamma, (int, float)) and gamma == 1.0):
57+
assert not torch.equal(result, img)
58+
59+
60+
def test_batched_adjust_contrast_dict():
61+
img = torch.rand([16, 1, 4, 8, 8]) + 0.1
62+
data = {"a": img, "b": img.clone()}
63+
transform = BatchedRandAdjustContrastd(keys=["a", "b"], prob=1.0, gamma=(0.5, 2.0))
64+
result = transform(data)
65+
assert not torch.equal(result["a"], img)
66+
assert torch.equal(result["a"], result["b"])
67+
68+
69+
def test_batched_adjust_contrast_gamma_validation():
70+
with pytest.raises(ValueError):
71+
BatchedRandAdjustContrast(gamma=0.0)
72+
73+
with pytest.raises(ValueError):
74+
BatchedRandAdjustContrast(gamma=-0.5)
75+
76+
with pytest.raises(ValueError):
77+
BatchedRandAdjustContrast(gamma=(0.5, 2.0, 1.0))
78+
79+
with pytest.raises(ValueError):
80+
BatchedRandAdjustContrast(gamma=(-0.1, 2.0))
81+
82+
BatchedRandAdjustContrast(gamma=0.1)
83+
BatchedRandAdjustContrast(gamma=0.3)
84+
BatchedRandAdjustContrast(gamma=1.5)
85+
BatchedRandAdjustContrast(gamma=(0.2, 0.8))
86+
BatchedRandAdjustContrast(gamma=(0.5, 2.0))
87+
88+
89+
@pytest.mark.parametrize("gamma_value", [0.2, 0.5, 0.8, 1.2, 2.0])
90+
@pytest.mark.parametrize("retain_stats", [True, False])
91+
@pytest.mark.parametrize("invert_image", [True, False])
92+
def test_batched_adjust_contrast_vs_monai(gamma_value, retain_stats, invert_image):
93+
torch.manual_seed(42)
94+
95+
batch_size = 4
96+
img_batch = torch.rand(batch_size, 1, 8, 8, 8) + 0.1
97+
98+
batched_transform = BatchedRandAdjustContrast(
99+
prob=1.0,
100+
gamma=(gamma_value, gamma_value),
101+
retain_stats=retain_stats,
102+
invert_image=invert_image,
103+
)
104+
105+
torch.manual_seed(42)
106+
batched_result = batched_transform(img_batch)
107+
108+
monai_transform = AdjustContrast(
109+
gamma=gamma_value, retain_stats=retain_stats, invert_image=invert_image
110+
)
111+
112+
monai_results = []
113+
for i in range(batch_size):
114+
individual_result = monai_transform(img_batch[i])
115+
monai_results.append(individual_result)
116+
117+
monai_batch_result = torch.stack(monai_results)
118+
119+
assert torch.allclose(batched_result, monai_batch_result, atol=1e-6, rtol=1e-5)

0 commit comments

Comments
 (0)