
Commit 2c4e42e
Review: fix formatting, remove useless type conversion, fix naming for bools
Parent: c02cd2f

1 file changed: ggml/src/ggml-cuda/ggml-cuda.cu (8 additions, 7 deletions)
@@ -1861,7 +1861,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct

             const auto convert_func = traits::get_nc_converter(src1->type);
             GGML_ASSERT(convert_func != nullptr);
-            convert_func((const void*)((const char*)src1->data), src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+            convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
             src1_ptr = src1_alloc.get();
             s11 = ne10;
             s12 = ne11*s11;
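
Note on the first hunk: the dropped cast chain was a no-op, since src1->data is already a void * and a void * converts implicitly to the const void * the converter takes. A minimal standalone C++ sketch (names here are illustrative, not from ggml):

    // Sketch only: shows that the cast removed in this commit changes nothing.
    // A void * converts implicitly to const void *, and routing it through
    // const char * and back yields the same pointer value.
    #include <cassert>

    static void takes_const_void(const void * p) { (void) p; }

    int main() {
        int    payload = 0;
        void * data    = &payload; // stands in for src1->data

        const void * old_form = (const void *)((const char *) data); // before this commit
        const void * new_form = data;                                // after this commit

        assert(old_form == new_form); // identical pointer either way
        takes_const_void(data);       // compiles without any explicit cast
        return 0;
    }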
@@ -1922,7 +1922,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
                 ne01, ne11, ne10,
                 alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
                        src1_ptr, cu_data_type_b, s11, s12,             // strideB
-                beta, dst_t, cu_data_type, ne0, ne1*ne0,                // strideC
+                beta, dst_t, cu_data_type, ne0, ne1*ne0,                // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1954,7 +1954,7 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
                 ne01, ne11, ne10,
                 alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
                        (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -2033,10 +2033,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

+    //TODO update for generic tensor parallelism
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    bool can_use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
-    bool can_use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
-    bool can_use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
+    bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+    bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;

     if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
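
Note on the renamed flags: they only record the data-type and hardware conditions under which the batched cuBLAS path is eligible; the final dispatch (next hunk) also checks transposition and batch size. A condensed sketch of what the three booleans encode, with stand-in types since the real ggml_type enum and helper functions live elsewhere in ggml:

    // Stand-ins (assumptions) for the real ggml types/helpers referenced in the hunk.
    enum tensor_type { TT_F32, TT_F16, TT_BF16, TT_Q8_0 };

    struct flag_inputs {
        tensor_type src0_type;
        tensor_type src1_type;
        bool any_gpus_with_slow_fp16; // computed earlier in ggml_cuda_mul_mat
        bool bf16_mma_available;      // result of bf16_mma_hardware_available(cc)
    };

    // Mirrors the three flags computed in the hunk above.
    static void compute_batched_cublas_flags(const flag_inputs & in,
                                             bool & use_f16, bool & use_bf16, bool & use_f32) {
        use_f16  = in.src0_type == TT_F16  && (in.src1_type == TT_F16 || !in.any_gpus_with_slow_fp16);
        use_bf16 = in.src0_type == TT_BF16 && in.bf16_mma_available;
        use_f32  = in.src0_type == TT_F32;
    }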
@@ -2046,7 +2047,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
         ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && (can_use_batched_cublas_f16 || can_use_batched_cublas_bf16 || can_use_batched_cublas_f32)
+    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
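
Note on the dispatch condition: combined with the flags above, the batched cuBLAS branch is taken only when the operation is not split across GPUs, neither operand is transposed, and src1 has more than one batch (ne[2]*ne[3] > 1). A small sketch of that predicate (hypothetical struct and field names, restating only what the diff shows):

    #include <cstdint>

    // Hypothetical inputs; the fields restate the checks visible in the diff.
    struct dispatch_inputs {
        bool    split;
        bool    use_batched_cublas_f16, use_batched_cublas_bf16, use_batched_cublas_f32;
        bool    src0_transposed, src1_transposed;
        int64_t src1_ne2, src1_ne3;
    };

    // True when the general multi-batch cuBLAS path (ggml_cuda_mul_mat_batched_cublas)
    // would be selected by the branch above.
    static bool picks_batched_cublas(const dispatch_inputs & in) {
        return !in.split
            && (in.use_batched_cublas_f16 || in.use_batched_cublas_bf16 || in.use_batched_cublas_f32)
            && !in.src0_transposed && !in.src1_transposed
            && in.src1_ne2*in.src1_ne3 > 1;
    }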
