Skip to content

Commit 1add669

Browse files
committed
Create new test that verifies basic functionality of all datasets
1 parent a4214d2 commit 1add669

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

tests/test_dataloaders.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,41 @@
1-
from utils.dataloaders.usps_0_6 import USPSDataset0_6
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import pytest
5+
import torch
6+
from PIL import Image
7+
from torchvision import transforms
8+
9+
from utils.dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset
10+
from utils.load_data import load_data
11+
12+
13+
@pytest.mark.parametrize(
14+
"data_name, expected",
15+
[
16+
("usps_0-6", USPSDataset0_6),
17+
("usps_7-9", USPSH5_Digit_7_9_Dataset),
18+
("mnist_0-3", MNISTDataset0_3),
19+
# TODO: Add more datasets here
20+
],
21+
)
22+
def test_load_data(data_name, expected):
23+
dataset = load_data(
24+
data_name,
25+
data_path=Path("data"),
26+
download=True,
27+
transform=transforms.ToTensor(),
28+
)
29+
assert isinstance(dataset, expected)
30+
assert len(dataset) > 0
31+
assert isinstance(dataset[0], tuple)
32+
assert isinstance(dataset[0][0], torch.Tensor)
33+
assert isinstance(
34+
dataset[0][1], (int, torch.Tensor, np.ndarray)
35+
) # Should probably restrict this to only int or one-hot encoded tensor or array for consistency.
236

337

438
def test_uspsdataset0_6():
5-
from pathlib import Path
639
from tempfile import TemporaryDirectory
740

841
import h5py

0 commit comments

Comments
 (0)