Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,9 @@ static void ggml_cuda_op_mul_mat_cublas(

const int compute_capability = ggml_cuda_info().devices[id].cc;

if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;

if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
Expand All @@ -1103,28 +1105,38 @@ static void ggml_cuda_op_mul_mat_cublas(
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));

cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
}
if (compute_capability == GGML_CUDA_CC_CDNA) {
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10,
&beta, dst_dd_i, CUDA_R_32F, ldc,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);

CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10,
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10,
&beta_f16, dst_dd_i, CUDA_R_16F, ldc,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
}
} else {
ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
Expand Down Expand Up @@ -1613,10 +1625,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;

if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
}

// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];
Expand Down Expand Up @@ -1645,6 +1653,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
beta = &beta_f32;
}

if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
alpha = &alpha_f32;
beta = &beta_f32;
}

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);

Expand Down
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1;

if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
switch(ncols_y) {
case 1:
nwarps = 4;
Expand All @@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda(
break;
}
}

const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
Expand Down
Loading