
Commit 049d544

revert 11380
1 parent c8fad12 commit 049d544

6 files changed: +43 additions, -59 deletions

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 2 additions & 2 deletions

@@ -7921,7 +7921,7 @@ static void ggml_compute_forward_out_prod_f32(

         float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
         float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-        float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+        float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));

         ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
     }
@@ -7930,7 +7930,7 @@ static void ggml_compute_forward_out_prod_f32(

         float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
         float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-        float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+        float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));

         ggml_vec_mad_f32(ne0, d, s0, *s1);
     }
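The two hunks above differ only in whitespace around the destination pointer; the surrounding code accumulates one row of the outer product at a time. As a point of reference, here is a minimal sketch of what ggml_vec_mad_f32(ne0, d, s0, *s1) is assumed to compute at the second call site (a scalar multiply-add over one row; a sketch only, not the ggml source):

// Sketch only: assumed per-row behaviour of ggml_vec_mad_f32 at this call site,
// i.e. d[i0] += s0[i0] * s1v for i0 in [0, n), where s1v is the scalar *s1.
static void vec_mad_f32_sketch(const int n, float * d, const float * s0, const float s1v) {
    for (int i = 0; i < n; ++i) {
        d[i] += s0[i] * s1v;
    }
}

The _unroll variant at the first call site is assumed to do the same for several adjacent rows per call.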

ggml/src/ggml-cpu/ggml-cpu.cpp

Lines changed: 1 addition & 2 deletions

@@ -416,8 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
-            return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
-                src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
         default:
             return true;
     }

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 24 additions & 30 deletions

@@ -93,31 +93,26 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s

 template <typename T>
 static __global__ void k_repeat_back(
-    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2) {

-    const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
-    const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
-    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
-    const int64_t tid2 = tid23 % ne2;
-    const int64_t tid3 = tid23 / ne2;
+    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
+    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;

     if (tid0 >= ne0) {
         return;
     }

     T sum = 0;
-    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
-        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
-            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
-                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
-                }
+    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
             }
         }
     }
-    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }

 template<float (*bin_op)(const float, const float)>
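The reverted kernel drops the explicit element strides s00..s03 and the fourth dimension, so it requires a contiguous 3-D src and indexes it as i2*ne01*ne00 + i1*ne00 + i0. Below is a hedged CPU reference of the reduction a single thread performs (illustrative helper, not ggml API); each dst element sums every src element whose indices are congruent to its own modulo the dst extents:

// Sketch only: CPU reference for the contiguous 3-D repeat-back reduction that
// one CUDA thread (tid0, tid1, tid2) performs in the reverted kernel.
#include <cstdint>

static float repeat_back_ref(const float * src,
                             int64_t ne00, int64_t ne01, int64_t ne02,   // src extents
                             int64_t ne0,  int64_t ne1,  int64_t ne2,    // dst extents
                             int64_t tid0, int64_t tid1, int64_t tid2) { // dst element
    float sum = 0.0f;
    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
                sum += src[i2*ne01*ne00 + i1*ne00 + i0]; // contiguous row-major indexing
            }
        }
    }
    return sum; // stored to dst[tid2*ne1*ne0 + tid1*ne0 + tid0]
}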
@@ -279,14 +274,12 @@ struct bin_bcast_cuda {

 template <typename T>
 static void repeat_back_cuda(
-    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {

     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
-    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
-        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
 }

 template<class op>
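The launcher maps dst's ne0 onto the x grid dimension in WARP_SIZE-wide blocks and ne1/ne2 directly onto y/z. One constraint worth noting (an observation, not something this code checks explicitly): CUDA caps gridDim.y and gridDim.z at 65535, which is also why the pre-revert variant, which packed ne2*ne3 onto z, asserted ne2*ne3 <= (1 << 15). A tiny sketch of that bound:

// Sketch only: the hardware bound a launch of this shape has to respect.
// gridDim.y and gridDim.z are limited to 65535 on current CUDA devices.
#include <cstdint>

static bool repeat_back_grid_fits(int64_t ne1, int64_t ne2) {
    const int64_t max_grid_yz = 65535;
    return ne1 <= max_grid_yz && ne2 <= max_grid_yz;
}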
@@ -333,26 +326,27 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * src0 = dst->src[0];

     GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_can_repeat(dst, src0));

     cudaStream_t stream = ctx.stream();

-    GGML_TENSOR_UNARY_OP_LOCALS;
-
-    GGML_ASSERT(ne2*ne3 <= (1 << 15));
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    GGML_ASSERT(src0->ne[3] == 1);

-    const size_t ts = ggml_type_size(src0->type);
-    const size_t s00 = nb00 / ts;
-    const size_t s01 = nb01 / ts;
-    const size_t s02 = nb02 / ts;
-    const size_t s03 = nb03 / ts;
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    GGML_ASSERT(dst->ne[3] == 1);

     switch (dst->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             float * dst_d = (float *) dst->data;
-            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
+            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
         } break;
         default: {
             GGML_ASSERT(false);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion

@@ -3007,7 +3007,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
         } break;
         case GGML_OP_REPEAT_BACK:
-            return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
+            return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
         case GGML_OP_CONCAT:
         {
             ggml_type src0_type = op->src[0]->type;

ggml/src/ggml-cuda/out-prod.cu

Lines changed: 2 additions & 5 deletions

@@ -34,9 +34,6 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     CUBLAS_CHECK(cublasSetStream(handle, stream));

-    const int64_t lda = nb01 / sizeof(float);
-    const int64_t ldc = nb1 / sizeof(float);
-
     const bool src1_T = ggml_is_transposed(src1);
     const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
     const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
@@ -60,9 +57,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             CUBLAS_CHECK(
                 cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
                         ne0, ne1, ne01,
-                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
+                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
                                 src1_d + i3 *s13 + i2 *s12, ldb,
-                        &beta, dst_d + i3 *s3 + i2 *s2, ldc));
+                        &beta, dst_d + i3 *s3 + i2 *s2, ne0));
         }
     }
 }
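The removed lda/ldc were the element strides between consecutive columns of src0 and dst (nb01/sizeof(float) and nb1/sizeof(float)); the revert passes ne00 and ne0 instead, which matches only when those tensors are densely packed. As a reminder of what the leading dimension means in cuBLAS's column-major layout, a minimal sketch (illustrative helper, not ggml or cuBLAS API):

// Sketch only: column-major addressing as cuBLAS sees it. ld is the element
// stride between consecutive columns; for a packed ne00 x k matrix, ld == ne00,
// so the two forms coincide in the contiguous case.
#include <cstddef>

static inline float col_major_at(const float * a, size_t row, size_t col, size_t ld) {
    return a[row + col*ld];
}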

ggml/src/ggml.c

Lines changed: 13 additions & 19 deletions

@@ -5352,7 +5352,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MUL: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
             }
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5444,25 +5444,21 @@ static void ggml_compute_backward(
             // src1.shape [n,p,qq,rr]

             if (src0_needs_grads) {
-                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
-                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
-                struct ggml_tensor * tmp =
+                struct ggml_tensor * s1_tg =
                     ggml_out_prod(ctx, // [n,m,qq,rr]
                         src1,          // [n,p,qq,rr]
                         grad);         // [m,p,qq,rr]
-                if (!ggml_are_same_shape(tmp, src0)) {
-                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
-                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
-                    GGML_ASSERT(tmp->ne[3] == 1);
-
-                    const int64_t nr2 = tmp->ne[2] / src0->ne[2];
-                    const size_t nb2 = tmp->nb[2] * nr2;
-                    const size_t nb3 = tmp->nb[2];
-
-                    tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
-                    tmp = ggml_repeat_back(ctx, tmp, src0);
+                const int64_t qq = s1_tg->ne[2];
+                const int64_t rr = s1_tg->ne[3];
+                const int64_t q1 = src0->ne[2];
+                const int64_t r1 = src0->ne[3];
+                const bool ne2_broadcasted = qq > q1;
+                const bool ne3_broadcasted = rr > r1;
+                if (ne2_broadcasted || ne3_broadcasted) {
+                    // sum broadcast repetitions of s1_tg into shape of src0
+                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
                 }
-                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
+                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
             }
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1,
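This hunk restores the earlier mul_mat backward: the gradient w.r.t. src0 is s1_tg = out_prod(src1, grad) with shape [n,m,qq,rr], and when src1's batch dims qq,rr exceed src0's q1,r1 (src0 was broadcast), ggml_repeat_back folds the extra repetitions back into src0's shape by summation. Below is a hedged C++ reference of that fold over one batch dimension, assuming (as the CUDA kernel earlier in this commit does) that repetition i2 maps to dst batch i2 % q1; names are illustrative:

// Sketch only: summing broadcast repetitions of the per-batch gradient back to
// src0's batch size q1, the role ggml_repeat_back plays above.
// Layout assumption: qq slices of n*m contiguous floats; slice i2 folds into i2 % q1.
#include <cstdint>
#include <vector>

static std::vector<float> fold_batches(const std::vector<float> & g, // qq slices of n*m floats
                                       int64_t n, int64_t m, int64_t qq, int64_t q1) {
    std::vector<float> out(n*m*q1, 0.0f);
    for (int64_t i2 = 0; i2 < qq; ++i2) {
        const int64_t j2 = i2 % q1;                 // dst batch this repetition belongs to
        for (int64_t k = 0; k < n*m; ++k) {
            out[j2*n*m + k] += g[i2*n*m + k];
        }
    }
    return out;
}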
@@ -5531,9 +5527,7 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                 GGML_ASSERT(ggml_is_contiguous(grad));
-                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
-                ggml_add_or_set(ctx, cgraph, isrc0,
-                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
             }
         } break;
         case GGML_OP_RESHAPE: {
