Skip to content

Commit b9dc34e

Browse files
committed
Update tests for Recall metric
1 parent bf8a09f commit b9dc34e

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

tests/test_metrics.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
def test_recall():
55
import torch
66

7-
recall = Recall(7)
8-
97
y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
10-
y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])
8+
logits = torch.randn(7, 7)
119

12-
recall_score = recall(y_true, y_pred)
10+
recall_micro = Recall(7)
11+
recall_macro = Recall(7, macro_averaging=True)
1312

14-
assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), (
15-
f"Recall Score: {recall_score.item()}"
16-
)
13+
recall_micro_score = recall_micro(y_true, logits)
14+
recall_macro_score = recall_macro(y_true, logits)
15+
16+
assert isinstance(recall_micro_score, torch.Tensor), "Expected a tensor output."
17+
assert isinstance(recall_macro_score, torch.Tensor), "Expected a tensor output."
18+
assert recall_micro_score.item() >= 0, "Expected a non-negative value."
19+
assert recall_macro_score.item() >= 0, "Expected a non-negative value."
1720

1821

1922
def test_f1score():

0 commit comments

Comments
 (0)