Skip to content

Commit c0ee89f

Browse files
committed
feat: introduce multi-backend support and update configurations
- Added support for `WGPU`, `CUDA`, and `Metal` backends in examples (`resnet_tiny`, `resnet-finetune`, and `swin_tiny`). - Refactored backend initialization logic to dynamically handle multiple features. - Updated `Cargo.toml` files to add feature flags (`wgpu`, `cuda`, `metal`) for enhanced flexibility. - Modified `.run` configurations to include appropriate backend-specific settings. - Upgraded `aho-corasick` dependency from `1.1.3` to `1.1.4` in `Cargo.lock`.
1 parent 5b46e9a commit c0ee89f

File tree

12 files changed

+76
-54
lines changed

12 files changed

+76
-54
lines changed

.run/Fix.run.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<component name="ProjectRunConfigurationManager">
22
<configuration default="false" name="Fix" type="CargoCommandRunConfiguration" factoryName="Cargo Command" focusToolWindowBeforeRun="true">
33
<option name="buildProfileId" value="test" />
4-
<option name="command" value="clippy --fix --allow-dirty --allow-staged" />
4+
<option name="command" value="clippy --fix --allow-dirty --allow-staged --features wgpu" />
55
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
66
<envs>
77
<env name="CINIC10_PATH" value="$CINIC10_PATH$" />

.run/Test.run.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<component name="ProjectRunConfigurationManager">
22
<configuration default="false" name="Test" type="CargoCommandRunConfiguration" factoryName="Cargo Command" focusToolWindowBeforeRun="true">
33
<option name="buildProfileId" value="test" />
4-
<option name="command" value="test" />
4+
<option name="command" value="test --features wgpu" />
55
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
66
<envs>
77
<env name="CINIC10_PATH" value="/media/Data/CINIC-10" />

Cargo.lock

