Skip to content

Commit c262dc8

Browse files
JohannesGaesslerggerganov
authored andcommitted
CPU/CUDA: fix (GQA) mul mat back, add CUDA support (llama/11380)
1 parent 30767b4 commit c262dc8

File tree

6 files changed

+59
-43
lines changed

6 files changed

+59
-43
lines changed

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32(
78837883

78847884
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
78857885
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
7886-
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
7886+
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
78877887

78887888
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
78897889
}
@@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32(
78927892

78937893
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
78947894
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
7895-
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
7895+
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
78967896

78977897
ggml_vec_mad_f32(ne0, d, s0, *s1);
78987898
}

ggml/src/ggml-cpu/ggml-cpu.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
416416
case GGML_OP_IM2COL_BACK:
417417
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
418418
case GGML_OP_OUT_PROD:
419-
return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
419+
return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
420+
src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
420421
default:
421422
return true;
422423
}

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
9393

9494
template <typename T>
9595
static __global__ void k_repeat_back(
96-
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
97-
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
96+
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
97+
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
98+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
9899

99-
const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
100-
const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
101-
const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
100+
const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
101+
const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
102+
const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
103+
const int64_t tid2 = tid23 % ne2;
104+
const int64_t tid3 = tid23 / ne2;
102105

103106
if (tid0 >= ne0) {
104107
return;
105108
}
106109

107110
T sum = 0;
108-
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
109-
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
110-
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
111-
sum += src[i2*ne01*ne00 + i1*ne00 + i0];
111+
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
112+
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
113+
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
114+
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
115+
sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
116+
}
112117
}
113118
}
114119
}
115-
dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
120+
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
116121
}
117122

118123
template<float (*bin_op)(const float, const float)>
@@ -274,12 +279,14 @@ struct bin_bcast_cuda {
274279

275280
template <typename T>
276281
static void repeat_back_cuda(
277-
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
278-
const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
282+
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
283+
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
284+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
279285

280286
const dim3 block_dims(WARP_SIZE, 1, 1);
281-
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
282-
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
287+
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
288+
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
289+
(src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
283290
}
284291

285292
template<class op>
@@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
326333
const ggml_tensor * src0 = dst->src[0];
327334

328335
GGML_ASSERT(src0->type == dst->type);
329-
GGML_ASSERT(ggml_is_contiguous(src0));
330336
GGML_ASSERT(ggml_is_contiguous(dst));
331337
GGML_ASSERT(ggml_can_repeat(dst, src0));
332338

333339
cudaStream_t stream = ctx.stream();
334340

335-
const int64_t ne00 = src0->ne[0];
336-
const int64_t ne01 = src0->ne[1];
337-
const int64_t ne02 = src0->ne[2];
338-
GGML_ASSERT(src0->ne[3] == 1);
341+
GGML_TENSOR_UNARY_OP_LOCALS;
342+
343+
GGML_ASSERT(ne2*ne3 <= (1 << 15));
339344

340-
const int64_t ne0 = dst->ne[0];
341-
const int64_t ne1 = dst->ne[1];
342-
const int64_t ne2 = dst->ne[2];
343-
GGML_ASSERT(dst->ne[3] == 1);
345+
const size_t ts = ggml_type_size(src0->type);
346+
const size_t s00 = nb00 / ts;
347+
const size_t s01 = nb01 / ts;
348+
const size_t s02 = nb02 / ts;
349+
const size_t s03 = nb03 / ts;
344350

345351
switch (dst->type) {
346352
case GGML_TYPE_F32: {
347353
const float * src0_d = (const float *) src0->data;
348354
float * dst_d = (float *) dst->data;
349-
repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
355+
repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
350356
} break;
351357
default: {
352358
GGML_ASSERT(false);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3002,7 +3002,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30023002
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
30033003
} break;
30043004
case GGML_OP_REPEAT_BACK:
3005-
return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
3005+
return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
30063006
case GGML_OP_CONCAT:
30073007
{
30083008
ggml_type src0_type = op->src[0]->type;

ggml/src/ggml-cuda/out-prod.cu

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
3434

3535
CUBLAS_CHECK(cublasSetStream(handle, stream));
3636

37+
const int64_t lda = nb01 / sizeof(float);
38+
const int64_t ldc = nb1 / sizeof(float);
39+
3740
const bool src1_T = ggml_is_transposed(src1);
3841
const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
3942
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
@@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
5760
CUBLAS_CHECK(
5861
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
5962
ne0, ne1, ne01,
60-
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
63+
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
6164
src1_d + i3 *s13 + i2 *s12, ldb,
62-
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
65+
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
6366
}
6467
}
6568
}

ggml/src/ggml.c

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5343,7 +5343,7 @@ static void ggml_compute_backward(
53435343
} break;
53445344
case GGML_OP_MUL: {
53455345
if (src0_needs_grads) {
5346-
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
5346+
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
53475347
}
53485348
if (src1_needs_grads) {
53495349
struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5435,21 +5435,25 @@ static void ggml_compute_backward(
54355435
// src1.shape [n,p,qq,rr]
54365436

54375437
if (src0_needs_grads) {
5438-
struct ggml_tensor * s1_tg =
5438+
GGML_ASSERT(grad->ne[2] == src1->ne[2]);
5439+
GGML_ASSERT(grad->ne[3] == src1->ne[3]);
5440+
struct ggml_tensor * tmp =
54395441
ggml_out_prod(ctx, // [n,m,qq,rr]
54405442
src1, // [n,p,qq,rr]
54415443
grad); // [m,p,qq,rr]
5442-
const int64_t qq = s1_tg->ne[2];
5443-
const int64_t rr = s1_tg->ne[3];
5444-
const int64_t q1 = src0->ne[2];
5445-
const int64_t r1 = src0->ne[3];
5446-
const bool ne2_broadcasted = qq > q1;
5447-
const bool ne3_broadcasted = rr > r1;
5448-
if (ne2_broadcasted || ne3_broadcasted) {
5449-
// sum broadcast repetitions of s1_tg into shape of src0
5450-
s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
5444+
if (!ggml_are_same_shape(tmp, src0)) {
5445+
GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
5446+
GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
5447+
GGML_ASSERT(tmp->ne[3] == 1);
5448+
5449+
const int64_t nr2 = tmp->ne[2] / src0->ne[2];
5450+
const size_t nb2 = tmp->nb[2] * nr2;
5451+
const size_t nb3 = tmp->nb[2];
5452+
5453+
tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
5454+
tmp = ggml_repeat_back(ctx, tmp, src0);
54515455
}
5452-
ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
5456+
ggml_add_or_set(ctx, cgraph, isrc0, tmp);
54535457
}
54545458
if (src1_needs_grads) {
54555459
ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5518,7 +5522,9 @@ static void ggml_compute_backward(
55185522
if (src0_needs_grads) {
55195523
GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
55205524
GGML_ASSERT(ggml_is_contiguous(grad));
5521-
ggml_add_or_set(ctx, cgraph, isrc0, grad);
5525+
GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
5526+
ggml_add_or_set(ctx, cgraph, isrc0,
5527+
ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
55225528
}
55235529
} break;
55245530
case GGML_OP_RESHAPE: {

0 commit comments

Comments
 (0)