Skip to content

Commit 40bb5c0

Browse files
committed
Add ChristianModel: 2 layer CNN w/maxpooling
1 parent e5aafb0 commit 40bb5c0

File tree

4 files changed

+106
-10
lines changed

4 files changed

+106
-10
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def main():
6666
"--modelname",
6767
type=str,
6868
default="MagnusModel",
69-
choices=["MagnusModel"],
69+
choices=["MagnusModel", "ChristianModel"],
7070
help="Model which to be trained on",
7171
)
7272
parser.add_argument(

utils/load_model.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import torch.nn as nn
22

3-
from .models import MagnusModel
3+
from .models import ChristianModel, MagnusModel
44

55

6-
def load_model(modelname: str) -> nn.Module:
7-
if modelname == "MagnusModel":
8-
return MagnusModel()
9-
else:
10-
raise ValueError(
11-
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
12-
)
6+
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
7+
match modelname.lower():
8+
case "magnusmodel":
9+
return MagnusModel(*args, **kwargs)
10+
case "christianmodel":
11+
return ChristianModel(*args, **kwargs)
12+
case _:
13+
raise ValueError(
14+
f"Model: {modelname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling"
15+
)

utils/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
__all__ = ["MagnusModel"]
1+
__all__ = ["MagnusModel", "ChristianModel"]
22

3+
from .christian_model import ChristianModel
34
from .magnus_model import MagnusModel

utils/models/christian_model.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
5+
6+
class CNNBlock(nn.Module):
7+
def __init__(self, in_channels, out_channels):
8+
super().__init__()
9+
10+
self.conv = nn.Conv2d(
11+
in_channels,
12+
out_channels,
13+
kernel_size=3,
14+
padding=1,
15+
)
16+
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
17+
self.relu = nn.ReLU()
18+
19+
def forward(self, x):
20+
x = self.conv(x)
21+
x = self.maxpool(x)
22+
x = self.relu(x)
23+
return x
24+
25+
26+
class ChristianModel(nn.Module):
27+
"""Simple CNN model for image classification.
28+
29+
Args
30+
----
31+
in_channels : int
32+
Number of input channels.
33+
num_classes : int
34+
Number of classes in the dataset.
35+
36+
Processing Images
37+
-----------------
38+
Input: (N, C, H, W)
39+
N: Batch size
40+
C: Number of input channels
41+
H: Height of the input image
42+
W: Width of the input image
43+
44+
Example:
45+
For grayscale images, C = 1.
46+
47+
Input Image Shape: (5, 1, 16, 16)
48+
CNN1 Output Shape: (5, 50, 8, 8)
49+
CNN2 Output Shape: (5, 100, 4, 4)
50+
FC Output Shape: (5, num_classes)
51+
"""
52+
def __init__(self, in_channels, num_classes):
53+
super().__init__()
54+
55+
self.cnn1 = CNNBlock(in_channels, 50)
56+
self.cnn2 = CNNBlock(50, 100)
57+
58+
self.fc1 = nn.Linear(100 * 4 * 4, num_classes)
59+
self.softmax = nn.Softmax(dim=1)
60+
61+
def forward(self, x):
62+
x = self.cnn1(x)
63+
x = self.cnn2(x)
64+
65+
x = x.view(x.size(0), -1)
66+
x = self.fc1(x)
67+
x = self.softmax(x)
68+
69+
return x
70+
71+
72+
@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
73+
def test_christian_model(in_channels, num_classes):
74+
n, c, h, w = 5, in_channels, 16, 16
75+
76+
model = ChristianModel(c, num_classes)
77+
78+
x = torch.randn(n, c, h, w)
79+
y = model(x)
80+
81+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
82+
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), f"Softmax output should sum to 1, but got: {y.sum()}"
83+
84+
85+
if __name__ == "__main__":
86+
87+
model = ChristianModel(3, 7)
88+
89+
x = torch.randn(3, 3, 16, 16)
90+
y = model(x)
91+
92+
print(y)

0 commit comments

Comments
 (0)