Skip to content

Commit af5a69e

Browse files
fp8 support (#2989)
* fp8 support * use float8 crate with cudarc 16.x, fix errors * fix Tensor::ones for fp8 * fp8: fix failing tests * more fp8 * add fp8 where bf16 is in tests * skip fp8 testing on metal * fixed onnx eval match statements that didn't have full coverage * Unused import backend::BackendDevice * kernels: fix cuda_arch guards for fp8 ops --------- Co-authored-by: keighbee <kb@huggingface.co>
1 parent 96415a4 commit af5a69e

File tree

35 files changed

+739
-44
lines changed

35 files changed

+739
-44
lines changed

Cargo.toml

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ members = [
1111
"tensor-tools",
1212
]
1313
exclude = [
14-
"candle-book",
15-
"candle-flash-attn",
16-
"candle-kernels",
17-
"candle-metal-kernels",
18-
"candle-onnx",
14+
"candle-book",
15+
"candle-flash-attn",
16+
"candle-kernels",
17+
"candle-metal-kernels",
18+
"candle-onnx",
1919
]
2020
resolver = "2"
2121

@@ -42,14 +42,35 @@ candle-nn = { path = "./candle-nn", version = "0.9.1" }
4242
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
4343
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
4444
clap = { version = "4.2.4", features = ["derive"] }
45-
criterion = { version = "0.5.1", default-features=false }
46-
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
45+
criterion = { version = "0.5.1", default-features = false }
46+
cudarc = { version = "0.16.3", features = [
47+
"std",
48+
"cublas",
49+
"cublaslt",
50+
"curand",
51+
"driver",
52+
"nvrtc",
53+
"f16",
54+
"cuda-version-from-build-system",
55+
"dynamic-linking",
56+
], default-features = false }
4757
fancy-regex = "0.13.0"
4858
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
4959
hf-hub = "0.4.1"
50-
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
60+
half = { version = "2.5.0", features = [
61+
"num-traits",
62+
"use-intrinsics",
63+
"rand_distr",
64+
] }
65+
float8 = { git = "https://github.com/zackangelo/float8", branch = "cudarc_0_16", features = [
66+
"num-traits",
67+
"rand_distr",
68+
] }
5169
hound = "3.5.1"
52-
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
70+
image = { version = "0.25.2", default-features = false, features = [
71+
"jpeg",
72+
"png",
73+
] }
5374
imageproc = { version = "0.24.0", default-features = false }
5475
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
5576
libc = { version = "0.2.147" }
@@ -75,7 +96,7 @@ ug-cuda = "0.4.0"
7596
ug-metal = "0.4.0"
7697
yoke = { version = "0.7.2", features = ["derive"] }
7798
zip = { version = "1.1.1", default-features = false }
78-
metal = { version = "0.27.0", features = ["mps"]}
99+
metal = { version = "0.27.0", features = ["mps"] }
79100

80101
[profile.release-with-debug]
81102
inherits = "release"

candle-core/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ metal = { workspace = true, optional = true }
1818
cudarc = { workspace = true, optional = true }
1919
gemm = { workspace = true }
2020
half = { workspace = true }
21+
float8 = { workspace = true }
2122
intel-mkl-src = { workspace = true, optional = true }
2223
libc = { workspace = true, optional = true }
2324
memmap2 = { workspace = true }
@@ -43,7 +44,7 @@ criterion = { workspace = true }
4344

4445
[features]
4546
default = []
46-
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
47+
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"]
4748
cudnn = ["cuda", "cudarc/cudnn"]
4849
mkl = ["dep:libc", "dep:intel-mkl-src"]
4950
accelerate = ["dep:libc", "dep:accelerate-src"]

candle-core/benches/benchmarks/affine.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ fn criterion_benchmark(c: &mut Criterion) {
3737
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
3838
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
3939
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
40+
run_affine_benchmark(c, &device, DType::F8E4M3, "affine_fp8");
4041
}
4142
}
4243

