2323 ("accuracy" , randint (2 , 10 ), True ),
2424 ("precision" , randint (2 , 10 ), False ),
2525 ("precision" , randint (2 , 10 ), True ),
26- # TODO: Add test for EntropyPrediction
26+ ( "entropy" , randint ( 2 , 10 ), False ),
2727 ],
2828)
2929def test_metric_wrapper (metric , num_classes , macro_averaging ):
@@ -40,9 +40,9 @@ def test_metric_wrapper(metric, num_classes, macro_averaging):
4040 )
4141
4242 metrics (y_true , logits )
43- score = metrics .__getmetrics__ ()
44- metrics .__resetmetrics__ ()
45- empty_score = metrics .__getmetrics__ ()
43+ score = metrics .getmetrics ()
44+ metrics .resetmetric ()
45+ empty_score = metrics .getmetrics ()
4646
4747 assert isinstance (score , dict ), "Expected a dictionary output."
4848 assert metric in score , f"Expected { metric } metric in the output."
@@ -169,16 +169,22 @@ def test_accuracy():
169169def test_entropypred ():
170170 import torch as th
171171
172- pred_logits = th .rand (6 , 5 )
173172 true_lab = th .rand (6 , 5 )
174173
175- metric = EntropyPrediction (averages = "mean" )
176- metric2 = EntropyPrediction (averages = "sum" )
174+ metric = EntropyPrediction (num_classes = 5 )
177175
178- # Test for averaging metric consistency
176+ # Test if the metric stores multiple values
177+ pred_logits = th .rand (6 , 5 )
179178 metric (true_lab , pred_logits )
180- metric2 (true_lab , pred_logits )
181- assert (
182- th .abs (th .sum (6 * metric .__returnmetric__ () - metric2 .__returnmetric__ ()))
183- < 1e-5
184- )
179+
180+ pred_logits = th .rand (6 , 5 )
181+ metric (true_lab , pred_logits )
182+
183+ pred_logits = th .rand (6 , 5 )
184+ metric (true_lab , pred_logits )
185+
186+ assert type (metric .__returnmetric__ ()) == th .Tensor
187+
188+ # Test than an error is raised with num_class != class dimension length
189+ with pytest .raises (AssertionError ):
190+ metric (true_lab , th .rand (6 , 6 ))
0 commit comments