Commit d1d43c5

Reapply 11356
Parent: 7301fd6

2 files changed: +38, -24 lines

File tree:
  ggml/src/ggml-cuda/ggml-cuda.cu
  ggml/src/ggml-cuda/mmvq.cu

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 36 additions & 23 deletions

@@ -1231,8 +1231,9 @@ static void ggml_cuda_op_mul_mat_cublas(
         return;
     }
 
-    if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
+    if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1253,28 +1254,38 @@
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), src1_ncols, ne10, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-        if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-            cu_compute_type = CUBLAS_COMPUTE_32F;
-        }
+        if (compute_capability == GGML_CUDA_CC_CDNA) {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha, src0_ptr, CUDA_R_16F, ne00,
+                                src1_ptr, CUDA_R_16F, ne10,
+                        &beta,  dst_dd_i, CUDA_R_32F, ldc,
+                        CUBLAS_COMPUTE_32F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        } else {
+            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
-        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-        CUBLAS_CHECK(
-            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                    row_diff, src1_ncols, ne10,
-                    &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
-                                src1_ptr,      CUDA_R_16F, ne10,
-                    &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            const half alpha_f16 = 1.0f;
+            const half beta_f16 = 0.0f;
 
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff, src1_ncols, stream);
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
+                                    src1_ptr,      CUDA_R_16F, ne10,
+                        &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
+                        CUBLAS_COMPUTE_16F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff, src1_ncols, stream);
+        }
     } else {
         ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
         ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
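
Note on the CDNA branch above: cublasGemmEx supports mixed precision, so under CUBLAS_COMPUTE_32F the A and B inputs can stay CUDA_R_16F while C is written directly as CUDA_R_32F. That is what lets the new path drop the intermediate dst_f16 buffer and the trailing fp16-to-fp32 conversion. Below is a minimal standalone sketch of the same call pattern; the matrix sizes, raw cudaMalloc buffers, and minimal error handling are illustrative assumptions, not ggml code.

// Sketch: FP16 inputs, FP32 output in a single cublasGemmEx call under
// CUBLAS_COMPUTE_32F. Compile with: nvcc gemm_f16_f32.cu -lcublas
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    const int m = 64, n = 32, k = 128;          // illustrative sizes
    half  *A, *B;                               // A: k x m (used transposed), B: k x n
    float *C;                                   // C: m x n, written directly in FP32
    cudaMalloc(&A, sizeof(half)  * k * m);
    cudaMalloc(&B, sizeof(half)  * k * n);
    cudaMalloc(&C, sizeof(float) * m * n);

    cublasHandle_t handle;
    cublasCreate(&handle);

    // Scalars must be FP32 to match CUBLAS_COMPUTE_32F.
    const float alpha = 1.0f, beta = 0.0f;
    cublasStatus_t st =
        cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                m, n, k,
                &alpha, A, CUDA_R_16F, k,
                        B, CUDA_R_16F, k,
                &beta,  C, CUDA_R_32F, m,
                CUBLAS_COMPUTE_32F,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    printf("cublasGemmEx status: %d\n", (int) st);

    cublasDestroy(handle);
    cudaFree(A); cudaFree(B); cudaFree(C);
    return 0;
}

The argument pairing mirrors the CDNA call in the hunk: FP16 inputs with FP32 scalars, FP32 output, and FP32 accumulation.
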
@@ -1816,10 +1827,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t cu_data_type = CUDA_R_16F;
 
-    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-    }
-
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
@@ -1848,6 +1855,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         beta = &beta_f32;
     }
 
+    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        alpha = &alpha_f32;
+        beta = &beta_f32;
+    }
+
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
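
Note on the two batched-path hunks above: cuBLAS dereferences alpha and beta with the scalar type implied by the compute type, so CUBLAS_COMPUTE_16F needs half scalars and CUBLAS_COMPUTE_32F needs float scalars. Moving the CDNA override below the precision selection, and having it also repoint alpha and beta at the FP32 scalars, keeps that pairing intact; the old placement could leave FP16 scalars paired with an FP32 compute type. A small sketch of the invariant, with a hypothetical helper name (pick_gemm_scalars is not ggml API):

#include <cublas_v2.h>
#include <cuda_fp16.h>

// Hypothetical helper (not ggml API): choose the GEMM scalars together with
// the compute type, so *alpha / *beta are always read with the type cuBLAS
// expects for that compute type.
struct gemm_scalars {
    const void * alpha;
    const void * beta;
};

static gemm_scalars pick_gemm_scalars(cublasComputeType_t cu_compute_type) {
    static const half  alpha_f16 = 1.0f, beta_f16 = 0.0f;
    static const float alpha_f32 = 1.0f, beta_f32 = 0.0f;

    if (cu_compute_type == CUBLAS_COMPUTE_16F) {
        return { &alpha_f16, &beta_f16 };   // FP16 accumulation, FP16 scalars
    }
    return { &alpha_f32, &beta_f32 };       // FP32 accumulation (e.g. CDNA), FP32 scalars
}

With a helper like this, overriding cu_compute_type for CDNA before picking the scalars cannot desynchronize the two.
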

ggml/src/ggml-cuda/mmvq.cu
Lines changed: 2 additions & 1 deletion

@@ -147,7 +147,7 @@ static void mul_mat_vec_q_cuda(
     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
 
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
+    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
         switch(ncols_y) {
             case 1:
                 nwarps = 4;
@@ -171,6 +171,7 @@ static void mul_mat_vec_q_cuda(
                 break;
         }
     }
+
     const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
     const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);
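
Note on the simplified mmvq condition: ggml encodes AMD compute capabilities above a large offset and orders them by ISA generation, with CDNA sorting below RDNA1 and RDNA2. Under that ordering, cc < GGML_CUDA_CC_RDNA2 covers NVIDIA, pre-RDNA2 AMD, and now also CDNA, which the old predicate explicitly carved out. A sketch of that ordering assumption; the constant values below are illustrative guesses, not copied from ggml's common.cuh:

// Illustrative ordering of the cc constants (assumed values, for the
// predicate logic only; the real definitions live in ggml-cuda/common.cuh).
enum {
    CC_OFFSET_AMD = 1000000,              // AMD ccs sit far above NVIDIA sm values
    CC_CDNA       = CC_OFFSET_AMD + 908,  // gfx908/gfx90a class
    CC_RDNA1      = CC_OFFSET_AMD + 1010,
    CC_RDNA2      = CC_OFFSET_AMD + 1030,
};

// Old predicate: NVIDIA and AMD older than RDNA2, but not CDNA.
static bool tuned_path_old(int cc) {
    return cc < CC_CDNA || cc == CC_RDNA1;
}

// New predicate: the same set plus CDNA, since CDNA < RDNA1 < RDNA2.
static bool tuned_path_new(int cc) {
    return cc < CC_RDNA2;
}

The two predicates agree everywhere except CDNA-class devices, which now take the tuned nwarps/rows_per_cuda_block path.
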
