Skip to content

Commit e80e648

Browse files
authored
Move nn components to burn-nn (tracel-ai#3740)
1 parent 47eac76 commit e80e648

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

103 files changed

+1081
-731
lines changed

.github/workflows/publish.yml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,16 +223,45 @@ jobs:
223223
secrets:
224224
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
225225

226+
publish-burn-collective:
227+
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v3
228+
with:
229+
crate: burn-collective
230+
needs:
231+
- publish-burn-common
232+
- publish-burn-tensor
233+
- publish-burn-communication
234+
# dev dependencies
235+
- publish-burn-wgpu
236+
- publish-burn-ndarray
237+
- publish-burn-cuda
238+
secrets:
239+
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
240+
241+
publish-burn-communication:
242+
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v3
243+
with:
244+
crate: burn-communication
245+
needs:
246+
- publish-burn-common
247+
- publish-burn-tensor
248+
secrets:
249+
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
250+
226251
publish-burn-core:
227252
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v3
228253
needs:
229254
- publish-burn-dataset
230255
- publish-burn-common
231256
- publish-burn-derive
232257
- publish-burn-tensor
258+
- publish-burn-vision
259+
- publish-burn-collective
260+
# dev dependencies
233261
- publish-burn-autodiff
234262
- publish-burn-wgpu
235263
- publish-burn-tch
264+
- publish-burn-cuda
236265
- publish-burn-ndarray
237266
- publish-burn-candle
238267
- publish-burn-remote
@@ -241,6 +270,22 @@ jobs:
241270
secrets:
242271
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
243272

273+
publish-burn-nn:
274+
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v3
275+
needs:
276+
- publish-burn-core
277+
# dev dependencies
278+
- publish-burn-autodiff
279+
- publish-burn-wgpu
280+
- publish-burn-tch
281+
- publish-burn-ndarray
282+
- publish-burn-candle
283+
- publish-burn-remote
284+
with:
285+
crate: burn-nn
286+
secrets:
287+
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
288+
244289
publish-burn-train:
245290
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v3
246291
needs:

Cargo.lock

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/burn-core/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ version.workspace = true
1515
workspace = true
1616

1717
[features]
18-
dataset = ["burn-dataset"]
1918
default = [
2019
"std",
2120
"burn-common/default",
@@ -26,12 +25,13 @@ doc = [
2625
"std",
2726
"dataset",
2827
"audio",
29-
"vision",
3028
# Doc features
3129
"burn-common/doc",
3230
"burn-dataset/doc",
3331
"burn-tensor/doc",
3432
]
33+
dataset = ["burn-dataset"]
34+
3535
network = ["burn-common/network"]
3636
sqlite = ["burn-dataset?/sqlite"]
3737
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]

crates/burn-core/README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# Burn Core
22

3-
This crate should be used with [burn](https://github.com/tracel-ai/burn).
3+
This crate should be used with [burn](https://github.com/tracel-ai/burn). It contains the core
4+
traits and components for building and training deep learning models with Burn.
45

56
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-core.svg)](https://crates.io/crates/burn-core)
67
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-core/blob/master/README.md)
78

89
## Feature Flags
910

10-
This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling
11-
the default `std` feature.
11+
This crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling the
12+
default `std` feature.
1213

1314
- `std` - enables the standard library. Enabled by default.
1415
- `experimental-named-tensor` - enables experimental named tensor.

crates/burn-core/src/lib.rs

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ pub mod grad_clipping;
3131
/// Module for the neural network module.
3232
pub mod module;
3333

34-
/// Neural network module.
35-
pub mod nn;
36-
3734
/// Module for the recorder.
3835
pub mod record;
3936

@@ -47,7 +44,6 @@ pub use tensor::Tensor;
4744
pub mod vision;
4845

4946
extern crate alloc;
50-
extern crate core;
5147

5248
/// Backend for test cases
5349
#[cfg(all(
@@ -84,6 +80,47 @@ mod tests {
8480
burn_fusion::memory_checks!();
8581
}
8682

83+
#[cfg(test)]
84+
mod test_utils {
85+
use crate as burn;
86+
use crate::module::Module;
87+
use crate::module::Param;
88+
use burn_tensor::Tensor;
89+
use burn_tensor::backend::Backend;
90+
use burn_tensor::module::linear;
91+
92+
/// Simple linear module.
93+
#[derive(Module, Debug)]
94+
pub struct SimpleLinear<B: Backend> {
95+
pub weight: Param<Tensor<B, 2>>,
96+
pub bias: Option<Param<Tensor<B, 1>>>,
97+
}
98+
99+
impl<B: Backend> SimpleLinear<B> {
100+
pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {
101+
let weight = Tensor::random(
102+
[out_features, in_features],
103+
burn_tensor::Distribution::Default,
104+
device,
105+
);
106+
let bias = Tensor::random([out_features], burn_tensor::Distribution::Default, device);
107+
108+
Self {
109+
weight: Param::from_tensor(weight),
110+
bias: Some(Param::from_tensor(bias)),
111+
}
112+
}
113+
114+
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
115+
linear(
116+
input,
117+
self.weight.val(),
118+
self.bias.as_ref().map(|b| b.val()),
119+
)
120+
}
121+
}
122+
}
123+
87124
/// Type alias for the learning rate.
88125
///
89126
/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LrScheduler) so it
@@ -97,7 +134,6 @@ pub mod prelude {
97134
pub use crate::{
98135
config::Config,
99136
module::Module,
100-
nn,
101137
tensor::{
102138
Bool, Device, ElementConversion, Float, Int, RangesArg, Shape, Tensor, TensorData,
103139
backend::Backend, cast::ToElement, s,

crates/burn-core/src/module/base.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ macro_rules! module {
6363
/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code
6464
/// necessary to optimize and train the module on any backend.
6565
///
66-
/// ```no_run
66+
/// ```rust, ignore
6767
/// // Not necessary when using the burn crate directly.
6868
/// use burn_core as burn;
6969
///

crates/burn-core/src/nn/initializer.rs renamed to crates/burn-core/src/module/initializer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ fn qr_decomposition<B: Backend>(
287287
mod tests {
288288
use super::*;
289289

290-
use crate::tensor::{ElementConversion, TensorData};
290+
use burn_tensor::{ElementConversion, TensorData};
291291
use num_traits::Pow;
292292

293293
pub type TB = burn_ndarray::NdArray<f32>;

crates/burn-core/src/module/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
mod base;
22
mod display;
3+
mod initializer;
34
mod param;
45
mod quantize;
56
#[cfg(feature = "std")]
67
mod reinit;
78

89
pub use base::*;
910
pub use display::*;
11+
pub use initializer::*;
1012
pub use param::*;
1113
pub use quantize::*;
1214

crates/burn-core/src/module/param/base.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
3434
/// # Laziness
3535
///
3636
/// The initialization of parameters can be lazy when created using
37-
/// [uninitialized](Self::uninitialized), which can be done using an [initializer](crate::nn::Initializer).
37+
/// [uninitialized](Self::uninitialized), which can be done using an [initializer](crate::module::Initializer).
3838
///
3939
/// This reduces the amount of allocations done when loading a model for inference without having
4040
/// to create a custom initialization function only for inference.

crates/burn-core/src/module/param/constant.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
use alloc::{format, string::ToString};
22
use core::{fmt::Display, marker::PhantomData};
33

4+
use crate as burn;
45
use crate::{
5-
self as burn,
66
module::{
77
AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
88
ModuleMapper, ModuleVisitor,
99
},
10-
record::Record,
10+
record::{PrecisionSettings, Record},
1111
};
12-
use burn::record::PrecisionSettings;
1312
use burn_tensor::{
1413
BasicAutodiffOps, BasicOps, Tensor,
1514
backend::{AutodiffBackend, Backend},

0 commit comments

Comments
 (0)