@@ -259,45 +259,47 @@ static __global__ void quantize_mmq_q8_1_id(
 }
 
 void quantize_row_q8_1_cuda(
-        const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-        const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+        const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+    GGML_ASSERT(ne0 % QK8_1 == 0);
 
-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, kx1*channels, 1);
+    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
-
-    GGML_UNUSED(type_x);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+    GGML_UNUSED(type_src0);
 }
 
 void quantize_mmq_q8_1_cuda(
-        const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-        const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+        const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+    GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
-    const dim3 num_blocks(block_num_x, kx1, channels);
+    const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
-    switch (mmq_get_q8_1_ds_layout(type_x)) {
+    switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_DS4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_D2S6:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
+    GGML_UNUSED(s01);
+    GGML_UNUSED(s02);
+    GGML_UNUSED(s03);
 }
 
 void quantize_mmq_q8_1_id_cuda(
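
The change above generalizes the q8_1 quantization entry points: instead of collapsing rows and channels into one grid dimension and assuming a contiguous source (kx0/kx1/channels/kx0_padded), the launchers now receive the unpadded row width ne00, per-dimension source strides s01/s02/s03, and the padded destination shape ne0..ne3, and quantize_q8_1 is launched on a 3D grid (ne1 rows in y, ne2*ne3 channels/samples in z). Below is a minimal sketch of what a stride-aware quantize_q8_1 kernel can look like under these parameters, assuming the strides are in elements (not bytes) and a block_q8_1-style layout of one (scale, sum) pair per 32 values; it is an illustration of the technique, not the kernel body from this commit.

#include <cuda_fp16.h>
#include <cstdint>

#define QK8_1 32

// Sketch of a q8_1 block: per-32-value scale and block sum, then the quants.
struct block_q8_1_sketch {
    half2  ds;          // x: d = amax/127, y: sum of the 32 source values
    int8_t qs[QK8_1];   // quantized values
};

static __global__ void quantize_q8_1_sketch(
        const float * __restrict__ x, void * __restrict__ vy,
        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t ne0, const int64_t ne1, const int64_t ne2) {
    const int64_t i0 = (int64_t) blockDim.x*blockIdx.x + threadIdx.x;
    if (i0 >= ne0) {
        return; // safe: ne0 % QK8_1 == 0 is asserted, so whole warps exit together
    }
    const int64_t i1 = blockIdx.y;       // row     (grid dim y = ne1)
    const int64_t i2 = blockIdx.z % ne2; // channel (grid dim z = ne2*ne3)
    const int64_t i3 = blockIdx.z / ne2; // sample

    // Gather through the source strides; zero-pad the row past ne00.
    const float xi = i0 < ne00 ? x[i3*s03 + i2*s02 + i1*s01 + i0] : 0.0f;

    // Reduce |x| max and sum across the 32 threads sharing one q8_1 block.
    float amax = fabsf(xi);
    float sum  = xi;
#pragma unroll
    for (int mask = QK8_1/2; mask > 0; mask >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, QK8_1));
        sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, QK8_1);
    }

    // The destination is contiguous and padded: ne0/QK8_1 blocks per row.
    block_q8_1_sketch * y = (block_q8_1_sketch *) vy;
    const int64_t ib  = ((i3*ne2 + i2)*ne1 + i1)*(ne0/QK8_1) + i0/QK8_1;
    const int     iqs = i0 % QK8_1;

    const float d = amax/127.0f;
    y[ib].qs[iqs] = (int8_t) roundf(d != 0.0f ? xi/d : 0.0f);
    if (iqs == 0) {
        y[ib].ds = make_half2(__float2half(d), __float2half(sum));
    }
}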
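On the host side, the new stride parameters would presumably be derived from the source tensor's byte strides divided by the element size, since ggml stores nb[] in bytes. A hypothetical wrapper along those lines; the helper name is illustrative, and only ggml_tensor's type/ne/nb/data fields plus the launcher signature from this diff are taken as given.

#include "ggml.h"          // ggml_tensor, GGML_ASSERT, GGML_TYPE_F32
#include "quantize.cuh"    // quantize_row_q8_1_cuda (signature from this diff)

// Hypothetical helper: quantize a possibly non-contiguous F32 src1 to q8_1.
static void quantize_src1_to_q8_1(const ggml_tensor * src0, const ggml_tensor * src1,
        void * vy, const int64_t ne0_padded, cudaStream_t stream) {
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    // ggml strides are in bytes; the launcher is assumed to take element strides.
    const int64_t s01 = src1->nb[1]/sizeof(float);
    const int64_t s02 = src1->nb[2]/sizeof(float);
    const int64_t s03 = src1->nb[3]/sizeof(float);
    quantize_row_q8_1_cuda((const float *) src1->data, vy, src0->type,
        src1->ne[0], s01, s02, s03,
        ne0_padded, src1->ne[1], src1->ne[2], src1->ne[3], stream);
}

Note that quantize_mmq_q8_1_cuda accepts the same stride parameters only to keep the two launchers interchangeable: it marks s01/s02/s03 as GGML_UNUSED, so the MMQ path still requires a contiguous source layout.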