@@ -1858,7 +1858,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
 
         const auto convert_func = traits::get_nc_converter(src1->type);
         GGML_ASSERT(convert_func != nullptr);
-        convert_func((const void *)((const char *) src1->data), src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
        src1_ptr = src1_alloc.get();
        s11 = ne10;
        s12 = ne11*s11;
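The removed cast chain was redundant: `ggml_tensor::data` is a plain `void *`, which converts implicitly to the converter's `const void *` source parameter, so `src1->data` can be passed as-is. Below is a minimal standalone sketch of why both call forms compile, using a hypothetical stub in place of the function returned by `traits::get_nc_converter()` (the real converter also takes the dimensions, strides, and CUDA stream shown in the hunk):

```cpp
#include <cstdio>

// Hypothetical stand-in for the converter type; only the first parameter's
// constness matters for this illustration.
using nc_convert_t = void (*)(const void * src, void * dst);

static void convert_stub(const void * src, void * dst) {
    (void) src; (void) dst; // no-op: we only care about the call-site conversions
}

int main() {
    void * data    = nullptr; // ggml_tensor::data is declared as void *
    void * scratch = nullptr;

    nc_convert_t convert_func = convert_stub;

    // After the change: void * converts implicitly to const void *.
    convert_func(data, scratch);

    // Before the change: the explicit casts compile too, but add nothing.
    convert_func((const void *)((const char *) data), scratch);

    std::puts("both call forms are equivalent");
    return 0;
}
```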
@@ -1919,7 +1919,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
             ne01, ne11, ne10,
             alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
                    src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-            beta,  dst_t,    cu_data_type,   ne0,       ne1*ne0,   // strideC
+            beta,  dst_t,    cu_data_type,   ne0,       ne1*ne0,   // strideC
             ne12*ne13,
             cu_compute_type,
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
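For reference, the call adjusted in this hunk is cuBLAS's strided-batched GEMM: each of A, B, and C is advanced by one fixed element stride per batch, which is what the `// strideA/B/C` columns annotate. The following self-contained sketch shows that call pattern with FP16 inputs, an FP32 output, and FP32 accumulation; the toy sizes, the `CHECK_*` macros, and the build line are assumptions local to the sketch, not part of ggml:

```cpp
// Assumed build line: nvcc -o strided_batched strided_batched.cu -lcublas
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CHECK_CUDA(x)   do { cudaError_t   err = (x); if (err != cudaSuccess)           { std::printf("CUDA error %d at line %d\n",   (int) err, __LINE__); return 1; } } while (0)
#define CHECK_CUBLAS(x) do { cublasStatus_t st = (x); if (st  != CUBLAS_STATUS_SUCCESS) { std::printf("cuBLAS error %d at line %d\n", (int) st,  __LINE__); return 1; } } while (0)

int main() {
    const int m = 4, n = 4, k = 4, batch = 3;

    // Fill A and B with ones so every element of C should equal k.
    std::vector<__half> h_a((size_t) m*k*batch, __float2half(1.0f));
    std::vector<__half> h_b((size_t) k*n*batch, __float2half(1.0f));

    __half *d_a, *d_b;
    float  *d_c;
    CHECK_CUDA(cudaMalloc(&d_a, h_a.size()*sizeof(__half)));
    CHECK_CUDA(cudaMalloc(&d_b, h_b.size()*sizeof(__half)));
    CHECK_CUDA(cudaMalloc(&d_c, (size_t) m*n*batch*sizeof(float)));
    CHECK_CUDA(cudaMemcpy(d_a, h_a.data(), h_a.size()*sizeof(__half), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_b, h_b.data(), h_b.size()*sizeof(__half), cudaMemcpyHostToDevice));

    cublasHandle_t handle;
    CHECK_CUBLAS(cublasCreate(&handle));

    const float alpha = 1.0f, beta = 0.0f;

    // Column-major C = alpha*A*B + beta*C for each of the `batch` problems;
    // the per-batch element strides play the role of strideA/strideB/strideC.
    CHECK_CUBLAS(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha, d_a, CUDA_R_16F, m, (long long) m*k, // strideA
                    d_b, CUDA_R_16F, k, (long long) k*n, // strideB
            &beta,  d_c, CUDA_R_32F, m, (long long) m*n, // strideC
            batch,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP));

    std::vector<float> h_c((size_t) m*n*batch);
    CHECK_CUDA(cudaMemcpy(h_c.data(), d_c, h_c.size()*sizeof(float), cudaMemcpyDeviceToHost));
    std::printf("C[0] = %.1f (expected %.1f)\n", h_c[0], (float) k);

    CHECK_CUBLAS(cublasDestroy(handle));
    CHECK_CUDA(cudaFree(d_a));
    CHECK_CUDA(cudaFree(d_b));
    CHECK_CUDA(cudaFree(d_c));
    return 0;
}
```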
@@ -1951,7 +1951,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
             ne01, ne11, ne10,
             alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
                    (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
-            beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
+            beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
             ne23,
             cu_compute_type,
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
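This second GEMM is the fallback for batches whose matrices cannot be described by a single fixed stride (for example when src0 is broadcast across the batch dimensions): instead of strides, cuBLAS receives device-resident arrays of per-matrix pointers, which is what `ptrs_src` and `ptrs_dst` hold here (ne23 entries each). Below is a sketch of that pointer-array variant, `cublasGemmBatchedEx`, reusing the toy shapes from the previous example; it only demonstrates the call shape and is not the ggml code:

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int m = 4, n = 4, k = 4, batch = 3;

    // Device data is left uninitialized: the sketch only checks the call status.
    __half *d_a, *d_b;
    float  *d_c;
    cudaMalloc(&d_a, (size_t) m*k*batch*sizeof(__half));
    cudaMalloc(&d_b, (size_t) k*n*batch*sizeof(__half));
    cudaMalloc(&d_c, (size_t) m*n*batch*sizeof(float));

    // Host tables of per-batch matrix pointers; here they happen to be evenly
    // spaced, but they could point anywhere - that is the point of this API.
    std::vector<const void *> h_pa(batch), h_pb(batch);
    std::vector<void *>       h_pc(batch);
    for (int i = 0; i < batch; ++i) {
        h_pa[i] = d_a + (size_t) i*m*k;
        h_pb[i] = d_b + (size_t) i*k*n;
        h_pc[i] = d_c + (size_t) i*m*n;
    }

    // cuBLAS expects the pointer tables themselves to live in device memory,
    // which is what ptrs_src/ptrs_dst provide in the hunk above.
    const void ** d_pa;
    const void ** d_pb;
    void       ** d_pc;
    cudaMalloc(&d_pa, batch*sizeof(void *));
    cudaMalloc(&d_pb, batch*sizeof(void *));
    cudaMalloc(&d_pc, batch*sizeof(void *));
    cudaMemcpy(d_pa, h_pa.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
    cudaMemcpy(d_pb, h_pb.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);
    cudaMemcpy(d_pc, h_pc.data(), batch*sizeof(void *), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.0f, beta = 0.0f;

    // Same GEMM shape as before, but A, B, and C are arrays of pointers.
    cublasStatus_t st = cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha, d_pa, CUDA_R_16F, m,
                    d_pb, CUDA_R_16F, k,
            &beta,  d_pc, CUDA_R_32F, m,
            batch,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    std::printf("cublasGemmBatchedEx status: %d\n", (int) st);

    cublasDestroy(handle);
    cudaFree(d_pa); cudaFree(d_pb); cudaFree(d_pc);
    cudaFree(d_a);  cudaFree(d_b);  cudaFree(d_c);
    return 0;
}
```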
@@ -2030,10 +2030,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     // printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     // printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
+    // TODO update for generic tensor parallelism
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    bool can_use_batched_cublas_f16  = src0->type == GGML_TYPE_F16  && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
-    bool can_use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
-    bool can_use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
+    bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16  && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
 
     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
@@ -2043,7 +2044,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && (can_use_batched_cublas_f16 || can_use_batched_cublas_bf16 || can_use_batched_cublas_f32)
+    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);