2323 ("accuracy" , randint (2 , 10 ), True ),
2424 ("precision" , randint (2 , 10 ), False ),
2525 ("precision" , randint (2 , 10 ), True ),
26- ("EntropyPrediction " , randint (2 , 10 ), False ),
26+ ("entropy " , randint (2 , 10 ), False ),
2727 ],
2828)
2929def test_metric_wrapper (metric , num_classes , macro_averaging ):
@@ -40,9 +40,9 @@ def test_metric_wrapper(metric, num_classes, macro_averaging):
4040 )
4141
4242 metrics (y_true , logits )
43- score = metrics .accumulate ()
44- metrics .reset ()
45- empty_score = metrics .accumulate ()
43+ score = metrics .getmetrics ()
44+ metrics .resetmetric ()
45+ empty_score = metrics .getmetrics ()
4646
4747 assert isinstance (score , dict ), "Expected a dictionary output."
4848 assert metric in score , f"Expected { metric } metric in the output."
@@ -151,16 +151,22 @@ def test_accuracy():
151151def test_entropypred ():
152152 import torch as th
153153
154- pred_logits = th .rand (6 , 5 )
155154 true_lab = th .rand (6 , 5 )
156155
157- metric = EntropyPrediction (averages = "mean" )
158- metric2 = EntropyPrediction (averages = "sum" )
156+ metric = EntropyPrediction (num_classes = 5 )
159157
160- # Test for averaging metric consistency
158+ # Test if the metric stores multiple values
159+ pred_logits = th .rand (6 , 5 )
161160 metric (true_lab , pred_logits )
162- metric2 (true_lab , pred_logits )
163- assert (
164- th .abs (th .sum (6 * metric .__returnmetric__ () - metric2 .__returnmetric__ ()))
165- < 1e-5
166- )
161+
162+ pred_logits = th .rand (6 , 5 )
163+ metric (true_lab , pred_logits )
164+
165+ pred_logits = th .rand (6 , 5 )
166+ metric (true_lab , pred_logits )
167+
168+ assert type (metric .__returnmetric__ ()) == th .Tensor
169+
170+ # Test than an error is raised with num_class != class dimension length
171+ with pytest .raises (AssertionError ):
172+ metric (true_lab , th .rand (6 , 6 ))
0 commit comments