@@ -6,12 +6,12 @@ class SolveigModel(nn.Module):
66 """
77 A Convolutional Neural Network model for classification.
88
9- Args:
9+ Args
1010 ----
11- in_channels : int
12- Number of input channels (e.g., 3 for RGB images, 1 for grayscale ).
11+ image_shape : tuple( int, int, int)
12+ Shape of the input image (C, H, W ).
1313 num_classes : int
14- The number of output classes (e.g., 2 for binary classification) .
14+ Number of classes in the dataset .
1515
1616 Attributes:
1717 -----------
@@ -25,12 +25,14 @@ class SolveigModel(nn.Module):
2525 Fully connected layer that outputs the final classification scores.
2626 """
2727
28- def __init__ (self , in_channels , num_classes ):
28+ def __init__ (self , image_shape , num_classes ):
2929 super ().__init__ ()
3030
31+ C , * _ = image_shape
32+
3133 # Define the first convolutional block (conv + relu + maxpool)
3234 self .conv_block1 = nn .Sequential (
33- nn .Conv2d (in_channels = in_channels , out_channels = 25 , kernel_size = 3 , padding = 1 ),
35+ nn .Conv2d (in_channels = C , out_channels = 25 , kernel_size = 3 , padding = 1 ),
3436 nn .ReLU (),
3537 nn .MaxPool2d (kernel_size = 2 , stride = 2 )
3638 )
@@ -62,9 +64,11 @@ def forward(self, x):
6264
6365
6466if __name__ == "__main__" :
65- model = SolveigModel (3 , 3 )
6667
67- x = torch .randn (1 , 3 , 16 , 16 )
68+ x = torch .randn (1 ,3 , 16 , 16 )
69+
70+ model = SolveigModel (x .shape [1 :], 3 )
71+
6872 y = model (x )
6973
7074 print (y )
0 commit comments