Skip to content

Commit 05a2c0e

Browse files
committed
Testing load_model
1 parent e6e5bf1 commit 05a2c0e

File tree

1 file changed

+31
-28
lines changed

1 file changed

+31
-28
lines changed

tests/test_wrappers.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,35 +27,38 @@ def test_load_model():
2727
)
2828

2929

30-
def test_load_data():
31-
from tempfile import TemporaryDirectory
30+
# def test_load_data():
31+
# from tempfile import TemporaryDirectory
3232

33-
import torch as th
34-
from torchvision import transforms
35-
36-
dataset_names = [
37-
"usps_0-6",
38-
"mnist_0-3",
39-
"usps_7-9",
40-
"svhn",
41-
# 'mnist_4-9' #Uncomment when implemented
42-
]
33+
# import torch as th
34+
# from torchvision import transforms
4335

44-
trans = transforms.Compose(
45-
[
46-
transforms.Resize((16, 16)),
47-
transforms.ToTensor(),
48-
]
49-
)
50-
51-
with TemporaryDirectory() as tmppath:
52-
for name in dataset_names:
53-
dataset = load_data(
54-
name, train=False, data_path=tmppath, download=True, transform=trans
55-
)
36+
# dataset_names = [
37+
# "usps_0-6",
38+
# "mnist_0-3",
39+
# "usps_7-9",
40+
# "svhn",
41+
# # 'mnist_4-9' #Uncomment when implemented
42+
# ]
43+
44+
# trans = transforms.Compose(
45+
# [
46+
# transforms.Resize((16, 16)),
47+
# transforms.ToTensor(),
48+
# ]
49+
# )
50+
51+
# with TemporaryDirectory() as tmppath:
52+
# for name in dataset_names:
53+
# dataset = load_data(
54+
# name, train=False, data_path=tmppath, download=True, transform=trans
55+
# )
56+
57+
# im, lab = dataset.__getitem__(0)
5658

57-
im, lab = dataset.__getitem__(0)
59+
# assert dataset.__len__() != 0
60+
# assert type(im) == th.Tensor and len(im.size()) == 3
61+
# assert lab - lab == 0.0
5862

59-
assert dataset.__len__() != 0
60-
assert type(im) == th.Tensor and len(im.size()) == 3
61-
assert lab - lab == 0.0
63+
# def test_load_metric():
64+
# pass

0 commit comments

Comments
 (0)