3939//! if aa_layer is not None:
4040//! if issubclass(aa_layer, nn.AvgPool2d):
4141//! self.maxpool = aa_layer(2)
42- //! else:
43- //! self.maxpool = nn.Sequential(*[
44- //! nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
45- //! aa_layer(channels=inplanes, stride=2)
46- //! ])
42+ //! else:
43+ //! self.maxpool = nn.Sequential(*[
44+ //! nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
45+ //! aa_layer(channels=inplanes, stride=2)
46+ //! ])
4747//! else:
4848//! self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
4949//! ```
6262//! ]
6363//! ```
6464//!
65- use crate :: compat:: activation_wrapper:: { Activation , ActivationConfig } ;
66- use crate :: layers:: blocks:: conv_norm:: { ConvNorm2d , ConvNorm2dConfig , ConvNorm2dMeta } ;
67- use crate :: models:: resnet:: util:: CONV_INTO_RELU_INITIALIZER ;
68- use burn:: config:: Config ;
69- use burn:: nn:: Initializer ;
70- use burn:: prelude:: { Backend , Module } ;
71- use burn:: tensor:: Tensor ;
72-
73- /// [`Stem`] Meta API.
74- pub trait StemMeta {
75- /// The number of input channels.
76- fn in_planes ( & self ) -> usize ;
7765
78- /// The number of output channels.
79- fn out_planes ( & self ) -> usize ;
80-
81- /// The stride of the first convolution.
82- fn stride ( & self ) -> [ usize ; 2 ] ;
83- }
66+ use crate :: compat :: activation_wrapper :: ActivationConfig ;
67+ use crate :: compat :: normalization_wrapper :: NormalizationConfig ;
68+ use crate :: layers :: blocks :: cna :: CNA2dConfig ;
69+ use burn :: nn :: PaddingConfig2d ;
70+ use burn :: nn :: conv :: Conv2dConfig ;
71+ use burn :: nn :: pool :: MaxPool2dConfig ;
8472
85- /// [`Stem`] configuration.
73+ /// stem contract configuration.
8674#[ derive( Debug , Clone , Default ) ]
87- pub enum StemAbstractConfig {
75+ pub enum ResNetStemContractConfig {
8876 /// Default; single 7x7 convolution with stride 2.
8977 #[ default]
9078 Default ,
@@ -108,109 +96,55 @@ pub enum StemAbstractConfig {
10896 } ,
10997}
11098
111- /// [`StemStage`] configuration.
112- #[ derive( Config , Debug ) ]
113- pub struct StemStageConfig {
114- /// Convolution + Normalization layer.
115- pub conv_norm : ConvNorm2dConfig ,
116-
117- /// Activation function.
118- #[ config( default = "ActivationConfig::Relu" ) ]
119- pub activation : ActivationConfig ,
120-
121- /// Initializer for the convolutional layers.
122- #[ config( default = "CONV_INTO_RELU_INITIALIZER.clone()" ) ]
123- pub initializer : Initializer ,
124- }
125-
126- impl StemMeta for StemStageConfig {
127- fn in_planes ( & self ) -> usize {
128- self . conv_norm . in_channels ( )
129- }
130-
131- fn out_planes ( & self ) -> usize {
132- self . conv_norm . out_channels ( )
133- }
134-
135- fn stride ( & self ) -> [ usize ; 2 ] {
136- * self . conv_norm . stride ( )
137- }
138- }
139-
140- impl StemStageConfig {
141- /// Initialize the [`StemStage`].
142- pub fn init < B : Backend > (
143- self ,
144- device : & B :: Device ,
145- ) -> StemStage < B > {
146- StemStage {
147- conv_norm : self
148- . conv_norm
149- . with_initializer ( self . initializer )
150- . init ( device) ,
151-
152- activation : self . activation . init ( device) ,
99+ impl ResNetStemContractConfig {
100+ /// Convert to a [`ResNetStemStructureConfig`].
101+ pub fn to_structure (
102+ & self ,
103+ in_channels : usize ,
104+ normalization : NormalizationConfig ,
105+ activation : ActivationConfig ,
106+ ) -> ResNetStemStructureConfig {
107+ match self {
108+ ResNetStemContractConfig :: Default => ( ) ,
109+ _ => unimplemented ! ( "{:?}" , self ) ,
153110 }
154- }
155- }
156-
157- /// `ResNet` [`StemStage`].
158- #[ derive( Module , Debug ) ]
159- pub struct StemStage < B : Backend > {
160- /// Convolution + Normalization layer.
161- pub conv_norm : ConvNorm2d < B > ,
162111
163- /// Activation function.
164- pub activation : Activation < B > ,
165- }
166-
167- impl < B : Backend > StemStage < B > {
168- /// Forward pass.
169- pub fn forward (
170- & self ,
171- input : Tensor < B , 4 > ,
172- ) -> Tensor < B , 4 > {
173- let x = self . conv_norm . forward ( input) ;
174- self . activation . forward ( x)
112+ let cna1 = CNA2dConfig {
113+ conv : Conv2dConfig :: new ( [ in_channels, 64 ] , [ 7 , 7 ] )
114+ . with_stride ( [ 2 , 2 ] )
115+ . with_padding ( PaddingConfig2d :: Explicit ( 3 , 3 ) )
116+ . with_bias ( false ) ,
117+ norm : normalization. clone ( ) ,
118+ act : activation. clone ( ) ,
119+ } ;
120+
121+ let pool = Some (
122+ MaxPool2dConfig :: new ( [ 3 , 3 ] )
123+ . with_strides ( [ 2 , 2 ] )
124+ . with_padding ( PaddingConfig2d :: Explicit ( 1 , 1 ) ) ,
125+ ) ;
126+
127+ ResNetStemStructureConfig {
128+ cna1,
129+ cna2 : None ,
130+ cna3 : None ,
131+ pool,
132+ }
175133 }
176134}
177135
178- /// [`Stem]` configuration.
179- #[ derive( Config , Debug ) ]
180- pub struct StemConfig {
181- /// Stem stages.
182- pub stages : Vec < StemStageConfig > ,
183- }
136+ /// stem contract configuration.
137+ #[ derive( Debug , Clone ) ]
138+ pub struct ResNetStemStructureConfig {
139+ /// The first convolution.
140+ pub cna1 : CNA2dConfig ,
184141
185- impl StemConfig {
186- /// Initialize a [`Stem`].
187- pub fn init < B : Backend > (
188- self ,
189- device : & B :: Device ,
190- ) -> Stem < B > {
191- // TODO: check that the stages have valid input/output sizes.
192- Stem {
193- stages : self
194- . stages
195- . into_iter ( )
196- . map ( |stage| stage. init ( device) )
197- . collect ( ) ,
198- }
199- }
200- }
142+ /// The second convolution.
143+ pub cna2 : Option < CNA2dConfig > ,
201144
202- /// `ResNet` Input [`Stem`] Module.
203- #[ derive( Module , Debug ) ]
204- pub struct Stem < B : Backend > {
205- stages : Vec < StemStage < B > > ,
206- }
145+ /// The third convolution.
146+ pub cna3 : Option < CNA2dConfig > ,
207147
208- impl < B : Backend > Stem < B > {
209- /// Forward pass.
210- pub fn forward (
211- & self ,
212- input : Tensor < B , 4 > ,
213- ) -> Tensor < B , 4 > {
214- self . stages . iter ( ) . fold ( input, |x, stage| stage. forward ( x) )
215- }
148+ /// The pooling layer.
149+ pub pool : Option < MaxPool2dConfig > ,
216150}
0 commit comments