@@ -188,37 +188,33 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
188188 }
189189 } else {
190190 dim3 grid_dim (dst->ne [1 ], dst->ne [2 ], dst->ne [3 ]);
191+ auto launch_kernel = [&](auto dim) {
192+ concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
193+ (const char *) src0->data , (const char *) src1->data , (char *) dst->data ,
194+ src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
195+ src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
196+ src1->ne [0 ], src1->ne [1 ], src1->ne [2 ], src1->ne [3 ],
197+ src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
198+ dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
199+ dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
200+ };
191201 switch (dim) {
192202 case 0 :
193- concat_f32_non_cont<0 ><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
194- (const char *) src0->data , (const char *) src1->data , (char *) dst->data , src0->ne [0 ], src0->ne [1 ],
195- src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], src1->ne [0 ],
196- src1->ne [1 ], src1->ne [2 ], src1->ne [3 ], src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
197- dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
203+ launch_kernel (std::integral_constant<int , 0 >{});
198204 break ;
199205 case 1 :
200- concat_f32_non_cont<1 ><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
201- (const char *) src0->data , (const char *) src1->data , (char *) dst->data , src0->ne [0 ], src0->ne [1 ],
202- src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], src1->ne [0 ],
203- src1->ne [1 ], src1->ne [2 ], src1->ne [3 ], src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
204- dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
206+ launch_kernel (std::integral_constant<int , 1 >{});
205207 break ;
206208 case 2 :
207- concat_f32_non_cont<2 ><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
208- (const char *) src0->data , (const char *) src1->data , (char *) dst->data , src0->ne [0 ], src0->ne [1 ],
209- src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], src1->ne [0 ],
210- src1->ne [1 ], src1->ne [2 ], src1->ne [3 ], src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
211- dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
209+ launch_kernel (std::integral_constant<int , 2 >{});
212210 break ;
213211 case 3 :
214- concat_f32_non_cont<3 ><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
215- (const char *) src0->data , (const char *) src1->data , (char *) dst->data , src0->ne [0 ], src0->ne [1 ],
216- src0->ne [2 ], src0->ne [3 ], src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], src1->ne [0 ],
217- src1->ne [1 ], src1->ne [2 ], src1->ne [3 ], src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
218- dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
212+ launch_kernel (std::integral_constant<int , 3 >{});
219213 break ;
220214 default :
215+ GGML_ABORT (" Invalid dim: %d" , dim);
221216 break ;
222217 }
223218 }
219+ }
224220}
0 commit comments