@@ -1861,7 +1861,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
18611861
18621862 const auto convert_func = traits::get_nc_converter (src1->type );
18631863 GGML_ASSERT (convert_func != nullptr );
1864- convert_func (( const void *)(( const char *) src1->data ) , src1_alloc.get (), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
1864+ convert_func (src1->data , src1_alloc.get (), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
18651865 src1_ptr = src1_alloc.get ();
18661866 s11 = ne10;
18671867 s12 = ne11*s11;
@@ -1922,7 +1922,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
19221922 ne01, ne11, ne10,
19231923 alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
19241924 src1_ptr, cu_data_type_b, s11, s12, // strideB
1925- beta, dst_t , cu_data_type, ne0, ne1*ne0, // strideC
1925+ beta, dst_t , cu_data_type, ne0, ne1*ne0, // strideC
19261926 ne12*ne13,
19271927 cu_compute_type,
19281928 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1954,7 +1954,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
19541954 ne01, ne11, ne10,
19551955 alpha, (const void **) (ptrs_src.get () + 0 *ne23), cu_data_type_a, nb01/nb00,
19561956 (const void **) (ptrs_src.get () + 1 *ne23), cu_data_type_b, s11,
1957- beta, ( void **) (ptrs_dst.get () + 0 *ne23), cu_data_type, ne0,
1957+ beta, ( void **) (ptrs_dst.get () + 0 *ne23), cu_data_type, ne0,
19581958 ne23,
19591959 cu_compute_type,
19601960 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -2033,10 +2033,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
20332033 // printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
20342034 // printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
20352035
2036+ // TODO update for generic tensor parallelism
20362037 const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
2037- bool can_use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
2038- bool can_use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available (cc);
2039- bool can_use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2038+ bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
2039+ bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available (cc);
2040+ bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
20402041
20412042 if (!split && use_mul_mat_vec) {
20422043 // the custom F16 vector kernel can be used over batched cuBLAS GEMM
@@ -2046,7 +2047,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
20462047 ggml_cuda_mul_mat_vec_q (ctx, src0, src1, nullptr , dst);
20472048 } else if (!split && use_mul_mat_q) {
20482049 ggml_cuda_mul_mat_q (ctx, src0, src1, nullptr , dst);
2049- } else if (!split && (can_use_batched_cublas_f16 || can_use_batched_cublas_bf16 || can_use_batched_cublas_f32 )
2050+ } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32 )
20502051 && !ggml_is_transposed (src0) && !ggml_is_transposed (src1) && src1->ne [2 ]*src1->ne [3 ] > 1 ) {
20512052 // general KQ + KQV multi-batch without FlashAttention
20522053 ggml_cuda_mul_mat_batched_cublas (ctx, src0, src1, dst);
0 commit comments