11from pathlib import Path
22
3- from CollaborativeCoding import load_data , load_metric , load_model
3+ from CollaborativeCoding import MetricWrapper , load_data , load_model
44
55
66def test_load_model ():
@@ -36,13 +36,7 @@ def test_load_data():
3636 import torch as th
3737 from torchvision import transforms
3838
39- dataset_names = [
40- "usps_0-6" ,
41- "mnist_0-3" ,
42- "usps_7-9" ,
43- "svhn" ,
44- # 'mnist_4-9' #Uncomment when implemented
45- ]
39+ dataset_names = ["usps_0-6" , "mnist_0-3" , "usps_7-9" , "svhn" , "mnist_4-9" ]
4640
4741 trans = transforms .Compose (
4842 [
@@ -64,4 +58,25 @@ def test_load_data():
6458
6559
6660def test_load_metric ():
67- pass
61+ import torch as th
62+
63+ metrics = ("entropy" , "f1" , "recall" , "precision" , "accuracy" )
64+
65+ class_sizes = [3 , 6 , 10 ]
66+ for class_size in class_sizes :
67+ y_true = th .rand ((5 , class_size )).argmax (dim = 1 )
68+ y_pred = th .rand ((5 , class_size ))
69+
70+ metricwrapper = MetricWrapper (
71+ * metrics ,
72+ num_classes = class_size ,
73+ macro_averaging = True if class_size % 2 == 0 else False ,
74+ )
75+
76+ metricwrapper (y_true , y_pred )
77+ metric = metricwrapper .getmetrics ()
78+ assert metric is not None
79+
80+ metricwrapper .resetmetric ()
81+ metric2 = metricwrapper .getmetrics ()
82+ assert metric != metric2
0 commit comments