Skip to content

Commit da91527

Browse files
committed
Add SolveigModel implementation
1 parent 7c5a9b2 commit da91527

File tree

1 file changed

+62
-2
lines changed

1 file changed

+62
-2
lines changed

utils/models/solveig_model.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,68 @@
44

55

66
class SolveigModel(nn.Module):
7-
def __init__(self):
7+
"""
8+
A Convolutional Neural Network model for classification.
9+
10+
Args:
11+
----
12+
in_channels : int
13+
Number of input channels (e.g., 3 for RGB images, 1 for grayscale).
14+
num_classes : int
15+
The number of output classes (e.g., 2 for binary classification).
16+
17+
Attributes:
18+
-----------
19+
conv_block1 : nn.Sequential
20+
First convolutional block containing a convolutional layer, ReLU activation, and max-pooling.
21+
conv_block2 : nn.Sequential
22+
Second convolutional block containing a convolutional layer and ReLU activation.
23+
conv_block3 : nn.Sequential
24+
Third convolutional block containing a convolutional layer and ReLU activation.
25+
fc1 : nn.Linear
26+
Fully connected layer that outputs the final classification scores.
27+
"""
28+
29+
def __init__(self, in_channels, num_classes):
830
super().__init__()
931

32+
# Define the first convolutional block (conv + relu + maxpool)
33+
self.conv_block1 = nn.Sequential(
34+
nn.Conv2d(in_channels=in_channels, out_channels=25, kernel_size=3, padding=1),
35+
nn.ReLU(),
36+
nn.MaxPool2d(kernel_size=2, stride=2)
37+
)
38+
39+
# Define the second convolutional block (conv + relu)
40+
self.conv_block2 = nn.Sequential(
41+
nn.Conv2d(in_channels=25, out_channels=50, kernel_size=3, padding=1),
42+
nn.ReLU()
43+
)
44+
45+
# Define the third convolutional block (conv + relu)
46+
self.conv_block3 = nn.Sequential(
47+
nn.Conv2d(in_channels=50, out_channels=100, kernel_size=3, padding=1),
48+
nn.ReLU()
49+
)
50+
51+
self.fc1 = nn.Linear(100 * 8 * 8, num_classes)
52+
1053
def forward(self, x):
11-
return
54+
x = self.conv_block1(x)
55+
x = self.conv_block2(x)
56+
x = self.conv_block3(x)
57+
x = torch.flatten(x, 1)
58+
59+
x = self.fc1(x)
60+
x = nn.Softmax(x)
61+
62+
return x
63+
64+
65+
if __name__ == "__main__":
66+
model = SolveigModel(3, 3)
67+
68+
x = torch.randn(1, 3, 16, 16)
69+
y = model(x)
70+
71+
print(y)

0 commit comments

Comments
 (0)