Skip to content

Commit c455b8a

Browse files
authored
Merge pull request #32 from SFI-Visual-Intelligence/solveig-branch
Solveig branch
2 parents 7ff097a + 9d6692c commit c455b8a

File tree

10 files changed

+107
-23
lines changed

10 files changed

+107
-23
lines changed

main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,14 @@ def main():
6666
"--modelname",
6767
type=str,
6868
default="MagnusModel",
69-
choices=["MagnusModel", "ChristianModel"],
69+
choices=["MagnusModel", "ChristianModel", "SolveigModel"],
7070
help="Model which to be trained on",
7171
)
7272
parser.add_argument(
7373
"--dataset",
7474
type=str,
7575
default="svhn",
76-
choices=["svhn", "usps_0-6"],
76+
choices=["svhn", "usps_0-6", "uspsh5_7_9"],
7777
help="Which dataset to train the model on.",
7878
)
7979

tests/test_metrics.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from utils.metrics import Recall
1+
from utils.metrics import Recall, F1Score
22

33

44
def test_recall():
@@ -14,3 +14,19 @@ def test_recall():
1414
assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), (
1515
f"Recall Score: {recall_score.item()}"
1616
)
17+
18+
19+
def test_f1score():
20+
import torch
21+
22+
f1_metric = F1Score(num_classes=3)
23+
preds = torch.tensor(
24+
[[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]
25+
)
26+
27+
target = torch.tensor([0, 1, 0, 2])
28+
29+
f1_metric.update(preds, target)
30+
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
31+
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
32+
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."

utils/dataloaders/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
__all__ = ["USPSDataset0_6"]
1+
__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset"]
22

33
from .usps_0_6 import USPSDataset0_6
4+
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset

utils/load_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from torch.utils.data import Dataset
22

3-
from .dataloaders import USPSDataset0_6
3+
from .dataloaders import USPSDataset0_6, USPSH5_Digit_7_9_Dataset
44

55

66
def load_data(dataset: str, *args, **kwargs) -> Dataset:
77
match dataset.lower():
88
case "usps_0-6":
99
return USPSDataset0_6(*args, **kwargs)
10+
case "usps_7-9":
11+
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
1012
case _:
1113
raise ValueError(f"Dataset: {dataset} not implemented.")

utils/load_metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import torch.nn as nn
55

6-
from .metrics import EntropyPrediction
6+
from .metrics import EntropyPrediction, F1Score
77

88

99
class MetricWrapper(nn.Module):
@@ -35,7 +35,7 @@ def _get_metric(self, key):
3535
case "entropy":
3636
return EntropyPrediction()
3737
case "f1":
38-
raise NotImplementedError("F1 score not implemented yet")
38+
raise F1Score()
3939
case "recall":
4040
raise NotImplementedError("Recall score not implemented yet")
4141
case "precision":

utils/load_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22

3-
from .models import ChristianModel, MagnusModel
3+
from .models import ChristianModel, MagnusModel, SolveigModel
44

55

66
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
@@ -9,6 +9,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
99
return MagnusModel(*args, **kwargs)
1010
case "christianmodel":
1111
return ChristianModel(*args, **kwargs)
12+
case "solveigmodel":
13+
return SolveigModel(*args, **kwargs)
1214
case _:
1315
raise ValueError(
1416
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"

utils/metrics/F1.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,3 @@ def compute(self):
8585

8686
return f1_score
8787

88-
89-
def test_f1score():
90-
f1_metric = F1Score(num_classes=3)
91-
preds = torch.tensor(
92-
[[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]]
93-
)
94-
95-
target = torch.tensor([0, 1, 0, 2])
96-
97-
f1_metric.update(preds, target)
98-
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
99-
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
100-
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."

utils/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
__all__ = ["EntropyPrediction", "Recall"]
1+
__all__ = ["EntropyPrediction", "Recall", "F1Score"]
22

33
from .EntropyPred import EntropyPrediction
4+
from .F1 import F1Score
45
from .recall import Recall

utils/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
__all__ = ["MagnusModel", "ChristianModel"]
1+
__all__ = ["MagnusModel", "ChristianModel", "SolveigModel"]
22

33
from .christian_model import ChristianModel
44
from .magnus_model import MagnusModel
5+
from .solveig_model import SolveigModel

utils/models/solveig_model.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class SolveigModel(nn.Module):
6+
"""
7+
A Convolutional Neural Network model for classification.
8+
9+
Args
10+
----
11+
image_shape : tuple(int, int, int)
12+
Shape of the input image (C, H, W).
13+
num_classes : int
14+
Number of classes in the dataset.
15+
16+
Attributes:
17+
-----------
18+
conv_block1 : nn.Sequential
19+
First convolutional block containing a convolutional layer, ReLU activation, and max-pooling.
20+
conv_block2 : nn.Sequential
21+
Second convolutional block containing a convolutional layer and ReLU activation.
22+
conv_block3 : nn.Sequential
23+
Third convolutional block containing a convolutional layer and ReLU activation.
24+
fc1 : nn.Linear
25+
Fully connected layer that outputs the final classification scores.
26+
"""
27+
28+
def __init__(self, image_shape, num_classes):
29+
super().__init__()
30+
31+
C, *_ = image_shape
32+
33+
# Define the first convolutional block (conv + relu + maxpool)
34+
self.conv_block1 = nn.Sequential(
35+
nn.Conv2d(in_channels=C, out_channels=25, kernel_size=3, padding=1),
36+
nn.ReLU(),
37+
nn.MaxPool2d(kernel_size=2, stride=2)
38+
)
39+
40+
# Define the second convolutional block (conv + relu)
41+
self.conv_block2 = nn.Sequential(
42+
nn.Conv2d(in_channels=25, out_channels=50, kernel_size=3, padding=1),
43+
nn.ReLU()
44+
)
45+
46+
# Define the third convolutional block (conv + relu)
47+
self.conv_block3 = nn.Sequential(
48+
nn.Conv2d(in_channels=50, out_channels=100, kernel_size=3, padding=1),
49+
nn.ReLU()
50+
)
51+
52+
self.fc1 = nn.Linear(100 * 8 * 8, num_classes)
53+
54+
def forward(self, x):
55+
x = self.conv_block1(x)
56+
x = self.conv_block2(x)
57+
x = self.conv_block3(x)
58+
x = torch.flatten(x, 1)
59+
60+
x = self.fc1(x)
61+
x = nn.Softmax(x)
62+
63+
return x
64+
65+
66+
if __name__ == "__main__":
67+
68+
x = torch.randn(1,3, 16, 16)
69+
70+
model = SolveigModel(x.shape[1:], 3)
71+
72+
y = model(x)
73+
74+
print(y)

0 commit comments

Comments
 (0)