33import pytest
44
55from CollaborativeCoding .load_metric import MetricWrapper
6- from CollaborativeCoding .metrics import Accuracy , F1Score , Precision , Recall
6+ from CollaborativeCoding .metrics import (
7+ Accuracy ,
8+ EntropyPrediction ,
9+ F1Score ,
10+ Precision ,
11+ Recall ,
12+ )
713
814
915@pytest .mark .parametrize (
1723 ("accuracy" , randint (2 , 10 ), True ),
1824 ("precision" , randint (2 , 10 ), False ),
1925 ("precision" , randint (2 , 10 ), True ),
20- # TODO: Add test for EntropyPrediction
26+ ( "entropy" , randint ( 2 , 10 ), False ),
2127 ],
2228)
2329def test_metric_wrapper (metric , num_classes , macro_averaging ):
@@ -34,9 +40,9 @@ def test_metric_wrapper(metric, num_classes, macro_averaging):
3440 )
3541
3642 metrics (y_true , logits )
37- score = metrics .accumulate ()
38- metrics .reset ()
39- empty_score = metrics .accumulate ()
43+ score = metrics .getmetrics ()
44+ metrics .resetmetric ()
45+ empty_score = metrics .getmetrics ()
4046
4147 assert isinstance (score , dict ), "Expected a dictionary output."
4248 assert metric in score , f"Expected { metric } metric in the output."
@@ -145,16 +151,22 @@ def test_accuracy():
145151def test_entropypred ():
146152 import torch as th
147153
148- pred_logits = th .rand (6 , 5 )
149154 true_lab = th .rand (6 , 5 )
150155
151- metric = EntropyPrediction (averages = "mean" )
152- metric2 = EntropyPrediction (averages = "sum" )
156+ metric = EntropyPrediction (num_classes = 5 )
153157
154- # Test for averaging metric consistency
158+ # Test if the metric stores multiple values
159+ pred_logits = th .rand (6 , 5 )
155160 metric (true_lab , pred_logits )
156- metric2 (true_lab , pred_logits )
157- assert (
158- th .abs (th .sum (6 * metric .__returnmetric__ () - metric2 .__returnmetric__ ()))
159- < 1e-5
160- )
161+
162+ pred_logits = th .rand (6 , 5 )
163+ metric (true_lab , pred_logits )
164+
165+ pred_logits = th .rand (6 , 5 )
166+ metric (true_lab , pred_logits )
167+
168+ assert type (metric .__returnmetric__ ()) == th .Tensor
169+
170+ # Test than an error is raised with num_class != class dimension length
171+ with pytest .raises (AssertionError ):
172+ metric (true_lab , th .rand (6 , 6 ))
0 commit comments