diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs
index c4864b7a81..3acf6744c4 100644
--- a/candle-core/src/cpu/mod.rs
+++ b/candle-core/src/cpu/mod.rs
@@ -78,7 +78,7 @@ pub use simd128::CurrentCpu;
 pub mod neon;
 #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
 #[cfg(target_feature = "neon")]
-pub use neon::CurrentCpu;
+pub use neon::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};
 
 #[cfg(any(
     target_feature = "neon",
@@ -163,7 +163,7 @@ pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
     }
 }
 
-#[cfg(target_feature = "avx2")]
+#[cfg(any(target_feature = "neon", target_feature = "avx2"))]
 #[inline(always)]
 pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
     let mut sumf = 0.0f32;
@@ -191,7 +191,7 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
     *c = sumf;
 }
 
-#[cfg(target_feature = "avx2")]
+#[cfg(any(target_feature = "neon", target_feature = "avx2"))]
 #[inline(always)]
 pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
     let mut sumf = 0.0f32;
@@ -219,7 +219,7 @@ pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mu
     *c = sumf;
 }
 
-#[cfg(not(target_feature = "avx2"))]
+#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))]
 #[inline(always)]
 pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
     // leftovers
@@ -230,7 +230,7 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
     *c = sum;
 }
 
-#[cfg(not(target_feature = "avx2"))]
+#[cfg(not(any(target_feature = "avx2", target_feature = "neon")))]
 #[inline(always)]
 pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
     // leftovers
diff --git a/candle-core/src/cpu/neon.rs b/candle-core/src/cpu/neon.rs
index 66b8b45e15..0923f81824 100644
--- a/candle-core/src/cpu/neon.rs
+++ b/candle-core/src/cpu/neon.rs
@@ -1,4 +1,6 @@
-use super::Cpu;
+use half::{bf16, f16};
+
+use super::{Cpu, CpuBF16, CpuF16};
 
 #[cfg(target_arch = "arm")]
 use core::arch::arm::*;
@@ -72,3 +74,125 @@ impl Cpu for CurrentCpu {
         *y = Self::reduce_one(x[0]);
     }
 }
+
+pub struct CurrentCpuF16 {}
+impl CpuF16 for CurrentCpuF16 {
+    type Unit = float32x4_t;
+    type Array = [float32x4_t; ARR];
+
+    const STEP: usize = STEP;
+    const EPR: usize = EPR;
+
+    fn n() -> usize {
+        ARR
+    }
+
+    unsafe fn zero() -> Self::Unit {
+        vdupq_n_f32(0.0)
+    }
+
+    unsafe fn from_f32(x: f32) -> Self::Unit {
+        vdupq_n_f32(x)
+    }
+
+    unsafe fn zero_array() -> Self::Array {
+        [Self::zero(); ARR]
+    }
+
+    unsafe fn load(mem_addr: *const f16) -> Self::Unit {
+        // Widen EPR (4) f16 values to f32 and load them into one float32x4_t.
+        let mut tmp = [0.0f32; 4];
+        for i in 0..4 {
+            tmp[i] = (*mem_addr.add(i)).to_f32();
+        }
+        vld1q_f32(tmp.as_ptr())
+    }
+
+    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
+        vaddq_f32(a, b)
+    }
+
+    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
+        vfmaq_f32(a, b, c)
+    }
+
+    unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
+        // Spill the four f32 lanes and narrow each one back to f16.
+        let mut tmp = [0.0f32; 4];
+        vst1q_f32(tmp.as_mut_ptr(), a);
+        for i in 0..4 {
+            *mem_addr.add(i) = f16::from_f32(tmp[i]);
+        }
+    }
+
+    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
+        for i in 0..ARR / 2 {
+            x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
+        }
+        for i in 0..ARR / 4 {
+            x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
+        }
+        *y = CurrentCpu::reduce_one(x[0]);
+    }
+}
+
+pub struct CurrentCpuBF16 {}
+impl CpuBF16 for CurrentCpuBF16 {
+    type Unit = float32x4_t;
+    type Array = [float32x4_t; ARR];
+
+    const STEP: usize = STEP;
+    const EPR: usize = EPR;
+
+    fn n() -> usize {
+        ARR
+    }
+
+    unsafe fn zero() -> Self::Unit {
+        vdupq_n_f32(0.0)
+    }
+
+    unsafe fn from_f32(x: f32) -> Self::Unit {
+        vdupq_n_f32(x)
+    }
+
+    unsafe fn zero_array() -> Self::Array {
+        [Self::zero(); ARR]
+    }
+
+    unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
+        // Widen EPR (4) bf16 values to f32 and load them into one float32x4_t.
+        let mut tmp = [0.0f32; 4];
+        for i in 0..4 {
+            tmp[i] = (*mem_addr.add(i)).to_f32();
+        }
+        vld1q_f32(tmp.as_ptr())
+    }
+
+    unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
+        vaddq_f32(a, b)
+    }
+
+    unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
+        vfmaq_f32(a, b, c)
+    }
+
+    unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
+        // Spill the four f32 lanes and narrow each one back to bf16.
+        let mut tmp = [0.0f32; 4];
+        vst1q_f32(tmp.as_mut_ptr(), a);
+        for i in 0..4 {
+            *mem_addr.add(i) = bf16::from_f32(tmp[i]);
+        }
+    }
+
+    unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
+        for i in 0..ARR / 2 {
+            x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
+        }
+        for i in 0..ARR / 4 {
+            x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
+        }
+        *y = CurrentCpu::reduce_one(x[0]);
+    }
+}
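For context, a minimal sketch of what the gated kernel computes. This is not part of the diff and not candle API: `vec_dot_f16_ref` and its `main` are hypothetical names, and a `half = "2"` dependency is assumed. It mirrors, in plain scalar Rust that runs on any target, the STEP/EPR/ARR blocking that the generic `vec_dot_f16` in mod.rs drives through `CurrentCpuF16`.

```rust
// Hypothetical scalar reference for the blocked f16 dot product.
// STEP/EPR/ARR mirror the constants in candle-core/src/cpu/neon.rs.
use half::f16;

const STEP: usize = 16; // elements consumed per outer iteration
const EPR: usize = 4; // f32 lanes in one float32x4_t register
const ARR: usize = STEP / EPR; // number of independent accumulators

fn vec_dot_f16_ref(a: &[f16], b: &[f16]) -> f32 {
    assert_eq!(a.len(), b.len());
    let k = a.len();
    // largest multiple of STEP that fits in k (STEP is a power of two)
    let np = k & !(STEP - 1);
    let mut acc = [0.0f32; ARR];
    for i in (0..np).step_by(STEP) {
        for j in 0..ARR {
            // one simulated float32x4_t: widen 4 f16 lanes to f32,
            // then multiply-accumulate, as vfmaq_f32 does in the NEON path
            for l in 0..EPR {
                let idx = i + j * EPR + l;
                acc[j] += a[idx].to_f32() * b[idx].to_f32();
            }
        }
    }
    // reduce the partial accumulators, then handle the scalar tail
    let mut sum: f32 = acc.iter().sum();
    for idx in np..k {
        sum += a[idx].to_f32() * b[idx].to_f32();
    }
    sum
}

fn main() {
    let a: Vec<f16> = (0..35).map(|i| f16::from_f32(0.5 * i as f32)).collect();
    let b: Vec<f16> = (0..35).map(|_| f16::from_f32(2.0)).collect();
    // expected: 2 * 0.5 * (0 + 1 + ... + 34) = 595
    println!("{}", vec_dot_f16_ref(&a, &b));
}
```

Keeping ARR (= STEP / EPR) separate accumulators breaks the dependency chain between consecutive fused multiply-adds; the NEON path gets the same effect with four `float32x4_t` registers. The cfg changes in mod.rs simply widen the existing avx2 dispatch so that aarch64 builds with `neon` take this blocked path instead of the scalar "leftovers" fallback.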