11from pathlib import Path
2- from tempfile import TemporaryDirectory
32
43import pytest
5- import torch
6- from torchvision import transforms
4+ import torch as th
75
86from CollaborativeCoding import MetricWrapper , load_data , load_model
97
108
119def test_load_model ():
12- import torch as th
13-
1410 image_shape = (1 , 16 , 16 )
1511 num_classes = 4
1612
@@ -36,17 +32,14 @@ def test_load_model():
3632
3733
3834def test_load_data ():
39- from tempfile import TemporaryDirectory
40-
41- import torch as th
4235 from torchvision import transforms
4336
4437 dataset_names = [
4538 "usps_0-6" ,
4639 "mnist_0-3" ,
4740 "usps_7-9" ,
4841 "svhn" ,
49- "mnist_4-9" , # Uncomment when implemented
42+ "mnist_4-9" ,
5043 ]
5144
5245 trans = transforms .Compose (
@@ -56,21 +49,16 @@ def test_load_data():
5649 ]
5750 )
5851
59- with TemporaryDirectory () as tmppath :
60- for name in dataset_names :
61- dataset , _ , _ = load_data (
62- name , train = False , data_dir = Path (tmppath ), transform = trans
63- )
52+ for name in dataset_names :
53+ dataset = load_data (name , train = False , data_dir = Path .cwd () / "Data" , transform = trans )
6454
65- im , _ = dataset .__getitem__ (0 )
55+ im , _ = dataset .__getitem__ (0 )
6656
67- assert dataset .__len__ () != 0
68- assert type (im ) == th .Tensor and len (im .size ()) == 3
57+ assert dataset .__len__ () != 0
58+ assert type (im ) is th .Tensor and len (im .size ()) == 3
6959
7060
7161def test_load_metric ():
72- import torch as th
73-
7462 metrics = ("entropy" , "f1" , "recall" , "precision" , "accuracy" )
7563
7664 class_sizes = [3 , 6 , 10 ]
0 commit comments