@@ -2186,6 +2186,7 @@ struct mmid_row_mapping {
     int32_t i2;
 };
 
+template <typename data_t = float>
 static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
                                                 const mmid_row_mapping * __restrict__ row_mapping,
                                                 int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {
@@ -2194,8 +2195,8 @@ static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_or
     const int32_t i11 = row_mapping[i].i1 % ne11;
     const int32_t i12 = row_mapping[i].i2;
 
-    float * src_row_contiguous = (float *)(src_contiguous + i*nb11);
-    const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12);
+    data_t * src_row_contiguous = (data_t *)(src_contiguous + i*nb11);
+    const data_t * src_row_original = (const data_t *)(src_original + i11*nb11 + i12*nb12);
 
     for (int j = threadIdx.x; j < ne10; j += blockDim.x) {
         src_row_contiguous[j] = src_row_original[j];
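The copy kernel is now templated on the element type, with data_t defaulting to float so existing call sites compile unchanged. A minimal usage sketch, assuming half-precision activations; the type and the launch shape are taken from the float call site further down, not from this patch:

// Hedged sketch, not part of the patch: instantiating the gather kernel for
// half-precision rows. <half> is an assumption; strides nb11/nb12 stay in
// bytes either way, so only the element copies change width.
//   dim3 block_dims(std::min((unsigned int)ne10, 768u));
//   dim3 grid_dims(num_src1_rows);
//   k_copy_src_to_contiguous<half><<<grid_dims, block_dims, 0, stream>>>(
//       (const char *)src1->data, src1_contiguous.get(),
//       dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12);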
@@ -2673,6 +2674,17 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             }
         }
     } else {
+        // printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]);
+        ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+        bool use_quantized_src1 = false;
+        int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0;
+        if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) {
+            src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+            src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1);
+            src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
+            src1_quantized.alloc(src1_quantized_size);
+            use_quantized_src1 = true;
+        }
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
         ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
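The quantized path is only taken for single-token batches per expert (src1->ne[1] == 1) with matching up/gate weight types. The buffer sizing pads each activation row to a multiple of MATRIX_ROW_PADDING and appends slack for MMQ tile loads. A worked example, assuming ggml's usual Q8_1 layout (32 values per 36-byte block: 32 int8 quants plus fp16 scale and sum) and MATRIX_ROW_PADDING = 512; the 4096 column count is hypothetical:

//   src1->ne[0]          = 4096 columns (hypothetical model width)
//   src1_padded_num_cols = GGML_PAD(4096, 512) = 4096
//   src1_padded_row_size = 4096/32 * 36        = 4608 bytes per row
//   src1_quantized_size  = 4608 * src1->ne[2]               // one row per token
//                        + get_mmq_x_max_host(cc)
//                          * sizeof(block_q8_1_mmq)         // tail slack so MMQ
//                                                           // tiles may overread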
@@ -2704,7 +2716,13 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             if (num_src1_rows == 0) continue;
             size_t mapping_offset = cum_moe_counts[i02];
 
-            {
+            if (use_quantized_src1) {
+                quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset),
+                        src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream);
+                CUDA_CHECK(cudaGetLastError());
+                src1_row.data = src1_quantized.get();
+            }
+            else {
                 dim3 block_dims(std::min((unsigned int)ne10, 768u));
                 dim3 grid_dims(num_src1_rows);
                 k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
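On the quantized path the float gather is skipped entirely: quantize_mmq_q8_1_id_cuda (defined elsewhere in this patch) reads the per-expert row mapping itself and writes the gathered rows directly in Q8_1 MMQ layout, so activations are quantized once per expert rather than being copied as float and re-quantized inside each matmul. A conceptual sketch only, showing the indexing such a kernel would share with k_copy_src_to_contiguous; the real kernel body is not part of this excerpt:

// Conceptual sketch, not the actual quantize_mmq_q8_1_id_cuda implementation.
// Per mapped row r, resolve the source row through the same mmid_row_mapping,
// then quantize 32-wide groups to int8 plus an fp16 scale/sum pair:
//   const int32_t i11 = row_mapping[r].i1 % ne11;
//   const int32_t i12 = row_mapping[r].i2;
//   const float * x = (const float *)(src + i11*nb11 + i12*nb12);
//   // quantize x[0..ne10) into the r-th padded Q8_1 row of the output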
@@ -2719,21 +2737,31 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             GGML_ASSERT(nb1 == sizeof(float)*ne0);
 
             src1_row.ne[1] = num_src1_rows;
-            src1_row.nb[1] = nb11;
-            src1_row.nb[2] = num_src1_rows*nb11;
-            src1_row.nb[3] = num_src1_rows*nb11;
+            src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11;
+            src1_row.nb[2] = num_src1_rows*src1_row.nb[1];
+            src1_row.nb[3] = num_src1_rows*src1_row.nb[1];
 
             dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
             dst_row.data = dst_up_contiguous.get();
-            ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+            if (use_quantized_src1) {
+                ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
+                        0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+            } else {
+                ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+            }
             CUDA_CHECK(cudaGetLastError());
 
             dst_row.data = dst_gate_contiguous.get();
-            ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+            if (use_quantized_src1) {
+                ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
+                        0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+            } else {
+                ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+            }
             CUDA_CHECK(cudaGetLastError());
 
             ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
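Since the up and gate projections consume the same activations, src1_quantized is filled once per expert and handed to ggml_cuda_op_mul_mat_q twice. The stride fix above is what makes this work: on the quantized path src1_row.nb[1] must be the padded Q8_1 row size rather than the float row size nb11, or MMQ would walk the buffer at the wrong pitch. Mapping of the arguments as used at these two call sites (names are descriptive glosses of the visible arguments, not necessarily those of the header):

//   src0_dd_i   = src0_{1,2}_row.data   quantized expert weights
//   src1_ddf_i  = nullptr               no float src1; the ddq buffer is used
//   src1_ddq_i  = src1_quantized.get()  gathered Q8_1 activations
//   dst_dd_i    = dst_row.data          float output rows
//   row range   = [0, src0_row.ne[1])   full weight matrix
//   src1_ncols  = num_src1_rows         tokens routed to this expert
//   padded cols = src1_padded_num_cols  padding used during quantization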