@@ -2030,15 +2030,15 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
20302030            const  int  cc            = ggml_cuda_info ().devices [id].cc ;
20312031            const  int  warp_size     = ggml_cuda_info ().devices [id].warp_size ;
20322032            use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq (src0->type , cc, src1->ne [1 ]);
2033-             use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf (src0->type , cc, warp_size, src0->ne , src1->ne );
2033+             use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf (src0->type , cc, warp_size, src0->ne , src1->ne [ 1 ] );
20342034            use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf (src0->type , cc, src0->ne , src1->ne [1 ]);
20352035            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available (cc);
20362036        }
20372037    } else  {
20382038        const  int  cc            = ggml_cuda_info ().devices [ctx.device ].cc ;
20392039        const  int  warp_size     = ggml_cuda_info ().devices [ctx.device ].warp_size ;
20402040        use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq (src0->type , cc, src1->ne [1 ]);
2041-         use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf (src0->type , cc, warp_size, src0->ne , src1->ne );
2041+         use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf (src0->type , cc, warp_size, src0->ne , src1->ne [ 1 ] );
20422042        use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf (src0->type , cc, src0->ne , src1->ne [1 ]);
20432043        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available (cc);
20442044    }
@@ -2110,7 +2110,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
21102110            return ;
21112111        }
21122112
2113-         if  ( ! ggml_is_quantized (src0-> type  ) &&  ggml_cuda_should_use_mmf (src0->type , cc, WARP_SIZE, src0->ne , src1->ne , ids)) {
2113+         if  (ggml_cuda_should_use_mmf (src0->type , cc, WARP_SIZE, src0->ne , src1->ne [ 2 ] , ids)) {
21142114            ggml_cuda_mul_mat_f (ctx, src0, src1, ids, dst);
21152115            return ;
21162116        }
0 commit comments