|
3 | 3 | import pytest |
4 | 4 | import torch |
5 | 5 | from PIL import Image |
| 6 | +from torchvision import transforms |
6 | 7 |
|
7 | | -# from torchvision.transforms.functional import pil_to_tensor |
8 | 8 | from torch.utils.data import DataLoader |
9 | 9 | from src.dataset import ForestDataset, OversampledDataset, UndersampledDataset, ForestDataModule |
10 | 10 |
|
@@ -92,9 +92,10 @@ def test_transforms_applied(self): |
92 | 92 | image, _ = self.dataset[0] |
93 | 93 |
|
94 | 94 | assert isinstance(image, torch.Tensor), self.error_msg["not-tensor"] |
95 | | - # Normalize check |
96 | | - assert torch.min(image) >= -1.0, self.error_msg["invalid-min"] |
97 | | - assert torch.max(image) <= 1.0, self.error_msg["invalid-max"] |
| 95 | + # Check that image is normalized (values should be centered around 0 after ImageNet normalization) |
| 96 | + # ImageNet normalization doesn't guarantee [-1, 1] range, so we check for reasonable normalized values |
| 97 | + assert torch.min(image) >= -3.0, self.error_msg["invalid-min"] |
| 98 | + assert torch.max(image) <= 3.0, self.error_msg["invalid-max"] |
98 | 99 |
|
99 | 100 | def test_missing_file_handling(self): |
100 | 101 | with pytest.raises(FileNotFoundError): |
@@ -176,15 +177,20 @@ def test_multiple_samples(forest_dataset, sample_data): |
176 | 177 |
|
177 | 178 |
|
178 | 179 | def mock_transform(x): |
179 | | - """Create simple tranformation for testing.""" |
| 180 | + """Create simple transformation for testing that converts PIL Image to tensor.""" |
| 181 | + if isinstance(x, Image.Image): |
| 182 | + return transforms.ToTensor()(x) |
180 | 183 | return x * 2 |
181 | 184 |
|
182 | 185 |
|
183 | 186 | def mock_minority_transform(x): |
184 | 187 | """ |
185 | | - Create simple minority transform that increments |
186 | | - the input for test verification. |
| 188 | + Create simple minority transform that converts PIL Image to tensor |
| 189 | + and increments for test verification. |
187 | 190 | """ |
| 191 | + if isinstance(x, Image.Image): |
| 192 | + tensor = transforms.ToTensor()(x) |
| 193 | + return tensor + 0.1 # Small increment for verification |
188 | 194 | return x + 1 |
189 | 195 |
|
190 | 196 |
|
|
0 commit comments