Skip to content

Commit ac00550

Browse files
committed
feat: implement ResNetStem struct and forward pass, and bump bimm-contracts to 0.4.2
- Added the `ResNetStem` struct with associated forward method for streamlined ResNet stem implementation. - Updated `bimm-contracts` to version 0.4.2 in Cargo files and lock file.
1 parent e6a3e0b commit ac00550

File tree

4 files changed

+48
-8
lines changed

4 files changed

+48
-8
lines changed

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ burn = "^0.18.0"
3232
burn-import = "^0.18.0"
3333
dirs = "^6.0.0"
3434

35-
bimm-contracts = "^0.4.1"
35+
bimm-contracts = "^0.4.2"
3636

3737
# Burn coupled-dependencies
3838
globwalk = "^0.9.1"

crates/bimm/src/models/resnet/stems.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,12 @@
6565
6666
use crate::compat::activation_wrapper::ActivationConfig;
6767
use crate::compat::normalization_wrapper::NormalizationConfig;
68-
use crate::layers::blocks::cna::CNA2dConfig;
68+
use crate::layers::blocks::cna::{CNA2d, CNA2dConfig};
69+
use burn::module::Module;
6970
use burn::nn::PaddingConfig2d;
7071
use burn::nn::conv::Conv2dConfig;
71-
use burn::nn::pool::MaxPool2dConfig;
72+
use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
73+
use burn::prelude::{Backend, Tensor};
7274

7375
/// stem contract configuration.
7476
#[derive(Debug, Clone, Default)]
@@ -148,3 +150,37 @@ pub struct ResNetStemStructureConfig {
148150
/// The pooling layer.
149151
pub pool: Option<MaxPool2dConfig>,
150152
}
153+
154+
/// stem impl.
155+
#[derive(Module, Debug)]
156+
pub struct ResNetStem<B: Backend> {
157+
/// The first convolution.
158+
pub cna1: CNA2d<B>,
159+
/// The second convolution.
160+
pub cna2: Option<CNA2d<B>>,
161+
/// The third convolution.
162+
pub cna3: Option<CNA2d<B>>,
163+
/// The pooling.
164+
pub pool: Option<MaxPool2d>,
165+
}
166+
167+
impl<B: Backend> ResNetStem<B> {
168+
/// forward pass.
169+
pub fn forward(
170+
&self,
171+
input: Tensor<B, 4>,
172+
) -> Tensor<B, 4> {
173+
let mut x = input;
174+
x = self.cna1.forward(x);
175+
if let Some(cna2) = &self.cna2 {
176+
x = cna2.forward(x);
177+
}
178+
if let Some(cna3) = &self.cna3 {
179+
x = cna3.forward(x);
180+
}
181+
if let Some(pool) = &self.pool {
182+
x = pool.forward(x);
183+
}
184+
x
185+
}
186+
}

examples/resnet-finetune/src/main.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,12 @@ pub struct Args {
115115
pub weight_decay: f32,
116116
}
117117

118+
/// Log config.
119+
///
120+
/// Only exists for logging.
118121
#[derive(Config, Debug)]
119-
struct LogConfig {
122+
#[allow(clippy::too_many_arguments)]
123+
pub struct LogConfig {
120124
seed: u64,
121125
train_percentage: u8,
122126
batch_size: usize,

0 commit comments

Comments
 (0)