Skip to content

Commit e6e5bf1

Browse files
committed
Added test_load_data
1 parent 24e40ae commit e6e5bf1

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

tests/test_wrappers.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,34 @@ def test_load_model():
2828

2929

3030
def test_load_data():
31-
pass
31+
from tempfile import TemporaryDirectory
32+
33+
import torch as th
34+
from torchvision import transforms
35+
36+
dataset_names = [
37+
"usps_0-6",
38+
"mnist_0-3",
39+
"usps_7-9",
40+
"svhn",
41+
# 'mnist_4-9' #Uncomment when implemented
42+
]
43+
44+
trans = transforms.Compose(
45+
[
46+
transforms.Resize((16, 16)),
47+
transforms.ToTensor(),
48+
]
49+
)
50+
51+
with TemporaryDirectory() as tmppath:
52+
for name in dataset_names:
53+
dataset = load_data(
54+
name, train=False, data_path=tmppath, download=True, transform=trans
55+
)
56+
57+
im, lab = dataset.__getitem__(0)
58+
59+
assert dataset.__len__() != 0
60+
assert type(im) == th.Tensor and len(im.size()) == 3
61+
assert lab - lab == 0.0

0 commit comments

Comments
 (0)