
Commit 1ef1341

*Major T/s improvement* Use the Metal qmatmul MM kernels (#2615)
* Add GGUF BF16 support (#17)
* Add GGUF bf16 type support
* Add non avx impl for vec_dot_bf16
* Fix from_u32
* Fix loading
* Fix dequant of bf16
* Update kernels for metal bf16 (#19)
* Update kernels for metal bf16
* Fix typo
* Check if have bfloat
* Sync ggml metal kernels (#33)
* Metal qmatmul mat-mat product (#39)
* Test passes
* All tests pass
* Now all the tests really pass
* Try out always using mm
* Mirror llama.cpp metric
* Mirror llama.cpp metric
* Update test
* Update test
* fixed merge error

---------

Co-authored-by: keighbee <[email protected]>
1 parent 6c95317 commit 1ef1341

12 files changed: +5612 -2525 lines changed
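The headline change dispatches quantized matmuls on Metal to mat-mat (MM) kernels when the batch is large enough (the commit list mentions mirroring llama.cpp's metric for that choice), which is where the tokens-per-second gain comes from. Below is a minimal sketch of the code path this affects, using candle's public quantized API; the shapes and the Q4_0 dtype are illustrative, not taken from the diff.

```rust
// Sketch: a quantized weight matrix applied to activations via QMatMul.
// On Metal this now routes through the mat-mat (MM) kernels.
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu; // Device::new_metal(0)? would exercise the Metal kernels
    let weight = Tensor::randn(0f32, 1.0, (64, 64), &device)?;
    let qweight = QTensor::quantize(&weight, GgmlDType::Q4_0)?;
    let qmatmul = QMatMul::from_qtensor(qweight)?;
    // A batch of 8 rows: a mat-mat rather than mat-vec product.
    let xs = Tensor::randn(0f32, 1.0, (8, 64), &device)?;
    let ys = qmatmul.forward(&xs)?;
    println!("{:?}", ys.shape()); // (8, 64)
    Ok(())
}
```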

candle-core/benches/benchmarks/mod.rs

Lines changed: 4 additions & 3 deletions

```diff
@@ -22,9 +22,10 @@ impl BenchDevice for Device {
             Device::Cpu => Ok(()),
             Device::Cuda(device) => {
                 #[cfg(feature = "cuda")]
-                return Ok(device
-                    .synchronize()
-                    .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
+                {
+                    use cuda::WrapErr;
+                    return Ok(device.synchronize().w()?);
+                }
                 #[cfg(not(feature = "cuda"))]
                 panic!("Cuda device without cuda feature enabled: {:?}", device)
             }
```
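Aside from the kernel work, this bench helper now funnels CUDA errors through the crate's `WrapErr` extension trait instead of an inline `map_err`. A self-contained sketch of that pattern follows; the `DriverError` and `Error` types are stand-ins for illustration, not candle's actual definitions.

```rust
// Extension-trait pattern behind `.w()`: wrap a backend-specific error
// into the crate-wide error type at the call site.
#[allow(dead_code)]
#[derive(Debug)]
struct DriverError(String);

#[derive(Debug)]
enum Error {
    Cuda(Box<DriverError>),
}

trait WrapErr<O> {
    fn w(self) -> Result<O, Error>;
}

// One impl replaces repeated `map_err(|e| Error::Cuda(Box::new(e)))` chains.
impl<O> WrapErr<O> for Result<O, DriverError> {
    fn w(self) -> Result<O, Error> {
        self.map_err(|e| Error::Cuda(Box::new(e)))
    }
}

fn synchronize() -> Result<(), DriverError> {
    Ok(()) // pretend the device sync succeeded
}

fn main() -> Result<(), Error> {
    // Equivalent to `synchronize().map_err(|e| Error::Cuda(Box::new(e)))?`.
    synchronize().w()?;
    Ok(())
}
```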

candle-core/src/cpu/avx.rs

Lines changed: 81 additions & 2 deletions

```diff
@@ -1,10 +1,10 @@
-use super::{Cpu, CpuF16};
+use super::{Cpu, CpuBF16, CpuF16};
 #[cfg(target_arch = "x86")]
 use core::arch::x86::*;
 #[cfg(target_arch = "x86_64")]
 use core::arch::x86_64::*;
 
-use half::f16;
+use half::{bf16, f16};
 
 pub struct CurrentCpu {}
 
@@ -146,3 +146,82 @@ impl CpuF16<ARR> for CurrentCpuF16 {
         *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
     }
 }
+
+pub struct CurrentCpuBF16 {}
+impl CpuBF16<ARR> for CurrentCpuBF16 {
+    type Unit = __m256;
+    type Array = [__m256; ARR];
+
+    const STEP: usize = STEP;
+    const EPR: usize = EPR;
+
+    fn n() -> usize {
+        ARR
+    }
+
+    unsafe fn zero() -> Self::Unit {
+        _mm256_setzero_ps()
+    }
+
+    unsafe fn zero_array() -> Self::Array {
+        [Self::zero(); ARR]
+    }
+
+    unsafe fn from_f32(v: f32) -> Self::Unit {
+        _mm256_set1_ps(v)
+    }
+
+    #[cfg(target_feature = "f16c")]
+    unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
+        _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
+    }
+
+    #[cfg(not(target_feature = "f16c"))]
+    unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
+        let mut tmp = [0.0f32; 8];
+        for i in 0..8 {
+            tmp[i] = (*mem_addr.add(i)).to_f32();
+        }
+        _mm256_loadu_ps(tmp.as_ptr())
+    }
+
+    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
+        _mm256_add_ps(a, b)
+    }
+
+    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
+        _mm256_add_ps(_mm256_mul_ps(b, c), a)
+    }
+
+    #[cfg(target_feature = "f16c")]
+    unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
+        _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
+    }
+
+    #[cfg(not(target_feature = "f16c"))]
+    unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
+        let mut tmp = [0.0f32; 8];
+        _mm256_storeu_ps(tmp.as_mut_ptr(), a);
+        for i in 0..8 {
+            *mem_addr.add(i) = bf16::from_f32(tmp[i]);
+        }
+    }
+
+    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
+        let mut offset = ARR >> 1;
+        for i in 0..offset {
+            x[i] = _mm256_add_ps(x[i], x[offset + i]);
+        }
+        offset >>= 1;
+        for i in 0..offset {
+            x[i] = _mm256_add_ps(x[i], x[offset + i]);
+        }
+        offset >>= 1;
+        for i in 0..offset {
+            x[i] = _mm256_add_ps(x[i], x[offset + i]);
+        }
+        let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
+        let t1 = _mm_hadd_ps(t0, t0);
+        *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
+    }
+}
```
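`vec_reduce` folds the `ARR` accumulator registers pairwise in log2 steps, then collapses the lanes of the surviving register with horizontal adds. A scalar model of the same reduction, with plain `f32` values standing in for `__m256` registers (assuming `ARR = 4`, i.e. `STEP / EPR = 32 / 8` as in this backend's constants, which are not shown in the diff):

```rust
// Pairwise tree reduction: halve the live prefix each pass, accumulating
// the upper half into the lower half, until one element remains.
fn tree_reduce(x: &mut [f32]) -> f32 {
    let mut offset = x.len() >> 1;
    while offset > 0 {
        for i in 0..offset {
            x[i] += x[offset + i];
        }
        offset >>= 1;
    }
    x[0]
}

fn main() {
    let mut regs = [1.0, 2.0, 3.0, 4.0]; // model of ARR = 4 accumulators
    assert_eq!(tree_reduce(&mut regs), 10.0);
    println!("reduced: {}", tree_reduce(&mut [8.0; 8])); // 64
}
```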

candle-core/src/cpu/kernels.rs

Lines changed: 7 additions & 0 deletions

```diff
@@ -121,6 +121,13 @@ impl VecOps for half::bf16 {
     fn max(self, other: Self) -> Self {
         Self::max(self, other)
     }
+
+    #[inline(always)]
+    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+        let mut res_f32 = 0f32;
+        super::vec_dot_bf16(lhs, rhs, &mut res_f32, len);
+        *res = half::bf16::from_f32(res_f32);
+    }
 }
 impl VecOps for u8 {
     #[inline(always)]
```
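Note that the new `vec_dot` accumulates in `f32` and rounds to `bf16` exactly once at the end. With only 8 significand bits, a running `bf16` sum stalls as soon as its ulp exceeds each addend. A small standalone demonstration of why that matters (this assumes the `half` crate, whose `bf16` arithmetic rounds after every operation):

```rust
use half::bf16;

fn main() {
    let n = 1000;

    // f32 accumulator, rounded to bf16 once -- what the new vec_dot does.
    let acc_f32: f32 = (0..n).map(|_| 0.25f32 * 0.25f32).sum();
    println!("f32 accumulate : {}", bf16::from_f32(acc_f32)); // 62.5

    // Naive bf16 accumulator, rounded on every step: the sum saturates early,
    // since 16.0 + 0.0625 rounds back to 16.0 under round-to-nearest-even.
    let mut acc_bf16 = bf16::from_f32(0.0);
    for _ in 0..n {
        acc_bf16 += bf16::from_f32(0.25) * bf16::from_f32(0.25);
    }
    println!("bf16 accumulate: {acc_bf16}"); // sticks at 16
}
```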

candle-core/src/cpu/mod.rs

Lines changed: 60 additions & 2 deletions

```diff
@@ -38,14 +38,33 @@ trait CpuF16<const ARR: usize> {
     unsafe fn from_f32(v: f32) -> Self::Unit;
     unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
 }
-use half::f16;
+
+#[allow(unused)]
+trait CpuBF16<const ARR: usize> {
+    type Unit;
+    type Array;
+    const STEP: usize;
+    const EPR: usize;
+
+    fn n() -> usize;
+    unsafe fn zero() -> Self::Unit;
+    unsafe fn zero_array() -> Self::Array;
+    unsafe fn load(mem_addr: *const bf16) -> Self::Unit;
+    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
+    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
+    unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
+    unsafe fn from_f32(v: f32) -> Self::Unit;
+    unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit);
+}
+
+use half::{bf16, f16};
 
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[cfg(target_feature = "avx")]
 pub mod avx;
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[cfg(target_feature = "avx")]
-pub use avx::{CurrentCpu, CurrentCpuF16};
+pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};
 
 #[cfg(target_arch = "wasm32")]
 #[cfg(target_feature = "simd128")]
@@ -172,6 +191,34 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
     *c = sumf;
 }
 
+#[cfg(target_feature = "avx")]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
+    let mut sumf = 0.0f32;
+    let np = k & !(CurrentCpuBF16::STEP - 1);
+
+    let mut sum = CurrentCpuBF16::zero_array();
+    let mut ax = CurrentCpuBF16::zero_array();
+    let mut ay = CurrentCpuBF16::zero_array();
+
+    for i in (0..np).step_by(CurrentCpuBF16::STEP) {
+        for j in 0..CurrentCpuBF16::n() {
+            ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR));
+            ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR));
+
+            sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    CurrentCpuBF16::vec_reduce(sum, &mut sumf);
+
+    // leftovers
+    for i in np..k {
+        sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
+    }
+    *c = sumf;
+}
+
 #[cfg(not(target_feature = "avx"))]
 #[inline(always)]
 pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
@@ -182,3 +229,14 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
     }
     *c = sum;
 }
+
+#[cfg(not(target_feature = "avx"))]
+#[inline(always)]
+pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
+    // leftovers
+    let mut sum = 0.0;
+    for i in 0..k {
+        sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
+    }
+    *c = sum;
+}
```
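In the AVX path, `np = k & !(STEP - 1)` rounds `k` down to a multiple of `STEP`, which is valid because `STEP` is a power of two; the scalar tail loop then covers the `k - np` leftovers. A quick standalone check of that arithmetic (assuming `STEP = 32`, per the AVX configuration):

```rust
fn main() {
    const STEP: usize = 32; // assumed AVX step: ARR registers x EPR lanes
    for k in [0usize, 31, 32, 33, 100, 256] {
        let np = k & !(STEP - 1);
        // Masking the low bits equals flooring to a multiple of STEP.
        assert_eq!(np, (k / STEP) * STEP);
        println!("k = {k:3} -> {np:3} vectorized, {} in the scalar tail", k - np);
    }
}
```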

candle-core/src/quantized/cuda.rs

Lines changed: 1 addition & 0 deletions

```diff
@@ -431,6 +431,7 @@ impl QCudaStorage {
         match self.dtype {
             GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
             GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
+            GgmlDType::BF16 => deq::<half::bf16>(&buffer, block_len, &mut out)?,
             GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
             GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
             GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
```
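The new arm sends `BF16` through the same generic `deq::<T>` path as `F16`: the raw buffer is reinterpreted as values of `T` and each is widened to `f32`. A scalar model of that dispatch; `Dequant` below is a stand-in trait for illustration, not candle's `GgmlType`:

```rust
use half::{bf16, f16};

trait Dequant: Copy {
    fn widen(self) -> f32;
}
impl Dequant for f16 {
    fn widen(self) -> f32 {
        self.to_f32() // inherent half::f16 conversion
    }
}
impl Dequant for bf16 {
    fn widen(self) -> f32 {
        self.to_f32() // inherent half::bf16 conversion
    }
}

// For F16/BF16 a "block" is a single element (BLCK_SIZE = 1), so
// dequantization reduces to an elementwise widening pass.
fn deq<T: Dequant>(buffer: &[T]) -> Vec<f32> {
    buffer.iter().map(|x| x.widen()).collect()
}

fn main() {
    let raw = [bf16::from_f32(1.5), bf16::from_f32(-2.0)];
    println!("{:?}", deq(&raw)); // [1.5, -2.0]
}
```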

candle-core/src/quantized/ggml_file.rs

Lines changed: 1 addition & 0 deletions

```diff
@@ -153,6 +153,7 @@ pub fn qtensor_from_ggml(
     match ggml_dtype {
         GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
         GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::BF16 => from_raw_data::<half::bf16>(raw_data, size_in_bytes, dims, device),
         GgmlDType::Q4_0 => {
             from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
         }
```
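With the loader arms above in place, a file carrying bf16 tensors loads like any other dtype. A hedged sketch using candle's GGUF reader; the file name here is hypothetical:

```rust
use candle_core::quantized::gguf_file;
use candle_core::Device;

fn main() -> candle_core::Result<()> {
    let mut file = std::fs::File::open("model-bf16.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;

    // bf16 tensors report their GGML dtype like any other entry.
    for (name, info) in content.tensor_infos.iter() {
        println!("{name}: {:?} {:?}", info.shape, info.ggml_dtype);
    }

    // Materialize one tensor on the CPU; after this commit bf16 data also
    // dequantizes on the CUDA and Metal backends.
    if let Some(name) = content.tensor_infos.keys().next() {
        let qtensor = content.tensor(&mut file, name, &Device::Cpu)?;
        println!("loaded {name} as {:?}", qtensor.dtype());
    }
    Ok(())
}
```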

candle-core/src/quantized/k_quants.rs

Lines changed: 45 additions & 1 deletion

```diff
@@ -5,7 +5,7 @@ use super::utils::{
 use super::GgmlDType;
 use crate::Result;
 use byteorder::{ByteOrder, LittleEndian};
-use half::f16;
+use half::{bf16, f16};
 use rayon::prelude::*;
 
 // Default to QK_K 256 rather than 64.
@@ -1963,3 +1963,47 @@ impl GgmlType for f16 {
         Ok(())
     }
 }
+
+impl GgmlType for bf16 {
+    const DTYPE: GgmlDType = GgmlDType::BF16;
+    const BLCK_SIZE: usize = 1;
+    type VecDotType = bf16;
+
+    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        Self::vec_dot_unopt(n, xs, ys)
+    }
+
+    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        if xs.len() < n {
+            crate::bail!("size mismatch {} < {n}", xs.len())
+        }
+        if ys.len() < n {
+            crate::bail!("size mismatch {} < {n}", ys.len())
+        }
+        let mut res = 0f32;
+        unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
+        Ok(res)
+    }
+
+    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
+        if xs.len() != ys.len() {
+            crate::bail!("size mismatch {} {}", xs.len(), ys.len());
+        }
+        // TODO: vectorize
+        for (x, y) in xs.iter().zip(ys.iter_mut()) {
+            *y = bf16::from_f32(*x)
+        }
+        Ok(())
+    }
+
+    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
+        if xs.len() != ys.len() {
+            crate::bail!("size mismatch {} {}", xs.len(), ys.len());
+        }
+        // TODO: vectorize
+        for (x, y) in xs.iter().zip(ys.iter_mut()) {
+            *y = x.to_f32()
+        }
+        Ok(())
+    }
+}
```
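The behavior this impl encodes: `from_float`/`to_float` are plain per-element bf16 conversions (`BLCK_SIZE = 1`, one value per "block"), and `vec_dot` multiplies and accumulates in `f32`. A standalone check of those properties, using only the `half` crate rather than candle's trait machinery:

```rust
use half::bf16;

fn main() {
    let xs: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();

    // from_float: f32 -> bf16, one rounding per element.
    let qs: Vec<bf16> = xs.iter().map(|&x| bf16::from_f32(x)).collect();
    // to_float: bf16 -> f32, an exact widening.
    let back: Vec<f32> = qs.iter().map(|q| q.to_f32()).collect();
    for (x, b) in xs.iter().zip(&back) {
        // bf16 keeps 8 significand bits: relative error is at most ~2^-9.
        assert!((x - b).abs() <= x.abs() / 256.0 + f32::EPSILON);
    }

    // vec_dot semantics: multiply in f32, accumulate in f32.
    let dot: f32 = qs.iter().zip(&qs).map(|(a, b)| a.to_f32() * b.to_f32()).sum();
    println!("dot(xs, xs) ~= {dot}");
}
```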
