From 41e6b40010527b092fc79827644f78abd3d0e0fa Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 25 Sep 2025 13:11:34 +0200 Subject: [PATCH 1/9] Add direct copy for floats in quantization --- candle-core/src/quantized/k_quants.rs | 88 ++++++++++++++++++--------- candle-core/tests/quantized_tests.rs | 49 +++++++++------ 2 files changed, 92 insertions(+), 45 deletions(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 4c41de9edb..3c6ca6fad6 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -5,7 +5,7 @@ use super::utils::{ use super::GgmlDType; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; -use half::{bf16, f16}; +use half::{bf16, f16, slice::HalfFloatSliceExt}; use rayon::prelude::*; // Default to QK_K 256 rather than 64. @@ -22,6 +22,7 @@ pub const QK8_1: usize = 32; pub trait GgmlType: Sized + Clone + Send + Sync { const DTYPE: GgmlDType; const BLCK_SIZE: usize; + const DIRECT_COPY: bool = false; type VecDotType: GgmlType; // This is only safe for types that include immediate values such as float/int/... @@ -31,6 +32,12 @@ pub trait GgmlType: Sized + Clone + Send + Sync { fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>; fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>; + fn direct_copy(_xs: &[f32], _ys: &mut [Self]) -> Result<()> { + Err(crate::Error::Msg( + "direct_copy not implemented for this type".into(), + )) + } + /// Dot product used as a building block for quantized mat-mul. /// n is the number of elements to be considered. fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; @@ -658,8 +665,24 @@ impl GgmlType for BlockQ8_1 { Self::vec_dot_unopt(n, xs, ys) } - fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { - unimplemented!("no support for vec-dot on Q8_1") + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = QK8_1; + if n % QK8_1 != 0 { + crate::bail!("vec_dot_q8_1_q8_1: {n} is not divisible by {qk}") + } + + // Generic implementation. + let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + let sum_i = xs + .qs + .iter() + .zip(ys.qs.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum::(); + sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + } + Ok(sumf) } fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { @@ -1838,7 +1861,7 @@ impl GgmlType for BlockQ8K { } } -// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 +// https://github.com/ggml-org/llama.cpp/blob/aa3ee0eb0b80efca126cedf9bcb4fb5864b46ce3/ggml/src/ggml-cpu/ggml-cpu.c#L1205 pub fn matmul( mkn: (usize, usize, usize), lhs: &[f32], @@ -1849,18 +1872,24 @@ pub fn matmul( if m * k != lhs.len() { crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); } - let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); - // TODO: Do not make this copy if the DotType is f32. + // TODO: Pre-allocate this. let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; - for row_idx in 0..m { - let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; - let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; - T::VecDotType::from_float(lhs, lhs_b)? - } - let lhs_b = lhs_b.as_slice(); + + // f32, f16, and bf16 support direct copy + let lhs_b = if T::DIRECT_COPY { + T::VecDotType::direct_copy(lhs, &mut lhs_b[..])?; + lhs_b.as_slice() + } else { + for row_idx in 0..m { + let lhs_b_mut = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; + T::VecDotType::from_float(lhs, lhs_b_mut)? + } + lhs_b.as_slice() + }; for row_idx in 0..m { let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; @@ -1885,6 +1914,7 @@ pub fn matmul( impl GgmlType for f32 { const DTYPE: GgmlDType = GgmlDType::F32; const BLCK_SIZE: usize = 1; + const DIRECT_COPY: bool = true; type VecDotType = f32; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -1918,11 +1948,16 @@ impl GgmlType for f32 { ys.copy_from_slice(xs); Ok(()) } + + fn direct_copy(xs: &[f32], ys: &mut [Self]) -> Result<()> { + Self::from_float(xs, ys) + } } impl GgmlType for f16 { const DTYPE: GgmlDType = GgmlDType::F16; const BLCK_SIZE: usize = 1; + const DIRECT_COPY: bool = true; type VecDotType = f16; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -1945,10 +1980,7 @@ impl GgmlType for f16 { if xs.len() != ys.len() { crate::bail!("size mismatch {} {}", xs.len(), ys.len()); } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = f16::from_f32(*x) - } + ys.convert_from_f32_slice(xs); Ok(()) } @@ -1956,17 +1988,19 @@ impl GgmlType for f16 { if xs.len() != ys.len() { crate::bail!("size mismatch {} {}", xs.len(), ys.len()); } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = x.to_f32() - } + xs.convert_to_f32_slice(ys); Ok(()) } + + fn direct_copy(xs: &[f32], ys: &mut [Self]) -> Result<()> { + Self::from_float(xs, ys) + } } impl GgmlType for bf16 { const DTYPE: GgmlDType = GgmlDType::BF16; const BLCK_SIZE: usize = 1; + const DIRECT_COPY: bool = true; type VecDotType = bf16; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { @@ -1989,10 +2023,7 @@ impl GgmlType for bf16 { if xs.len() != ys.len() { crate::bail!("size mismatch {} {}", xs.len(), ys.len()); } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = bf16::from_f32(*x) - } + ys.convert_from_f32_slice(xs); Ok(()) } @@ -2000,10 +2031,11 @@ impl GgmlType for bf16 { if xs.len() != ys.len() { crate::bail!("size mismatch {} {}", xs.len(), ys.len()); } - // TODO: vectorize - for (x, y) in xs.iter().zip(ys.iter_mut()) { - *y = x.to_f32() - } + xs.convert_to_f32_slice(ys); Ok(()) } + + fn direct_copy(xs: &[f32], ys: &mut [Self]) -> Result<()> { + Self::from_float(xs, ys) + } } diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 46a92b2961..0e25742d00 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -816,7 +816,9 @@ fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { /// Returns the error achieved by the GGML matmul unit test. fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { let err = match dtype { + GgmlDType::F32 => 0.000000, GgmlDType::F16 => 0.000010, + GgmlDType::BF16 => 0.000200, GgmlDType::Q2K => 0.004086, GgmlDType::Q3K => 0.016148, GgmlDType::Q4K => 0.002425, @@ -827,6 +829,7 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { GgmlDType::Q5_0 => 0.001353, GgmlDType::Q5_1 => 0.00149, GgmlDType::Q8_0 => 0.000092, + GgmlDType::Q8_1 => 0.000092, // Not from the ggml repo. GgmlDType::Q8K => 0.00065, @@ -862,7 +865,6 @@ fn ggml_matmul_error_test_(a: &[f32], b: &[f32], err_m: f32) -> Res let result = T::vec_dot(length, &a_quant, &b_quant)?; let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?; - let reference_result = vec_dot_reference(a, b); if (result - result_unopt).abs() / length as f32 > 1e-6 { bail!( @@ -870,6 +872,17 @@ fn ggml_matmul_error_test_(a: &[f32], b: &[f32], err_m: f32) -> Res ) } + let mut dst = vec![0.0f32; 1]; + crate::k_quants::matmul((1, length, 1), a, &b_quant, &mut dst)?; + let result_matmul = dst[0]; + if (result_matmul - result_unopt).abs() / length as f32 > 1e-6 { + bail!( + "calling matmul vs calling vec-dot directly returned different values, matmul {result_matmul}, unopt {result_unopt}" + ) + } + + let reference_result = vec_dot_reference(a, b); + let error = (result - reference_result).abs() / length as f32; let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m; @@ -893,11 +906,15 @@ fn ggml_matmul_error_test_(a: &[f32], b: &[f32], err_m: f32) -> Res #[test] fn quantized_mm() -> Result<()> { + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; ggml_matmul_error_test::()?; ggml_matmul_error_test::()?; ggml_matmul_error_test::()?; ggml_matmul_error_test::()?; ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; Ok(()) } @@ -973,15 +990,13 @@ quantized_matmul!( quantized_matmul_q8_0_metal, GgmlDType::Q8_0 ); -// Not implemented in Ggml -// quantized_matmul!( -// quantized_matmul_q8_1_bis, -// quantized_matmul_q8_1_cpu, -// quantized_matmul_q8_1_cuda, -// quantized_matmul_q8_1_metal, -// GgmlDType::Q8_1 -// ); -// TODO This is bugged (also bugged in GGML +quantized_matmul!( + quantized_matmul_q8_1_bis, + quantized_matmul_q8_1_cpu, + quantized_matmul_q8_1_cuda, + quantized_matmul_q8_1_metal, + GgmlDType::Q8_1 +); quantized_matmul!( quantized_matmul_q2k_bis, quantized_matmul_q2k_cpu, @@ -1018,13 +1033,13 @@ quantized_matmul!( GgmlDType::Q6K ); // Not implemented on metal -// quantized_matmul!( -// quantized_matmul_q8k_bis, -// quantized_matmul_q8k_cpu, -// quantized_matmul_q8k_cuda, -// quantized_matmul_q8k_metal, -// GgmlDType::Q8K -// ); +quantized_matmul!( + quantized_matmul_q8k_bis, + quantized_matmul_q8k_cpu, + quantized_matmul_q8k_cuda, + quantized_matmul_q8k_metal, + GgmlDType::Q8K +); #[test] fn quantized_matmul_q2k() -> Result<()> { From b1a294803ced73d2539a828281c8ec426aba9f6e Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 25 Sep 2025 13:23:58 +0200 Subject: [PATCH 2/9] Calling q matmul directly gives slightly different results, but within ggml error leniency --- candle-core/tests/quantized_tests.rs | 40 +++++++++++++++------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 0e25742d00..fbf5e3dbe6 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -833,7 +833,6 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { // Not from the ggml repo. GgmlDType::Q8K => 0.00065, - _ => bail!("No GGML results for quantization type {dtype:?}",), }; Ok(err) } @@ -875,32 +874,37 @@ fn ggml_matmul_error_test_(a: &[f32], b: &[f32], err_m: f32) -> Res let mut dst = vec![0.0f32; 1]; crate::k_quants::matmul((1, length, 1), a, &b_quant, &mut dst)?; let result_matmul = dst[0]; + /* if (result_matmul - result_unopt).abs() / length as f32 > 1e-6 { bail!( "calling matmul vs calling vec-dot directly returned different values, matmul {result_matmul}, unopt {result_unopt}" ) } + */ let reference_result = vec_dot_reference(a, b); - let error = (result - reference_result).abs() / length as f32; - - let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m; - - if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR { - bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",); - } + let verify_result = |result: f32| { + let error = (result - reference_result).abs() / length as f32; + let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m; + if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR { + bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",); + } + // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML + // => we use a slightly higher error threshold + const ERROR_LENIENCY: f32 = 0.00001; + if error - ERROR_LENIENCY > ggml_error { + bail!( + "Dot product error {} exceeds ggml reference error {}", + error, + ggml_error + ); + } + Ok(()) + }; - // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML - // => we use a slightly higher error threshold - const ERROR_LENIENCY: f32 = 0.00001; - if error - ERROR_LENIENCY > ggml_error { - bail!( - "Dot product error {} exceeds ggml reference error {}", - error, - ggml_error - ); - } + verify_result(result)?; + verify_result(result_matmul)?; Ok(()) } From 2b56d0ad9bca84b90eb5ec0c0c51102a4eda002f Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 2 Oct 2025 07:41:59 +0200 Subject: [PATCH 3/9] fix quantized_mm test. Flip a and b input to matmul --- candle-core/src/quantized/k_quants.rs | 5 ++--- candle-core/tests/quantized_tests.rs | 11 +++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 96d4549376..9e3a211deb 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1863,14 +1863,13 @@ impl GgmlType for BlockQ8K { // https://github.com/ggml-org/llama.cpp/blob/aa3ee0eb0b80efca126cedf9bcb4fb5864b46ce3/ggml/src/ggml-cpu/ggml-cpu.c#L1205 pub fn matmul( - mkn: (usize, usize, usize), + (m, k, n): (usize, usize, usize), lhs: &[f32], rhs_t: &[T], dst: &mut [f32], ) -> Result<()> { - let (m, k, n) = mkn; if m * k != lhs.len() { - crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); + crate::bail!("unexpected lhs length {} ({m},{k},{n})", lhs.len()); } let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 320dea1a45..8e84c52673 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -867,20 +867,19 @@ fn ggml_matmul_error_test_(a: &[f32], b: &[f32], err_m: f32) -> Res if (result - result_unopt).abs() / length as f32 > 1e-6 { bail!( - "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}" + "the opt and unopt vec-dot returned different values, opt: {result} vs unopt: {result_unopt}" ) } let mut dst = vec![0.0f32; 1]; - crate::k_quants::matmul((1, length, 1), a, &b_quant, &mut dst)?; + crate::k_quants::matmul((1, length, 1), b, &a_quant, &mut dst)?; let result_matmul = dst[0]; - /* - if (result_matmul - result_unopt).abs() / length as f32 > 1e-6 { + + if (result_matmul - result).abs() / length as f32 > 1e-6 { bail!( - "calling matmul vs calling vec-dot directly returned different values, matmul {result_matmul}, unopt {result_unopt}" + "calling matmul vs calling vec-dot directly returned different values, matmul: {result_matmul} vs vec-dot: {result}" ) } - */ let reference_result = vec_dot_reference(a, b); From b1881729bf0ac2f069be023245e6abcfd07b3cfe Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 2 Oct 2025 07:54:40 +0200 Subject: [PATCH 4/9] Add compile time verification of block sizes being equal to vec dot type block sizes --- candle-core/src/quantized/k_quants.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 9e3a211deb..bf9c6d2a4c 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1868,9 +1868,12 @@ pub fn matmul( rhs_t: &[T], dst: &mut [f32], ) -> Result<()> { + debug_assert_eq!(T::BLCK_SIZE, T::VecDotType::BLCK_SIZE); + if m * k != lhs.len() { crate::bail!("unexpected lhs length {} ({m},{k},{n})", lhs.len()); } + let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); @@ -2038,3 +2041,23 @@ impl GgmlType for bf16 { Self::from_float(xs, ys) } } + +macro_rules! verify_block_size { + ( $block_type:ident ) => { + const _: () = + assert!($block_type::BLCK_SIZE == <$block_type as GgmlType>::VecDotType::BLCK_SIZE); + }; +} + +macro_rules! verify_block_sizes { + ( $( $block_type:ident ),* ) => { + $( + verify_block_size!($block_type); + )* + }; +} + +verify_block_sizes!( + BlockQ4_0, BlockQ4_1, BlockQ5_0, BlockQ5_1, BlockQ8_0, BlockQ8_1, BlockQ2K, BlockQ3K, BlockQ4K, + BlockQ5K, BlockQ6K, BlockQ8K, f32, f16, bf16 +); From 851cab44947188aff66082014ec79de1efd14576 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 2 Oct 2025 08:07:33 +0200 Subject: [PATCH 5/9] Since we have verified that the block sizes are equal we can simplify qmatmul --- candle-core/src/quantized/k_quants.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index bf9c6d2a4c..dcf23863ae 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1874,11 +1874,10 @@ pub fn matmul( crate::bail!("unexpected lhs length {} ({m},{k},{n})", lhs.len()); } - let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); - let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); + let k_in_blocks = k.div_ceil(T::BLCK_SIZE); // TODO: Pre-allocate this. - let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; + let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_blocks]; // f32, f16, and bf16 support direct copy let lhs_b = if T::DIRECT_COPY { @@ -1886,7 +1885,7 @@ pub fn matmul( lhs_b.as_slice() } else { for row_idx in 0..m { - let lhs_b_mut = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs_b_mut = &mut lhs_b[row_idx * k_in_blocks..(row_idx + 1) * k_in_blocks]; let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; T::VecDotType::from_float(lhs, lhs_b_mut)? } @@ -1894,7 +1893,7 @@ pub fn matmul( }; for row_idx in 0..m { - let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs_row = &lhs_b[row_idx * k_in_blocks..(row_idx + 1) * k_in_blocks]; let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; let result: Result> = dst_row @@ -1903,7 +1902,7 @@ pub fn matmul( .with_min_len(128) .with_max_len(512) .map(|(col_idx, dst)| { - let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + let rhs_col = &rhs_t[col_idx * k_in_blocks..(col_idx + 1) * k_in_blocks]; T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) }) .collect(); From d01f2d33c5febf69e4aa9dab7ab2d8a961ab4c67 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 2 Oct 2025 08:10:55 +0200 Subject: [PATCH 6/9] clippy --- candle-core/src/quantized/k_quants.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index dcf23863ae..0df050b017 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -667,7 +667,7 @@ impl GgmlType for BlockQ8_1 { fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { let qk = QK8_1; - if n % QK8_1 != 0 { + if !n.is_multiple_of(QK8_1) { crate::bail!("vec_dot_q8_1_q8_1: {n} is not divisible by {qk}") } From 0c718d8c049593dc9d40739f0ea0d07bb50f11b7 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 2 Oct 2025 10:49:03 +0200 Subject: [PATCH 7/9] Improved direct copy. Add comment to debug assert --- candle-core/src/quantized/k_quants.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 0df050b017..bca98b15bf 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1868,29 +1868,29 @@ pub fn matmul( rhs_t: &[T], dst: &mut [f32], ) -> Result<()> { - debug_assert_eq!(T::BLCK_SIZE, T::VecDotType::BLCK_SIZE); + debug_assert_eq!( + T::BLCK_SIZE, + T::VecDotType::BLCK_SIZE, + "Mismatched block sizes" + ); if m * k != lhs.len() { crate::bail!("unexpected lhs length {} ({m},{k},{n})", lhs.len()); } - let k_in_blocks = k.div_ceil(T::BLCK_SIZE); // TODO: Pre-allocate this. let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_blocks]; - // f32, f16, and bf16 support direct copy - let lhs_b = if T::DIRECT_COPY { - T::VecDotType::direct_copy(lhs, &mut lhs_b[..])?; - lhs_b.as_slice() + if T::DIRECT_COPY { + T::VecDotType::direct_copy(lhs, &mut lhs_b)?; } else { for row_idx in 0..m { let lhs_b_mut = &mut lhs_b[row_idx * k_in_blocks..(row_idx + 1) * k_in_blocks]; let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; T::VecDotType::from_float(lhs, lhs_b_mut)? } - lhs_b.as_slice() - }; + } for row_idx in 0..m { let lhs_row = &lhs_b[row_idx * k_in_blocks..(row_idx + 1) * k_in_blocks]; From fbb6e17431e7fa9768e214bab84fdd173a9c3c07 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 2 Oct 2025 11:19:14 +0200 Subject: [PATCH 8/9] Add more info to quantized matmul test failures --- candle-core/tests/quantized_tests.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 8e84c52673..703eb090c9 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -883,27 +883,26 @@ fn ggml_matmul_error_test_(a: &[f32], b: &[f32], err_m: f32) -> Res let reference_result = vec_dot_reference(a, b); - let verify_result = |result: f32| { + let verify_result = |result: f32, source: &str| { let error = (result - reference_result).abs() / length as f32; let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m; if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR { - bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",); + bail!("Dot product with dtype {:?} error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}. Source: {source}", T::DTYPE); } // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML // => we use a slightly higher error threshold const ERROR_LENIENCY: f32 = 0.00001; if error - ERROR_LENIENCY > ggml_error { bail!( - "Dot product error {} exceeds ggml reference error {}", - error, - ggml_error + "Dot product with dtype {:?} error {error} exceeds ggml reference error {ggml_error}. Source: {source}", + T::DTYPE, ); } Ok(()) }; - verify_result(result)?; - verify_result(result_matmul)?; + verify_result(result, "vec-dot")?; + verify_result(result_matmul, "matmul")?; Ok(()) } From 96e1b089ad9698eaa831322e6f754416d81c43e1 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 3 Oct 2025 21:05:23 +0200 Subject: [PATCH 9/9] Add neon CpuF16 and CpuBF16 implementations --- candle-core/src/cpu/mod.rs | 10 +-- candle-core/src/cpu/neon.rs | 122 +++++++++++++++++++++++++++++++++++- 2 files changed, 126 insertions(+), 6 deletions(-) diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs index c4864b7a81..3acf6744c4 100644 --- a/candle-core/src/cpu/mod.rs +++ b/candle-core/src/cpu/mod.rs @@ -78,7 +78,7 @@ pub use simd128::CurrentCpu; pub mod neon; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] #[cfg(target_feature = "neon")] -pub use neon::CurrentCpu; +pub use neon::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16}; #[cfg(any( target_feature = "neon", @@ -163,7 +163,7 @@ pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) { } } -#[cfg(target_feature = "avx2")] +#[cfg(any(target_feature = "neon", target_feature = "avx2"))] #[inline(always)] pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) { let mut sumf = 0.0f32; @@ -191,7 +191,7 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f *c = sumf; } -#[cfg(target_feature = "avx2")] +#[cfg(any(target_feature = "neon", target_feature = "avx2"))] #[inline(always)] pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { let mut sumf = 0.0f32; @@ -219,7 +219,7 @@ pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mu *c = sumf; } -#[cfg(not(target_feature = "avx2"))] +#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))] #[inline(always)] pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) { // leftovers @@ -230,7 +230,7 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f *c = sum; } -#[cfg(not(target_feature = "avx2"))] +#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))] #[inline(always)] pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { // leftovers diff --git a/candle-core/src/cpu/neon.rs b/candle-core/src/cpu/neon.rs index 66b8b45e15..0923f81824 100644 --- a/candle-core/src/cpu/neon.rs +++ b/candle-core/src/cpu/neon.rs @@ -1,4 +1,6 @@ -use super::Cpu; +use half::{bf16, f16}; + +use super::{Cpu, CpuBF16, CpuF16}; #[cfg(target_arch = "arm")] use core::arch::arm::*; @@ -72,3 +74,121 @@ impl Cpu for CurrentCpu { *y = Self::reduce_one(x[0]); } } + +pub struct CurrentCpuF16 {} +impl CpuF16 for CurrentCpuF16 { + type Unit = float32x4_t; + type Array = [float32x4_t; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + vdupq_n_f32(0.0) + } + + unsafe fn from_f32(x: f32) -> Self::Unit { + vdupq_n_f32(x) + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn load(mem_addr: *const f16) -> Self::Unit { + let mut tmp = [0.0f32; 8]; + for i in 0..8 { + tmp[i] = (*mem_addr.add(i)).to_f32(); + } + vld1q_f32(tmp.as_ptr()) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + vaddq_f32(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + vfmaq_f32(a, b, c) + } + + unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) { + let mut tmp = [0.0f32; 8]; + vst1q_f32(tmp.as_mut_ptr(), a); + for i in 0..8 { + *mem_addr.add(i) = f16::from_f32(tmp[i]); + } + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + for i in 0..ARR / 2 { + x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]); + } + for i in 0..ARR / 4 { + x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]); + } + *y = CurrentCpu::reduce_one(x[0]); + } +} + +pub struct CurrentCpuBF16 {} +impl CpuBF16 for CurrentCpuBF16 { + type Unit = float32x4_t; + type Array = [float32x4_t; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + vdupq_n_f32(0.0) + } + + unsafe fn from_f32(x: f32) -> Self::Unit { + vdupq_n_f32(x) + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn load(mem_addr: *const bf16) -> Self::Unit { + let mut tmp = [0.0f32; 8]; + for i in 0..8 { + tmp[i] = (*mem_addr.add(i)).to_f32(); + } + vld1q_f32(tmp.as_ptr()) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + vaddq_f32(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + vfmaq_f32(a, b, c) + } + + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) { + let mut tmp = [0.0f32; 8]; + vst1q_f32(tmp.as_mut_ptr(), a); + for i in 0..8 { + *mem_addr.add(i) = bf16::from_f32(tmp[i]); + } + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + for i in 0..ARR / 2 { + x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]); + } + for i in 0..ARR / 4 { + x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]); + } + *y = CurrentCpu::reduce_one(x[0]); + } +}