Skip to content

Commit e6a3e0b

Browse files
committed
refactor: streamline ResNet stem configuration and clean up unused code
- Simplified the `StemStage` and `StemConfig` structures, removing redundant abstractions.
- Introduced `ResNetStemContractConfig` and `ResNetStemStructureConfig` for a cleaner separation of concerns.
- Removed unused imports for improved maintainability.
1 parent 50410b8 commit e6a3e0b

File tree

1 file changed

+56
-122
lines changed

1 file changed

+56
-122
lines changed

crates/bimm/src/models/resnet/stems.rs

Lines changed: 56 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@
3939
//! if aa_layer is not None:
4040
//! if issubclass(aa_layer, nn.AvgPool2d):
4141
//! self.maxpool = aa_layer(2)
42-
//! else:
43-
//! self.maxpool = nn.Sequential(*[
44-
//! nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
45-
//! aa_layer(channels=inplanes, stride=2)
46-
//! ])
42+
//! else:
43+
//! self.maxpool = nn.Sequential(*[
44+
//! nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
45+
//! aa_layer(channels=inplanes, stride=2)
46+
//! ])
4747
//! else:
4848
//! self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
4949
//! ```
@@ -62,29 +62,17 @@
6262
//! ]
6363
//! ```
6464
//!
65-
use crate::compat::activation_wrapper::{Activation, ActivationConfig};
66-
use crate::layers::blocks::conv_norm::{ConvNorm2d, ConvNorm2dConfig, ConvNorm2dMeta};
67-
use crate::models::resnet::util::CONV_INTO_RELU_INITIALIZER;
68-
use burn::config::Config;
69-
use burn::nn::Initializer;
70-
use burn::prelude::{Backend, Module};
71-
use burn::tensor::Tensor;
72-
73-
/// [`Stem`] Meta API.
74-
pub trait StemMeta {
75-
/// The number of input channels.
76-
fn in_planes(&self) -> usize;
7765
78-
/// The number of output channels.
79-
fn out_planes(&self) -> usize;
80-
81-
/// The stride of the first convolution.
82-
fn stride(&self) -> [usize; 2];
83-
}
66+
use crate::compat::activation_wrapper::ActivationConfig;
67+
use crate::compat::normalization_wrapper::NormalizationConfig;
68+
use crate::layers::blocks::cna::CNA2dConfig;
69+
use burn::nn::PaddingConfig2d;
70+
use burn::nn::conv::Conv2dConfig;
71+
use burn::nn::pool::MaxPool2dConfig;
8472

