
Commit 1e5381d

ziw-liu, ritvikvasan, and edyoshikun authored
Batched transforms (#282)
* add test for scale intensity range percentiles transform
* randomize test data
* add decollate
* detect leading slashes
* configurable label column
* fix typo
* configurable example shape
* type hint
* fix example array
* format
* auto-align validation fov names
* Adding some batched transforms for MAE training (#284)
  * Add batched transforms for MAE training:
    - BatchedRandFlipd
    - BatchedRandSharpend
    - BatchedRandLocalPixelShufflingd
    - BatchedRandHistogramShiftd
    - BatchedRandZStackShiftd
    - BatchedRand3DElasticd
  * Revert existing files to original state; keep only new transform files and necessary imports
  * Fix import sorting in new transform files
  * Fix import formatting in __init__.py
  * Format code with ruff formatter
* bugfix for device
* ruff format fix
* wip: random crop
* format
* use gather operation for cropping
* register batched random crop
* use unfold and bench
* add dict version for random spatial crop
* lint
* benchmark crop
* rename flip file
* use tensorstore to stack
* rename in register
* add notebook to showcase transforms
* batched center crop
* batched gaussian noise
* test and add to notebook
* fix docstring
* test random flip
* add map version of batched gaussian noise
* fix noise scaling and shifting
* per-sample random flip
* fix gaussian blur
* rename blur to smooth to match monai
* update smoothing tests
* sigma-based truncation
* print timer results
* random adjust contrast
* wrap monai
* random scale intensity
* update tests
* allow final crop override
* add center crop to example
* remove redundant call method
* check gamma value
* use batched center crop
* fix import
* wip: fix gathering
* limit worker to 1
* wip: execute augmentations in data hook
* use the full loop for profiling
* wip: update concat wrapper
* fix negative sampling
* update tests
* revert to use threads
* cache tensorstore arrays
* update default num_workers
* skip transforms for example array
* relax flaky randomness test
* explicitly configure cache pool
* use new iohub API
* only check cache when opening
* fix transform initialization
* remove unnecessary TripletDataModule inits
* explain num_workers in the docstring
* update paths for profiling
* fix batched rand intensity scale in the notebook

---------

Co-authored-by: Ritvik <[email protected]>
Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
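A minimal sketch of how the new batched dictionary transforms can be applied to a whole mini-batch in one call, based only on the BatchedRandAdjustContrastd signature exercised in the tests below; the batch keys and (B, C, Z, Y, X) shapes are illustrative placeholders, not the project's actual configuration.

import torch

from viscy.transforms import BatchedRandAdjustContrastd

# Illustrative batch: keys and shapes are placeholders.
batch = {
    "anchor": torch.rand(16, 1, 4, 8, 8),
    "positive": torch.rand(16, 1, 4, 8, 8),
}

# One call augments the entire batch instead of looping over samples.
augment = BatchedRandAdjustContrastd(
    keys=["anchor", "positive"], prob=0.5, gamma=(0.5, 2.0)
)
augmented = augment(batch)
assert augmented["anchor"].shape == batch["anchor"].shape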
1 parent 4df580d commit 1e5381d

34 files changed: +3184 −330 lines

examples/transforms/batched_transforms.ipynb

Lines changed: 622 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ classifiers = [
     "Programming Language :: Python :: 3.13",
 ]
 dependencies = [
-    "iohub[tensorstore]>=0.2.2rc0",
+    "iohub[tensorstore]>=0.3.0a2",
     "kornia",
     "torch>=2.4.1",
     "timm>=0.9.5",

tests/data/test_triplet.py

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@ def test_datamodule_setup_fit(
         assert fov_name not in exclude_fovs
     assert len(all_tracks) == len_total
     for batch in dm.train_dataloader():
+        dm.on_after_batch_transfer(batch, 0)
         assert batch["anchor"].shape == (
             batch_size,
             len(channel_names),
@@ -95,6 +96,7 @@ def test_datamodule_z_window_size(
     else:
         expected_z_shape = z_window_size
     for batch in dm.train_dataloader():
+        dm.on_after_batch_transfer(batch, 0)
         assert batch["anchor"].shape == (
             batch_size,
             len(channel_names),
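The tests now invoke the data module's on_after_batch_transfer hook explicitly because this commit moves the augmentations into that Lightning hook ("wip: execute augmentations in data hook"). A rough sketch of the pattern is below; it is not the actual viscy TripletDataModule implementation, and self.augmentations is an assumed attribute holding the composed batched transforms.

from lightning.pytorch import LightningDataModule


class BatchedAugmentationDataModule(LightningDataModule):
    """Sketch only: illustrates where batched augmentations could run."""

    def __init__(self, augmentations) -> None:
        super().__init__()
        # Assumed to be a composed batched dictionary transform.
        self.augmentations = augmentations

    def on_after_batch_transfer(self, batch: dict, dataloader_idx: int) -> dict:
        # Called after the batch is moved to the training device, so batched
        # (GPU-friendly) transforms see the whole (B, C, Z, Y, X) stack at once.
        return self.augmentations(batch)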

tests/transforms/__init__.py

Whitespace-only changes.
Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
import pytest
import torch
from monai.transforms import AdjustContrast, Compose

from viscy.transforms import BatchedRandAdjustContrast, BatchedRandAdjustContrastd


@pytest.mark.parametrize("ndim", [4, 5])
@pytest.mark.parametrize("prob", [0.0, 0.5, 1.0])
@pytest.mark.parametrize(
    "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
)
@pytest.mark.parametrize("compose", [True, False])
def test_batched_adjust_contrast(device, ndim, prob, compose):
    img = torch.rand([16] + [2] * (ndim - 1), device=device) + 0.1
    transform = BatchedRandAdjustContrast(prob=prob, gamma=(0.5, 2.0))
    if compose:
        transform = Compose([transform])
    result = transform(img)
    assert result.shape == img.shape
    changed = ~torch.isclose(result, img, atol=1e-6).all(
        dim=list(range(1, result.ndim))
    )
    if prob == 1.0:
        assert changed.all()
    elif prob == 0.5:
        assert changed.any()
        assert not changed.all()
    elif prob == 0.0:
        assert not changed.any()
    assert result.device == img.device
    if not compose:
        repeat = transform(img, randomize=False)
        assert torch.equal(result, repeat)


@pytest.mark.parametrize("gamma", [0.8, 1.5, (0.5, 2.0)])
@pytest.mark.parametrize("retain_stats", [True, False])
@pytest.mark.parametrize("invert_image", [True, False])
def test_batched_adjust_contrast_options(gamma, retain_stats, invert_image):
    img = torch.rand(8, 1, 8, 8, 8) + 0.1
    original_mean = img.mean()
    original_std = img.std()

    transform = BatchedRandAdjustContrast(
        prob=1.0, gamma=gamma, retain_stats=retain_stats, invert_image=invert_image
    )
    result = transform(img)

    assert result.shape == img.shape

    if retain_stats:
        assert torch.isclose(result.mean(), original_mean, atol=1e-5)
        assert torch.isclose(result.std(), original_std, atol=1e-5)

    if not (isinstance(gamma, (int, float)) and gamma == 1.0):
        assert not torch.equal(result, img)


def test_batched_adjust_contrast_dict():
    img = torch.rand([16, 1, 4, 8, 8]) + 0.1
    data = {"a": img, "b": img.clone()}
    transform = BatchedRandAdjustContrastd(keys=["a", "b"], prob=1.0, gamma=(0.5, 2.0))
    result = transform(data)
    assert not torch.equal(result["a"], img)
    assert torch.equal(result["a"], result["b"])


def test_batched_adjust_contrast_gamma_validation():
    with pytest.raises(ValueError):
        BatchedRandAdjustContrast(gamma=0.0)

    with pytest.raises(ValueError):
        BatchedRandAdjustContrast(gamma=-0.5)

    with pytest.raises(ValueError):
        BatchedRandAdjustContrast(gamma=(0.5, 2.0, 1.0))

    with pytest.raises(ValueError):
        BatchedRandAdjustContrast(gamma=(-0.1, 2.0))

    BatchedRandAdjustContrast(gamma=0.1)
    BatchedRandAdjustContrast(gamma=0.3)
    BatchedRandAdjustContrast(gamma=1.5)
    BatchedRandAdjustContrast(gamma=(0.2, 0.8))
    BatchedRandAdjustContrast(gamma=(0.5, 2.0))


@pytest.mark.parametrize("gamma_value", [0.2, 0.5, 0.8, 1.2, 2.0])
@pytest.mark.parametrize("retain_stats", [True, False])
@pytest.mark.parametrize("invert_image", [True, False])
def test_batched_adjust_contrast_vs_monai(gamma_value, retain_stats, invert_image):
    torch.manual_seed(42)

    batch_size = 4
    img_batch = torch.rand(batch_size, 1, 8, 8, 8) + 0.1

    batched_transform = BatchedRandAdjustContrast(
        prob=1.0,
        gamma=(gamma_value, gamma_value),
        retain_stats=retain_stats,
        invert_image=invert_image,
    )

    torch.manual_seed(42)
    batched_result = batched_transform(img_batch)

    monai_transform = AdjustContrast(
        gamma=gamma_value, retain_stats=retain_stats, invert_image=invert_image
    )

    monai_results = []
    for i in range(batch_size):
        individual_result = monai_transform(img_batch[i])
        monai_results.append(individual_result)

    monai_batch_result = torch.stack(monai_results)

    assert torch.allclose(batched_result, monai_batch_result, atol=1e-6, rtol=1e-5)
