|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | +from monai.transforms import AdjustContrast, Compose |
| 4 | + |
| 5 | +from viscy.transforms import BatchedRandAdjustContrast, BatchedRandAdjustContrastd |
| 6 | + |
| 7 | + |
| 8 | +@pytest.mark.parametrize("ndim", [4, 5]) |
| 9 | +@pytest.mark.parametrize("prob", [0.0, 0.5, 1.0]) |
| 10 | +@pytest.mark.parametrize( |
| 11 | + "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] |
| 12 | +) |
| 13 | +@pytest.mark.parametrize("compose", [True, False]) |
| 14 | +def test_batched_adjust_contrast(device, ndim, prob, compose): |
| 15 | + img = torch.rand([16] + [2] * (ndim - 1), device=device) + 0.1 |
| 16 | + transform = BatchedRandAdjustContrast(prob=prob, gamma=(0.5, 2.0)) |
| 17 | + if compose: |
| 18 | + transform = Compose([transform]) |
| 19 | + result = transform(img) |
| 20 | + assert result.shape == img.shape |
| 21 | + changed = ~torch.isclose(result, img, atol=1e-6).all( |
| 22 | + dim=list(range(1, result.ndim)) |
| 23 | + ) |
| 24 | + if prob == 1.0: |
| 25 | + assert changed.all() |
| 26 | + elif prob == 0.5: |
| 27 | + assert changed.any() |
| 28 | + assert not changed.all() |
| 29 | + elif prob == 0.0: |
| 30 | + assert not changed.any() |
| 31 | + assert result.device == img.device |
| 32 | + if not compose: |
| 33 | + repeat = transform(img, randomize=False) |
| 34 | + assert torch.equal(result, repeat) |
| 35 | + |
| 36 | + |
| 37 | +@pytest.mark.parametrize("gamma", [0.8, 1.5, (0.5, 2.0)]) |
| 38 | +@pytest.mark.parametrize("retain_stats", [True, False]) |
| 39 | +@pytest.mark.parametrize("invert_image", [True, False]) |
| 40 | +def test_batched_adjust_contrast_options(gamma, retain_stats, invert_image): |
| 41 | + img = torch.rand(8, 1, 8, 8, 8) + 0.1 |
| 42 | + original_mean = img.mean() |
| 43 | + original_std = img.std() |
| 44 | + |
| 45 | + transform = BatchedRandAdjustContrast( |
| 46 | + prob=1.0, gamma=gamma, retain_stats=retain_stats, invert_image=invert_image |
| 47 | + ) |
| 48 | + result = transform(img) |
| 49 | + |
| 50 | + assert result.shape == img.shape |
| 51 | + |
| 52 | + if retain_stats: |
| 53 | + assert torch.isclose(result.mean(), original_mean, atol=1e-5) |
| 54 | + assert torch.isclose(result.std(), original_std, atol=1e-5) |
| 55 | + |
| 56 | + if not (isinstance(gamma, (int, float)) and gamma == 1.0): |
| 57 | + assert not torch.equal(result, img) |
| 58 | + |
| 59 | + |
| 60 | +def test_batched_adjust_contrast_dict(): |
| 61 | + img = torch.rand([16, 1, 4, 8, 8]) + 0.1 |
| 62 | + data = {"a": img, "b": img.clone()} |
| 63 | + transform = BatchedRandAdjustContrastd(keys=["a", "b"], prob=1.0, gamma=(0.5, 2.0)) |
| 64 | + result = transform(data) |
| 65 | + assert not torch.equal(result["a"], img) |
| 66 | + assert torch.equal(result["a"], result["b"]) |
| 67 | + |
| 68 | + |
| 69 | +def test_batched_adjust_contrast_gamma_validation(): |
| 70 | + with pytest.raises(ValueError): |
| 71 | + BatchedRandAdjustContrast(gamma=0.0) |
| 72 | + |
| 73 | + with pytest.raises(ValueError): |
| 74 | + BatchedRandAdjustContrast(gamma=-0.5) |
| 75 | + |
| 76 | + with pytest.raises(ValueError): |
| 77 | + BatchedRandAdjustContrast(gamma=(0.5, 2.0, 1.0)) |
| 78 | + |
| 79 | + with pytest.raises(ValueError): |
| 80 | + BatchedRandAdjustContrast(gamma=(-0.1, 2.0)) |
| 81 | + |
| 82 | + BatchedRandAdjustContrast(gamma=0.1) |
| 83 | + BatchedRandAdjustContrast(gamma=0.3) |
| 84 | + BatchedRandAdjustContrast(gamma=1.5) |
| 85 | + BatchedRandAdjustContrast(gamma=(0.2, 0.8)) |
| 86 | + BatchedRandAdjustContrast(gamma=(0.5, 2.0)) |
| 87 | + |
| 88 | + |
| 89 | +@pytest.mark.parametrize("gamma_value", [0.2, 0.5, 0.8, 1.2, 2.0]) |
| 90 | +@pytest.mark.parametrize("retain_stats", [True, False]) |
| 91 | +@pytest.mark.parametrize("invert_image", [True, False]) |
| 92 | +def test_batched_adjust_contrast_vs_monai(gamma_value, retain_stats, invert_image): |
| 93 | + torch.manual_seed(42) |
| 94 | + |
| 95 | + batch_size = 4 |
| 96 | + img_batch = torch.rand(batch_size, 1, 8, 8, 8) + 0.1 |
| 97 | + |
| 98 | + batched_transform = BatchedRandAdjustContrast( |
| 99 | + prob=1.0, |
| 100 | + gamma=(gamma_value, gamma_value), |
| 101 | + retain_stats=retain_stats, |
| 102 | + invert_image=invert_image, |
| 103 | + ) |
| 104 | + |
| 105 | + torch.manual_seed(42) |
| 106 | + batched_result = batched_transform(img_batch) |
| 107 | + |
| 108 | + monai_transform = AdjustContrast( |
| 109 | + gamma=gamma_value, retain_stats=retain_stats, invert_image=invert_image |
| 110 | + ) |
| 111 | + |
| 112 | + monai_results = [] |
| 113 | + for i in range(batch_size): |
| 114 | + individual_result = monai_transform(img_batch[i]) |
| 115 | + monai_results.append(individual_result) |
| 116 | + |
| 117 | + monai_batch_result = torch.stack(monai_results) |
| 118 | + |
| 119 | + assert torch.allclose(batched_result, monai_batch_result, atol=1e-6, rtol=1e-5) |
0 commit comments