Skip to content

Commit 018b669

Browse files
authored
Merge pull request #49 from SFI-Visual-Intelligence/christian/test-model-metric-data
Create tests for load_model/metric/data, closes #46 LGTM
2 parents 365d389 + 6be010a commit 018b669

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ _build/
77
bin/*
88
wandb/*
99
wandb_api.py
10+
doc/autoapi
1011

1112
#Magnus specific
1213
job*

tests/test_dataloaders.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,41 @@
1-
from utils.dataloaders.usps_0_6 import USPSDataset0_6
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pytest
5+
import torch
6+
from PIL import Image
7+
from torchvision import transforms
8+
9+
from utils.dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset
10+
from utils.load_data import load_data
11+
12+
13+
@pytest.mark.parametrize(
14+
"data_name, expected",
15+
[
16+
("usps_0-6", USPSDataset0_6),
17+
("usps_7-9", USPSH5_Digit_7_9_Dataset),
18+
("mnist_0-3", MNISTDataset0_3),
19+
# TODO: Add more datasets here
20+
],
21+
)
22+
def test_load_data(data_name, expected):
23+
dataset = load_data(
24+
data_name,
25+
data_path=Path("data"),
26+
download=True,
27+
transform=transforms.ToTensor(),
28+
)
29+
assert isinstance(dataset, expected)
30+
assert len(dataset) > 0
31+
assert isinstance(dataset[0], tuple)
32+
assert isinstance(dataset[0][0], torch.Tensor)
33+
assert isinstance(
34+
dataset[0][1], (int, torch.Tensor, np.ndarray)
35+
) # Should probably restrict this to only int or one-hot encoded tensor or array for consistency.
236

337

438
def test_uspsdataset0_6():
5-
from pathlib import Path
639
from tempfile import TemporaryDirectory
740

841
import h5py

utils/dataloaders/uspsh5_7_9.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def __init__(self, data_path, train=False, transform=None):
5050
self.filepath = path / self.filename
5151
self.transform = transform
5252
self.mode = "train" if train else "test"
53+
self.h5_path = data_path / self.filename
54+
5355
# Load the dataset from the HDF5 file
5456
with h5py.File(self.filepath, "r") as hf:
5557
images = hf[self.mode]["data"][:]

0 commit comments

Comments
 (0)