@@ -32,6 +32,22 @@ use half::{bf16, f16};
3232
3333pub use k_quants:: GgmlType ;
3434
35+ fn as_t_slice < T > ( data : Cow < ' _ , [ u8 ] > ) -> & [ T ] {
36+ let size = std:: mem:: size_of :: < T > ( ) ;
37+ assert_eq ! (
38+ data. len( ) % size,
39+ 0 ,
40+ "Data length must be a multiple of T's size"
41+ ) ;
42+ let ptr = data. as_ptr ( ) ;
43+ assert_eq ! (
44+ ( ptr as usize ) % std:: mem:: align_of:: <T >( ) ,
45+ 0 ,
46+ "Data pointer must be aligned to T's alignment"
47+ ) ;
48+ unsafe { std:: slice:: from_raw_parts ( ptr as * const T , data. len ( ) / size) }
49+ }
50+
3551pub struct QTensor {
3652 storage : QStorage ,
3753 shape : Shape ,
@@ -63,6 +79,46 @@ pub enum QStorage {
6379}
6480
6581impl QStorage {
82+ pub fn from_data ( data : Cow < ' _ , [ u8 ] > , device : & Device , dtype : GgmlDType ) -> Result < Self > {
83+ match device {
84+ Device :: Cpu => Ok ( Self :: Cpu ( dtype. from_data ( data) ) ) ,
85+ Device :: Metal ( d) => match dtype {
86+ GgmlDType :: F32 => metal:: load_quantized ( d, as_t_slice :: < f32 > ( data) ) ,
87+ GgmlDType :: F16 => metal:: load_quantized ( d, as_t_slice :: < f16 > ( data) ) ,
88+ GgmlDType :: Q4_0 => metal:: load_quantized ( d, as_t_slice :: < BlockQ4_0 > ( data) ) ,
89+ GgmlDType :: Q4_1 => metal:: load_quantized ( d, as_t_slice :: < BlockQ4_1 > ( data) ) ,
90+ GgmlDType :: Q5_0 => metal:: load_quantized ( d, as_t_slice :: < BlockQ5_0 > ( data) ) ,
91+ GgmlDType :: Q5_1 => metal:: load_quantized ( d, as_t_slice :: < BlockQ5_1 > ( data) ) ,
92+ GgmlDType :: Q8_0 => metal:: load_quantized ( d, as_t_slice :: < BlockQ8_0 > ( data) ) ,
93+ GgmlDType :: Q8_1 => metal:: load_quantized ( d, as_t_slice :: < BlockQ8_1 > ( data) ) ,
94+ GgmlDType :: Q2K => metal:: load_quantized ( d, as_t_slice :: < BlockQ2K > ( data) ) ,
95+ GgmlDType :: Q3K => metal:: load_quantized ( d, as_t_slice :: < BlockQ3K > ( data) ) ,
96+ GgmlDType :: Q4K => metal:: load_quantized ( d, as_t_slice :: < BlockQ4K > ( data) ) ,
97+ GgmlDType :: Q5K => metal:: load_quantized ( d, as_t_slice :: < BlockQ5K > ( data) ) ,
98+ GgmlDType :: Q6K => metal:: load_quantized ( d, as_t_slice :: < BlockQ6K > ( data) ) ,
99+ GgmlDType :: Q8K => metal:: load_quantized ( d, as_t_slice :: < BlockQ8K > ( data) ) ,
100+ GgmlDType :: BF16 => metal:: load_quantized ( d, as_t_slice :: < bf16 > ( data) ) ,
101+ } ,
102+ Device :: Cuda ( d) => match dtype {
103+ GgmlDType :: F32 => cuda:: load_quantized ( d, as_t_slice :: < f32 > ( data) ) ,
104+ GgmlDType :: F16 => cuda:: load_quantized ( d, as_t_slice :: < f16 > ( data) ) ,
105+ GgmlDType :: Q4_0 => cuda:: load_quantized ( d, as_t_slice :: < BlockQ4_0 > ( data) ) ,
106+ GgmlDType :: Q4_1 => cuda:: load_quantized ( d, as_t_slice :: < BlockQ4_1 > ( data) ) ,
107+ GgmlDType :: Q5_0 => cuda:: load_quantized ( d, as_t_slice :: < BlockQ5_0 > ( data) ) ,
108+ GgmlDType :: Q5_1 => cuda:: load_quantized ( d, as_t_slice :: < BlockQ5_1 > ( data) ) ,
109+ GgmlDType :: Q8_0 => cuda:: load_quantized ( d, as_t_slice :: < BlockQ8_0 > ( data) ) ,
110+ GgmlDType :: Q8_1 => cuda:: load_quantized ( d, as_t_slice :: < BlockQ8_1 > ( data) ) ,
111+ GgmlDType :: Q2K => cuda:: load_quantized ( d, as_t_slice :: < BlockQ2K > ( data) ) ,
112+ GgmlDType :: Q3K => cuda:: load_quantized ( d, as_t_slice :: < BlockQ3K > ( data) ) ,
113+ GgmlDType :: Q4K => cuda:: load_quantized ( d, as_t_slice :: < BlockQ4K > ( data) ) ,
114+ GgmlDType :: Q5K => cuda:: load_quantized ( d, as_t_slice :: < BlockQ5K > ( data) ) ,
115+ GgmlDType :: Q6K => cuda:: load_quantized ( d, as_t_slice :: < BlockQ6K > ( data) ) ,
116+ GgmlDType :: Q8K => cuda:: load_quantized ( d, as_t_slice :: < BlockQ8K > ( data) ) ,
117+ GgmlDType :: BF16 => cuda:: load_quantized ( d, as_t_slice :: < bf16 > ( data) ) ,
118+ } ,
119+ }
120+ }
121+
66122 fn block_size ( & self ) -> usize {
67123 match self {
68124 QStorage :: Cpu ( storage) => storage. block_size ( ) ,
@@ -214,6 +270,27 @@ impl GgmlDType {
214270 Self :: BF16 => Box :: new ( vec ! [ bf16:: zeros( ) ; elem_count] ) ,
215271 }
216272 }
273+
274+ pub fn from_data ( & self , data : Cow < ' _ , [ u8 ] > ) -> Box < dyn QuantizedType > {
275+ match self {
276+ Self :: F32 => Box :: new ( as_t_slice :: < f32 > ( data) . to_vec ( ) ) ,
277+ Self :: F16 => Box :: new ( as_t_slice :: < f16 > ( data) . to_vec ( ) ) ,
278+ Self :: Q4_0 => Box :: new ( as_t_slice :: < BlockQ4_0 > ( data) . to_vec ( ) ) ,
279+ Self :: Q4_1 => Box :: new ( as_t_slice :: < BlockQ4_1 > ( data) . to_vec ( ) ) ,
280+ Self :: Q5_0 => Box :: new ( as_t_slice :: < BlockQ5_0 > ( data) . to_vec ( ) ) ,
281+ Self :: Q5_1 => Box :: new ( as_t_slice :: < BlockQ5_1 > ( data) . to_vec ( ) ) ,
282+ Self :: Q8_0 => Box :: new ( as_t_slice :: < BlockQ8_0 > ( data) . to_vec ( ) ) ,
283+ Self :: Q8_1 => Box :: new ( as_t_slice :: < BlockQ8_1 > ( data) . to_vec ( ) ) ,
284+ Self :: Q2K => Box :: new ( as_t_slice :: < BlockQ2K > ( data) . to_vec ( ) ) ,
285+ Self :: Q3K => Box :: new ( as_t_slice :: < BlockQ3K > ( data) . to_vec ( ) ) ,
286+ Self :: Q4K => Box :: new ( as_t_slice :: < BlockQ4K > ( data) . to_vec ( ) ) ,
287+ Self :: Q5K => Box :: new ( as_t_slice :: < BlockQ5K > ( data) . to_vec ( ) ) ,
288+ Self :: Q6K => Box :: new ( as_t_slice :: < BlockQ6K > ( data) . to_vec ( ) ) ,
289+ Self :: Q8K => Box :: new ( as_t_slice :: < BlockQ8K > ( data) . to_vec ( ) ) ,
290+ Self :: BF16 => Box :: new ( as_t_slice :: < bf16 > ( data) . to_vec ( ) ) ,
291+ }
292+ }
293+
217294 /// The type size for blocks in bytes.
218295 pub fn type_size ( & self ) -> usize {
219296 use k_quants:: * ;
0 commit comments