
Commit 90f115c

Add cubecl re-export, root Tensor, doc updates and Noam scheduler fix (tracel-ai#3742)
1 parent 9358a07 commit 90f115c

File tree

7 files changed (+88, -104 lines)

Cargo.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default.

burn-book/src/quantization.md

Lines changed: 51 additions & 51 deletions

@@ -30,7 +30,7 @@ Quantization support in Burn is currently in active development.
 
 It supports the following modes on some backends:
 
-- Static per-tensor quantization to signed 8-bit integer (`i8`)
+- Per-tensor and per-block (linear) quantization to 8-bit, 4-bit and 2-bit representations
 
 No integer operations are currently supported, which means tensors are dequantized to perform the
 operations in floating point precision.
@@ -45,48 +45,22 @@ tensors and can collect their statistics, such as the min and max value when usi
 
 ```rust , ignore
 # use burn::module::Quantizer;
-# use burn::tensor::quantization::{Calibration, QuantizationScheme, QuantizationType};
+# use burn::tensor::quantization::{Calibration, QuantLevel, QuantParam, QuantScheme, QuantValue};
 #
 // Quantization config
+let scheme = QuantScheme::default()
+    .with_level(QuantLevel::Block(32))
+    .with_value(QuantValue::Q4F)
+    .with_param(QuantParam::F16);
 let mut quantizer = Quantizer {
     calibration: Calibration::MinMax,
-    scheme: QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8),
+    scheme,
 };
 
 // Quantize the weights
 let model = model.quantize_weights(&mut quantizer);
 ```
 
-> Given that all operations are currently performed in floating point precision, it might be wise to
-> dequantize the module parameters before inference. This allows us to save disk space by storing
-> the model in reduced precision while preserving the inference speed.
->
-> This can easily be implemented with a `ModuleMapper`.
->
-> ```rust, ignore
-> # use burn::module::{ModuleMapper, ParamId};
-> # use burn::tensor::{backend::Backend, Tensor};
-> #
-> /// Module mapper used to dequantize the model params being loaded.
-> pub struct Dequantize {}
->
-> impl<B: Backend> ModuleMapper<B> for Dequantize {
->     fn map_float<const D: usize>(
->         &mut self,
->         _id: ParamId,
->         tensor: Tensor<B, D>,
->     ) -> Tensor<B, D> {
->         tensor.dequantize()
->     }
-> }
->
-> // Load saved quantized model in floating point precision
-> model = model
->     .load_file(file_path, recorder, &device)
->     .expect("Should be able to load the quantized model weights")
->     .map(&mut Dequantize {});
-> ```
-
 ### Calibration
 
 Calibration is the step during quantization where the range of all floating-point tensors is
@@ -101,29 +75,55 @@ To compute the quantization parameters, Burn supports the following `Calibration
 
 ### Quantization Scheme
 
-A quantization scheme defines the quantized type, quantization granularity and range mapping
-technique.
+A quantization scheme defines how an input is quantized, including the representation of quantized
+values, storage format, granularity, and how the values are scaled.
 
-Burn currently supports the following `QuantizationType` variants.
+```rust
+let scheme = QuantScheme::default()
+    .with_mode(QuantMode::Symmetric) // Quantization mode
+    .with_level(QuantLevel::Block(32)) // Granularity (per-tensor or per-block)
+    .with_value(QuantValue::Q8S) // Data type of quantized values, independent of how they're stored
+    .with_store(QuantStore::Native) // Storage format for quantized values
+    .with_param(QuantParam::F16); // Precision for quantization parameters
+```
 
-| Type    | Description                        |
-| :------ | :--------------------------------- |
-| `QInt8` | 8-bit signed integer quantization. |
+#### Quantization Mode
 
-Quantization parameters are defined based on the range of values to represent and can typically be
-calculated for the layer's entire weight tensor with per-tensor quantization or separately for each
-channel with per-channel quantization (commonly used with CNNs).
+| Mode        | Description                                   |
+| :---------- | :-------------------------------------------- |
+| `Symmetric` | Values are scaled symmetrically around zero.  |
 
-Burn currently supports the following `QuantizationScheme` variants.
+#### Quantization Level
 
-| Variant                 | Description                                                                                                                                                               |
-| :---------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| `PerTensor(mode, type)` | Applies a single set of quantization parameters to the entire tensor. The `mode` defines how values are transformed, and `type` represents the target quantization type. |
+| Level               | Description                                                                   |
+| :------------------ | :---------------------------------------------------------------------------- |
+| `Tensor`            | A single quantization parameter set for the entire tensor.                    |
+| `Block(block_size)` | Tensor divided into 1D linear blocks, each with its own quantization params.  |
 
-#### Quantization Mode
+#### Quantization Value
+
+| Value | Bits | Description                   |
+| :---- | :--: | :---------------------------- |
+| `Q8F` |  8   | 8-bit full-range quantization |
+| `Q4F` |  4   | 4-bit full-range quantization |
+| `Q2F` |  2   | 2-bit full-range quantization |
+| `Q8S` |  8   | 8-bit symmetric quantization  |
+| `Q4S` |  4   | 4-bit symmetric quantization  |
+| `Q2S` |  2   | 2-bit symmetric quantization  |
+
+#### Quantization Store
+
+| Store    | Description                                              |
+| :------- | :------------------------------------------------------- |
+| `Native` | Each quantized value stored directly in memory.          |
+| `U32`    | Multiple quantized values packed into a 32-bit integer.  |
+
+Native storage is not supported for sub-byte quantization values.
 
-| Mode        | Description                                                         |
-| ----------- | ------------------------------------------------------------------- |
-| `Symmetric` | Maps values using a scale factor for a range centered around zero.  |
+#### Quantization Parameters Precision
 
----
+| Param  | Description                    |
+| :----- | :----------------------------- |
+| `F32`  | Full floating-point precision. |
+| `F16`  | Half-precision floating point. |
+| `BF16` | Brain float 16-bit precision.  |
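
Taken together, the pieces introduced by this diff compose into a single quantization call. Below is a minimal sketch assembled from the book snippets above; the `quantize_model` helper and its generic bounds are illustrative, not Burn API, and `model` is assumed to be any module with float parameters as in the surrounding chapter.

```rust
use burn::module::{Module, Quantizer};
use burn::tensor::backend::Backend;
use burn::tensor::quantization::{Calibration, QuantLevel, QuantParam, QuantScheme, QuantValue};

// Illustrative helper that wires the diff's snippets together.
fn quantize_model<B: Backend, M: Module<B>>(model: M) -> M {
    // 4-bit full-range values, 1D blocks of 32, f16 quantization parameters.
    let scheme = QuantScheme::default()
        .with_level(QuantLevel::Block(32))
        .with_value(QuantValue::Q4F)
        .with_param(QuantParam::F16);
    let mut quantizer = Quantizer {
        calibration: Calibration::MinMax,
        scheme,
    };
    model.quantize_weights(&mut quantizer)
}
```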

crates/burn-core/src/lib.rs

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,8 @@ pub mod record;
 
 /// Module for the tensor.
 pub mod tensor;
+// Tensor at root: `burn::Tensor`
+pub use tensor::Tensor;
 
 /// Module for visual operations
 #[cfg(feature = "vision")]
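
The effect of the new re-export is a shorter import path: `burn::Tensor` now resolves to the same type as `burn::tensor::Tensor`. A small usage sketch, assuming the `ndarray` backend feature for concreteness (the tensor values are arbitrary):

```rust
use burn::Tensor; // previously: use burn::tensor::Tensor;
use burn::backend::NdArray;

fn main() {
    let device = Default::default();
    // Same `Tensor` type as before, reached from the crate root.
    let x = Tensor::<NdArray, 1>::from_floats([1.0, 2.0, 3.0], &device);
    println!("{x}");
}
```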

crates/burn-core/src/lr_scheduler/noam.rs

Lines changed: 5 additions & 5 deletions

@@ -8,8 +8,8 @@ use crate::{LearningRate, config::Config};
 /// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler.
 #[derive(Config, Debug)]
 pub struct NoamLrSchedulerConfig {
-    /// The initial learning rate.
-    init_lr: LearningRate,
+    /// The overall scale factor for the learning rate decay.
+    factor: f64,
     /// The number of steps before the exponential decay stats.
     #[config(default = 4000)]
     warmup_steps: usize,
@@ -23,7 +23,7 @@ pub struct NoamLrSchedulerConfig {
 pub struct NoamLrScheduler {
     warmup_steps: f64,
     embedding_size: f64,
-    init_lr: LearningRate,
+    factor: f64,
     step: f64,
 }
 
@@ -49,7 +49,7 @@ impl NoamLrSchedulerConfig {
         Ok(NoamLrScheduler {
             warmup_steps: self.warmup_steps as f64,
             embedding_size: self.model_size as f64,
-            init_lr: self.init_lr,
+            factor: self.factor,
             step: 0.0,
         })
     }
@@ -64,7 +64,7 @@ impl LrScheduler for NoamLrScheduler {
         let arg1 = self.step.powf(-0.5);
         let arg2 = self.step * self.warmup_steps.powf(-1.5);
 
-        self.init_lr * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2)
+        self.factor * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2)
    }
 
    fn to_record<B: Backend>(&self) -> Self::Record<B> {
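
The rename from `init_lr` to `factor` matches the Noam schedule from the Transformer paper, `lr(step) = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5))`: the parameter scales the whole curve rather than setting an initial rate. A standalone sketch of the arithmetic in `step` above; the `noam_lr` function and the sample constants are illustrative, not Burn API.

```rust
// Mirrors the computation in `NoamLrScheduler::step`.
fn noam_lr(factor: f64, model_size: f64, warmup_steps: f64, step: f64) -> f64 {
    let arg1 = step.powf(-0.5);
    let arg2 = step * warmup_steps.powf(-1.5);
    factor * model_size.powf(-0.5) * f64::min(arg1, arg2)
}

fn main() {
    // With factor = 1.0, d_model = 512, warmup = 4000, the rate rises during
    // warmup and peaks where the two `min` arguments meet, at step == warmup.
    for step in [1.0_f64, 1000.0, 4000.0, 10_000.0] {
        println!("step {step}: lr = {:.6}", noam_lr(1.0, 512.0, 4000.0, step));
    }
}
```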

crates/burn-tensor/src/tensor/element/base.rs

Lines changed: 17 additions & 46 deletions

@@ -14,7 +14,6 @@ pub trait Element:
     ToElement
     + ElementRandom
     + ElementConversion
-    + ElementPrecision
     + ElementComparison
     + ElementLimits
     + bytemuck::CheckedBitPattern
@@ -78,42 +77,20 @@ pub trait ElementLimits {
     const MAX: Self;
 }
 
-/// Element precision trait for tensor.
-#[derive(Clone, PartialEq, Eq, Copy, Debug)]
-pub enum Precision {
-    /// Double precision, e.g. f64.
-    Double,
-
-    /// Full precision, e.g. f32.
-    Full,
-
-    /// Half precision, e.g. f16.
-    Half,
-
-    /// Other precision.
-    Other,
-}
-
-/// Element precision trait for tensor.
-pub trait ElementPrecision {
-    /// Returns the precision of the element.
-    fn precision() -> Precision;
-}
-
 /// Macro to implement the element trait for a type.
 #[macro_export]
 macro_rules! make_element {
     (
-        ty $type:ident $precision:expr,
+        ty $type:ident,
         convert $convert:expr,
         random $random:expr,
         cmp $cmp:expr,
         dtype $dtype:expr
     ) => {
-        make_element!(ty $type $precision, convert $convert, random $random, cmp $cmp, dtype $dtype, min $type::MIN, max $type::MAX);
+        make_element!(ty $type, convert $convert, random $random, cmp $cmp, dtype $dtype, min $type::MIN, max $type::MAX);
    };
    (
-        ty $type:ident $precision:expr,
+        ty $type:ident,
         convert $convert:expr,
         random $random:expr,
         cmp $cmp:expr,
@@ -140,12 +117,6 @@ macro_rules! make_element {
            }
        }
 
-        impl ElementPrecision for $type {
-            fn precision() -> Precision {
-                $precision
-            }
-        }
-
        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
@@ -170,87 +141,87 @@ macro_rules! make_element {
 }
 
 make_element!(
-    ty f64 Precision::Double,
+    ty f64,
     convert ToElement::to_f64,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &f64, b: &f64| a.total_cmp(b),
     dtype DType::F64
 );
 
 make_element!(
-    ty f32 Precision::Full,
+    ty f32,
     convert ToElement::to_f32,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &f32, b: &f32| a.total_cmp(b),
     dtype DType::F32
 );
 
 make_element!(
-    ty i64 Precision::Double,
+    ty i64,
     convert ToElement::to_i64,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &i64, b: &i64| Ord::cmp(a, b),
     dtype DType::I64
 );
 
 make_element!(
-    ty u64 Precision::Double,
+    ty u64,
     convert ToElement::to_u64,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &u64, b: &u64| Ord::cmp(a, b),
     dtype DType::U64
 );
 
 make_element!(
-    ty i32 Precision::Full,
+    ty i32,
     convert ToElement::to_i32,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &i32, b: &i32| Ord::cmp(a, b),
     dtype DType::I32
 );
 
 make_element!(
-    ty u32 Precision::Full,
+    ty u32,
     convert ToElement::to_u32,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &u32, b: &u32| Ord::cmp(a, b),
     dtype DType::U32
 );
 
 make_element!(
-    ty i16 Precision::Half,
+    ty i16,
     convert ToElement::to_i16,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &i16, b: &i16| Ord::cmp(a, b),
     dtype DType::I16
 );
 
 make_element!(
-    ty u16 Precision::Half,
+    ty u16,
     convert ToElement::to_u16,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &u16, b: &u16| Ord::cmp(a, b),
     dtype DType::U16
 );
 
 make_element!(
-    ty i8 Precision::Other,
+    ty i8,
     convert ToElement::to_i8,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &i8, b: &i8| Ord::cmp(a, b),
     dtype DType::I8
 );
 
 make_element!(
-    ty u8 Precision::Other,
+    ty u8,
     convert ToElement::to_u8,
     random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
     cmp |a: &u8, b: &u8| Ord::cmp(a, b),
     dtype DType::U8
 );
 
 make_element!(
-    ty f16 Precision::Half,
+    ty f16,
     convert ToElement::to_f16,
     random |distribution: Distribution, rng: &mut R| {
         let sample: f32 = distribution.sampler(rng).sample();
@@ -260,7 +231,7 @@ make_element!(
     dtype DType::F16
 );
 make_element!(
-    ty bf16 Precision::Half,
+    ty bf16,
     convert ToElement::to_bf16,
     random |distribution: Distribution, rng: &mut R| {
         let sample: f32 = distribution.sampler(rng).sample();
@@ -272,7 +243,7 @@ make_element!(
 
 #[cfg(feature = "cubecl")]
 make_element!(
-    ty flex32 Precision::Half,
+    ty flex32,
     convert |elem: &dyn ToElement| flex32::from_f32(elem.to_f32()),
     random |distribution: Distribution, rng: &mut R| {
         let sample: f32 = distribution.sampler(rng).sample();
@@ -285,7 +256,7 @@ make_element!(
 );
 
 make_element!(
-    ty bool Precision::Other,
+    ty bool,
     convert ToElement::to_bool,
     random |distribution: Distribution, rng: &mut R| {
         let sample: u8 = distribution.sampler(rng).sample();
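
With `Precision` and `ElementPrecision` removed, the width of an element is still recoverable from the `DType` that `make_element!` registers. A hedged migration sketch follows: it assumes `Element::dtype()` exposes that `DType` (the diff implies the association but does not show the accessor), assumes the `half` crate for the `f16` type, and `is_half_precision` is an illustrative helper, not Burn API.

```rust
use burn_tensor::{DType, Element};

// Illustrative replacement for the removed `ElementPrecision::precision()`:
// classify an element type by matching on its registered `DType`.
fn is_half_precision<E: Element>() -> bool {
    matches!(E::dtype(), DType::F16 | DType::BF16)
}

fn main() {
    println!("f32 half-precision? {}", is_half_precision::<f32>());
    println!("f16 half-precision? {}", is_half_precision::<half::f16>());
}
```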
