33import pytest
44
55from CollaborativeCoding .load_metric import MetricWrapper
6- from CollaborativeCoding .metrics import Accuracy , F1Score , Precision , Recall
6+ from CollaborativeCoding .metrics import (
7+ Accuracy ,
8+ EntropyPrediction ,
9+ F1Score ,
10+ Precision ,
11+ Recall ,
12+ )
713
814
915@pytest .mark .parametrize (
1723 ("accuracy" , randint (2 , 10 ), True ),
1824 ("precision" , randint (2 , 10 ), False ),
1925 ("precision" , randint (2 , 10 ), True ),
20- # TODO: Add test for EntropyPrediction
26+ ( "entropy" , randint ( 2 , 10 ), False ),
2127 ],
2228)
2329def test_metric_wrapper (metric , num_classes , macro_averaging ):
@@ -34,9 +40,9 @@ def test_metric_wrapper(metric, num_classes, macro_averaging):
3440 )
3541
3642 metrics (y_true , logits )
37- score = metrics .accumulate ()
38- metrics .reset ()
39- empty_score = metrics .accumulate ()
43+ score = metrics .getmetrics ()
44+ metrics .resetmetric ()
45+ empty_score = metrics .getmetrics ()
4046
4147 assert isinstance (score , dict ), "Expected a dictionary output."
4248 assert metric in score , f"Expected { metric } metric in the output."
@@ -143,16 +149,22 @@ def test_accuracy():
143149def test_entropypred ():
144150 import torch as th
145151
146- pred_logits = th .rand (6 , 5 )
147152 true_lab = th .rand (6 , 5 )
148153
149- metric = EntropyPrediction (averages = "mean" )
150- metric2 = EntropyPrediction (averages = "sum" )
154+ metric = EntropyPrediction (num_classes = 5 )
151155
152- # Test for averaging metric consistency
156+ # Test if the metric stores multiple values
157+ pred_logits = th .rand (6 , 5 )
153158 metric (true_lab , pred_logits )
154- metric2 (true_lab , pred_logits )
155- assert (
156- th .abs (th .sum (6 * metric .__returnmetric__ () - metric2 .__returnmetric__ ()))
157- < 1e-5
158- )
159+
160+ pred_logits = th .rand (6 , 5 )
161+ metric (true_lab , pred_logits )
162+
163+ pred_logits = th .rand (6 , 5 )
164+ metric (true_lab , pred_logits )
165+
166+ assert type (metric .__returnmetric__ ()) == th .Tensor
167+
168+ # Test than an error is raised with num_class != class dimension length
169+ with pytest .raises (AssertionError ):
170+ metric (true_lab , th .rand (6 , 6 ))
0 commit comments