@@ -3010,16 +3010,25 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
30103010 0 , src0_2->ne [1 ], src1->ne [1 ], ne10_padded, stream);
30113011 CUDA_CHECK (cudaGetLastError ());
30123012 } else {
3013- quantize_mmq_q8_1_cuda ((const float *)src1->data , src1_quantized.get (), src1->ne [0 ], src1->ne [1 ], 1 , ne10_padded, src0_1->type , stream);
3014- CUDA_CHECK (cudaGetLastError ());
30153013
3016- ggml_cuda_op_mul_mat_q (ctx, src0_1, src1, dst, (const char *)src0_1->data , nullptr , src1_quantized.get (), dst_up.get (),
3017- 0 , src0_1->ne [1 ], src1->ne [1 ], ne10_padded, stream);
3018- CUDA_CHECK (cudaGetLastError ());
3014+ if (ggml_cuda_should_use_mmq (src0_1->type , ggml_cuda_info ().devices [ctx.device ].cc , src1->ne [1 ])) {
3015+ quantize_mmq_q8_1_cuda ((const float *)src1->data , src1_quantized.get (), src1->ne [0 ], src1->ne [1 ], 1 ,
3016+ ne10_padded, src0_1->type , stream);
3017+ CUDA_CHECK (cudaGetLastError ());
30193018
3020- ggml_cuda_op_mul_mat_q (ctx, src0_2, src1, dst, (const char *)src0_2->data , nullptr , src1_quantized.get (), (float *)dst->data ,
3021- 0 , src0_1->ne [1 ], src1->ne [1 ], ne10_padded, stream);
3022- CUDA_CHECK (cudaGetLastError ());
3019+ ggml_cuda_op_mul_mat_q (ctx, src0_1, src1, dst, (const char *)src0_1->data , nullptr , src1_quantized.get (), dst_up.get (),
3020+ 0 , src0_1->ne [1 ], src1->ne [1 ], ne10_padded, stream);
3021+ CUDA_CHECK (cudaGetLastError ());
3022+
3023+ ggml_cuda_op_mul_mat_q (ctx, src0_2, src1, dst, (const char *)src0_2->data , nullptr , src1_quantized.get (), (float *)dst->data ,
3024+ 0 , src0_1->ne [1 ], src1->ne [1 ], ne10_padded, stream);
3025+ CUDA_CHECK (cudaGetLastError ());
3026+ } else {
3027+ auto local_dst = *dst;
3028+ local_dst.data = dst_up.get ();
3029+ ggml_cuda_mul_mat (ctx, src0_1, src1, &local_dst, nullptr , 0 );
3030+ ggml_cuda_mul_mat (ctx, src0_2, src1, dst, nullptr , 0 );
3031+ }
30233032 }
30243033
30253034 ggml_fused_mul_unary (ctx, (ggml_unary_op)dst->op_params [0 ], ggml_nelements (dst),
0 commit comments