@@ -686,3 +686,65 @@ bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src1,
686686#endif
687687 return true ;
688688}
689+
// Fused concat + convert kernel: fills dst[0..ne) where the first ne1
// elements are read from csrc1 and the remaining (ne - ne1) from csrc2,
// converting each element from src_t to dst_t.
// Launch contract: a single block; each thread strides the row by blockDim.x.
// When dest_ptrs is non-null (CUDA-graph pointer indirection) the destination
// buffer is resolved at kernel run time from dest_ptrs[copy_index] instead of
// the captured cdst pointer.
template <typename src_t, typename dst_t>
static __global__ void concat_cpy(const char * csrc1, const char * csrc2, char * cdst, int ne1, int ne,
                                  char ** dest_ptrs, int copy_index) {

    dst_t       * dst  = (dst_t *) (dest_ptrs ? dest_ptrs[copy_index] : cdst);
    const src_t * src1 = (const src_t *) csrc1;
    const src_t * src2 = (const src_t *) csrc2;

    for (int i = threadIdx.x; i < ne; i += blockDim.x) {
        const src_t v = i < ne1 ? src1[i] : src2[i - ne1];
        if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
            // nv_bfloat16 needs the explicit conversion intrinsic
            dst[i] = __float2bfloat16(v);
        } else {
            dst[i] = (dst_t) v;
        }
    }
}
706+
// Host-side launcher for the fused concat + convert kernel on `stream`.
// ne1: element count of the first source; ne: total output element count.
// copy_index is the CUDA-graph copy-node slot consumed by this launch; it is
// always advanced so the caller's node bookkeeping stays in step.
template <typename src_t, typename dst_t>
static void ggml_concat_cpy_cuda(const char * src1, const char * src2, char * dst, int ne1, int ne, cudaStream_t stream,
                                 char ** dest_ptrs, int & copy_index) {

    // Guard against ne == 0: std::min(ne, 768) would yield a zero-thread
    // block, and <<<1, 0>>> is an invalid execution configuration.
    if (ne > 0) {
        // Single-block launch; the kernel strides the whole row with one block.
        const int block_dim = std::min(ne, 768);
        concat_cpy<src_t, dst_t><<<1, block_dim, 0, stream>>>(src1, src2, dst, ne1, ne, dest_ptrs, copy_index);
    }
    ++copy_index;
}
715+
// Tries to fuse a CONCAT -> CPY node pair into one kernel launch.
// Supported pattern only: both concat inputs are single-row F32 tensors joined
// along dim 0, and dst is the F16/BF16 copy of that concat result.
// Returns false when the pattern does not match, so the caller falls back to
// the generic copy path; returns true after enqueuing the fused kernel.
bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * concat, const ggml_tensor * dst,
                          [[maybe_unused]] bool disable_indirection) {

    const ggml_tensor * a = concat->src[0];
    const ggml_tensor * b = concat->src[1];

    const bool dst_type_ok = dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16;
    if (!dst_type_ok || dst->src[0] != concat) {
        return false;
    }
    if (ggml_nrows(a) != 1 || ggml_nrows(b) != 1) {
        return false;
    }
    if (a->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
        return false;
    }
    if (dst->ne[0] != a->ne[0] + b->ne[0]) {
        return false;
    }

    char ** dest_ptrs  = nullptr;
    int     node_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
    // Under CUDA graphs, destinations are resolved through an indirection table
    // so a captured graph can be replayed against different buffers.
    if (ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
        dest_ptrs  = ctx.cuda_graph->dest_ptrs_d;
        node_index = ctx.cuda_graph->graph_cpynode_index;
    }
#endif

    const char * src1_d = (const char *) a->data;
    const char * src2_d = (const char *) b->data;
    char       * dst_d  = (char *) dst->data;

    if (dst->type == GGML_TYPE_F16) {
        ggml_concat_cpy_cuda<float, half>(src1_d, src2_d, dst_d, a->ne[0], dst->ne[0], ctx.stream(), dest_ptrs, node_index);
    } else {
        ggml_concat_cpy_cuda<float, nv_bfloat16>(src1_d, src2_d, dst_d, a->ne[0], dst->ne[0], ctx.stream(), dest_ptrs, node_index);
    }

#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
    // Publish the advanced node index back to the graph context.
    if (ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
        ctx.cuda_graph->graph_cpynode_index = node_index;
    }
#endif
    return true;
}
0 commit comments