|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | + |
| 4 | +USE_MEAN = True |
| 5 | + |
| 6 | +# Precision = TP / (TP + FP) |
| 7 | + |
| 8 | + |
| 9 | +class Precision(nn.Module): |
| 10 | + """Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives. This is for now controller with the USE_MEAN macro. |
| 11 | +
|
| 12 | + Parameters |
| 13 | + ---------- |
| 14 | + num_classes : int |
| 15 | + Number of classes in the dataset. |
| 16 | + """ |
| 17 | + |
| 18 | + def __init__(self, num_classes): |
| 19 | + super().__init__() |
| 20 | + |
| 21 | + self.num_classes = num_classes |
| 22 | + |
| 23 | + def forward(self, y_true, y_pred): |
| 24 | + """Calculates the precision score given number of classes and the true and predicted labels. |
| 25 | +
|
| 26 | + Parameters |
| 27 | + ---------- |
| 28 | + y_true : torch.tensor |
| 29 | + true labels |
| 30 | + y_pred : torch.tensor |
| 31 | + predicted labels |
| 32 | +
|
| 33 | + Returns |
| 34 | + ------- |
| 35 | + torch.tensor |
| 36 | + precision score |
| 37 | + """ |
| 38 | + # One-hot encode the target tensor |
| 39 | + true_oh = torch.zeros(y_true.size(0), self.num_classes).scatter_( |
| 40 | + 1, y_true.unsqueeze(1), 1 |
| 41 | + ) |
| 42 | + pred_oh = torch.zeros(y_pred.size(0), self.num_classes).scatter_( |
| 43 | + 1, y_pred.unsqueeze(1), 1 |
| 44 | + ) |
| 45 | + |
| 46 | + if USE_MEAN: |
| 47 | + tp = torch.sum(true_oh * pred_oh, 0) |
| 48 | + fp = torch.sum(~true_oh.bool() * pred_oh, 0) |
| 49 | + |
| 50 | + else: |
| 51 | + tp = torch.sum(true_oh * pred_oh) |
| 52 | + fp = torch.sum(~true_oh[pred_oh.bool()].bool()) |
| 53 | + |
| 54 | + return torch.nanmean(tp / (tp + fp)) |
| 55 | + |
| 56 | + |
| 57 | +def test_precision_case1(): |
| 58 | + true_precision = 25.0 / 36 if USE_MEAN else 7.0 / 10 |
| 59 | + |
| 60 | + true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1]) |
| 61 | + pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1]) |
| 62 | + P = Precision(3) |
| 63 | + precision1 = P(true1, pred1) |
| 64 | + assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), ( |
| 65 | + f"Precision Score: {precision1.item()}" |
| 66 | + ) |
| 67 | + |
| 68 | + |
| 69 | +def test_precision_case2(): |
| 70 | + true_precision = 8.0 / 15 if USE_MEAN else 6.0 / 15 |
| 71 | + |
| 72 | + true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) |
| 73 | + pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0]) |
| 74 | + P = Precision(5) |
| 75 | + precision2 = P(true2, pred2) |
| 76 | + assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), ( |
| 77 | + f"Precision Score: {precision2.item()}" |
| 78 | + ) |
| 79 | + |
| 80 | + |
| 81 | +def test_precision_case3(): |
| 82 | + true_precision = 3.0 / 4 if USE_MEAN else 4.0 / 5 |
| 83 | + |
| 84 | + true3 = torch.tensor([0, 0, 0, 1, 0]) |
| 85 | + pred3 = torch.tensor([1, 0, 0, 1, 0]) |
| 86 | + P = Precision(2) |
| 87 | + precision3 = P(true3, pred3) |
| 88 | + assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), ( |
| 89 | + f"Precision Score: {precision3.item()}" |
| 90 | + ) |
| 91 | + |
| 92 | + |
| 93 | +def test_for_zero_denominator(): |
| 94 | + true_precision = 0.0 |
| 95 | + true4 = torch.tensor([1, 1, 1, 1, 1]) |
| 96 | + pred4 = torch.tensor([0, 0, 0, 0, 0]) |
| 97 | + P = Precision(2) |
| 98 | + precision4 = P(true4, pred4) |
| 99 | + assert precision4.allclose(torch.tensor(true_precision), atol=1e-5), ( |
| 100 | + f"Precision Score: {precision4.item()}" |
| 101 | + ) |
| 102 | + |
| 103 | + |
| 104 | +if __name__ == "__main__": |
| 105 | + pass |
0 commit comments