Lines changed: 9 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ members = [
1010

1111
[workspace.package]
1212
edition = "2024"
13-
version = "0.19.0"
13+
version = "0.19.1"
1414
repository = "https://github.com/crutcher/bimm"
1515
license = "MIT"
1616
rust-version = "1.88.0"

crates/bimm-firehose-image/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ description = "bimm-firehose image processing support"
1212
workspace = true
1313

1414
[dependencies]
15-
bimm-firehose = { version = "0.19.0", path = "../bimm-firehose" }
15+
bimm-firehose = { version = "0.19.1", path = "../bimm-firehose" }
1616

1717
burn = { workspace = true, features = ["dataset", "vision", "ndarray"] }
1818
serde = { workspace = true, features = ["derive"] }

crates/bimm/src/models/resnet/residual_block.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ mod tests {
411411
let output = block.forward(input);
412412

413413
assert_shape_contract!(
414-
["batch", "out_channels", "out_height", "out_width"],
414+
["batch", "out_planes", "out_height", "out_width"],
415415
&output,
416416
&[
417417
("batch", batch_size),

examples/resnet-finetune/Cargo.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ workspace = true
1111

1212
[dependencies]
1313
anyhow = { workspace = true }
14-
bimm = { version = "0.19.0", path = "../../crates/bimm" }
14+
bimm = { version = "0.19.1", path = "../../crates/bimm" }
1515

16-
burn = { workspace = true, features = ["network", "train", "autodiff", "vision", "cuda"] }
16+
burn = { workspace = true, features = ["network", "train", "autodiff", "vision"] }
1717

1818
clap = { workspace = true, features = ["derive"] }
1919

@@ -24,7 +24,9 @@ rand = { workspace = true }
2424
serde = { workspace = true, features = ["std", "derive"] }
2525
tar = { workspace = true }
2626

27-
2827
[features]
2928
default = []
29+
wgpu = ["burn/wgpu"]
30+
cuda = ["burn/cuda"]
31+
metal = ["burn/metal"]
3032

examples/resnet-finetune/src/main.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::data::{ClassificationBatch, ClassificationBatcher};
99
use crate::dataset::{CLASSES, PlanetLoader, download};
1010
use bimm::cache::disk::DiskCacheConfig;
1111
use bimm::models::resnet::{PREFAB_RESNET_MAP, ResNet, ResNetContractConfig};
12-
use burn::backend::{Autodiff, Cuda};
12+
use burn::backend::Autodiff;
1313
use burn::config::Config;
1414
use burn::data::dataloader::DataLoaderBuilder;
1515
use burn::data::dataset::transform::ShuffledDataset;
@@ -150,21 +150,26 @@ fn main() -> anyhow::Result<()> {
150150

151151
let _source_tree = download();
152152

153-
let device = Default::default();
154-
train::<Autodiff<Cuda>>(&args, &device)
153+
#[cfg(feature = "wgpu")]
154+
return train::<Autodiff<burn::backend::Wgpu>>(&args);
155+
156+
#[cfg(feature = "cuda")]
157+
return train::<Autodiff<burn::backend::Cuda>>(&args);
158+
159+
#[cfg(feature = "metal")]
160+
return train::<Autodiff<burn::backend::Metal>>(&args);
155161
}
156162

157163
#[must_use]
158-
pub fn train<B: AutodiffBackend>(
159-
args: &Args,
160-
device: &B::Device,
161-
) -> anyhow::Result<()> {
164+
pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
165+
let device: B::Device = Default::default();
166+
162167
// Remove existing artifacts before to get an accurate learner summary
163168
let artifact_dir: &str = args.artifact_dir.as_ref();
164169
std::fs::remove_dir_all(artifact_dir);
165170
std::fs::create_dir_all(artifact_dir).expect("Failed to create artifacts directory");
166171

167-
B::seed(device, args.seed);
172+
B::seed(&device, args.seed);
168173

169174
let disk_cache = DiskCacheConfig::default();
170175

@@ -180,7 +185,7 @@ pub fn train<B: AutodiffBackend>(
180185
let model: ResNet<B> = resnet_config
181186
.clone()
182187
.to_structure()
183-
.init(device)
188+
.init(&device)
184189
.load_pytorch_weights(weights)
185190
.expect("Failed to load pretrained weights")
186191
.with_classes(CLASSES.len())

examples/resnet_tiny/Cargo.toml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ repository.workspace = true
1010
workspace = true
1111

1212
[dependencies]
13-
bimm = { version = "0.19.0", path = "../../crates/bimm" }
14-
bimm-firehose = { version = "0.19.0", path = "../../crates/bimm-firehose" }
15-
bimm-firehose-image = { version = "0.19.0", path = "../../crates/bimm-firehose-image" }
13+
bimm = { version = "0.19.1", path = "../../crates/bimm" }
14+
bimm-firehose = { version = "0.19.1", path = "../../crates/bimm-firehose" }
15+
bimm-firehose-image = { version = "0.19.1", path = "../../crates/bimm-firehose-image" }
1616

17-
burn = { workspace = true, features = ["train", "autodiff", "cuda", "fusion", "vision", "dataset"] }
17+
burn = { workspace = true, features = ["train", "autodiff", "fusion", "vision", "dataset"] }
1818
serde = { workspace = true, features = ["derive"] }
1919

2020
clap = { workspace = true, features = ["derive"] }
@@ -23,4 +23,7 @@ rand = { workspace = true }
2323
anyhow = { workspace = true }
2424

2525
[features]
26-
nightly = []
26+
default = []
27+
wgpu = ["burn/wgpu"]
28+
cuda = ["burn/cuda"]
29+
metal = ["burn/metal"]

examples/resnet_tiny/src/main.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use bimm_firehose_image::augmentation::orientation::flip::HorizontalFlipStage;
1919
use bimm_firehose_image::burn_support::{ImageToTensorData, stack_tensor_data_column};
2020
use bimm_firehose_image::loader::{ImageLoader, ResizeSpec};
2121
use bimm_firehose_image::{ColorType, ImageShape};
22-
use burn::backend::{Autodiff, Cuda};
22+
use burn::backend::Autodiff;
2323
use burn::data::dataloader::{DataLoaderBuilder, Dataset};
2424
use burn::data::dataset::transform::ShuffledDataset;
2525
use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
@@ -126,10 +126,15 @@ pub struct Args {
126126

127127
fn main() -> anyhow::Result<()> {
128128
let args = Args::parse();
129-
type B = Autodiff<Cuda>;
130129

131-
let device = Default::default();
132-
backend_main::<B>(&args, &device)
130+
#[cfg(feature = "wgpu")]
131+
return backend_main::<Autodiff<burn::backend::Wgpu>>(&args);
132+
133+
#[cfg(feature = "cuda")]
134+
return backend_main::<Autodiff<burn::backend::Cuda>>(&args);
135+
136+
#[cfg(feature = "metal")]
137+
return backend_main::<Autodiff<burn::backend::Metal>>(&args);
133138
}
134139

135140
/// Create the artifact directory for saving training artifacts.
@@ -140,17 +145,16 @@ fn create_artifact_dir(artifact_dir: &str) {
140145
}
141146

142147
/// Train the model with the given configuration and devices.
143-
pub fn backend_main<B: AutodiffBackend>(
144-
args: &Args,
145-
device: &B::Device,
146-
) -> anyhow::Result<()> {
148+
pub fn backend_main<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
149+
let device: B::Device = Default::default();
150+
147151
let image_shape = ImageShape {
148152
height: 32,
149153
width: 32,
150154
};
151155
let num_classes = 10;
152156

153-
B::seed(device, args.seed);
157+
B::seed(&device, args.seed);
154158

155159
let prefab = PREFAB_RESNET_MAP.expect_lookup_prefab(&args.resnet_prefab);
156160

@@ -159,7 +163,7 @@ pub fn backend_main<B: AutodiffBackend>(
159163
.with_activation(ActivationConfig::Gelu)
160164
.to_structure();
161165

162-
let resnet: ResNet<B> = resnet_config.init(device);
166+
let resnet: ResNet<B> = resnet_config.init(&device);
163167

164168
let resnet: ResNet<B> = match &args.resnet_pretrained {
165169
Some(pretrained) => {

0 commit comments

Comments
 (0)