22import torch .nn as nn
33
44
5+ def find_fc_input_shape (image_shape , model ):
6+ """
7+ Find the shape of the input to the fully connected layer after passing through the convolutional layers.
8+
9+ Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)
10+
11+ Args
12+ ----
13+ image_shape : tuple(int, int, int)
14+ Shape of the input image (C, H, W), where C is the number of channels,
15+ H is the height, and W is the width of the image.
16+ model : nn.Module
17+ The CNN model containing the convolutional layers, whose output size is used to
18+ determine the number of input features for the fully connected layer.
19+
20+ Returns
21+ -------
22+ int
23+ The number of elements in the input to the fully connected layer.
24+ """
25+
26+ dummy_img = torch .randn (1 , * image_shape )
27+ with torch .no_grad ():
28+ x = model .conv_block1 (dummy_img )
29+ x = model .conv_block2 (x )
30+ x = model .conv_block3 (x )
31+ x = torch .flatten (x , 1 )
32+
33+ return x .size (1 )
34+
35+
536class SolveigModel (nn .Module ):
637 """
738 A Convolutional Neural Network model for classification.
@@ -49,9 +80,19 @@ def __init__(self, image_shape, num_classes):
4980 nn .ReLU (),
5081 )
5182
52- self .fc1 = nn .Linear (100 * 8 * 8 , num_classes )
83+ fc_input_size = find_fc_input_shape (image_shape , self )
84+
85+ self .fc1 = nn .Linear (fc_input_size , num_classes )
5386
5487 def forward (self , x ):
88+ """
89+ Defines the forward pass.
90+ Args:
91+ x (torch.Tensor): A four-dimensional tensor with shape
92+ (Batch Size, Channels, Image Height, Image Width).
93+ Returns:
94+ torch.Tensor: The output tensor containing class logits for each input sample.
95+ """
5596 x = self .conv_block1 (x )
5697 x = self .conv_block2 (x )
5798 x = self .conv_block3 (x )
@@ -63,7 +104,7 @@ def forward(self, x):
63104
64105
65106if __name__ == "__main__" :
66- x = torch .randn (1 , 3 , 16 , 16 )
107+ x = torch .randn (1 , 3 , 28 , 28 )
67108
68109 model = SolveigModel (x .shape [1 :], 3 )
69110
0 commit comments