Commit f31e4e7

Move load_data_test to test_wrappers.py
1 parent 2accbe4 commit f31e4e7

2 files changed: +37 -60 lines

tests/test_dataloaders.py

Lines changed: 0 additions & 30 deletions
@@ -1,49 +1,19 @@
 from pathlib import Path

 import numpy as np
-import pytest
-import torch
-from PIL import Image
 from torchvision import transforms

 from CollaborativeCoding.dataloaders import (
     MNISTDataset0_3,
     USPSDataset0_6,
     USPSH5_Digit_7_9_Dataset,
 )
-from CollaborativeCoding.load_data import load_data
-
-
-@pytest.mark.parametrize(
-    "data_name, expected",
-    [
-        ("usps_0-6", USPSDataset0_6),
-        ("usps_7-9", USPSH5_Digit_7_9_Dataset),
-        ("mnist_0-3", MNISTDataset0_3),
-        # TODO: Add more datasets here
-    ],
-)
-def test_load_data(data_name, expected):
-    dataset = load_data(
-        data_name,
-        data_dir=Path("data"),
-        transform=transforms.ToTensor(),
-    )
-    assert isinstance(dataset, expected)
-    assert len(dataset) > 0
-    assert isinstance(dataset[0], tuple)
-    assert isinstance(dataset[0][0], torch.Tensor)
-    assert isinstance(
-        dataset[0][1], (int, torch.Tensor, np.ndarray)
-    )  # Should probably restrict this to only int or one-hot encoded tensor or array for consistency.


 def test_uspsdataset0_6():
     from tempfile import TemporaryDirectory

     import h5py
-    import numpy as np
-    from torchvision import transforms

     # Create a temporary directory (deleted after the test)
     with TemporaryDirectory() as tempdir:
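The surviving test_uspsdataset0_6 keeps its TemporaryDirectory and h5py imports because it writes its own HDF5 fixture before constructing the dataset. The hunk ends before that body, so the sketch below only illustrates the general pattern; the key names ("data", "target"), the array shapes, and the commented USPSDataset0_6 call are assumptions for illustration, not the repository's actual layout.

# Sketch of the temp-directory + HDF5 fixture pattern; names and shapes are assumed.
from pathlib import Path
from tempfile import TemporaryDirectory

import h5py
import numpy as np

with TemporaryDirectory() as tempdir:
    h5_path = Path(tempdir) / "usps.h5"

    # Write a tiny fake USPS-style file: ten flattened 16x16 images with labels 0-6.
    with h5py.File(h5_path, "w") as f:
        f.create_dataset("data", data=np.random.rand(10, 16 * 16).astype("float32"))
        f.create_dataset("target", data=np.random.randint(0, 7, size=10))

    # The test would then point the dataset class at the temporary directory, e.g.
    # USPSDataset0_6(data_dir=Path(tempdir), transform=transforms.ToTensor())

Because the file lives under the TemporaryDirectory, it is deleted automatically when the with-block exits, so the test leaves nothing behind.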

tests/test_wrappers.py

Lines changed: 37 additions & 30 deletions
@@ -1,6 +1,17 @@
 from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+import torch
+from torchvision import transforms

 from CollaborativeCoding import load_data, load_metric, load_model
+from CollaborativeCoding.dataloaders import (
+    MNISTDataset0_3,
+    SVHNDataset,
+    USPSDataset0_6,
+    USPSH5_Digit_7_9_Dataset,
+)


 def test_load_model():
@@ -30,38 +41,34 @@ def test_load_model():
     )


-def test_load_data():
-    from tempfile import TemporaryDirectory
-
-    import torch as th
-    from torchvision import transforms
-
-    dataset_names = [
-        "usps_0-6",
-        "mnist_0-3",
-        "usps_7-9",
-        "svhn",
-        # 'mnist_4-9' #Uncomment when implemented
-    ]
-
-    trans = transforms.Compose(
-        [
-            transforms.Resize((16, 16)),
-            transforms.ToTensor(),
-        ]
-    )
-
-    with TemporaryDirectory() as tmppath:
-        for name in dataset_names:
-            dataset = load_data(
-                name, train=False, data_dir=Path(tmppath), transform=trans
+@pytest.mark.parametrize(
+    "data_name, expected",
+    [
+        ("usps_0-6", USPSDataset0_6),
+        ("usps_7-9", USPSH5_Digit_7_9_Dataset),
+        ("mnist_0-3", MNISTDataset0_3),
+        ("svhn", SVHNDataset),
+    ],
+)
+def test_load_data(data_name, expected):
+    with TemporaryDirectory() as tempdir:
+        tempdir = Path(tempdir)
+
+        train, val, test = load_data(
+            data_name,
+            data_dir=tempdir,
+            transform=transforms.ToTensor(),
+        )
+
+        for dataset in [train, val, test]:
+            assert isinstance(dataset, expected)
+            assert len(dataset) > 0
+            assert isinstance(dataset[0], tuple)
+            assert isinstance(dataset[0][0], torch.Tensor)
+            assert isinstance(
+                dataset[0][1], int
             )

-        im, _ = dataset.__getitem__(0)
-
-        assert dataset.__len__() != 0
-        assert type(im) == th.Tensor and len(im.size()) == 3
-

 def test_load_metric():
     pass
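After this change, test_load_data expects a single load_data call to return three dataset objects (train, validation, test), each yielding (torch.Tensor image, int label) pairs. The real wrapper lives in CollaborativeCoding/load_data.py and is not shown in this commit; the following is only a sketch of a wrapper consistent with those assertions, and the split keyword it passes to each dataset class is an assumption.

# Sketch of a load_data wrapper consistent with the new test's assertions.
# The actual implementation (CollaborativeCoding/load_data.py) is not part of
# this commit; the `split` keyword passed to each dataset class is assumed.
from pathlib import Path

from CollaborativeCoding.dataloaders import (
    MNISTDataset0_3,
    SVHNDataset,
    USPSDataset0_6,
    USPSH5_Digit_7_9_Dataset,
)

DATASETS = {
    "usps_0-6": USPSDataset0_6,
    "usps_7-9": USPSH5_Digit_7_9_Dataset,
    "mnist_0-3": MNISTDataset0_3,
    "svhn": SVHNDataset,
}


def load_data(data_name: str, data_dir: Path, transform=None):
    try:
        dataset_cls = DATASETS[data_name]
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {data_name}") from err

    # One instance per split, returned in the (train, val, test) order the test unpacks.
    return tuple(
        dataset_cls(data_dir=data_dir, transform=transform, split=split)
        for split in ("train", "val", "test")
    )

With the parametrization in place, pytest tests/test_wrappers.py -k test_load_data runs one case per dataset name, so a failing loader is reported individually instead of aborting a single monolithic loop.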
