
Commit a5a6c01

Merge pull request #29 from SFI-Visual-Intelligence/johan/devbranch
Johan/devbranch
2 parents c455b8a + b1a3627 commit a5a6c01

File tree: 3 files changed (+169, -2 lines)

utils/load_metric.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 import numpy as np
 import torch.nn as nn

-from .metrics import EntropyPrediction, F1Score
+from .metrics import EntropyPrediction, F1Score, precision


 class MetricWrapper(nn.Module):
@@ -39,7 +39,7 @@ def _get_metric(self, key):
             case "recall":
                 raise NotImplementedError("Recall score not implemented yet")
             case "precision":
-                raise NotImplementedError("Precision score not implemented yet")
+                return precision()
             case "accuracy":
                 raise NotImplementedError("Accuracy score not implemented yet")
             case _:
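For `from .metrics import precision` to resolve, `utils/metrics/__init__.py` presumably re-exports the class defined in the new file below; a minimal sketch of such a re-export, assuming the lowercase alias is intentional (that file is not shown in this diff). Note that `Precision.__init__` requires `num_classes`, so the bare `precision()` call above would need that argument supplied (or given a default) to construct the metric.

# utils/metrics/__init__.py -- hypothetical sketch, not part of this commit
from .precision import Precision as precision  # lowercase alias matching the import in load_metric.py
# EntropyPrediction and F1Score would be re-exported here as well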

utils/metrics/precision.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn

USE_MEAN = True

# Precision = TP / (TP + FP)


class Precision(nn.Module):
    """Metric module for precision. Can calculate precision either as the mean of
    per-class precisions (macro) or directly from the overall counts of true and
    false positives (micro). For now this is controlled by the USE_MEAN flag.

    Parameters
    ----------
    num_classes : int
        Number of classes in the dataset.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.num_classes = num_classes

    def forward(self, y_true, y_pred):
        """Calculates the precision score given the number of classes and the true and predicted labels.

        Parameters
        ----------
        y_true : torch.tensor
            True labels.
        y_pred : torch.tensor
            Predicted labels.

        Returns
        -------
        torch.tensor
            Precision score.
        """
        # One-hot encode the true and predicted label tensors
        true_oh = torch.zeros(y_true.size(0), self.num_classes).scatter_(
            1, y_true.unsqueeze(1), 1
        )
        pred_oh = torch.zeros(y_pred.size(0), self.num_classes).scatter_(
            1, y_pred.unsqueeze(1), 1
        )

        if USE_MEAN:
            # Per-class counts; the mean over classes below gives macro precision
            tp = torch.sum(true_oh * pred_oh, 0)
            fp = torch.sum(~true_oh.bool() * pred_oh, 0)

        else:
            # Pooled counts over all classes (micro precision)
            tp = torch.sum(true_oh * pred_oh)
            fp = torch.sum(~true_oh[pred_oh.bool()].bool())

        return torch.nanmean(tp / (tp + fp))


def test_precision_case1():
    true_precision = 25.0 / 36 if USE_MEAN else 7.0 / 10

    true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
    pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
    P = Precision(3)
    precision1 = P(true1, pred1)
    assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
        f"Precision Score: {precision1.item()}"
    )


def test_precision_case2():
    true_precision = 8.0 / 15 if USE_MEAN else 6.0 / 15

    true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
    pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
    P = Precision(5)
    precision2 = P(true2, pred2)
    assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
        f"Precision Score: {precision2.item()}"
    )


def test_precision_case3():
    true_precision = 3.0 / 4 if USE_MEAN else 4.0 / 5

    true3 = torch.tensor([0, 0, 0, 1, 0])
    pred3 = torch.tensor([1, 0, 0, 1, 0])
    P = Precision(2)
    precision3 = P(true3, pred3)
    assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
        f"Precision Score: {precision3.item()}"
    )


def test_for_zero_denominator():
    true_precision = 0.0
    true4 = torch.tensor([1, 1, 1, 1, 1])
    pred4 = torch.tensor([0, 0, 0, 0, 0])
    P = Precision(2)
    precision4 = P(true4, pred4)
    assert precision4.allclose(torch.tensor(true_precision), atol=1e-5), (
        f"Precision Score: {precision4.item()}"
    )


if __name__ == "__main__":
    pass
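Reproducing the first test case by hand shows what the USE_MEAN switch changes: with the labels from `test_precision_case1`, the per-class precisions are 3/4, 2/3 and 2/3, whose mean is 25/36 (macro), while pooling all counts gives 7 correct predictions out of 10, i.e. 7/10 (micro). A small usage sketch; the import path assumes the package layout introduced in this commit:

import torch
from utils.metrics.precision import Precision  # path as added in this commit

metric = Precision(num_classes=3)
y_true = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
y_pred = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])

# With USE_MEAN = True this prints ~0.6944 (= 25/36, macro precision);
# with USE_MEAN = False it would print 0.7 (= 7/10, micro precision).
print(metric(y_true, y_pred).item())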

utils/models/johan_model.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
import torch.nn as nn

"""
Multi-layer perceptron model for image classification.
"""

# class NeuronLayer(nn.Module):
#     def __init__(self, in_features, out_features):
#         super().__init__()

#         self.fc = nn.Linear(in_features, out_features)
#         self.relu = nn.ReLU()

#     def forward(self, x):
#         x = self.fc(x)
#         x = self.relu(x)
#         return x


class JohanModel(nn.Module):
    """Small MLP model for image classification.

    Parameters
    ----------
    image_shape : tuple of int
        Shape of the input images as (channels, height, width).
    num_classes : int
        Number of classes in the dataset.

    """

    def __init__(self, image_shape, num_classes):
        super().__init__()

        # Extract features from image shape
        self.in_channels = image_shape[0]
        self.height = image_shape[1]
        self.width = image_shape[2]
        self.num_classes = num_classes
        self.in_features = self.in_channels * self.height * self.width

        self.fc1 = nn.Linear(self.in_features, 77)
        self.fc2 = nn.Linear(77, 77)
        self.fc3 = nn.Linear(77, 77)
        self.fc4 = nn.Linear(77, num_classes)
        self.softmax = nn.Softmax(dim=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        # ReLU is applied after every layer, including fc4, before the final Softmax
        for layer in [self.fc1, self.fc2, self.fc3, self.fc4]:
            x = layer(x)
            x = self.relu(x)
        x = self.softmax(x)
        return x


# TODO
# Add your tests here


if __name__ == "__main__":
    pass  # Add your tests here
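A quick smoke test of the model as committed; the 1x28x28 input shape and batch size are illustrative assumptions, and note that `forward` expects the images to be flattened beforehand:

import torch
from utils.models.johan_model import JohanModel  # path as added in this commit

model = JohanModel(image_shape=(1, 28, 28), num_classes=10)  # MNIST-sized input (assumption)
x = torch.rand(4, 1 * 28 * 28)  # batch of 4 pre-flattened images

probs = model(x)
print(probs.shape)        # torch.Size([4, 10])
print(probs.sum(dim=1))   # each row sums to 1 due to the final Softmax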
