 #include <cstdint>
 #include <utility>
 
-static __device__ __forceinline__ float op_repeat(const float a, const float b) {
-    return b;
-    GGML_UNUSED(a);
-}
-
-static __device__ __forceinline__ float op_add(const float a, const float b) {
-    return a + b;
-}
-
-static __device__ __forceinline__ float op_sub(const float a, const float b) {
-    return a - b;
-}
-
-static __device__ __forceinline__ float op_mul(const float a, const float b) {
-    return a * b;
-}
-
-static __device__ __forceinline__ float op_div(const float a, const float b) {
-    return a / b;
-}
-
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... S1Ptrs>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
 static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
         const int ne0, const int ne1, const int ne2, const int ne3,
         const int ne10, const int ne11, const int ne12, const int ne13,
         /*int s0, */ const int s1, const int s2, const int s3,
         /*int s00,*/ const int s01, const int s02, const int s03,
         /*int s10,*/ const int s11, const int s12, const int s13,
-        S1Ptrs... src1s) {
+        src1_ptrs... src1s) {
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
     const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
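For readers unfamiliar with the pattern: the renamed `src1_ptrs` pack means the kernel's arity is fixed per instantiation, with one pointer argument per fused `src1` tensor. A minimal standalone sketch of a variadic CUDA kernel follows; the names are illustrative, not from the patch, and it assumes compilation with `-std=c++17`:

```cpp
// Toy variadic kernel: kernel<float*> and kernel<float*, float*> are
// distinct instantiations, each with a fixed argument count.
template <typename... src1_ptrs>
__global__ void toy_kernel(float * dst, src1_ptrs... srcs) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    float acc = 0.0f;
    // C++17 fold over the comma operator visits every pointer in the pack.
    (..., (acc += srcs[i]));
    dst[i] = acc;
}

int main() {
    float *d_a, *d_b, *d_dst;
    cudaMalloc(&d_a,   256*sizeof(float));
    cudaMalloc(&d_b,   256*sizeof(float));
    cudaMalloc(&d_dst, 256*sizeof(float));
    toy_kernel<<<1, 256>>>(d_dst, d_a, d_b);   // pack = {d_a, d_b}
    cudaDeviceSynchronize();
    cudaFree(d_a); cudaFree(d_b); cudaFree(d_dst);
    return 0;
}
```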
@@ -55,26 +34,20 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
         const int i10 = i0 % ne10;
 
         float result = src0_row ? (float) src0_row[i0] : 0.0f;
-
-        auto add_one = [&](const src1_t * p) {
-            const src1_t * row = p + i_src1;
-            result = bin_op(result, (float) row[i10]);
-            return 0;
-        };
-        (void) std::initializer_list<int>{ (add_one(src1s), 0)... };
+        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
 
         dst_row[i0] = (dst_t) result;
     }
 }
 
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... S1Ptrs>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
 static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
         const int ne0, const int ne1, const int ne2, const int ne3,
         const int ne10, const int ne11, const int ne12, const int ne13,
         /*int s0, */ const int s1, const int s2, const int s3,
         /*int s00,*/ const int s01, const int s02, const int s03,
         /*int s10,*/ const int s11, const int s12, const int s13,
-        S1Ptrs... src1s) {
+        src1_ptrs... src1s) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     const int i3 = i/(ne2*ne1*ne0);
@@ -100,13 +73,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
     const int i10 = i0 % ne10;
 
     float result = src0_row ? (float) src0_row[i0] : 0.0f;
-
-    auto add_one = [&](const src1_t * p) {
-        const src1_t * row = p + i_src1;
-        result = bin_op(result, (float) row[i10]);
-        return 0;
-    };
-    (void) std::initializer_list<int>{ (add_one(src1s), 0)... };
+    result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
 
     dst_row[i0] = (dst_t) result;
 }
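The deleted `initializer_list` expansion and the new comma fold are two idioms for visiting a parameter pack in order; the fold simply drops the helper lambda. A host-only comparison of the two, assuming nothing beyond C++17 and using toy names:

```cpp
#include <cstdio>
#include <initializer_list>

template <typename... Ts>
float sum_old(float acc, Ts... vs) {
    // Pre-C++17 idiom: expand the pack inside a braced init-list, whose
    // elements are evaluated left to right; the lambda runs once per element.
    auto add_one = [&](float v) { acc += v; return 0; };
    (void) std::initializer_list<int>{ (add_one(vs), 0)... };
    return acc;
}

template <typename... Ts>
float sum_new(float acc, Ts... vs) {
    // C++17: a comma fold does the same left-to-right walk, no helper needed.
    (..., (acc += vs));
    return acc;
}

int main() {
    printf("%f %f\n", sum_old(0.0f, 1.0f, 2.0f), sum_new(0.0f, 1.0f, 2.0f));
    return 0;
}
```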
@@ -291,7 +258,8 @@ static __global__ void k_repeat_back(
     dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
-template <float (*bin_op)(const float, const float), int n_fuse = 1> struct bin_bcast_cuda {
+template <float (*bin_op)(const float, const float), int n_fuse = 1>
+struct bin_bcast_cuda {
     template <typename src0_t, typename src1_t, typename dst_t>
     void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
             const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
@@ -355,26 +323,27 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
-template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template <float (*op)(const float, const float), int n_fuse>
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     cudaStream_t stream = ctx.stream();
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, float, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
             (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, half, half, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
             (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, half, float, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
             (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, half, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
             (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else {
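`launch_bin_bcast_pack` is not shown in this diff; presumably the `std::make_index_sequence<n_fuse>{}` argument is what lets it turn `dst->src[1..n_fuse]` into the kernels' pointer pack. A hedged sketch of that gather pattern with toy types (not the real helper, and not the real `ggml_tensor`):

```cpp
#include <cstdio>
#include <cstddef>
#include <utility>

// Toy stand-ins: just enough structure to show the expansion.
struct tensor { float * data; tensor * src[8]; };

template <std::size_t... I>
void launch(tensor * dst, std::index_sequence<I...>) {
    // I... expands to 0, 1, ..., n_fuse-1, so dst->src[I + 1]->data...
    // becomes a fixed-arity argument list that could feed a kernel's
    // parameter pack; here we just print the pointers.
    (..., printf("src1[%zu] = %p\n", I + 1, (void *) dst->src[I + 1]->data));
}

int main() {
    float a[4] = {}, b[4] = {};
    tensor t1{a, {}}, t2{b, {}};
    tensor d{nullptr, {nullptr, &t1, &t2}};
    launch(&d, std::make_index_sequence<2>{});   // n_fuse = 2
    return 0;
}
```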
@@ -385,30 +354,32 @@ template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_
     }
 }
 
-void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+
+template <float (*op)(const float, const float)>
+void ggml_cuda_op_fused_binbcast(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
     GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
 
     switch (n_fuse) {
         case 2:
-            ggml_cuda_op_fused_add_impl<2>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 2>(ctx, dst);
             break;
         case 3:
-            ggml_cuda_op_fused_add_impl<3>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 3>(ctx, dst);
             break;
         case 4:
-            ggml_cuda_op_fused_add_impl<4>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 4>(ctx, dst);
             break;
         case 5:
-            ggml_cuda_op_fused_add_impl<5>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 5>(ctx, dst);
             break;
         case 6:
-            ggml_cuda_op_fused_add_impl<6>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 6>(ctx, dst);
             break;
         case 7:
-            ggml_cuda_op_fused_add_impl<7>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 7>(ctx, dst);
             break;
         case 8:
-            ggml_cuda_op_fused_add_impl<8>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 8>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false && "Unsupported n_fuse value");
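The switch is the standard way to lift a runtime `n_fuse` into a compile-time template argument: each case names a distinct instantiation, so the compiler emits one function (and, here, one set of kernels) per supported fuse count. The same shape in miniature, with toy names:

```cpp
#include <cstdio>

template <int N> static void impl() { printf("compiled for N = %d\n", N); }

static void dispatch(int n) {
    // Each case names a separate instantiation; unsupported values fall
    // through to the default instead of instantiating anything.
    switch (n) {
        case 2: impl<2>(); break;
        case 3: impl<3>(); break;
        case 4: impl<4>(); break;
        default: /* unsupported */ break;
    }
}

int main() { dispatch(3); return 0; }
```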
@@ -445,3 +416,5 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         } break;
     }
 }
+
+template void ggml_cuda_op_fused_binbcast<op_add>(ggml_backend_cuda_context &, ggml_tensor *, int);
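The new explicit instantiation at the end is what lets `ggml_cuda_op_fused_binbcast` keep its definition in the `.cu` file while other translation units call it through a declaration; presumably only the `op_add` specialization is needed so far, and further ops would be added the same way. The mechanism in isolation, with toy names:

```cpp
#include <cstdio>

// Stand-in for op_add; a plain host function so the sketch compiles anywhere.
float my_add(float a, float b) { return a + b; }

// Toy analogue of the dispatcher: a function template whose definition
// would normally be hidden in one translation unit.
template <float (*op)(float, float)>
float apply3(float a, float b, float c) { return op(op(a, b), c); }

// Explicit instantiation: forces the op=my_add specialization's object
// code to be emitted here, so callers elsewhere link against it.
template float apply3<my_add>(float, float, float);

int main() { printf("%f\n", apply3<my_add>(1.0f, 2.0f, 3.0f)); return 0; }
```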