Commit 549eacb

Add initial support for imatrix quantization (#3193)

1 parent: ab56dfe
13 files changed: +1426, -45 lines

candle-core/src/cuda_backend/device.rs (12 additions, 0 deletions)

```diff
@@ -94,6 +94,18 @@ impl CudaDevice {
         self.stream.memcpy_dtod(src, dst).w()
     }

+    pub fn memcpy_dtoh<
+        T: cudarc::driver::DeviceRepr,
+        Src: cudarc::driver::DevicePtr<T>,
+        Dst: cudarc::driver::HostSlice<T>,
+    >(
+        &self,
+        src: &Src,
+        dst: &mut Dst,
+    ) -> Result<()> {
+        self.stream.memcpy_dtoh(src, dst).w()
+    }
+
     pub fn memcpy_stod<
         T: cudarc::driver::DeviceRepr,
         Src: cudarc::driver::HostSlice<T> + ?Sized,
```
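The new `memcpy_dtoh` mirrors the existing `memcpy_stod` helper in the opposite direction. A minimal usage sketch, assuming a `CudaDevice` binding named `dev` inside a function returning `Result<()>` (the bindings are illustrative, not part of the commit):

```rust
// Round-trip a host buffer through the device using both helpers.
let host: Vec<f32> = vec![1.0, 2.0, 3.0];
let device_buf = dev.memcpy_stod(&host)?; // host -> device (pre-existing helper)
let mut back = vec![0f32; host.len()];
dev.memcpy_dtoh(&device_buf, &mut back)?; // device -> host (added in this commit)
assert_eq!(host, back);
```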

candle-core/src/metal_backend/device.rs (3 additions, 3 deletions)

```diff
@@ -125,7 +125,7 @@ impl MetalDevice {
     }

     pub fn command_encoder(&self) -> Result<ComputeCommandEncoder> {
-        let mut commands = self.commands.write().map_err(MetalError::from)?;
+        let commands = self.commands.write().map_err(MetalError::from)?;
         let (flush, command_encoder) = commands.command_encoder().map_err(MetalError::from)?;
         if flush {
             self.drop_unused_buffers()?
@@ -134,7 +134,7 @@ impl MetalDevice {
     }

     pub fn blit_command_encoder(&self) -> Result<BlitCommandEncoder> {
-        let mut commands = self.commands.write().map_err(MetalError::from)?;
+        let commands = self.commands.write().map_err(MetalError::from)?;
         let (flush, command_encoder) = commands.blit_command_encoder().map_err(MetalError::from)?;
         if flush {
             self.drop_unused_buffers()?
@@ -143,7 +143,7 @@ impl MetalDevice {
     }

     pub fn wait_until_completed(&self) -> Result<()> {
-        let mut commands = self.commands.write().map_err(MetalError::from)?;
+        let commands = self.commands.write().map_err(MetalError::from)?;
         commands.wait_until_completed().map_err(MetalError::from)?;
         Ok(())
     }
```
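The only change here is dropping `mut` from the write-guard bindings. In Rust that is purely about how the binding is used: a guard needs `mut` only if something calls `&mut self` methods through it. A standalone sketch of the pattern with stand-in types (not the Metal backend's actual ones):

```rust
use std::sync::RwLock;

struct Commands(u32);

impl Commands {
    // Takes &self, so a non-mut guard binding suffices at the call site.
    fn command_encoder(&self) -> u32 {
        self.0
    }
}

fn demo(lock: &RwLock<Commands>) -> u32 {
    // No `mut` needed: the guard is never used through `&mut self`.
    let commands = lock.write().expect("lock poisoned");
    commands.command_encoder()
}
```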

candle-core/src/quantized/cuda.rs (143 additions, 22 deletions)

```diff
@@ -46,24 +46,57 @@ fn pad(p: usize, q: usize) -> usize {
 fn quantize_q8_1(
     src: &CudaView<f32>,
     dst: &mut CudaSlice<u8>,
-    elem_count: usize,
+    k: usize,
     ky: usize,
     dev: &CudaDevice,
 ) -> Result<()> {
-    let kx = elem_count;
-    let kx_padded = pad(kx, MATRIX_ROW_PADDING);
+    let kx_padded = pad(k, MATRIX_ROW_PADDING);
     let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
+
+    let total_rows = ky;
+    // Get Q8_1 metadata.
+    let q8_1_block_size = GgmlDType::Q8_1.block_size();
+    let q8_1_type_size = GgmlDType::Q8_1.type_size();
+
+    // Calculate the size of the output buffer in bytes.
+    let num_blocks_per_row = kx_padded / q8_1_block_size;
+    let dst_row_size_bytes = num_blocks_per_row * q8_1_type_size;
+
+    const CHUNK_SIZE: usize = 65535; // gridDim.y limit
     let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
-    let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (num_blocks as u32, ky as u32, 1),
-        block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
-        shared_mem_bytes: 0,
-    };
-    let mut builder = func.builder();
-    builder.arg(src);
-    builder.arg(dst);
-    barg!(builder, kx as i32, kx_padded as i32);
-    unsafe { builder.launch(cfg) }.w()?;
+
+    let mut rows_processed = 0;
+    while rows_processed < total_rows {
+        // --- calculate the number of rows for this chunk ---
+        let remaining_rows = total_rows - rows_processed;
+        // This is our gridDim.y, now <= 65535
+        let rows_in_chunk = std::cmp::min(CHUNK_SIZE, remaining_rows);
+
+        // --- slice the source (f32) tensor by elements ---
+        let src_start_elem = rows_processed * k;
+        let src_num_elems = rows_in_chunk * k;
+        let src_chunk = src.slice(src_start_elem..(src_start_elem + src_num_elems));
+
+        // --- slice the destination (u8) tensor by bytes ---
+        let dst_start_byte = rows_processed * dst_row_size_bytes;
+        let dst_num_bytes = rows_in_chunk * dst_row_size_bytes;
+        let dst_chunk = dst.slice(dst_start_byte..(dst_start_byte + dst_num_bytes));
+
+        let cfg = cudarc::driver::LaunchConfig {
+            grid_dim: (num_blocks as u32, rows_in_chunk as u32, 1),
+            block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
+            shared_mem_bytes: 0,
+        };
+
+        let mut builder = func.builder();
+        builder.arg(&src_chunk);
+        builder.arg(&dst_chunk);
+        barg!(builder, k as i32, kx_padded as i32);
+        unsafe { builder.launch(cfg) }.w()?;
+
+        rows_processed += rows_in_chunk;
+    }
+
     Ok(())
 }
```
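The rewrite above exists because CUDA caps `gridDim.y` at 65535, so quantizing a tensor with more rows than that in a single launch would fail. A standalone sketch of just the chunking arithmetic (names are illustrative; `row_chunks` is not a function in the commit):

```rust
const CHUNK_SIZE: usize = 65535; // CUDA gridDim.y limit

/// Split `total_rows` into (first_row, row_count) chunks of at most CHUNK_SIZE rows.
fn row_chunks(total_rows: usize) -> Vec<(usize, usize)> {
    let mut chunks = Vec::new();
    let mut rows_processed = 0;
    while rows_processed < total_rows {
        let rows_in_chunk = CHUNK_SIZE.min(total_rows - rows_processed);
        chunks.push((rows_processed, rows_in_chunk));
        rows_processed += rows_in_chunk;
    }
    chunks
}

fn main() {
    // 150_000 rows need three launches: 65535 + 65535 + 18930 rows.
    assert_eq!(
        row_chunks(150_000),
        vec![(0, 65535), (65535, 65535), (131070, 18930)]
    );
}
```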

```diff
@@ -477,6 +510,87 @@ impl QCudaStorage {
         Ok(())
     }

+    pub fn quantize_imatrix(
+        &mut self,
+        src: &CudaStorage,
+        imatrix_weights: &[f32],
+        n_per_row: usize,
+    ) -> Result<()> {
+        // Run the quantization on cpu.
+        let src = match &src.slice {
+            crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
+            _ => crate::bail!("only f32 can be quantized"),
+        };
+        let src_len = src.len();
+        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
+        let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
+        qcpu_storage.quantize_imatrix(&src, imatrix_weights, n_per_row)?;
+        let data = qcpu_storage.data()?;
+        let padded_len =
+            data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
+        let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
+        self.device
+            .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
+        self.data = PaddedCudaSlice {
+            inner,
+            len: data.len(),
+        };
+        Ok(())
+    }
+
+    pub fn quantize_imatrix_onto(
+        &mut self,
+        src: &crate::CpuStorage,
+        imatrix_weights: &[f32],
+        n_per_row: usize,
+    ) -> Result<()> {
+        // Run the quantization on cpu.
+        let src_len = src.as_slice::<f32>()?.len();
+        let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
+
+        if let QStorage::Cpu(storage) = &mut qcpu_storage {
+            storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
+        } else {
+            unreachable!()
+        }
+
+        let data = qcpu_storage.data()?;
+        let padded_len =
+            data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
+        let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
+        self.device
+            .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
+        self.data = PaddedCudaSlice {
+            inner,
+            len: data.len(),
+        };
+        Ok(())
+    }
+
+    pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> {
+        // Run the quantization on cpu.
+        let src_len = src.as_slice::<f32>()?.len();
+        let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
+
+        if let QStorage::Cpu(storage) = &mut qcpu_storage {
+            storage.from_float(src.as_slice::<f32>()?);
+        } else {
+            unreachable!()
+        }
+
+        let data = qcpu_storage.data()?;
+        let padded_len =
+            data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
+        let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
+        self.device
+            .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
+        self.data = PaddedCudaSlice {
+            inner,
+            len: data.len(),
+        };
+        Ok(())
+    }
+
     pub fn storage_size_in_bytes(&self) -> usize {
         self.data.len
     }
```
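All three new methods share the same `padded_len` computation: the device buffer reserves room for `MATRIX_ROW_PADDING` extra elements, expressed as extra quantized blocks. A worked sketch of that arithmetic with Q4_0 numbers plugged in (block_size = 32 elements, type_size = 18 bytes; the `MATRIX_ROW_PADDING` value of 512 is assumed here for illustration):

```rust
const MATRIX_ROW_PADDING: usize = 512; // assumed value, for illustration

fn main() {
    // Q4_0: 32 elements per block, 18 bytes per block (f16 scale + 16 bytes of nibbles).
    let (block_size, type_size) = (32usize, 18usize);
    let data_len = 1_000_000usize; // quantized payload in bytes
    // 512 padding elements -> 512 / 32 = 16 blocks -> 16 * 18 = 288 extra bytes.
    let padded_len = data_len + MATRIX_ROW_PADDING * type_size / block_size;
    assert_eq!(padded_len, 1_000_288);
}
```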
```diff
@@ -503,6 +617,13 @@ impl QCudaStorage {
             self.dequantize_matmul(self_shape, storage, layout)
         }
     }
+
+    pub fn data(&self) -> Result<Vec<u8>> {
+        let mut out = vec![0u8; self.data.len];
+        self.device
+            .memcpy_dtoh(&self.data.inner.slice(..self.data.len), &mut out)?;
+        Ok(out)
+    }
 }

 impl QCudaStorage {
```
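A quick sketch of what the new accessor makes possible, e.g. pulling the quantized bytes back to the host for serialization (`qstorage` is an illustrative `QCudaStorage` binding, not from the diff):

```rust
// Copy the quantized payload (without the row padding) back to host memory.
let bytes: Vec<u8> = qstorage.data()?;
// data() returns exactly self.data.len bytes, matching storage_size_in_bytes().
assert_eq!(bytes.len(), qstorage.storage_size_in_bytes());
```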
```diff
@@ -629,7 +750,7 @@ mod test {
         let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
         let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
         let y = dev.memcpy_stod(&vs)?;
-        quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
+        quantize_q8_1(&y.as_view(), &mut y_q8_1, el, 1, &dev)?;
         Ok(())
     }
@@ -643,30 +764,30 @@ mod test {
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_vec_via_q8_1(
             &xs.data,
-            &y.slice(..),
+            &y.as_view(),
             /* dtype */ GgmlDType::Q4_0,
             /* ncols */ ncols,
             /* nrows */ 1,
             /* b_size */ 1,
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.memcpy_dtov(&vs.slice(..))?;
+        let vs = dev.memcpy_dtov(&vs.as_view())?;
         assert_eq!(vs.len(), 1);
         // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
         // Q8 means 1/256 precision.
         assert_eq!(vs[0], 5561664.5);

         let cuda_storage = dequantize_mul_mat_vec(
             &xs.data,
-            &y.slice(..),
+            &y.as_view(),
             /* dtype */ GgmlDType::Q4_0,
             /* ncols */ ncols,
             /* nrows */ 1,
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.memcpy_dtov(&vs.slice(..))?;
+        let vs = dev.memcpy_dtov(&vs.as_view())?;
         assert_eq!(vs.len(), 1);
         assert_eq!(vs[0], 5561851.0);
         Ok(())
@@ -682,7 +803,7 @@ mod test {
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
             &xs.data,
-            &y.slice(..),
+            &y.as_view(),
             /* dtype */ GgmlDType::Q4_0,
             /* x_rows */ 4,
             /* x_cols */ ncols,
@@ -691,7 +812,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.memcpy_dtov(&vs.slice(..))?;
+        let vs = dev.memcpy_dtov(&vs.as_view())?;

         /*
         x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@@ -723,7 +844,7 @@ mod test {
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
             &xs.data,
-            &y.slice(..),
+            &y.as_view(),
             /* dtype */ GgmlDType::Q4_0,
             /* x_rows */ x_rows,
             /* x_cols */ ncols,
@@ -732,7 +853,7 @@ mod test {
             &dev,
        )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let _vs = dev.memcpy_dtov(&vs.slice(..))?;
+        let _vs = dev.memcpy_dtov(&vs.as_view())?;
         Ok(())
     }
 }
```
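The test churn above is a mechanical API migration: full-range `slice(..)` borrows become `as_view()` calls. A minimal before/after sketch, assuming the cudarc helpers already used in these tests:

```rust
// Before: dev.memcpy_dtov(&y.slice(..))?
// After:  dev.memcpy_dtov(&y.as_view())?
let y = dev.memcpy_stod(&vec![0f32; 256])?; // CudaSlice<f32>
let host: Vec<f32> = dev.memcpy_dtov(&y.as_view())?; // borrow the whole slice as a view
assert_eq!(host.len(), 256);
```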

candle-core/src/quantized/dummy_cuda.rs (26 additions, 0 deletions)

```diff
@@ -32,6 +32,28 @@ impl QCudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

+    pub fn quantize_imatrix(
+        &mut self,
+        _src: &CudaStorage,
+        _imatrix_weights: &[f32],
+        _n_per_row: usize,
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    pub fn quantize_imatrix_onto(
+        &mut self,
+        _src: &crate::CpuStorage,
+        _imatrix_weights: &[f32],
+        _n_per_row: usize,
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub fn storage_size_in_bytes(&self) -> usize {
         0
     }
@@ -44,6 +66,10 @@ impl QCudaStorage {
     ) -> Result<(CudaStorage, crate::Shape)> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    pub fn data(&self) -> Result<Vec<u8>> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }

 pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
```

candle-core/src/quantized/dummy_metal.rs (26 additions, 0 deletions)

```diff
@@ -28,6 +28,28 @@ impl QMetalStorage {
         Err(Error::NotCompiledWithMetalSupport)
     }

+    pub fn quantize_imatrix(
+        &mut self,
+        _src: &MetalStorage,
+        _imatrix_weights: &[f32],
+        _n_per_row: usize,
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    pub fn quantize_imatrix_onto(
+        &mut self,
+        _src: &crate::CpuStorage,
+        _imatrix_weights: &[f32],
+        _n_per_row: usize,
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
     pub fn storage_size_in_bytes(&self) -> usize {
         0
     }
@@ -40,6 +62,10 @@ impl QMetalStorage {
     ) -> Result<(MetalStorage, crate::Shape)> {
         Err(Error::NotCompiledWithMetalSupport)
     }
+
+    pub fn data(&self) -> Result<Vec<u8>> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
 }

 pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
```

candle-core/src/quantized/ggml_file.rs (1 addition, 1 deletion)

```diff
@@ -134,7 +134,7 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     super::QTensor::new(data, dims)
 }

-/// Creates a Tensor from a raw GGML tensor.
+/// Creates a [Tensor] from a raw GGML tensor.
 pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
```

candle-core/src/quantized/gguf_file.rs (1 addition, 0 deletions)

```diff
@@ -1,5 +1,6 @@
 //! Support for the [GGUF file format](https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md).
 //!
+//! Spec: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md

 use super::{GgmlDType, QTensor};
 use crate::{Context, Device, Result};
```
