Skip to content

Commit 17c8d15

Browse files
authored
feat: update to latest Burn (#25)
1 parent d52deb4 commit 17c8d15

File tree

9 files changed

+85
-98
lines changed

9 files changed

+85
-98
lines changed

.github/workflows/validate.yml

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,11 @@ jobs:
5252
uses: Swatinem/rust-cache@v2
5353
with:
5454
key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.backend }}-${{ hashFiles('**/Cargo.toml') }}
55-
- name: (Linux) Install llvmpipe, lavapipe
56-
if: runner.os == 'Linux'
55+
- name: (Linux) Install Vulkan deps
56+
if: runner.os == 'Linux' && matrix.backend == 'wgpu'
5757
run: |-
5858
sudo apt-get update -y -qq
59-
sudo add-apt-repository ppa:kisak/kisak-mesa -y
60-
sudo apt-get update
61-
sudo apt install -y libegl1-mesa libgl1-mesa-dri libxcb-xfixes0-dev mesa-vulkan-drivers
59+
sudo apt-get install -y libvulkan1 mesa-vulkan-drivers vulkan-tools libxcb-xfixes0-dev
6260
- name: (Windows) Install warp
6361
if: runner.os == 'Windows'
6462
shell: bash
@@ -88,9 +86,6 @@ jobs:
8886
8987
echo "VK_DRIVER_FILES=$PWD/mesa/lvp_icd.x86_64.json" >> "$GITHUB_ENV"
9088
echo "GALLIUM_DRIVER=llvmpipe" >> "$GITHUB_ENV"
91-
- name: (Windows) Install dxc
92-
if: runner.os == 'Windows'
93-
uses: napokue/setup-dxc@v1.1.0
9489
- name: Run tests
9590
shell: bash
9691
run: |-

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ Cargo.lock
1919
# IDEs
2020
.idea
2121
.fleet
22+
23+
# Others
24+
inspiration/

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ torch = ["burn/tch"]
1616
wgpu = ["burn/wgpu"]
1717

