Skip to content

Commit 18dfea2

Browse files
committed
adjusted my dataloader to the new format, added test for my model
1 parent 26ff680 commit 18dfea2

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

tests/test_models.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from utils.models import ChristianModel, JanModel
4+
from utils.models import ChristianModel, JanModel, SolveigModel
55

66

77
@pytest.mark.parametrize(
@@ -33,3 +33,18 @@ def test_jan_model(image_shape, num_classes):
3333

3434
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
3535

36+
37+
@pytest.mark.parametrize(
38+
"image_shape, num_classes",
39+
[((3, 16, 16), 3), ((3, 16, 16), 7)],
40+
)
41+
def test_solveig_model(image_shape, num_classes):
42+
n, c, h, w = 5, *image_shape
43+
44+
model = SolveigModel(image_shape, num_classes)
45+
46+
x = torch.randn(n, c, h, w)
47+
y = model(x)
48+
49+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
50+

utils/dataloaders/uspsh5_7_9.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from PIL import Image
55
from torch.utils.data import Dataset
66
from torchvision import transforms
7+
from pathlib import Path
78

89

910
class USPSH5_Digit_7_9_Dataset(Dataset):
@@ -30,7 +31,7 @@ class USPSH5_Digit_7_9_Dataset(Dataset):
3031
A transform function to apply to the images.
3132
"""
3233

33-
def __init__(self, h5_path, mode, transform=None):
34+
def __init__(self, data_path, train = False, transform=None):
3435
super().__init__()
3536
"""
3637
Initializes the USPS dataset by loading images and labels from the given `.h5` file.
@@ -43,12 +44,13 @@ def __init__(self, h5_path, mode, transform=None):
4344
transform : callable, optional, default=None
4445
A transform function to apply on images.
4546
"""
46-
47+
self.filename = "usps.h5"
48+
path = data_path if isinstance(data_path, Path) else Path(data_path)
49+
self.filepath = path / self.filename
4750
self.transform = transform
48-
self.mode = mode
49-
self.h5_path = h5_path
51+
self.mode = "train" if train else "test"
5052
# Load the dataset from the HDF5 file
51-
with h5py.File(self.h5_path, "r") as hf:
53+
with h5py.File(self.filepath, "r") as hf:
5254
images = hf[self.mode]["data"][:]
5355
labels = hf[self.mode]["target"][:]
5456

@@ -105,8 +107,8 @@ def main():
105107

106108
# Load the dataset
107109
dataset = USPSH5_Digit_7_9_Dataset(
108-
h5_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git/usps.h5",
109-
mode="train",
110+
data_path="C:/Users/Solveig/OneDrive/Dokumente/UiT PhD/Courses/Git",
111+
train = False,
110112
transform=transform,
111113
)
112114
data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

0 commit comments

Comments
 (0)