@@ -57,7 +57,11 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
5757        const  int  i10 = i0 % ne10;
5858
5959        float  result = src0_row ? (float ) src0_row[i0] : 0 .0f ;
60-         result = (..., (result = bin_op (result, (float )src1s[i_src1 + i10])));
60+         if  constexpr  (sizeof ...(src1_ptrs) > 0 ) {
61+             result = (..., (result = bin_op (result, (float )src1s[i_src1 + i10])));
62+         } else  {
63+             result = bin_op (result, (float )src1[i_src1 + i10]);
64+         }
6165
6266        dst_row[i0] = (dst_t ) result;
6367    }
@@ -96,7 +100,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t *   src0, const src1_t *
96100    const  int  i10 = i0 % ne10;
97101
98102    float  result = src0_row ? (float ) src0_row[i0] : 0 .0f ;
99-     result = (..., (result = bin_op (result, (float )src1s[i_src1 + i10])));
103+     if  constexpr  (sizeof ...(src1_ptrs) > 0 ) {
104+         result = (..., (result = bin_op (result, (float )src1s[i_src1 + i10])));
105+     } else  {
106+         result = bin_op (result, (float )src1[i_src1 + i10]);
107+     }
100108
101109    dst_row[i0] = (dst_t ) result;
102110}
@@ -231,23 +239,43 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
231239
232240        if  (block_nums.z  > 65535 ) {
233241            int  block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1 ) / block_size;
234-             k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t >
235-                 <<<block_num, block_size, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
236-                     ne0, ne1, ne2, ne3,
237-                     ne10, ne11, ne12, ne13,
238-                     /*  s0, */   s1, s2, s3,
239-                     /*  s00,*/   s01, s02, s03,
240-                     /*  s10,*/   s11, s12,s13,
241-                     (const  src1_t  *) dst->src [I + 1 ]->data ...);
242+             if  constexpr  (sizeof ...(I) > 0 ) {
243+                 k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t >
244+                     <<<block_num, block_size, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
245+                         ne0, ne1, ne2, ne3,
246+                         ne10, ne11, ne12, ne13,
247+                         /*  s0, */   s1, s2, s3,
248+                         /*  s00,*/   s01, s02, s03,
249+                         /*  s10,*/   s11, s12,s13,
250+                         (const  src1_t  *) dst->src [I + 1 ]->data ...);
251+             } else  {
252+                 k_bin_bcast_unravel<bin_op, src0_t , src1_t , dst_t >
253+                     <<<block_num, block_size, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
254+                         ne0, ne1, ne2, ne3,
255+                         ne10, ne11, ne12, ne13,
256+                         /*  s0, */   s1, s2, s3,
257+                         /*  s00,*/   s01, s02, s03,
258+                         /*  s10,*/   s11, s12,s13);
259+             }
242260        } else  {
243-             k_bin_bcast<bin_op, src0_t , src1_t , dst_t >
244-                 <<<block_nums, block_dims, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
245-                     ne0, ne1, ne2, ne3,
246-                     ne10, ne11, ne12, ne13,
247-                     /*  s0, */   s1, s2, s3,
248-                     /*  s00,*/   s01, s02, s03,
249-                     /*  s10,*/   s11, s12,s13,
250-                     (const  src1_t  *) dst->src [I + 1 ]->data ...);
261+             if  constexpr  (sizeof ...(I) > 0 ) {
262+                 k_bin_bcast<bin_op, src0_t , src1_t , dst_t >
263+                     <<<block_nums, block_dims, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
264+                         ne0, ne1, ne2, ne3,
265+                         ne10, ne11, ne12, ne13,
266+                         /*  s0, */   s1, s2, s3,
267+                         /*  s00,*/   s01, s02, s03,
268+                         /*  s10,*/   s11, s12,s13,
269+                         (const  src1_t  *) dst->src [I + 1 ]->data ...);
270+             } else  {
271+                 k_bin_bcast<bin_op, src0_t , src1_t , dst_t >
272+                     <<<block_nums, block_dims, 0 , stream>>> (src0_dd, src1_dd, dst_dd,
273+                         ne0, ne1, ne2, ne3,
274+                         ne10, ne11, ne12, ne13,
275+                         /*  s0, */   s1, s2, s3,
276+                         /*  s00,*/   s01, s02, s03,
277+                         /*  s10,*/   s11, s12,s13);
278+             }
251279        }
252280    }
253281}
@@ -327,7 +355,7 @@ static void ggml_cuda_op_bin_bcast(
327355}
328356
329357void  ggml_cuda_op_repeat (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
330-     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src [0 ], dst, nullptr , dst->src [0 ]->data , dst->data , ctx.stream ());
358+     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat,  0 >>(dst, dst->src [0 ], dst, nullptr , dst->src [0 ]->data , dst->data , ctx.stream ());
331359}
332360
333361void  ggml_cuda_op_add (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
0 commit comments