Skip to content

Commit 8ebfc22

Browse files
authored
Add cublas_handle api, update safetensors (#3192)
* Add cublas_handle api, update safetensors
* Add more quantized apis
* Make .vscode a .gitignore
1 parent 60252cc commit 8ebfc22

File tree

6 files changed

+85
-15
lines changed

6 files changed

+85
-15
lines changed

.vscode/settings.json

Lines changed: 0 additions & 11 deletions
This file was deleted.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ parquet = { version = "51.0.0" }
8686
rand = "0.9.0"
8787
rand_distr = "0.5.1"
8888
rayon = "1.7.0"
89-
safetensors = "0.4.1"
89+
safetensors = "0.6.0"
9090
serde = { version = "1.0.171", features = ["derive"] }
9191
serde_plain = "1.0.2"
9292
serde_json = "1.0.99"

candle-core/src/cuda_backend/device.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ impl CudaDevice {
145145
self.stream.clone()
146146
}
147147

148+
/// Returns a clone of the `Arc` stored in `self.blas`, giving the caller
/// shared access to this device's `cudarc::cublas::CudaBlas` handle.
pub fn cublas_handle(&self) -> Arc<cudarc::cublas::CudaBlas> {
149+
self.blas.clone()
150+
}
151+
148152
/// When turned on, all cuda tensors **created after calling this function** will
149153
/// not track uses via cuda events.
150154
///

candle-core/src/quantized/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@ use half::{bf16, f16};
3232

3333
pub use k_quants::GgmlType;
3434

35+
/// Reinterprets the bytes of `data` as a slice of `T` without copying.
///
/// The returned slice has `data.len() / size_of::<T>()` elements and points
/// at the same memory as `data`.
///
/// # Panics
///
/// Panics if `data.len()` is not a multiple of `size_of::<T>()`, or if the
/// start of the buffer is not aligned to `align_of::<T>()`. The remainder
/// operation also panics (division by zero) if `T` is zero-sized.
// NOTE(review): `data` is taken by value and dropped when this function
// returns, but the returned slice still points into its buffer. For
// `Cow::Borrowed` the underlying bytes outlive the call; for `Cow::Owned`
// the buffer is freed and the slice dangles. Confirm callers never pass
// owned data, or change the parameter to a borrow.
fn as_t_slice<T>(data: Cow<'_, [u8]>) -> &[T] {
36+
let size = std::mem::size_of::<T>();
37+
assert_eq!(
38+
data.len() % size,
39+
0,
40+
"Data length must be a multiple of T's size"
41+
);
42+
let ptr = data.as_ptr();
43+
assert_eq!(
44+
(ptr as usize) % std::mem::align_of::<T>(),
45+
0,
46+
"Data pointer must be aligned to T's alignment"
47+
);
48+
// SAFETY: length and alignment are checked by the asserts above. This
// additionally assumes every bit pattern is a valid `T` and that the bytes
// stay alive for the returned lifetime — TODO confirm (see note on `data`).
unsafe { std::slice::from_raw_parts(ptr as *const T, data.len() / size) }
49+
}
50+
3551
pub struct QTensor {
3652
storage: QStorage,
3753
shape: Shape,
@@ -63,6 +79,46 @@ pub enum QStorage {
6379
}
6480

6581
impl QStorage {
82+
/// Builds a `QStorage` on `device` from raw bytes interpreted as `dtype`.
///
/// - `Device::Cpu`: wraps the result of `dtype.from_data(data)` in `Self::Cpu`.
/// - `Device::Metal` / `Device::Cuda`: reinterprets the bytes via
///   `as_t_slice` as the element or block type matching `dtype`, then returns
///   whatever `metal::load_quantized` / `cuda::load_quantized` returns.
///
/// # Panics
///
/// Inherits the panics of `as_t_slice`: the byte length must be a multiple of
/// the target type's size and the buffer must be aligned for that type.
// NOTE(review): `data` is moved into `as_t_slice`, so an owned `Cow` would be
// dropped before `load_quantized` reads the slice — confirm callers only pass
// borrowed data.
pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
83+
match device {
84+
Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
85+
Device::Metal(d) => match dtype {
86+
GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
87+
GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
88+
GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
89+
GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
90+
GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
91+
GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
92+
GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
93+
GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
94+
GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
95+
GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
96+
GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
97+
GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
98+
GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
99+
GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
100+
GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
101+
},
102+
Device::Cuda(d) => match dtype {
103+
GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
104+
GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
105+
GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
106+
GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
107+
GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
108+
GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
109+
GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
110+
GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
111+
GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
112+
GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
113+
GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
114+
GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
115+
GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
116+
GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
117+
GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
118+
},
119+
}
120+
}
121+
66122
fn block_size(&self) -> usize {
67123
match self {
68124
QStorage::Cpu(storage) => storage.block_size(),
@@ -214,6 +270,27 @@ impl GgmlDType {
214270
Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
215271
}
216272
}
273+
274+
/// Copies raw bytes into a new boxed `Vec` of the element or block type that
/// corresponds to `self`, returned as a `Box<dyn QuantizedType>`.
///
/// The bytes are first reinterpreted with `as_t_slice` and then copied with
/// `to_vec`.
///
/// # Panics
///
/// Inherits the panics of `as_t_slice`: the byte length must be a multiple of
/// the target type's size and the buffer must be aligned for that type.
// NOTE(review): `data` is moved into `as_t_slice` and dropped there, so with
// an owned `Cow` the `to_vec` copy would read a freed buffer — confirm callers
// only pass borrowed data.
pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
275+
match self {
276+
Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
277+
Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
278+
Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
279+
Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
280+
Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
281+
Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
282+
Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
283+
Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
284+
Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
285+
Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
286+
Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
287+
Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
288+
Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
289+
Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
290+
Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
291+
}
292+
}
293+
217294
/// The type size for blocks in bytes.
218295
pub fn type_size(&self) -> usize {
219296
use k_quants::*;

candle-core/src/safetensors.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ impl st::View for &Tensor {
9494
impl Tensor {
9595
pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
9696
let data = [(name, self.clone())];
97-
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
97+
Ok(st::serialize_to_file(data, None, filename.as_ref())?)
9898
}
9999
}
100100

@@ -268,7 +268,7 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
268268
tensors: &HashMap<K, Tensor>,
269269
filename: P,
270270
) -> Result<()> {
271-
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
271+
Ok(st::serialize_to_file(tensors, None, filename.as_ref())?)
272272
}
273273

274274
#[derive(yoke::Yokeable)]

candle-nn/src/var_map.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ impl VarMap {
3232
pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
3333
let tensor_data = self.data.lock().unwrap();
3434
let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
35-
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
35+
safetensors::tensor::serialize_to_file(data, None, path.as_ref())?;
3636
Ok(())
3737
}
3838

0 commit comments

Comments
 (0)