33
44
55class CNNBlock (nn .Module ):
6+ """
7+ CNN block with Conv2d, MaxPool2d, and ReLU.
8+
9+ Args
10+ ----
11+
12+ in_channels : int
13+ Number of input channels.
14+ out_channels : int
15+ Number of output channels.
16+ """
17+
618 def __init__ (self , in_channels , out_channels ):
719 super ().__init__ ()
820
@@ -22,6 +34,37 @@ def forward(self, x):
2234 return x
2335
2436
def find_fc_input_shape(image_shape, *cnn_layers):
    """
    Find the shape of the input to the fully connected layer.

    A dummy batch holding a single random image is passed through the given
    layers, in the order given, and the number of elements per sample in the
    result is returned.

    Code inspired by @Seilmast (https://github.com/SFI-Visual-Intelligence/Collaborative-Coding-Exam/issues/67#issuecomment-2651212254)

    Args
    ----
    image_shape : tuple(int, int, int)
        Shape of the input image (C, H, W).
    *cnn_layers : nn.Module
        CNN layers, passed as separate positional arguments and applied in
        order. May be empty, in which case the element count of the image
        itself is returned.

    Returns
    -------
    int
        Number of elements in the input to the fully connected layer.
    """

    # Batch dimension of 1: only the per-sample element count matters.
    x = torch.randn(1, *image_shape)

    # Shape probing only, so no autograd graph needs to be recorded.
    with torch.no_grad():
        # Iterating over all layers (rather than special-casing the first)
        # also handles the case where no layers are supplied.
        for layer in cnn_layers:
            x = layer(x)

    # Count the elements of the single sample directly. Unlike
    # `x.view(x.size(0), -1)`, this does not require the layer output to be
    # contiguous in memory.
    return x[0].numel()
66+
67+
2568class ChristianModel (nn .Module ):
2669 """Simple CNN model for image classification.
2770
@@ -57,7 +100,9 @@ def __init__(self, image_shape, num_classes):
57100 self .cnn1 = CNNBlock (C , 50 )
58101 self .cnn2 = CNNBlock (50 , 100 )
59102
60- self .fc1 = nn .Linear (100 * 4 * 4 , num_classes )
103+ fc_input_shape = find_fc_input_shape (image_shape , self .cnn1 , self .cnn2 )
104+
105+ self .fc1 = nn .Linear (fc_input_shape , num_classes )
61106
62107 def forward (self , x ):
63108 x = self .cnn1 (x )
@@ -70,9 +115,10 @@ def forward(self, x):
70115
71116
if __name__ == "__main__":
    # Smoke test: build a random batch, derive the per-image shape from it,
    # and print the model output for that batch.
    batch = torch.randn(3, 3, 28, 28)
    image_shape = batch.shape[1:]  # everything after the batch dimension

    classifier = ChristianModel(image_shape, 7)

    print(classifier(batch))
0 commit comments