Skip to content

Commit 4350664

Browse files
authored
Merge pull request #24 from SFI-Visual-Intelligence/christian/model
2 parents 970fe05 + 68b5616 commit 4350664

File tree

7 files changed

+177
-12
lines changed

7 files changed

+177
-12
lines changed

main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def main():
6666
"--modelname",
6767
type=str,
6868
default="MagnusModel",
69-
choices=["MagnusModel"],
69+
choices=["MagnusModel", "ChristianModel"],
7070
help="Model which to be trained on",
7171
)
7272
parser.add_argument(
@@ -196,7 +196,7 @@ def main():
196196
model.eval()
197197
with th.no_grad():
198198
for x, y in valiloader:
199-
x = x.to(device)
199+
x, y = x.to(device), y.to(device)
200200
pred = model.forward(x)
201201
loss = criterion(y, pred)
202202
evalloss.append(loss.item())

utils/dataloaders/usps_0_6.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ def __getitem__(self, idx):
106106

107107
data = data.reshape(16, 16)
108108

109+
# one hot encode the target
110+
target = np.eye(self.num_classes, dtype=np.float32)[target]
111+
112+
# Add channel dimension
113+
data = np.expand_dims(data, axis=0)
114+
109115
if self.transform:
110116
data = self.transform(data)
111117

utils/load_model.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import torch.nn as nn
22

3-
from .models import MagnusModel
3+
from .models import ChristianModel, MagnusModel
44

55

6-
def load_model(modelname: str) -> nn.Module:
7-
if modelname == "MagnusModel":
8-
return MagnusModel()
9-
else:
10-
raise ValueError(
11-
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
12-
)
6+
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
7+
match modelname.lower():
8+
case "magnusmodel":
9+
return MagnusModel(*args, **kwargs)
10+
case "christianmodel":
11+
return ChristianModel(*args, **kwargs)
12+
case _:
13+
raise ValueError(
14+
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
15+
)

utils/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
__all__ = ["EntropyPrediction"]
1+
__all__ = ["EntropyPrediction", "Recall"]
22

33
from .EntropyPred import EntropyPrediction
4+
from .recall import Recall

utils/metrics/recall.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
def one_hot_encode(y_true, num_classes):
6+
"""One-hot encode the target tensor.
7+
8+
Args
9+
----
10+
y_true : torch.Tensor
11+
Target tensor.
12+
num_classes : int
13+
Number of classes in the dataset.
14+
15+
Returns
16+
-------
17+
torch.Tensor
18+
One-hot encoded tensor.
19+
"""
20+
21+
y_onehot = torch.zeros(y_true.size(0), num_classes)
22+
y_onehot.scatter_(1, y_true.unsqueeze(1), 1)
23+
return y_onehot
24+
25+
26+
class Recall(nn.Module):
27+
def __init__(self, num_classes):
28+
super().__init__()
29+
30+
self.num_classes = num_classes
31+
32+
def forward(self, y_true, y_pred):
33+
true_onehot = one_hot_encode(y_true, self.num_classes)
34+
pred_onehot = one_hot_encode(y_pred, self.num_classes)
35+
36+
true_positives = (true_onehot * pred_onehot).sum()
37+
38+
false_negatives = torch.sum(~pred_onehot[true_onehot.bool()].bool())
39+
40+
recall = true_positives / (true_positives + false_negatives)
41+
42+
return recall
43+
44+
45+
def test_recall():
46+
recall = Recall(7)
47+
48+
y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
49+
y_pred = torch.tensor([2, 1, 2, 1, 4, 5, 6])
50+
51+
recall_score = recall(y_true, y_pred)
52+
53+
assert recall_score.allclose(torch.tensor(0.7143), atol=1e-5), f"Recall Score: {recall_score.item()}"
54+
55+
56+
def test_one_hot_encode():
57+
num_classes = 7
58+
59+
y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
60+
y_onehot = one_hot_encode(y_true, num_classes)
61+
62+
assert y_onehot.shape == (7, 7), f"Shape: {y_onehot.shape}"

utils/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
__all__ = ["MagnusModel"]
1+
__all__ = ["MagnusModel", "ChristianModel"]
22

3+
from .christian_model import ChristianModel
34
from .magnus_model import MagnusModel

utils/models/christian_model.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
5+
6+
class CNNBlock(nn.Module):
7+
def __init__(self, in_channels, out_channels):
8+
super().__init__()
9+
10+
self.conv = nn.Conv2d(
11+
in_channels,
12+
out_channels,
13+
kernel_size=3,
14+
padding=1,
15+
)
16+
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
17+
self.relu = nn.ReLU()
18+
19+
def forward(self, x):
20+
x = self.conv(x)
21+
x = self.maxpool(x)
22+
x = self.relu(x)
23+
return x
24+
25+
26+
class ChristianModel(nn.Module):
27+
"""Simple CNN model for image classification.
28+
29+
Args
30+
----
31+
in_channels : int
32+
Number of input channels.
33+
num_classes : int
34+
Number of classes in the dataset.
35+
36+
Processing Images
37+
-----------------
38+
Input: (N, C, H, W)
39+
N: Batch size
40+
C: Number of input channels
41+
H: Height of the input image
42+
W: Width of the input image
43+
44+
Example:
45+
For grayscale images, C = 1.
46+
47+
Input Image Shape: (5, 1, 16, 16)
48+
CNN1 Output Shape: (5, 50, 8, 8)
49+
CNN2 Output Shape: (5, 100, 4, 4)
50+
FC Output Shape: (5, num_classes)
51+
"""
52+
def __init__(self, in_channels, num_classes):
53+
super().__init__()
54+
55+
self.cnn1 = CNNBlock(in_channels, 50)
56+
self.cnn2 = CNNBlock(50, 100)
57+
58+
self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
59+
self.softmax = nn.Softmax(dim=1)
60+
61+
def forward(self, x):
62+
x = self.cnn1(x)
63+
x = self.cnn2(x)
64+
65+
x = x.view(x.size(0), -1)
66+
x = self.fc1(x)
67+
x = self.softmax(x)
68+
69+
return x
70+
71+
72+
@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
73+
def test_christian_model(in_channels, num_classes):
74+
n, c, h, w = 5, in_channels, 16, 16
75+
76+
model = ChristianModel(c, num_classes)
77+
78+
x = torch.randn(n, c, h, w)
79+
y = model(x)
80+
81+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
82+
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), f"Softmax output should sum to 1, but got: {y.sum()}"
83+
84+
85+
if __name__ == "__main__":
86+
87+
model = ChristianModel(3, 7)
88+
89+
x = torch.randn(3, 3, 16, 16)
90+
y = model(x)
91+
92+
print(y)

0 commit comments

Comments
 (0)