candle-core/src/convert.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Implement conversion traits for tensors
22
use crate::{DType, Device, Error, Tensor, WithDType};
3+
use float8::F8E4M3;
34
use half::{bf16, f16, slice::HalfFloatSliceExt};
45
use std::convert::TryFrom;
56

@@ -139,6 +140,11 @@ impl Tensor {
139140
let vs = vs.to_vec1::<u8>()?;
140141
f.write_all(&vs)?;
141142
}
143+
DType::F8E4M3 => {
144+
for v in vs.to_vec1::<F8E4M3>()? {
145+
f.write_u8(v.to_bits())?
146+
}
147+
}
142148
}
143149
Ok(())
144150
}

candle-core/src/cpu_backend/mod.rs

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
use crate::backend::{BackendDevice, BackendStorage};
33
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
44
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
5+
use float8::F8E4M3;
56
use half::{bf16, f16};
67
use rayon::prelude::*;
78

@@ -25,6 +26,7 @@ pub enum CpuStorage {
2526
F16(Vec<f16>),
2627
F32(Vec<f32>),
2728
F64(Vec<f64>),
29+
F8E4M3(Vec<F8E4M3>),
2830
}
2931

3032
#[derive(Debug, Clone)]
@@ -36,6 +38,7 @@ pub enum CpuStorageRef<'a> {
3638
F16(&'a [f16]),
3739
F32(&'a [f32]),
3840
F64(&'a [f64]),
41+
F8E4M3(&'a [F8E4M3]),
3942
}
4043

4144
#[derive(Debug, Clone)]
@@ -1691,6 +1694,17 @@ impl CpuStorage {
16911694
.concat();
16921695
Self::F64(storages)
16931696
}
1697+
Self::F8E4M3(_) => {
1698+
let storages = storages
1699+
.iter()
1700+
.map(|s| match s {
1701+
Self::F8E4M3(s) => Ok(s.as_slice()),
1702+
_ => crate::bail!("dtype mismatch"),
1703+
})
1704+
.collect::<Result<Vec<_>>>()?
1705+
.concat();
1706+
Self::F8E4M3(storages)
1707+
}
16941708
};
16951709
Ok(s)
16961710
}
@@ -1708,6 +1722,7 @@ impl BackendStorage for CpuStorage {
17081722
Self::F16(_) => DType::F16,
17091723
Self::F32(_) => DType::F32,
17101724
Self::F64(_) => DType::F64,
1725+
Self::F8E4M3(_) => DType::F8E4M3,
17111726
}
17121727
}
17131728

