Skip to content

Commit fa08a02

Browse files
committed
test: Fit transform test to new code
1 parent e5ba79a commit fa08a02

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

src/dataset.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,19 @@ def __len__(self):
7878
return len(self.image_paths)
7979

8080
def __getitem__(self, idx):
81-
# TODO: Load an image from path here
8281
image_path = self.image_paths[idx]
8382
label = self.labels[idx]
8483

8584
with Image.open(image_path) as img:
86-
# Convert to numpy array
87-
image = np.array(img)
88-
image = image[:, :, 1:] if image.shape[-1] == 4 else image # Removing "near-inferred" channel
89-
# We found out that PIL conversion to RGB
90-
# keeps the "near-inferred" channel which was not desired
85+
# Remove "near-infrared" channel if present (4-channel RGBA)
86+
if img.mode == "RGBA":
87+
# Convert RGBA to RGB (drops alpha/near-infrared channel)
88+
image = img.convert("RGB")
89+
else:
90+
# Keep as PIL Image for transforms
91+
image = img.copy()
9192

92-
# Apply transformations
93+
# Apply transformations (expects PIL Image)
9394
if self.transform:
9495
image = self.transform(image)
9596

tests/test_dataset.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import pytest
44
import torch
55
from PIL import Image
6+
from torchvision import transforms
67

7-
# from torchvision.transforms.functional import pil_to_tensor
88
from torch.utils.data import DataLoader
99
from src.dataset import ForestDataset, OversampledDataset, UndersampledDataset, ForestDataModule
1010

@@ -92,9 +92,10 @@ def test_transforms_applied(self):
9292
image, _ = self.dataset[0]
9393

9494
assert isinstance(image, torch.Tensor), self.error_msg["not-tensor"]
95-
# Normalize check
96-
assert torch.min(image) >= -1.0, self.error_msg["invalid-min"]
97-
assert torch.max(image) <= 1.0, self.error_msg["invalid-max"]
95+
# Check that image is normalized (values should be centered around 0 after ImageNet normalization)
96+
# ImageNet normalization doesn't guarantee [-1, 1] range, so we check for reasonable normalized values
97+
assert torch.min(image) >= -3.0, self.error_msg["invalid-min"]
98+
assert torch.max(image) <= 3.0, self.error_msg["invalid-max"]
9899

99100
def test_missing_file_handling(self):
100101
with pytest.raises(FileNotFoundError):
@@ -176,15 +177,20 @@ def test_multiple_samples(forest_dataset, sample_data):
176177

177178

178179
def mock_transform(x):
179-
"""Create simple tranformation for testing."""
180+
"""Create simple transformation for testing that converts PIL Image to tensor."""
181+
if isinstance(x, Image.Image):
182+
return transforms.ToTensor()(x)
180183
return x * 2
181184

182185

183186
def mock_minority_transform(x):
184187
"""
185-
Create simple minority transform that increments
186-
the input for test verification.
188+
Create simple minority transform that converts PIL Image to tensor
189+
and increments for test verification.
187190
"""
191+
if isinstance(x, Image.Image):
192+
tensor = transforms.ToTensor()(x)
193+
return tensor + 0.1 # Small increment for verification
188194
return x + 1
189195

190196

0 commit comments

Comments
 (0)