@@ -94,7 +94,9 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
9494}
9595
9696//  non-contiguous kernel (slow)
97- static  __global__  void  concat_f32_non_cont (
97+ template  <int  dim>
98+ static  __global__  void  __launch_bounds__ (CUDA_CONCAT_BLOCK_SIZE)
99+     concat_f32_non_cont(
98100        const  char  * src0,
99101        const  char  * src1,
100102              char  * dst,
@@ -121,22 +123,28 @@ static __global__ void concat_f32_non_cont(
121123          uint64_t    nb0,
122124          uint64_t    nb1,
123125          uint64_t    nb2,
124-           uint64_t    nb3,
125-           int32_t    dim) {
126+           uint64_t    nb3){
127+     static_assert (dim >= 0  && dim <= 3 );
128+ 
126129    const  int64_t  i3 = blockIdx .z ;
127130    const  int64_t  i2 = blockIdx .y ;
128131    const  int64_t  i1 = blockIdx .x ;
129132
130-     int64_t  o[4 ] = {0 , 0 , 0 , 0 };
131-     o[dim] = dim == 0  ? ne00 : (dim == 1  ? ne01 : (dim == 2  ? ne02 : ne03));
132- 
133133    const  float  * x;
134134
135-     for  (int  i0 = threadIdx .x ; i0 < ne0; i0 += blockDim .x ) {
135+     for  (int64_t  i0 = threadIdx .x ; i0 < ne0; i0 += blockDim .x ) {
136136        if  (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
137137            x = (const  float  *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
138138        } else  {
139-             x = (const  float  *)(src1 + (i3 - o[3 ])*nb13 + (i2 - o[2 ])*nb12 + (i1 - o[1 ])*nb11 + (i0 - o[0 ])*nb10);
139+             if  constexpr  (dim == 0 ) {
140+                 x = (const  float  *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
141+             } else  if  constexpr  (dim == 1 ) {
142+                 x = (const  float  *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
143+             } else  if  constexpr  (dim == 2 ) {
144+                 x = (const  float  *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
145+             } else  if  constexpr  (dim == 3 ) {
146+                 x = (const  float  *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
147+             }
140148        }
141149
142150        float  * y = (float  *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -182,15 +190,32 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
182190        }
183191    } else  {
184192        dim3  grid_dim (dst->ne [1 ], dst->ne [2 ], dst->ne [3 ]);
185-         concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
186-                 (const  char  *)src0->data ,
187-                 (const  char  *)src1->data ,
188-                 (      char  *)dst->data ,
193+         auto  launch_kernel = [&](auto  dim) {
194+             concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
195+                 (const  char  *) src0->data , (const  char  *) src1->data , (char  *) dst->data ,
189196                src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
190197                src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
191198                src1->ne [0 ], src1->ne [1 ], src1->ne [2 ], src1->ne [3 ],
192199                src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
193-                 dst->ne [0 ],  dst->ne [1 ],  dst->ne [2 ],  dst->ne [3 ],
194-                 dst->nb [0 ],  dst->nb [1 ],  dst->nb [2 ],  dst->nb [3 ], dim);
200+                 dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
201+                 dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
202+         };
203+         switch  (dim) {
204+             case  0 :
205+                 launch_kernel (std::integral_constant<int , 0 >{});
206+                 break ;
207+             case  1 :
208+                 launch_kernel (std::integral_constant<int , 1 >{});
209+                 break ;
210+             case  2 :
211+                 launch_kernel (std::integral_constant<int , 2 >{});
212+                 break ;
213+             case  3 :
214+                 launch_kernel (std::integral_constant<int , 3 >{});
215+                 break ;
216+             default :
217+                 GGML_ABORT (" Invalid dim: %d" 
218+                 break ;
219+         }
195220    }
196221}
0 commit comments