@@ -93,31 +93,26 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
 
 template <typename T>
 static __global__ void k_repeat_back(
-    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
 
-    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
-    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
-    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
-    const int64_t tid2  = tid23 % ne2;
-    const int64_t tid3  = tid23 / ne2;
+    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
+    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
 
     if (tid0 >= ne0) {
         return;
     }
 
     T sum = 0;
-    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
-        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
-            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
-                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
-                }
+    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
             }
         }
     }
-    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
 template <float (*bin_op)(const float, const float)>
@@ -279,14 +274,12 @@ struct bin_bcast_cuda {
 
 template <typename T>
 static void repeat_back_cuda(
-    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
 
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
-    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
-        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
 }
 
 template <class op>
@@ -333,26 +326,27 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_can_repeat(dst, src0));
 
     cudaStream_t stream = ctx.stream();
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
-
-    GGML_ASSERT(ne2*ne3 <= (1 << 15));
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    GGML_ASSERT(src0->ne[3] == 1);
 
-    const size_t ts  = ggml_type_size(src0->type);
-    const size_t s00 = nb00 / ts;
-    const size_t s01 = nb01 / ts;
-    const size_t s02 = nb02 / ts;
-    const size_t s03 = nb03 / ts;
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    GGML_ASSERT(dst->ne[3] == 1);
 
     switch (dst->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             float * dst_d = (float *) dst->data;
-            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
+            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
         } break;
         default: {
             GGML_ASSERT(false);
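
Net effect of the three hunks: k_repeat_back drops the 4D strided indexing (ne03, s00..s03, ne3) in favor of a contiguous 3D layout, with the host wrapper asserting contiguity and ne[3] == 1 instead. Each thread then owns one dst element and sums every src element that repeats onto it. Below is a minimal standalone sketch of that reduction; the kernel body mirrors the post-change code, while main(), the CPU reference, and the 2x-per-dimension repeat shapes are illustrative assumptions, not part of the commit.

// Standalone sketch of the simplified repeat_back reduction.
// The kernel mirrors the post-change k_repeat_back; the host harness
// (shapes, repeat factors, verification) is assumed for illustration.
#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

template <typename T>
static __global__ void k_repeat_back(
    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;

    if (tid0 >= ne0) {
        return;
    }

    // One thread per dst element: accumulate every src element whose index
    // is congruent to (tid0, tid1, tid2) modulo (ne0, ne1, ne2), i.e. every
    // position the dst element was broadcast to by the forward repeat.
    T sum = 0;
    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
            }
        }
    }
    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}

int main() {
    const int64_t ne0 = 4, ne1 = 3, ne2 = 2;                 // dst shape (assumed)
    const int64_t ne00 = 2*ne0, ne01 = 2*ne1, ne02 = 2*ne2;  // src = dst repeated 2x per dim
    const int64_t n_src = ne00*ne01*ne02;
    const int64_t n_dst = ne0*ne1*ne2;

    float * src = nullptr;
    float * dst = nullptr;
    cudaMallocManaged(&src, n_src*sizeof(float));
    cudaMallocManaged(&dst, n_dst*sizeof(float));
    for (int64_t i = 0; i < n_src; ++i) {
        src[i] = (float) i;
    }

    // Same launch shape as repeat_back_cuda: warp-sized blocks along dim 0,
    // one grid slot per dst position in dims 1 and 2.
    const dim3 block_dims(32, 1, 1);
    const dim3 block_nums((ne0 + 31)/32, ne1, ne2);
    k_repeat_back<float><<<block_nums, block_dims>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
    cudaDeviceSynchronize();

    // CPU reference: forward-map each src element onto dst and accumulate.
    std::vector<float> ref(n_dst, 0.0f);
    for (int64_t i2 = 0; i2 < ne02; ++i2) {
        for (int64_t i1 = 0; i1 < ne01; ++i1) {
            for (int64_t i0 = 0; i0 < ne00; ++i0) {
                ref[(i2 % ne2)*ne1*ne0 + (i1 % ne1)*ne0 + (i0 % ne0)] += src[i2*ne01*ne00 + i1*ne00 + i0];
            }
        }
    }

    int n_mismatch = 0;
    for (int64_t i = 0; i < n_dst; ++i) {
        n_mismatch += dst[i] != ref[i];
    }
    printf("%s\n", n_mismatch == 0 ? "repeat_back OK" : "repeat_back MISMATCH");

    cudaFree(src);
    cudaFree(dst);
    return n_mismatch != 0;
}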