@@ -1742,6 +1757,10 @@ impl BackendStorage for CpuStorage {
17421757
let data = unary_map(storage, layout, bf16::from_f64);
17431758
Ok(Self::BF16(data))
17441759
}
1760+
(Self::F8E4M3(storage), DType::BF16) => {
1761+
let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
1762+
Ok(Self::BF16(data))
1763+
}
17451764
(Self::U8(storage), DType::F16) => {
17461765
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
17471766
Ok(Self::F16(data))
@@ -1770,6 +1789,10 @@ impl BackendStorage for CpuStorage {
17701789
let data = unary_map(storage, layout, f16::from_f64);
17711790
Ok(Self::F16(data))
17721791
}
1792+
(Self::F8E4M3(storage), DType::F16) => {
1793+
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
1794+
Ok(Self::F16(data))
1795+
}
17731796
(Self::U8(storage), DType::F32) => {
17741797
let data = unary_map(storage, layout, |v| v as f32);
17751798
Ok(Self::F32(data))
@@ -1798,6 +1821,10 @@ impl BackendStorage for CpuStorage {
17981821
let data = unary_map(storage, layout, |v| v as f32);
17991822
Ok(Self::F32(data))
18001823
}
1824+
(Self::F8E4M3(storage), DType::F32) => {
1825+
let data = unary_map(storage, layout, |v| v.to_f32());
1826+
Ok(Self::F32(data))
1827+
}
18011828
(Self::U8(storage), DType::U8) => {
18021829
let data = unary_map(storage, layout, |v| v);
18031830
Ok(Self::U8(data))
@@ -1826,6 +1853,10 @@ impl BackendStorage for CpuStorage {
18261853
let data = unary_map(storage, layout, |v| v as u8);
18271854
Ok(Self::U8(data))
18281855
}
1856+
(Self::F8E4M3(storage), DType::U8) => {
1857+
let data = unary_map(storage, layout, |v| v.to_f32() as u8);
1858+
Ok(Self::U8(data))
1859+
}
18291860
(Self::U8(storage), DType::U32) => {
18301861
let data = unary_map(storage, layout, |v| v as u32);
18311862
Ok(Self::U32(data))
@@ -1854,6 +1885,10 @@ impl BackendStorage for CpuStorage {
18541885
let data = unary_map(storage, layout, |v| v as u32);
18551886
Ok(Self::U32(data))
18561887
}
1888+
(Self::F8E4M3(storage), DType::U32) => {
1889+
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
1890+
Ok(Self::U32(data))
1891+
}
18571892
(Self::U8(storage), DType::I64) => {
18581893
let data = unary_map(storage, layout, |v| v as i64);
18591894
Ok(Self::I64(data))
@@ -1882,6 +1917,10 @@ impl BackendStorage for CpuStorage {
18821917
let data = unary_map(storage, layout, |v| v as i64);
18831918
Ok(Self::I64(data))
18841919
}
1920+
(Self::F8E4M3(storage), DType::I64) => {
1921+
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
1922+
Ok(Self::I64(data))
1923+
}
18851924
(Self::U8(storage), DType::F64) => {
18861925
let data = unary_map(storage, layout, |v| v as f64);
18871926
Ok(Self::F64(data))
@@ -1910,6 +1949,42 @@ impl BackendStorage for CpuStorage {
19101949
let data = unary_map(storage, layout, |v| v);
19111950
Ok(Self::F64(data))
19121951
}
1952+
(Self::F8E4M3(storage), DType::F64) => {
1953+
let data = unary_map(storage, layout, |v| v.to_f64());
1954+
Ok(Self::F64(data))
1955+
}
1956+
(Self::U8(storage), DType::F8E4M3) => {
1957+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
1958+
Ok(Self::F8E4M3(data))
1959+
}
1960+
(Self::U32(storage), DType::F8E4M3) => {
1961+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
1962+
Ok(Self::F8E4M3(data))
1963+
}
1964+
(Self::I64(storage), DType::F8E4M3) => {
1965+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
1966+
Ok(Self::F8E4M3(data))
1967+
}
1968+
(Self::BF16(storage), DType::F8E4M3) => {
1969+
let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32()));
1970+
Ok(Self::F8E4M3(data))
1971+
}
1972+
(Self::F16(storage), DType::F8E4M3) => {
1973+
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
1974+
Ok(Self::F8E4M3(data))
1975+
}
1976+
(Self::F32(storage), DType::F8E4M3) => {
1977+
let data = unary_map(storage, layout, F8E4M3::from_f32);
1978+
Ok(Self::F8E4M3(data))
1979+
}
1980+
(Self::F64(storage), DType::F8E4M3) => {
1981+
let data = unary_map(storage, layout, F8E4M3::from_f64);
1982+
Ok(Self::F8E4M3(data))
1983+
}
1984+
(Self::F8E4M3(storage), DType::F8E4M3) => {
1985+
let data = unary_map(storage, layout, |v| v);
1986+
Ok(Self::F8E4M3(data))
1987+
}
19131988
}
19141989
}
19151990

@@ -2023,6 +2098,10 @@ impl BackendStorage for CpuStorage {
20232098
let data = unary_map(storage, layout, |v| v.powf(e));
20242099
Ok(Self::F64(data))
20252100
}
2101+
Self::F8E4M3(storage) => {
2102+
let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
2103+
Ok(Self::F8E4M3(data))
2104+
}
20262105
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
20272106
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
20282107
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
@@ -2048,6 +2127,10 @@ impl BackendStorage for CpuStorage {
20482127
let data = unary_map(storage, layout, |v| elu(v, alpha));
20492128
Ok(Self::F64(data))
20502129
}
2130+
Self::F8E4M3(storage) => {
2131+
let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
2132+
Ok(Self::F8E4M3(data))
2133+
}
20512134
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
20522135
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
20532136
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
@@ -2092,6 +2175,15 @@ impl BackendStorage for CpuStorage {
20922175
Ok(Self::F64(data))
20932176
}
20942177
}
2178+
Self::F8E4M3(storage) => {
2179+
if B::F8E4M3_VEC {
2180+
let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec);
2181+
Ok(Self::F8E4M3(data))
2182+
} else {
2183+
let data = unary_map(storage, layout, B::f8e4m3);
2184+
Ok(Self::F8E4M3(data))
2185+
}
2186+
}
20952187
Self::U8(storage) => {
20962188
let data = unary_map(storage, layout, B::u8);
20972189
Ok(Self::U8(data))
@@ -2564,6 +2656,7 @@ impl BackendStorage for CpuStorage {
25642656
(Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
25652657
(Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
25662658
(Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
2659+
(Self::F8E4M3(storage), Scalar::F8E4M3(v)) => set(storage, l, v),
25672660
(st, s) => crate::bail!(
25682661
"const_set dtype mismatch, expected {:?} but got {:?}",
25692662
st.dtype(),
@@ -2632,6 +2725,16 @@ impl BackendDevice for CpuDevice {
26322725
}
26332726
Ok(CpuStorage::F16(data))
26342727
}
2728+
DType::F8E4M3 => {
2729+
let mut data = Vec::with_capacity(elem_count);
2730+
let uniform =
2731+
rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
2732+
.map_err(Error::wrap)?;
2733+
for _i in 0..elem_count {
2734+
data.push(rng.sample::<F8E4M3, _>(uniform))
2735+
}
2736+
Ok(CpuStorage::F8E4M3(data))
2737+
}
26352738
DType::F32 => {
26362739
let mut data = Vec::with_capacity(elem_count);
26372740
let uniform =
@@ -2679,6 +2782,15 @@ impl BackendDevice for CpuDevice {
26792782
}
26802783
Ok(CpuStorage::F16(data))
26812784
}
2785+
DType::F8E4M3 => {
2786+
let mut data = Vec::with_capacity(elem_count);
2787+
let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
2788+
.map_err(Error::wrap)?;
2789+
for _i in 0..elem_count {
2790+
data.push(normal.sample(&mut rng))
2791+
}
2792+
Ok(CpuStorage::F8E4M3(data))
2793+
}
26822794
DType::F32 => {
26832795
let mut data = Vec::with_capacity(elem_count);
26842796
let normal =
@@ -2742,6 +2854,11 @@ impl BackendDevice for CpuDevice {
27422854
v.set_len(elem_count);
27432855
CpuStorage::F64(v)
27442856
}
2857+
DType::F8E4M3 => {
2858+
let mut v = Vec::with_capacity(elem_count);
2859+
v.set_len(elem_count);
2860+
CpuStorage::F8E4M3(v)
2861+
}
27452862
};
27462863
Ok(storage)
27472864
}
@@ -2754,6 +2871,7 @@ impl BackendDevice for CpuDevice {
27542871
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
27552872
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
27562873
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
2874+
DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
27572875
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
27582876
DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
27592877
};

candle-core/src/cpu_backend/utils.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub trait Map1 {
1515
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
1616
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
1717
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
18+
C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)),
1819
}
1920
}
2021
}
@@ -31,6 +32,7 @@ pub trait Map1Any {
3132
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
3233
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
3334
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
35+
C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?),
3436
}
3537
}
3638
}
@@ -48,6 +50,7 @@ pub trait Map2 {
4850
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
4951
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
5052
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
53+
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)),
5154
_ => Err(Error::DTypeMismatchBinaryOp {
5255
lhs: v1.dtype(),
5356
rhs: v2.dtype(),
@@ -95,6 +98,7 @@ pub trait Map2U8 {
9598
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
9699
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
97100
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
101+
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
98102
_ => Err(Error::DTypeMismatchBinaryOp {
99103
lhs: v1.dtype(),
100104
rhs: v2.dtype(),

0 commit comments

Comments
 (0)