@@ -27,35 +27,38 @@ def test_load_model():
2727 )
2828
2929
30- def test_load_data ():
31- from tempfile import TemporaryDirectory
30+ # def test_load_data():
31+ # from tempfile import TemporaryDirectory
3232
33- import torch as th
34- from torchvision import transforms
35-
36- dataset_names = [
37- "usps_0-6" ,
38- "mnist_0-3" ,
39- "usps_7-9" ,
40- "svhn" ,
41- # 'mnist_4-9' #Uncomment when implemented
42- ]
33+ # import torch as th
34+ # from torchvision import transforms
4335
44- trans = transforms .Compose (
45- [
46- transforms .Resize ((16 , 16 )),
47- transforms .ToTensor (),
48- ]
49- )
50-
51- with TemporaryDirectory () as tmppath :
52- for name in dataset_names :
53- dataset = load_data (
54- name , train = False , data_path = tmppath , download = True , transform = trans
55- )
36+ # dataset_names = [
37+ # "usps_0-6",
38+ # "mnist_0-3",
39+ # "usps_7-9",
40+ # "svhn",
41+ # # 'mnist_4-9' #Uncomment when implemented
42+ # ]
43+
44+ # trans = transforms.Compose(
45+ # [
46+ # transforms.Resize((16, 16)),
47+ # transforms.ToTensor(),
48+ # ]
49+ # )
50+
51+ # with TemporaryDirectory() as tmppath:
52+ # for name in dataset_names:
53+ # dataset = load_data(
54+ # name, train=False, data_path=tmppath, download=True, transform=trans
55+ # )
56+
57+ # im, lab = dataset.__getitem__(0)
5658
57- im , lab = dataset .__getitem__ (0 )
59+ # assert dataset.__len__() != 0
60+ # assert type(im) == th.Tensor and len(im.size()) == 3
61+ # assert lab - lab == 0.0
5862
59- assert dataset .__len__ () != 0
60- assert type (im ) == th .Tensor and len (im .size ()) == 3
61- assert lab - lab == 0.0
63+ # def test_load_metric():
64+ # pass
0 commit comments