85-
/// [`Stem`] configuration.
73+
/// Stem contract configuration.
8674
#[derive(Debug, Clone, Default)]
87-
pub enum StemAbstractConfig {
75+
pub enum ResNetStemContractConfig {
8876
/// Default; single 7x7 convolution with stride 2.
8977
#[default]
9078
Default,
@@ -108,109 +96,55 @@ pub enum StemAbstractConfig {
10896
},
10997
}
11098

111-
/// [`StemStage`] configuration.
112-
#[derive(Config, Debug)]
113-
pub struct StemStageConfig {
114-
/// Convolution + Normalization layer.
115-
pub conv_norm: ConvNorm2dConfig,
116-
117-
/// Activation function.
118-
#[config(default = "ActivationConfig::Relu")]
119-
pub activation: ActivationConfig,
120-
121-
/// Initializer for the convolutional layers.
122-
#[config(default = "CONV_INTO_RELU_INITIALIZER.clone()")]
123-
pub initializer: Initializer,
124-
}
125-
126-
impl StemMeta for StemStageConfig {
127-
fn in_planes(&self) -> usize {
128-
self.conv_norm.in_channels()
129-
}
130-
131-
fn out_planes(&self) -> usize {
132-
self.conv_norm.out_channels()
133-
}
134-
135-
fn stride(&self) -> [usize; 2] {
136-
*self.conv_norm.stride()
137-
}
138-
}
139-
140-
impl StemStageConfig {
141-
/// Initialize the [`StemStage`].
142-
pub fn init<B: Backend>(
143-
self,
144-
device: &B::Device,
145-
) -> StemStage<B> {
146-
StemStage {
147-
conv_norm: self
148-
.conv_norm
149-
.with_initializer(self.initializer)
150-
.init(device),
151-
152-
activation: self.activation.init(device),
99+
impl ResNetStemContractConfig {
100+
/// Convert to a [`ResNetStemStructureConfig`].
101+
pub fn to_structure(
102+
&self,
103+
in_channels: usize,
104+
normalization: NormalizationConfig,
105+
activation: ActivationConfig,
106+
) -> ResNetStemStructureConfig {
107+
match self {
108+
ResNetStemContractConfig::Default => (),
109+
_ => unimplemented!("{:?}", self),
153110
}
154-
}
155-
}
156-
157-
/// `ResNet` [`StemStage`].
158-
#[derive(Module, Debug)]
159-
pub struct StemStage<B: Backend> {
160-
/// Convolution + Normalization layer.
161-
pub conv_norm: ConvNorm2d<B>,
162111

163-
/// Activation function.
164-
pub activation: Activation<B>,
165-
}
166-
167-
impl<B: Backend> StemStage<B> {
168-
/// Forward pass.
169-
pub fn forward(
170-
&self,
171-
input: Tensor<B, 4>,
172-
) -> Tensor<B, 4> {
173-
let x = self.conv_norm.forward(input);
174-
self.activation.forward(x)
112+
let cna1 = CNA2dConfig {
113+
conv: Conv2dConfig::new([in_channels, 64], [7, 7])
114+
.with_stride([2, 2])
115+
.with_padding(PaddingConfig2d::Explicit(3, 3))
116+
.with_bias(false),
117+
norm: normalization.clone(),
118+
act: activation.clone(),
119+
};
120+
121+
let pool = Some(
122+
MaxPool2dConfig::new([3, 3])
123+
.with_strides([2, 2])
124+
.with_padding(PaddingConfig2d::Explicit(1, 1)),
125+
);
126+
127+
ResNetStemStructureConfig {
128+
cna1,
129+
cna2: None,
130+
cna3: None,
131+
pool,
132+
}
175133
}
176134
}
177135

178-
/// [`Stem]` configuration.
179-
#[derive(Config, Debug)]
180-
pub struct StemConfig {
181-
/// Stem stages.
182-
pub stages: Vec<StemStageConfig>,
183-
}
136+
/// Stem structure configuration.
137+
#[derive(Debug, Clone)]
138+
pub struct ResNetStemStructureConfig {
139+
/// The first convolution.
140+
pub cna1: CNA2dConfig,
184141

185-
impl StemConfig {
186-
/// Initialize a [`Stem`].
187-
pub fn init<B: Backend>(
188-
self,
189-
device: &B::Device,
190-
) -> Stem<B> {
191-
// TODO: check that the stages have valid input/output sizes.
192-
Stem {
193-
stages: self
194-
.stages
195-
.into_iter()
196-
.map(|stage| stage.init(device))
197-
.collect(),
198-
}
199-
}
200-
}
142+
/// The second convolution.
143+
pub cna2: Option<CNA2dConfig>,
201144

202-
/// `ResNet` Input [`Stem`] Module.
203-
#[derive(Module, Debug)]
204-
pub struct Stem<B: Backend> {
205-
stages: Vec<StemStage<B>>,
206-
}
145+
/// The third convolution.
146+
pub cna3: Option<CNA2dConfig>,
207147

208-
impl<B: Backend> Stem<B> {
209-
/// Forward pass.
210-
pub fn forward(
211-
&self,
212-
input: Tensor<B, 4>,
213-
) -> Tensor<B, 4> {
214-
self.stages.iter().fold(input, |x, stage| stage.forward(x))
215-
}
148+
/// The pooling layer.
149+
pub pool: Option<MaxPool2dConfig>,
216150
}

0 commit comments

Comments (0)