@@ -94,7 +94,9 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
9494}
9595
9696// non-contiguous kernel (slow)
97- static __global__ void concat_f32_non_cont (
97+ template <int dim>
98+ static __global__ void __launch_bounds__ (CUDA_CONCAT_BLOCK_SIZE)
99+ concat_f32_non_cont(
98100 const char * src0,
99101 const char * src1,
100102 char * dst,
@@ -121,22 +123,26 @@ static __global__ void concat_f32_non_cont(
121123 uint64_t nb0,
122124 uint64_t nb1,
123125 uint64_t nb2,
124- uint64_t nb3,
125- int32_t dim) {
126+ uint64_t nb3){
126127 const int64_t i3 = blockIdx .z ;
127128 const int64_t i2 = blockIdx .y ;
128129 const int64_t i1 = blockIdx .x ;
129130
130- int64_t o[4 ] = {0 , 0 , 0 , 0 };
131- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
132-
133131 const float * x;
134132
135- for (int i0 = threadIdx .x ; i0 < ne0; i0 += blockDim .x ) {
133+ for (int64_t i0 = threadIdx .x ; i0 < ne0; i0 += blockDim .x ) {
136134 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
137135 x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
138136 } else {
139- x = (const float *)(src1 + (i3 - o[3 ])*nb13 + (i2 - o[2 ])*nb12 + (i1 - o[1 ])*nb11 + (i0 - o[0 ])*nb10);
137+ if /* constexpr*/ (dim == 0 ) {
138+ x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
139+ } else if (dim == 1 ) {
140+ x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
141+ } else if (dim == 2 ) {
142+ x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
143+ } else if (dim == 3 ) {
144+ x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
145+ }
140146 }
141147
142148 float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -182,15 +188,37 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
182188 }
183189 } else {
184190 dim3 grid_dim (dst->ne [1 ], dst->ne [2 ], dst->ne [3 ]);
185- concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
186- (const char *)src0->data ,
187- (const char *)src1->data ,
188- ( char *)dst->data ,
189- src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
190- src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
191- src1->ne [0 ], src1->ne [1 ], src1->ne [2 ], src1->ne [3 ],
192- src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
193- dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
194- dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ], dim);
191+ switch (dim) {
192+ case 0 :
193+ concat_f32_non_cont<0 ><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
194+ (const char *) src0->data , (const char *) src1->data , (char *) dst->data , src0->ne [0 ], src0->ne [1 ],
195+ src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], src1->ne [0 ],
196+ src1->ne [1 ], src1->ne [2 ], src1->ne [3 ], src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
197+ dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
198+ break ;
199+ case 1 :
200+ concat_f32_non_cont<1 ><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
201+ (const char *) src0->data , (const char *) src1->data , (char *) dst->data , src0->ne [0 ], src0->ne [1 ],
202+ src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], src1->ne [0 ],
203+ src1->ne [1 ], src1->ne [2 ], src1->ne [3 ], src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
204+ dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
205+ break ;
206+ case 2 :
207+ concat_f32_non_cont<2 ><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
208+ (const char *) src0->data , (const char *) src1->data , (char *) dst->data , src0->ne [0 ], src0->ne [1 ],
209+ src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], src1->ne [0 ],
210+ src1->ne [1 ], src1->ne [2 ], src1->ne [3 ], src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
211+ dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
212+ break ;
213+ case 3 :
214+ concat_f32_non_cont<3 ><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
215+ (const char *) src0->data , (const char *) src1->data , (char *) dst->data , src0->ne [0 ], src0->ne [1 ],
216+ src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], src1->ne [0 ],
217+ src1->ne [1 ], src1->ne [2 ], src1->ne [3 ], src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
218+ dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
219+ break ;
220+ default :
221+ break ;
222+ }
195223 }
196224}
0 commit comments