1212
1313
1414class CNN (nn .Sequential ):
15+ """Convolutional Neural Network (CNN).
16+
17+ The CNN is a sequence of convolutional layers, optional batch normalization, activation functions, and
18+ optional max pooling. The final output can be flattened or pooled depending on the configuration.
19+ """
20+
1521 def __init__ (
1622 self ,
1723 in_channels : int ,
@@ -24,13 +30,25 @@ def __init__(
2430 batchnorm : bool | list [bool ] = False ,
2531 max_pool : bool | list [bool ] = False ,
2632 ) -> None :
27- """Convolutional Neural Network model.
33+ """Initialize the CNN.
34+
35+ Args:
36+ in_channels: Number of input channels.
37+ activation: Activation function to use.
38+ out_channels: List of output channels for each convolutional layer.
39+ kernel_size: List of kernel sizes for each convolutional layer or a single kernel size for all layers.
40+ stride: List of strides for each convolutional layer or a single stride for all layers.
41+ flatten: Whether to flatten the output tensor.
42+ avg_pool: If specified, applies adaptive average pooling to the given output size after the convolutions.
43+ batchnorm: Whether to apply batch normalization after each convolutional layer.
44+ max_pool: Whether to apply max pooling after each convolutional layer.
2845
2946 .. note::
3047 The config is intentionally not saved, so that the model can be jit compiled.
3148 """
3249 super ().__init__ ()
3350
51+ # If parameters are not lists, convert them to lists
3452 if isinstance (batchnorm , bool ):
3553 batchnorm = [batchnorm ] * len (out_channels )
3654 if isinstance (max_pool , bool ):
@@ -40,12 +58,11 @@ def __init__(
4058 if isinstance (stride , int ):
4159 stride = [stride ] * len (out_channels )
4260
43- # get activation function
61+ # Resolve activation function
4462 activation_function = resolve_nn_activation (activation )
4563
46- # build model layers
64+ # Create layers sequentially
4765 layers = []
48-
4966 for idx in range (len (out_channels )):
5067 in_channels = in_channels if idx == 0 else out_channels [idx - 1 ]
5168 layers .append (
@@ -62,16 +79,17 @@ def __init__(
6279 if max_pool [idx ]:
6380 layers .append (nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ))
6481
65- # register the layers
82+ # Register the layers
6683 for idx , layer in enumerate (layers ):
6784 self .add_module (f"{ idx } " , layer )
6885
86+ # Add avgpool if specified
6987 if avg_pool is not None :
7088 self .avgpool = nn .AdaptiveAvgPool2d (avg_pool )
7189 else :
7290 self .avgpool = None
7391
74- # save flatten config for forward function
92+ # Save flatten flag for forward function
7593 self .flatten = flatten
7694
7795 def forward (self , x : torch .Tensor ) -> torch .Tensor :
@@ -84,9 +102,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
84102 x = x .flatten (start_dim = 1 )
85103 return x
86104
87- def init_weights (self , scales : float | tuple [float ]) -> None :
88- """Initialize the weights of the CNN."""
89- # initialize the weights
105+ def init_weights (self ) -> None :
106+ """Initialize the weights of the CNN with Xavier initialization."""
90107 for idx , module in enumerate (self ):
91108 if isinstance (module , nn .Conv2d ):
92109 nn .init .xavier_uniform_ (module .weight )
0 commit comments