Skip to content

Commit c1de9cc

Browse files
committed
formatted utils folder
1 parent 562800d commit c1de9cc

File tree

6 files changed

+34
-79
lines changed

6 files changed

+34
-79
lines changed

utils/dataloaders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset"]
22

33
from .usps_0_6 import USPSDataset0_6
4-
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
4+
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset

utils/load_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
88
case "usps_0-6":
99
return USPSDataset0_6(*args, **kwargs)
1010
case "usps_7-9":
11-
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
11+
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
1212
case _:
1313
raise ValueError(f"Dataset: {dataset} not implemented.")

utils/metrics/F1.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,3 @@ def compute(self):
8484
)
8585

8686
return f1_score
87-

utils/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .EntropyPred import EntropyPrediction
44
from .F1 import F1Score
55
from .recall import Recall
6+
from .precision import Precision

utils/metrics/precision.py

Lines changed: 7 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,23 @@
77

88

99
class Precision(nn.Module):
10-
"""Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives. This is for now controller with the USE_MEAN macro.
10+
"""Metric module for precision. Can calculate precision both as a mean of precisions or as brute function of true positives and false positives.
1111
1212
Parameters
1313
----------
1414
num_classes : int
1515
Number of classes in the dataset.
16+
use_mean : bool
17+
Whether to calculate precision as a mean of precisions or as a brute function of true positives and false positives.
1618
"""
1719

18-
def __init__(self, num_classes):
20+
def __init__(self, num_classes: int, use_mean: bool = True):
1921
super().__init__()
2022

2123
self.num_classes = num_classes
24+
self.use_mean = use_mean
2225

23-
def forward(self, y_true, y_pred):
26+
def forward(self, y_true: torch.tensor, y_pred: torch.tensor) -> torch.tensor:
2427
"""Calculates the precision score given number of classes and the true and predicted labels.
2528
2629
Parameters
@@ -43,7 +46,7 @@ def forward(self, y_true, y_pred):
4346
1, y_pred.unsqueeze(1), 1
4447
)
4548

46-
if USE_MEAN:
49+
if self.use_mean:
4750
tp = torch.sum(true_oh * pred_oh, 0)
4851
fp = torch.sum(~true_oh.bool() * pred_oh, 0)
4952

@@ -54,52 +57,5 @@ def forward(self, y_true, y_pred):
5457
return torch.nanmean(tp / (tp + fp))
5558

5659

57-
def test_precision_case1():
58-
true_precision = 25.0 / 36 if USE_MEAN else 7.0 / 10
59-
60-
true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
61-
pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
62-
P = Precision(3)
63-
precision1 = P(true1, pred1)
64-
assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
65-
f"Precision Score: {precision1.item()}"
66-
)
67-
68-
69-
def test_precision_case2():
70-
true_precision = 8.0 / 15 if USE_MEAN else 6.0 / 15
71-
72-
true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
73-
pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
74-
P = Precision(5)
75-
precision2 = P(true2, pred2)
76-
assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
77-
f"Precision Score: {precision2.item()}"
78-
)
79-
80-
81-
def test_precision_case3():
82-
true_precision = 3.0 / 4 if USE_MEAN else 4.0 / 5
83-
84-
true3 = torch.tensor([0, 0, 0, 1, 0])
85-
pred3 = torch.tensor([1, 0, 0, 1, 0])
86-
P = Precision(2)
87-
precision3 = P(true3, pred3)
88-
assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
89-
f"Precision Score: {precision3.item()}"
90-
)
91-
92-
93-
def test_for_zero_denominator():
94-
true_precision = 0.0
95-
true4 = torch.tensor([1, 1, 1, 1, 1])
96-
pred4 = torch.tensor([0, 0, 0, 0, 0])
97-
P = Precision(2)
98-
precision4 = P(true4, pred4)
99-
assert precision4.allclose(torch.tensor(true_precision), atol=1e-5), (
100-
f"Precision Score: {precision4.item()}"
101-
)
102-
103-
10460
if __name__ == "__main__":
10561
pass

utils/models/solveig_model.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,26 @@
44

55
class SolveigModel(nn.Module):
66
"""
7-
A Convolutional Neural Network model for classification.
8-
9-
Args
10-
----
11-
image_shape : tuple(int, int, int)
12-
Shape of the input image (C, H, W).
13-
num_classes : int
14-
Number of classes in the dataset.
15-
16-
Attributes:
17-
-----------
18-
conv_block1 : nn.Sequential
19-
First convolutional block containing a convolutional layer, ReLU activation, and max-pooling.
20-
conv_block2 : nn.Sequential
21-
Second convolutional block containing a convolutional layer and ReLU activation.
22-
conv_block3 : nn.Sequential
23-
Third convolutional block containing a convolutional layer and ReLU activation.
24-
fc1 : nn.Linear
25-
Fully connected layer that outputs the final classification scores.
26-
"""
7+
A Convolutional Neural Network model for classification.
8+
9+
Args
10+
----
11+
image_shape : tuple(int, int, int)
12+
Shape of the input image (C, H, W).
13+
num_classes : int
14+
Number of classes in the dataset.
15+
16+
Attributes:
17+
-----------
18+
conv_block1 : nn.Sequential
19+
First convolutional block containing a convolutional layer, ReLU activation, and max-pooling.
20+
conv_block2 : nn.Sequential
21+
Second convolutional block containing a convolutional layer and ReLU activation.
22+
conv_block3 : nn.Sequential
23+
Third convolutional block containing a convolutional layer and ReLU activation.
24+
fc1 : nn.Linear
25+
Fully connected layer that outputs the final classification scores.
26+
"""
2727

2828
def __init__(self, image_shape, num_classes):
2929
super().__init__()
@@ -34,19 +34,19 @@ def __init__(self, image_shape, num_classes):
3434
self.conv_block1 = nn.Sequential(
3535
nn.Conv2d(in_channels=C, out_channels=25, kernel_size=3, padding=1),
3636
nn.ReLU(),
37-
nn.MaxPool2d(kernel_size=2, stride=2)
37+
nn.MaxPool2d(kernel_size=2, stride=2),
3838
)
3939

4040
# Define the second convolutional block (conv + relu)
4141
self.conv_block2 = nn.Sequential(
4242
nn.Conv2d(in_channels=25, out_channels=50, kernel_size=3, padding=1),
43-
nn.ReLU()
43+
nn.ReLU(),
4444
)
4545

4646
# Define the third convolutional block (conv + relu)
4747
self.conv_block3 = nn.Sequential(
4848
nn.Conv2d(in_channels=50, out_channels=100, kernel_size=3, padding=1),
49-
nn.ReLU()
49+
nn.ReLU(),
5050
)
5151

5252
self.fc1 = nn.Linear(100 * 8 * 8, num_classes)
@@ -64,8 +64,7 @@ def forward(self, x):
6464

6565

6666
if __name__ == "__main__":
67-
68-
x = torch.randn(1,3, 16, 16)
67+
x = torch.randn(1, 3, 16, 16)
6968

7069
model = SolveigModel(x.shape[1:], 3)
7170

0 commit comments

Comments
 (0)