
Commit 95ea453

Authored by EricLBuehler, ivarflakstad, anonenity, and haricot
Add more misc. changes from candle fork (#3196)
* Merge with fork (Co-authored-by: Guoqing Bao <[email protected]>)
* Update sdpa
* Fix flash attn bf16 case
* Metal fixes
* Add metal methods
* Add new_private_buffer
* Fix metal tests
* Format
* Apply review comments
* Update CI (#3194)
* Update CI
* I have no clue what was going on with this maturin file, but I don't like it
* update cuda container options
* Add compute cap to cuda wf
* Fix rust toolchain call
* update cuda ci runner and bindgen_cuda
* Add initial support for imatrix quantization (#3193)
* add clear kv cache to quantized qwen3 weights (#3189)
* Fix metal bug
* Apply review comments
* Fix merge
* Add lld installation and test steps for Linux (#3213)

---------

Co-authored-by: ivarflakstad <[email protected]>
Co-authored-by: anonenity <[email protected]>
Co-authored-by: Nicolas PASCAL <[email protected]>
1 parent 01bea21 commit 95ea453

15 files changed: +2568 −967 lines

candle-core/src/error.rs

Lines changed: 70 additions & 14 deletions

@@ -1,4 +1,6 @@
 //! Candle-specific Error and Result
+use std::{convert::Infallible, fmt::Display};
+
 use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
 
 #[derive(Debug, Clone)]
@@ -209,6 +211,13 @@ pub enum Error {
     #[error("{0}")]
     Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
 
+    /// Arbitrary errors wrapping with context.
+    #[error("{wrapped:?}\n{context:?}")]
+    WrappedContext {
+        wrapped: Box<dyn std::error::Error + Send + Sync>,
+        context: String,
+    },
+
     #[error("{context}\n{inner}")]
     Context {
         inner: Box<Self>,
@@ -299,40 +308,87 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
     }
 }
 
-// Taken from anyhow.
-pub trait Context<T> {
+pub(crate) mod private {
+    pub trait Sealed {}
+
+    impl<T, E> Sealed for std::result::Result<T, E> where E: std::error::Error {}
+    impl<T> Sealed for Option<T> {}
+}
+
+/// Attach more context to an error.
+///
+/// Inspired by [`anyhow::Context`].
+pub trait Context<T, E>: private::Sealed {
     /// Wrap the error value with additional context.
-    fn context<C>(self, context: C) -> Result<T>
+    fn context<C>(self, context: C) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static;
+        C: Display + Send + Sync + 'static;
 
     /// Wrap the error value with additional context that is evaluated lazily
     /// only once an error does occur.
-    fn with_context<C, F>(self, f: F) -> Result<T>
+    fn with_context<C, F>(self, f: F) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static,
+        C: Display + Send + Sync + 'static,
         F: FnOnce() -> C;
 }
 
-impl<T> Context<T> for Option<T> {
-    fn context<C>(self, context: C) -> Result<T>
+impl<T, E> Context<T, E> for std::result::Result<T, E>
+where
+    E: std::error::Error + Send + Sync + 'static,
+{
+    fn context<C>(self, context: C) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static,
+        C: Display + Send + Sync + 'static,
     {
+        // Not using map_err to save 2 useless frames off the captured backtrace
+        // in ext_context.
         match self {
-            Some(v) => Ok(v),
-            None => Err(Error::UnwrapNone.context(context).bt()),
+            Ok(ok) => Ok(ok),
+            Err(error) => Err(Error::WrappedContext {
+                wrapped: Box::new(error),
+                context: context.to_string(),
+            }
+            .bt()),
+        }
+    }
+
+    fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
+    where
+        C: Display + Send + Sync + 'static,
+        F: FnOnce() -> C,
+    {
+        match self {
+            Ok(ok) => Ok(ok),
+            Err(error) => Err(Error::WrappedContext {
+                wrapped: Box::new(error),
+                context: context().to_string(),
+            }
+            .bt()),
+        }
+    }
+}
+
+impl<T> Context<T, Infallible> for Option<T> {
+    fn context<C>(self, context: C) -> std::result::Result<T, Error>
+    where
+        C: Display + Send + Sync + 'static,
+    {
+        // Not using ok_or_else to save 2 useless frames off the captured
+        // backtrace.
+        match self {
+            Some(ok) => Ok(ok),
+            None => Err(Error::msg(context).bt()),
         }
     }
 
-    fn with_context<C, F>(self, f: F) -> Result<T>
+    fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
     where
-        C: std::fmt::Display + Send + Sync + 'static,
+        C: Display + Send + Sync + 'static,
         F: FnOnce() -> C,
     {
         match self {
             Some(v) => Ok(v),
-            None => Err(Error::UnwrapNone.context(f()).bt()),
+            None => Err(Error::UnwrapNone.context(context()).bt()),
        }
    }
}

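The reworked `Context` trait now covers `Result` (through the new `WrappedContext` variant, which preserves the source error) in addition to `Option`. A minimal usage sketch, assuming the trait is imported from `candle_core::error`; the function and messages are illustrative, not part of the diff:

```rust
use candle_core::error::Context;

fn layer_count(raw: &str) -> candle_core::Result<usize> {
    // On a Result, `context` boxes the underlying std::error::Error into
    // Error::WrappedContext, keeping the original error for inspection.
    let n: usize = raw.trim().parse().context("parsing the layer count")?;

    // On an Option, `with_context` builds the message lazily, only when
    // the value is actually None.
    let _first_digit = raw
        .chars()
        .find(|c| c.is_ascii_digit())
        .with_context(|| format!("no digit in {raw:?}"))?;

    Ok(n)
}
```

The sealed `private::Sealed` supertrait keeps downstream crates from adding their own `Context` impls, so the two blanket impls above can evolve without breaking anyone.
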
candle-core/src/metal_backend/device.rs

Lines changed: 23 additions & 0 deletions

@@ -67,6 +67,12 @@ pub const RESOURCE_OPTIONS: MTLResourceOptions =
 //| MTLResourceOptions::HazardTrackingModeUntracked.bits(),
 //);
 
+// Resource options used for `new_private_buffer`. This uses `private` where supported.
+#[cfg(target_os = "ios")]
+pub const PRIVATE_RESOURCE_OPTIONS: MTLResourceOptions = MTLResourceOptions::StorageModeShared;
+#[cfg(not(target_os = "ios"))]
+pub const PRIVATE_RESOURCE_OPTIONS: MTLResourceOptions = MTLResourceOptions::StorageModePrivate;
+
 impl std::fmt::Debug for MetalDevice {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "MetalDevice({:?})", self.id)
@@ -169,6 +175,23 @@ impl MetalDevice {
         self.allocate_buffer(size)
     }
 
+    /// Creates a new private buffer (not necessarily zeroed).
+    ///
+    /// This is intentionally not in the Metal buffer pool to allow the efficient implementation of persistent buffers.
+    pub fn new_private_buffer(
+        &self,
+        element_count: usize,
+        dtype: DType,
+        _name: &str,
+    ) -> Result<Arc<Buffer>> {
+        let size = element_count * dtype.size_in_bytes();
+        let buffer = self
+            .device
+            .new_buffer(size, PRIVATE_RESOURCE_OPTIONS)
+            .map_err(MetalError::from)?;
+        Ok(Arc::new(buffer))
+    }
+
     /// Creates a new buffer from data.
     ///
     /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)

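A short sketch of the intended use: carving out a persistent allocation (e.g. for a KV cache) that bypasses the pooled allocator. The element count and label are illustrative, and the `MetalDevice` import assumes the metal feature's usual re-export:

```rust
use candle_core::{DType, MetalDevice, Result};

// Allocate 4096 f16 elements (8 KiB) outside the buffer pool. On macOS the
// buffer uses StorageModePrivate (GPU-only memory); on iOS it falls back to
// StorageModeShared, per PRIVATE_RESOURCE_OPTIONS above. The name argument
// is accepted but currently unused (`_name`).
fn alloc_cache(device: &MetalDevice) -> Result<()> {
    let cache = device.new_private_buffer(4096, DType::F16, "kv-cache")?;
    // The returned Arc keeps the buffer alive for as long as the caller
    // holds it; it is never handed back to the pool for reuse between ops.
    let _ = cache;
    Ok(())
}
```
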
candle-core/src/quantized/cuda.rs

Lines changed: 118 additions & 0 deletions

@@ -406,7 +406,125 @@ fn mul_mat_via_q8_1(
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
+fn indexed_moe_forward_fused_q8_1_input(
+    weight: &CudaView<u8>,
+    w_shape: &crate::Shape, //[num_experts, n, k]
+    w_dtype: GgmlDType,
+    input: &CudaSlice<f32>,
+    in_shape: &crate::Shape, //[batch, topk or 1, k]
+    ids: &CudaView<u32>,
+    idx_shape: &crate::Shape, //[batch, topk]
+    dev: &CudaDevice,
+) -> Result<(CudaStorage, crate::Shape)> {
+    let (_, n, k) = w_shape.dims3()?;
+    let batch = in_shape.dims()[0];
+    let input_dim1 = in_shape.dims()[1];
+
+    let topk = idx_shape.dims()[1];
+    assert!(batch == idx_shape.dims()[0], "batch dim not match!");
+
+    // Quantize input into q8_1.
+    let total_rows = batch * input_dim1;
+    let k_padded = pad(k, MATRIX_ROW_PADDING);
+    // Get Q8_1 metadata.
+    let q8_1_block_size = GgmlDType::Q8_1.block_size();
+    let q8_1_type_size = GgmlDType::Q8_1.type_size();
+
+    // Calculate the size of the output buffer in bytes.
+    let num_blocks_per_row = k_padded / q8_1_block_size;
+    let dst_row_size_bytes = num_blocks_per_row * q8_1_type_size;
+    let y_size_in_bytes = total_rows * dst_row_size_bytes;
+    let mut input_quant = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
+
+    let input_view = input.slice(0..);
+    quantize_q8_1(&input_view, &mut input_quant, k, total_rows, dev)?;
+
+    // output buffer
+    let outsize = batch * topk * n;
+    let out = unsafe { dev.alloc::<f32>(outsize)? };
+
+    let kernel_name = match w_dtype {
+        GgmlDType::Q2K => "indexed_moe_forward_q2k_q8_1",
+        GgmlDType::Q3K => "indexed_moe_forward_q3k_q8_1",
+        GgmlDType::Q4K => "indexed_moe_forward_q4k_q8_1",
+        GgmlDType::Q5K => "indexed_moe_forward_q5k_q8_1",
+        GgmlDType::Q6K => "indexed_moe_forward_q6k_q8_1",
+        GgmlDType::Q8_0 => "indexed_moe_forward_q8_0_q8_1",
+        _ => crate::bail!("unsupported dtype for indexed_moe_forward {w_dtype:?}"),
+    };
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
+    let (nblocks, nwarps) = (n as u32, 4);
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (nblocks, batch as u32, topk as u32),
+        block_dim: (WARP_SIZE as u32, nwarps, 1),
+        shared_mem_bytes: 0,
+    };
+
+    let mut builder = func.builder();
+    builder.arg(weight);
+    builder.arg(&input_quant);
+    builder.arg(ids);
+    builder.arg(&out);
+
+    barg!(
+        builder,
+        n as i32,
+        k as i32,
+        batch as i32,
+        topk as i32,
+        k_padded as i32,
+        input_dim1 as i32
+    );
+    unsafe { builder.launch(cfg) }.w()?;
+
+    let mut out_shape = in_shape.dims().to_vec();
+    out_shape.pop();
+    out_shape.push(n);
+    out_shape[1] = topk;
+    Ok((
+        CudaStorage::wrap_cuda_slice(out, dev.clone()),
+        out_shape.into(),
+    ))
+}
+
 impl QCudaStorage {
+    pub fn indexed_moe_forward(
+        &self,
+        self_shape: &crate::Shape, //[num_experts, n, k]
+        input: &CudaStorage, //[batch, topk or 1, k]
+        input_l: &crate::Layout,
+        ids: &CudaStorage, //[batch, topk]
+        ids_l: &crate::Layout,
+    ) -> Result<(CudaStorage, crate::Shape)> {
+        if matches!(
+            self.dtype(),
+            GgmlDType::Q8_0
+                | GgmlDType::Q2K
+                | GgmlDType::Q3K
+                | GgmlDType::Q4K
+                | GgmlDType::Q5K
+                | GgmlDType::Q6K
+        ) {
+            let input_storage = input.as_cuda_slice::<f32>()?;
+            let ids_storage = ids.as_cuda_slice::<u32>()?;
+            indexed_moe_forward_fused_q8_1_input(
+                &self.data.inner.slice(0..),
+                self_shape, //[num_experts, n, k]
+                self.dtype(),
+                &input_storage,
+                input_l.shape(), //[batch, topk or 1, k]
+                &ids_storage.slice(0..),
+                ids_l.shape(), //[batch, topk]
+                &self.device,
+            )
+        } else {
+            crate::bail!(
+                "The given quantized dtype {:?} is not supported for indexed_moe_forward!",
+                self.dtype()
+            );
+        }
+    }
+
     pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
         let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
         let padded_size_in_bytes =

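To make the sizing math above concrete: each activation row of length `k` is padded up to a multiple of `MATRIX_ROW_PADDING` and quantized into Q8_1 blocks before the fused gather-matmul launches one grid cell per `(output row n, batch, topk)` triple. A worked sketch of the byte arithmetic, under assumed constants (in candle, `MATRIX_ROW_PADDING` is 512 and a Q8_1 block covers 32 values in 36 bytes: 32 i8 quants plus two f16 scalars):

```rust
// Staging-buffer size for the q8_1 quantization of the activations,
// mirroring the arithmetic in indexed_moe_forward_fused_q8_1_input.
fn q8_1_staging_bytes(batch: usize, input_dim1: usize, k: usize) -> usize {
    const MATRIX_ROW_PADDING: usize = 512; // assumed, per candle's cuda.rs
    const Q8_1_BLOCK_SIZE: usize = 32; // values per block
    const Q8_1_TYPE_SIZE: usize = 36; // bytes per block

    let total_rows = batch * input_dim1;
    // pad(k, MATRIX_ROW_PADDING): round k up to the next multiple.
    let k_padded = k.div_ceil(MATRIX_ROW_PADDING) * MATRIX_ROW_PADDING;
    total_rows * (k_padded / Q8_1_BLOCK_SIZE) * Q8_1_TYPE_SIZE
}

fn main() {
    // batch = 4 tokens, one shared input row per token, k = 2048:
    // 2048 is already a multiple of 512, so
    // 64 blocks/row * 36 B/block * 4 rows = 9216 B.
    assert_eq!(q8_1_staging_bytes(4, 1, 2048), 9216);
}
```
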
candle-core/src/quantized/dummy_cuda.rs

Lines changed: 11 additions & 0 deletions

@@ -70,6 +70,17 @@ impl QCudaStorage {
     pub fn data(&self) -> Result<Vec<u8>> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    pub fn indexed_moe_forward(
+        &self,
+        _: &crate::Shape,
+        _: &CudaStorage,
+        _: &crate::Layout,
+        _: &CudaStorage,
+        _: &crate::Layout,
+    ) -> Result<(CudaStorage, crate::Shape)> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
 
 pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(

candle-core/src/quantized/dummy_metal.rs

Lines changed: 11 additions & 0 deletions

@@ -66,6 +66,17 @@ impl QMetalStorage {
     pub fn data(&self) -> Result<Vec<u8>> {
         Err(Error::NotCompiledWithMetalSupport)
     }
+
+    pub fn indexed_moe_forward(
+        &self,
+        _: &crate::Shape,
+        _: &MetalStorage,
+        _: &crate::Layout,
+        _: &MetalStorage,
+        _: &crate::Layout,
+    ) -> Result<(MetalStorage, crate::Shape)> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
 }
 
 pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(

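These two stubs mirror the real backends' signatures so that call sites such as `QTensor::indexed_moe_forward` type-check even when the corresponding feature is disabled, failing at runtime with `Error::NotCompiledWith{Cuda,Metal}Support`. A sketch of the cfg-gating this relies on — the exact module wiring in `quantized/mod.rs` is an assumption, not shown in this diff:

```rust
// Sketch: pick the real backend or the stub at compile time; both expose
// the same inherent methods, so callers are oblivious to the switch.
#[cfg(feature = "cuda")]
mod cuda;
#[cfg(not(feature = "cuda"))]
mod dummy_cuda;
#[cfg(not(feature = "cuda"))]
use dummy_cuda as cuda;

#[cfg(feature = "metal")]
mod metal;
#[cfg(not(feature = "metal"))]
mod dummy_metal;
#[cfg(not(feature = "metal"))]
use dummy_metal as metal;
```
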
candle-core/src/quantized/mod.rs

Lines changed: 37 additions & 0 deletions

@@ -642,6 +642,34 @@ impl QTensor {
     pub fn data(&self) -> Result<Cow<'_, [u8]>> {
         self.storage.data()
     }
+
+    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
+        match &self.storage {
+            QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
+                (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
+                    let (storage, out_shape) = s.indexed_moe_forward(
+                        self.shape(),
+                        x_storage,
+                        x.layout(),
+                        ids_storage,
+                        ids.layout(),
+                    )?;
+                    Ok(crate::tensor::from_storage(
+                        Storage::Cuda(storage),
+                        out_shape,
+                        crate::op::BackpropOp::none(),
+                        false,
+                    ))
+                }
+                _ => {
+                    panic!("Non-cuda indexed_moe_forward is not implemented!");
+                }
+            },
+            _ => {
+                panic!("indexed_moe_forward is not implemented in this platform!");
+            }
+        }
+    }
 }
 
 #[derive(Clone, Debug)]
@@ -713,6 +741,15 @@ impl QMatMul {
         };
         xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
     }
+
+    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::QTensor(t) => t.indexed_moe_forward(x, ids),
+            _ => {
+                panic!("Not implemented!")
+            }
+        }
+    }
 }
 
 impl crate::CustomOp1 for QTensor {

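Putting the pieces together, a hedged end-to-end sketch of the new MoE path on CUDA. The shapes follow the comments in the diff; `qw` is assumed to be a `QMatMul::QTensor` holding an expert bank of shape `[num_experts, n, k]` in one of the supported dtypes (Q8_0 or Q2K–Q6K):

```rust
use candle_core::quantized::QMatMul;
use candle_core::{DType, Device, Result, Tensor};

fn moe_step(qw: &QMatMul, dev: &Device) -> Result<Tensor> {
    let (batch, topk, k) = (2, 4, 2048);
    // Activations: one f32 row per token (the middle dim may be 1 or topk).
    let x = Tensor::zeros((batch, 1, k), DType::F32, dev)?;
    // Router output: the top-k expert indices per token, as u32.
    let ids = Tensor::zeros((batch, topk), DType::U32, dev)?;
    // Fused gather + q8_1 quantization + matmul in one kernel launch;
    // the result has shape [batch, topk, n].
    qw.indexed_moe_forward(&x, &ids)
}
```

Note that the non-CUDA arms currently `panic!` rather than returning an error, so this path should only be exercised on CUDA devices with a plain (non-dequantized) `QMatMul::QTensor`.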