1818
[dependencies]
19-
burn = { version = "0.13.0", default-features = false }
19+
burn = { version = "0.20.1", default-features = false }
2020
num-traits = { version = "0.2.18", default-features = false }
2121
serde = { version = "1.0.197", default-features = false, features = [
2222
"derive",

src/lib.rs

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,13 @@ pub mod pipelines;
99
pub mod transformers;
1010
pub mod utils;
1111

12-
#[cfg(all(test, feature = "ndarray"))]
13-
use burn::backend::ndarray;
14-
15-
#[cfg(all(test, feature = "torch"))]
16-
use burn::backend::libtorch;
17-
18-
#[cfg(all(test, feature = "wgpu"))]
19-
use burn::backend::wgpu;
20-
2112
extern crate alloc;
2213

2314
#[cfg(all(test, feature = "ndarray"))]
24-
pub type TestBackend = ndarray::NdArray<f32>;
15+
pub type TestBackend = burn::backend::NdArray<f32>;
2516

2617
#[cfg(all(test, feature = "torch"))]
27-
pub type TestBackend = libtorch::LibTorch<f32>;
18+
pub type TestBackend = burn::backend::LibTorch<f32>;
2819

29-
#[cfg(all(test, feature = "wgpu", not(target_os = "macos")))]
30-
pub type TestBackend = wgpu::Wgpu<wgpu::Vulkan, f32, i32>;
31-
32-
#[cfg(all(test, feature = "wgpu", target_os = "macos"))]
33-
pub type TestBackend = wgpu::Wgpu<wgpu::Metal, f32, i32>;
20+
#[cfg(all(test, feature = "wgpu"))]
21+
pub type TestBackend = burn::backend::Wgpu;

src/models/attention.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use burn::tensor::Tensor;
1616
#[allow(unused_imports)]
1717
use num_traits::Float;
1818

19-
#[derive(Config)]
19+
#[derive(Config, Debug)]
2020
pub struct GeGluConfig {
2121
/// The size of the input features.
2222
d_input: usize,
@@ -44,7 +44,7 @@ impl<B: Backend> GeGlu<B> {
4444
}
4545
}
4646

47-
#[derive(Config)]
47+
#[derive(Config, Debug)]
4848
pub struct FeedForwardConfig {
4949
/// The size of the input features.
5050
pub d_input: usize,
@@ -90,7 +90,7 @@ impl<B: Backend> FeedForward<B> {
9090
}
9191
}
9292

93-
#[derive(Config)]
93+
#[derive(Config, Debug)]
9494
pub struct CrossAttentionConfig {
9595
/// The number of channels in the query.
9696
d_query: usize,
@@ -232,7 +232,7 @@ impl<B: Backend> CrossAttention<B> {
232232
}
233233
}
234234

235-
#[derive(Config)]
235+
#[derive(Config, Debug)]
236236
pub struct BasicTransformerBlockConfig {
237237
d_model: usize,
238238
d_context: Option<usize>,
@@ -488,13 +488,13 @@ mod tests {
488488
use super::*;
489489
use crate::TestBackend;
490490
use burn::module::{Param, ParamId};
491-
use burn::tensor::{Data, Shape};
491+
use burn::tensor::{Shape, TensorData, Tolerance};
492492

493493
#[test]
494494
fn test_geglu_tensor_shape_3() {
495495
let device = Default::default();
496496
let weight = Tensor::from_data(
497-
Data::from([
497+
TensorData::from([
498498
[
499499
0.1221, 2.0378, -0.1171, 1.3004, -0.9630, -0.3108, -1.3376, -1.0593,
500500
],
@@ -505,7 +505,7 @@ mod tests {
505505
&device,
506506
);
507507
let bias = Tensor::from_data(
508-
Data::from([
508+
TensorData::from([
509509
0.2867778149426027,
510510
0.6646517317105776,
511511
0.023946332404821136,
@@ -526,7 +526,7 @@ mod tests {
526526
};
527527

528528
let tensor: Tensor<TestBackend, 3> = Tensor::from_data(
529-
Data::from([
529+
TensorData::from([
530530
[[1., 2.], [3., 4.], [5., 6.]],
531531
[[7., 8.], [9., 10.], [11., 12.]],
532532
]),
@@ -535,8 +535,8 @@ mod tests {
535535

536536
let output = geglu.forward(tensor);
537537
assert_eq!(output.shape(), Shape::from([2, 3, 4]));
538-
output.to_data().assert_approx_eq(
539-
&Data::from([
538+
output.into_data().assert_approx_eq::<f32>(
539+
&TensorData::from([
540540
[
541541
[4.2632e0, -1.7927e-1, -2.3216e-1, -3.7916e-2],
542542
[1.3460e1, -2.9266e-1, -2.1707e-4, -4.5595e-2],
@@ -548,22 +548,22 @@ mod tests {
548548
[1.0119e2, -2.1943e-5, -0.0000e0, -0.0000e0],
549549
],
550550
]),
551-
2,
551+
Tolerance::rel_abs(1e-2, 1e-2),
552552
);
553553
}
554554

555555
#[test]
556556
fn test_geglu_tensor_shape_2() {
557557
let device = Default::default();
558558
let weight = Tensor::from_data(
559-
Data::from([
559+
TensorData::from([
560560
[0.6054, 1.9322, 0.1445, 1.3004, -0.6853, -0.8947],
561561
[-0.3678, 0.4081, -1.9001, -1.5843, -0.9399, 0.1018],
562562
]),
563563
&device,
564564
);
565565
let bias = Tensor::from_data(
566-
Data::from([
566+
TensorData::from([
567567
0.3237631905393836,
568568
0.22052049807936902,
569569
-0.3196353346822061,
@@ -582,17 +582,17 @@ mod tests {
582582
};
583583

584584
let tensor: Tensor<TestBackend, 2> =
585-
Tensor::from_data(Data::from([[1., 2.], [3., 4.], [5., 6.]]), &device);
585+
Tensor::from_data(TensorData::from([[1., 2.], [3., 4.], [5., 6.]]), &device);
586586

587587
let output = geglu.forward(tensor);
588588
assert_eq!(output.shape(), Shape::from([3, 3]));
589-
output.to_data().assert_approx_eq(
590-
&Data::from([
589+
output.into_data().assert_approx_eq::<f32>(
590+
&TensorData::from([
591591
[-2.4192e-5, -3.3057e-2, 2.8535e-1],
592592
[-0.0000e0, -2.0983e-7, 5.2465e-1],
593593
[-0.0000e0, -0.0000e0, 1.2599e-2],
594594
]),
595-
1,
595+
Tolerance::rel_abs(1e-1, 1e-1),
596596
);
597597
}
598598

@@ -601,7 +601,7 @@ mod tests {
601601
let device = Default::default();
602602
// create tensor of size [2, 4, 2]
603603
let query: Tensor<TestBackend, 3> = Tensor::from_data(
604-
Data::from([
604+
TensorData::from([
605605
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
606606
[[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]],
607607
[[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]],
@@ -610,7 +610,7 @@ mod tests {
610610
&device,
611611
);
612612
let key: Tensor<TestBackend, 3> = Tensor::from_data(
613-
Data::from([
613+
TensorData::from([
614614
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
615615
[[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]],
616616
[[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]],
@@ -619,7 +619,7 @@ mod tests {
619619
&device,
620620
);
621621
let value: Tensor<TestBackend, 3> = Tensor::from_data(
622-
Data::from([
622+
TensorData::from([
623623
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
624624
[[9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]],
625625
[[17.0, 18.0], [19.0, 20.0], [21.0, 22.0], [23.0, 24.0]],
@@ -637,8 +637,8 @@ mod tests {
637637
let output = cross_attention.sliced_attention(query, key, value, 2);
638638

639639
assert_eq!(output.shape(), Shape::from([2, 4, 4]));
640-
output.into_data().assert_approx_eq(
641-
&Data::from([
640+
output.into_data().assert_approx_eq::<f32>(
641+
&TensorData::from([
642642
[
643643
[5.9201, 6.9201, 14.9951, 15.9951],
644644
[6.7557, 7.7557, 14.9986, 15.9986],
@@ -652,7 +652,7 @@ mod tests {
652652
[23.0000, 24.0000, 31.0000, 32.0000],
653653
],
654654
]),
655-
3,
655+
Tolerance::rel_abs(1e-3, 1e-3),
656656
)
657657
}
658658
}

src/models/embeddings.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,26 +85,27 @@ impl<B: Backend> Timesteps<B> {
8585
mod tests {
8686
use super::*;
8787
use crate::TestBackend;
88-
use burn::tensor::{Data, Shape};
88+
use burn::tensor::{Shape, TensorData, Tolerance};
8989

9090
#[test]
9191
#[cfg(not(feature = "torch"))]
9292
fn test_timesteps_even_channels() {
9393
let device = Default::default();
9494
let timesteps = Timesteps::<TestBackend>::new(4, true, 0.);
95-
let xs: Tensor<TestBackend, 1> = Tensor::from_data(Data::from([1., 2., 3., 4.]), &device);
95+
let xs: Tensor<TestBackend, 1> =
96+
Tensor::from_data(TensorData::from([1., 2., 3., 4.]), &device);
9697

97-
let emb = timesteps.forward(xs);
98+
let emb: Tensor<TestBackend, 2> = timesteps.forward(xs);
9899

99100
assert_eq!(emb.shape(), Shape::from([4, 4]));
100-
emb.to_data().assert_approx_eq(
101-
&Data::from([
101+
emb.into_data().assert_approx_eq::<f32>(
102+
&TensorData::from([
102103
[0.5403, 1.0000, 0.8415, 0.0100],
103104
[-0.4161, 0.9998, 0.9093, 0.0200],
104105
[-0.9900, 0.9996, 0.1411, 0.0300],
105106
[-0.6536, 0.9992, -0.7568, 0.0400],
106107
]),
107-
3,
108+
Tolerance::rel_abs(1e-3, 1e-3),
108109
);
109110
}
110111

@@ -114,21 +115,21 @@ mod tests {
114115
let device = Default::default();
115116
let timesteps = Timesteps::<TestBackend>::new(5, true, 0.);
116117
let xs: Tensor<TestBackend, 1> =
117-
Tensor::from_data(Data::from([1., 2., 3., 4., 5.]), &device);
118+
Tensor::from_data(TensorData::from([1., 2., 3., 4., 5.]), &device);
118119

119-
let emb = timesteps.forward(xs);
120+
let emb: Tensor<TestBackend, 2> = timesteps.forward(xs);
120121

121122
assert_eq!(emb.shape(), Shape::from([6, 4]));
122-
emb.to_data().assert_approx_eq(
123-
&Data::from([
123+
emb.into_data().assert_approx_eq::<f32>(
124+
&TensorData::from([
124125
[0.5403, 1.0000, 0.8415, 0.0100],
125126
[-0.4161, 0.9998, 0.9093, 0.0200],
126127
[-0.9900, 0.9996, 0.1411, 0.0300],
127128
[-0.6536, 0.9992, -0.7568, 0.0400],
128129
[0.2837, 0.9988, -0.9589, 0.0500],
129130
[0.0000, 0.0000, 0.0000, 0.0000],
130131
]),
131-
3,
132+
Tolerance::rel_abs(1e-3, 1e-3),
132133
);
133134
}
134135
}

0 commit comments

Comments (0)