44
55
66class SolveigModel (nn .Module ):
7- def __init__ (self ):
7+ """
8+ A Convolutional Neural Network model for classification.
9+
10+ Args:
11+ ----
12+ in_channels : int
13+ Number of input channels (e.g., 3 for RGB images, 1 for grayscale).
14+ num_classes : int
15+ The number of output classes (e.g., 2 for binary classification).
16+
17+ Attributes:
18+ -----------
19+ conv_block1 : nn.Sequential
20+ First convolutional block containing a convolutional layer, ReLU activation, and max-pooling.
21+ conv_block2 : nn.Sequential
22+ Second convolutional block containing a convolutional layer and ReLU activation.
23+ conv_block3 : nn.Sequential
24+ Third convolutional block containing a convolutional layer and ReLU activation.
25+ fc1 : nn.Linear
26+ Fully connected layer that outputs the final classification scores.
27+ """
28+
29+ def __init__ (self , in_channels , num_classes ):
830 super ().__init__ ()
931
32+ # Define the first convolutional block (conv + relu + maxpool)
33+ self .conv_block1 = nn .Sequential (
34+ nn .Conv2d (in_channels = in_channels , out_channels = 25 , kernel_size = 3 , padding = 1 ),
35+ nn .ReLU (),
36+ nn .MaxPool2d (kernel_size = 2 , stride = 2 )
37+ )
38+
39+ # Define the second convolutional block (conv + relu)
40+ self .conv_block2 = nn .Sequential (
41+ nn .Conv2d (in_channels = 25 , out_channels = 50 , kernel_size = 3 , padding = 1 ),
42+ nn .ReLU ()
43+ )
44+
45+ # Define the third convolutional block (conv + relu)
46+ self .conv_block3 = nn .Sequential (
47+ nn .Conv2d (in_channels = 50 , out_channels = 100 , kernel_size = 3 , padding = 1 ),
48+ nn .ReLU ()
49+ )
50+
51+ self .fc1 = nn .Linear (100 * 8 * 8 , num_classes )
52+
1053 def forward (self , x ):
11- return
54+ x = self .conv_block1 (x )
55+ x = self .conv_block2 (x )
56+ x = self .conv_block3 (x )
57+ x = torch .flatten (x , 1 )
58+
59+ x = self .fc1 (x )
60+ x = nn .Softmax (x )
61+
62+ return x
63+
64+
65+ if __name__ == "__main__" :
66+ model = SolveigModel (3 , 3 )
67+
68+ x = torch .randn (1 , 3 , 16 , 16 )
69+ y = model (x )
70+
71+ print (y )
0 commit comments