|
1 | | -from utils.metrics import Recall, F1Score |
| 1 | +from utils.metrics import F1Score, Precision, Recall |
2 | 2 |
|
3 | 3 |
|
4 | 4 | def test_recall(): |
@@ -30,3 +30,55 @@ def test_f1score(): |
30 | 30 | assert f1_metric.tp.sum().item() > 0, "Expected some true positives." |
31 | 31 | assert f1_metric.fp.sum().item() > 0, "Expected some false positives." |
32 | 32 | assert f1_metric.fn.sum().item() > 0, "Expected some false negatives." |
| 33 | + |
| 34 | + |
| 35 | +def test_precision_case1(): |
| 36 | + import torch |
| 37 | + |
| 38 | + for boolean, true_precision in zip([True, False], [25.0 / 36, 7.0 / 10]): |
| 39 | + true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1]) |
| 40 | + pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1]) |
| 41 | + P = Precision(3, use_mean=boolean) |
| 42 | + precision1 = P(true1, pred1) |
| 43 | + assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), ( |
| 44 | + f"Precision Score: {precision1.item()}" |
| 45 | + ) |
| 46 | + |
| 47 | + |
| 48 | +def test_precision_case2(): |
| 49 | + import torch |
| 50 | + |
| 51 | + for boolean, true_precision in zip([True, False], [8.0 / 15, 6.0 / 15]): |
| 52 | + true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) |
| 53 | + pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0]) |
| 54 | + P = Precision(5, use_mean=boolean) |
| 55 | + precision2 = P(true2, pred2) |
| 56 | + assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), ( |
| 57 | + f"Precision Score: {precision2.item()}" |
| 58 | + ) |
| 59 | + |
| 60 | + |
| 61 | +def test_precision_case3(): |
| 62 | + import torch |
| 63 | + |
| 64 | + for boolean, true_precision in zip([True, False], [3.0 / 4, 4.0 / 5]): |
| 65 | + true3 = torch.tensor([0, 0, 0, 1, 0]) |
| 66 | + pred3 = torch.tensor([1, 0, 0, 1, 0]) |
| 67 | + P = Precision(2, use_mean=boolean) |
| 68 | + precision3 = P(true3, pred3) |
| 69 | + assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), ( |
| 70 | + f"Precision Score: {precision3.item()}" |
| 71 | + ) |
| 72 | + |
| 73 | + |
| 74 | +def test_for_zero_denominator(): |
| 75 | + import torch |
| 76 | + |
| 77 | + for boolean in [True, False]: |
| 78 | + true4 = torch.tensor([1, 1, 1, 1, 1]) |
| 79 | + pred4 = torch.tensor([0, 0, 0, 0, 0]) |
| 80 | + P = Precision(2, use_mean=boolean) |
| 81 | + precision4 = P(true4, pred4) |
| 82 | + assert precision4.allclose(torch.tensor(0.0), atol=1e-5), ( |
| 83 | + f"Precision Score: {precision4.item()}" |
| 84 | + ) |
0 commit comments