@@ -46,24 +46,57 @@ fn pad(p: usize, q: usize) -> usize {
4646fn quantize_q8_1 (
4747 src : & CudaView < f32 > ,
4848 dst : & mut CudaSlice < u8 > ,
49- elem_count : usize ,
49+ k : usize ,
5050 ky : usize ,
5151 dev : & CudaDevice ,
5252) -> Result < ( ) > {
53- let kx = elem_count;
54- let kx_padded = pad ( kx, MATRIX_ROW_PADDING ) ;
53+ let kx_padded = pad ( k, MATRIX_ROW_PADDING ) ;
5554 let num_blocks = ceil_div ( kx_padded, CUDA_QUANTIZE_BLOCK_SIZE ) ;
55+
56+ let total_rows = ky;
57+ // Get Q8_1 metadata.
58+ let q8_1_block_size = GgmlDType :: Q8_1 . block_size ( ) ;
59+ let q8_1_type_size = GgmlDType :: Q8_1 . type_size ( ) ;
60+
61+ // Calculate the size of the output buffer in bytes.
62+ let num_blocks_per_row = kx_padded / q8_1_block_size;
63+ let dst_row_size_bytes = num_blocks_per_row * q8_1_type_size;
64+
65+ const CHUNK_SIZE : usize = 65535 ; // gridDim.y limit
5666 let func = dev. get_or_load_func ( "quantize_q8_1" , & candle_kernels:: QUANTIZED ) ?;
57- let cfg = cudarc:: driver:: LaunchConfig {
58- grid_dim : ( num_blocks as u32 , ky as u32 , 1 ) ,
59- block_dim : ( CUDA_QUANTIZE_BLOCK_SIZE as u32 , 1 , 1 ) ,
60- shared_mem_bytes : 0 ,
61- } ;
62- let mut builder = func. builder ( ) ;
63- builder. arg ( src) ;
64- builder. arg ( dst) ;
65- barg ! ( builder, kx as i32 , kx_padded as i32 ) ;
66- unsafe { builder. launch ( cfg) } . w ( ) ?;
67+
68+ let mut rows_processed = 0 ;
69+ while rows_processed < total_rows {
70+ // --- calculate the number of rows for this chunk ---
71+ let remaining_rows = total_rows - rows_processed;
72+ // This is our gridDim.y, now <= 65535
73+ let rows_in_chunk = std:: cmp:: min ( CHUNK_SIZE , remaining_rows) ;
74+
75+ // --- slice the source (f32) tensor by elements ---
76+ let src_start_elem = rows_processed * k;
77+ let src_num_elems = rows_in_chunk * k;
78+ let src_chunk = src. slice ( src_start_elem..( src_start_elem + src_num_elems) ) ;
79+
80+ // --- slice the destination (u8) tensor by bytes ---
81+ let dst_start_byte = rows_processed * dst_row_size_bytes;
82+ let dst_num_bytes = rows_in_chunk * dst_row_size_bytes;
83+ let dst_chunk = dst. slice ( dst_start_byte..( dst_start_byte + dst_num_bytes) ) ;
84+
85+ let cfg = cudarc:: driver:: LaunchConfig {
86+ grid_dim : ( num_blocks as u32 , rows_in_chunk as u32 , 1 ) ,
87+ block_dim : ( CUDA_QUANTIZE_BLOCK_SIZE as u32 , 1 , 1 ) ,
88+ shared_mem_bytes : 0 ,
89+ } ;
90+
91+ let mut builder = func. builder ( ) ;
92+ builder. arg ( & src_chunk) ;
93+ builder. arg ( & dst_chunk) ;
94+ barg ! ( builder, k as i32 , kx_padded as i32 ) ;
95+ unsafe { builder. launch ( cfg) } . w ( ) ?;
96+
97+ rows_processed += rows_in_chunk;
98+ }
99+
67100 Ok ( ( ) )
68101}
69102
@@ -477,6 +510,87 @@ impl QCudaStorage {
477510 Ok ( ( ) )
478511 }
479512
513+ pub fn quantize_imatrix (
514+ & mut self ,
515+ src : & CudaStorage ,
516+ imatrix_weights : & [ f32 ] ,
517+ n_per_row : usize ,
518+ ) -> Result < ( ) > {
519+ // Run the quantization on cpu.
520+ let src = match & src. slice {
521+ crate :: cuda_backend:: CudaStorageSlice :: F32 ( data) => self . device . memcpy_dtov ( data) ?,
522+ _ => crate :: bail!( "only f32 can be quantized" ) ,
523+ } ;
524+ let src_len = src. len ( ) ;
525+ let src = crate :: Storage :: Cpu ( crate :: CpuStorage :: F32 ( src) ) ;
526+ let mut qcpu_storage = crate :: Device :: Cpu . qzeros ( src_len, self . dtype ) ?;
527+ qcpu_storage. quantize_imatrix ( & src, imatrix_weights, n_per_row) ?;
528+ let data = qcpu_storage. data ( ) ?;
529+ let padded_len =
530+ data. len ( ) + MATRIX_ROW_PADDING * self . dtype . type_size ( ) / self . dtype . block_size ( ) ;
531+ let mut inner = unsafe { self . device . alloc :: < u8 > ( padded_len) ? } ;
532+ self . device
533+ . memcpy_htod ( data. as_ref ( ) , & mut inner. slice_mut ( ..data. len ( ) ) ) ?;
534+ self . data = PaddedCudaSlice {
535+ inner,
536+ len : data. len ( ) ,
537+ } ;
538+ Ok ( ( ) )
539+ }
540+
541+ pub fn quantize_imatrix_onto (
542+ & mut self ,
543+ src : & crate :: CpuStorage ,
544+ imatrix_weights : & [ f32 ] ,
545+ n_per_row : usize ,
546+ ) -> Result < ( ) > {
547+ // Run the quantization on cpu.
548+ let src_len = src. as_slice :: < f32 > ( ) ?. len ( ) ;
549+ let mut qcpu_storage = crate :: Device :: Cpu . qzeros ( src_len, self . dtype ) ?;
550+
551+ if let QStorage :: Cpu ( storage) = & mut qcpu_storage {
552+ storage. from_float_imatrix ( src. as_slice :: < f32 > ( ) ?, imatrix_weights, n_per_row) ;
553+ } else {
554+ unreachable ! ( )
555+ }
556+
557+ let data = qcpu_storage. data ( ) ?;
558+ let padded_len =
559+ data. len ( ) + MATRIX_ROW_PADDING * self . dtype . type_size ( ) / self . dtype . block_size ( ) ;
560+ let mut inner = unsafe { self . device . alloc :: < u8 > ( padded_len) ? } ;
561+ self . device
562+ . memcpy_htod ( data. as_ref ( ) , & mut inner. slice_mut ( ..data. len ( ) ) ) ?;
563+ self . data = PaddedCudaSlice {
564+ inner,
565+ len : data. len ( ) ,
566+ } ;
567+ Ok ( ( ) )
568+ }
569+
570+ pub fn quantize_onto ( & mut self , src : & crate :: CpuStorage ) -> Result < ( ) > {
571+ // Run the quantization on cpu.
572+ let src_len = src. as_slice :: < f32 > ( ) ?. len ( ) ;
573+ let mut qcpu_storage = crate :: Device :: Cpu . qzeros ( src_len, self . dtype ) ?;
574+
575+ if let QStorage :: Cpu ( storage) = & mut qcpu_storage {
576+ storage. from_float ( src. as_slice :: < f32 > ( ) ?) ;
577+ } else {
578+ unreachable ! ( )
579+ }
580+
581+ let data = qcpu_storage. data ( ) ?;
582+ let padded_len =
583+ data. len ( ) + MATRIX_ROW_PADDING * self . dtype . type_size ( ) / self . dtype . block_size ( ) ;
584+ let mut inner = unsafe { self . device . alloc :: < u8 > ( padded_len) ? } ;
585+ self . device
586+ . memcpy_htod ( data. as_ref ( ) , & mut inner. slice_mut ( ..data. len ( ) ) ) ?;
587+ self . data = PaddedCudaSlice {
588+ inner,
589+ len : data. len ( ) ,
590+ } ;
591+ Ok ( ( ) )
592+ }
593+
480594 pub fn storage_size_in_bytes ( & self ) -> usize {
481595 self . data . len
482596 }
@@ -503,6 +617,13 @@ impl QCudaStorage {
503617 self . dequantize_matmul ( self_shape, storage, layout)
504618 }
505619 }
620+
621+ pub fn data ( & self ) -> Result < Vec < u8 > > {
622+ let mut out = vec ! [ 0u8 ; self . data. len] ;
623+ self . device
624+ . memcpy_dtoh ( & self . data . inner . slice ( ..self . data . len ) , & mut out) ?;
625+ Ok ( out)
626+ }
506627}
507628
508629impl QCudaStorage {
@@ -629,7 +750,7 @@ mod test {
629750 let mut y_q8_1 = unsafe { dev. alloc :: < u8 > ( y_size_in_bytes) ? } ;
630751 let vs: Vec < f32 > = ( 0 ..el) . map ( |v| v as f32 ) . collect ( ) ;
631752 let y = dev. memcpy_stod ( & vs) ?;
632- quantize_q8_1 ( & y. slice ( .. ) , & mut y_q8_1, el, 1 , & dev) ?;
753+ quantize_q8_1 ( & y. as_view ( ) , & mut y_q8_1, el, 1 , & dev) ?;
633754 Ok ( ( ) )
634755 }
635756
@@ -643,30 +764,30 @@ mod test {
643764 xs. quantize ( & CudaStorage :: wrap_cuda_slice ( y. clone ( ) , dev. clone ( ) ) ) ?;
644765 let cuda_storage = mul_mat_vec_via_q8_1 (
645766 & xs. data ,
646- & y. slice ( .. ) ,
767+ & y. as_view ( ) ,
647768 /* dtype */ GgmlDType :: Q4_0 ,
648769 /* ncols */ ncols,
649770 /* nrows */ 1 ,
650771 /* b_size */ 1 ,
651772 & dev,
652773 ) ?;
653774 let vs = cuda_storage. as_cuda_slice :: < f32 > ( ) ?;
654- let vs = dev. memcpy_dtov ( & vs. slice ( .. ) ) ?;
775+ let vs = dev. memcpy_dtov ( & vs. as_view ( ) ) ?;
655776 assert_eq ! ( vs. len( ) , 1 ) ;
656777 // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
657778 // Q8 means 1/256 precision.
658779 assert_eq ! ( vs[ 0 ] , 5561664.5 ) ;
659780
660781 let cuda_storage = dequantize_mul_mat_vec (
661782 & xs. data ,
662- & y. slice ( .. ) ,
783+ & y. as_view ( ) ,
663784 /* dtype */ GgmlDType :: Q4_0 ,
664785 /* ncols */ ncols,
665786 /* nrows */ 1 ,
666787 & dev,
667788 ) ?;
668789 let vs = cuda_storage. as_cuda_slice :: < f32 > ( ) ?;
669- let vs = dev. memcpy_dtov ( & vs. slice ( .. ) ) ?;
790+ let vs = dev. memcpy_dtov ( & vs. as_view ( ) ) ?;
670791 assert_eq ! ( vs. len( ) , 1 ) ;
671792 assert_eq ! ( vs[ 0 ] , 5561851.0 ) ;
672793 Ok ( ( ) )
@@ -682,7 +803,7 @@ mod test {
682803 xs. quantize ( & CudaStorage :: wrap_cuda_slice ( y. clone ( ) , dev. clone ( ) ) ) ?;
683804 let cuda_storage = mul_mat_via_q8_1 (
684805 & xs. data ,
685- & y. slice ( .. ) ,
806+ & y. as_view ( ) ,
686807 /* dtype */ GgmlDType :: Q4_0 ,
687808 /* x_rows */ 4 ,
688809 /* x_cols */ ncols,
@@ -691,7 +812,7 @@ mod test {
691812 & dev,
692813 ) ?;
693814 let vs = cuda_storage. as_cuda_slice :: < f32 > ( ) ?;
694- let vs = dev. memcpy_dtov ( & vs. slice ( .. ) ) ?;
815+ let vs = dev. memcpy_dtov ( & vs. as_view ( ) ) ?;
695816
696817 /*
697818 x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@@ -723,7 +844,7 @@ mod test {
723844 xs. quantize ( & CudaStorage :: wrap_cuda_slice ( y. clone ( ) , dev. clone ( ) ) ) ?;
724845 let cuda_storage = mul_mat_via_q8_1 (
725846 & xs. data ,
726- & y. slice ( .. ) ,
847+ & y. as_view ( ) ,
727848 /* dtype */ GgmlDType :: Q4_0 ,
728849 /* x_rows */ x_rows,
729850 /* x_cols */ ncols,
@@ -732,7 +853,7 @@ mod test {
732853 & dev,
733854 ) ?;
734855 let vs = cuda_storage. as_cuda_slice :: < f32 > ( ) ?;
735- let _vs = dev. memcpy_dtov ( & vs. slice ( .. ) ) ?;
856+ let _vs = dev. memcpy_dtov ( & vs. as_view ( ) ) ?;
736857 Ok ( ( ) )
737858 }
738859}
0 commit comments