
Commit 7747000

ikawrakow and Iwan Kawrakow authored
DeepSeek TG optimizations (#928)

* Fuse concat and copy into K cache
* Avoid ggml_cont() when n_tokens == 1

Combined effect: about +2% TG performance with full GPU offload.

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent: 9dfbc69

File tree

4 files changed: +79 −5 lines changed

* ggml/src/ggml-cuda.cu
* ggml/src/ggml-cuda/cpy.cu
* ggml/src/ggml-cuda/cpy.cuh
* src/llama-build-context.cpp

ggml/src/ggml-cuda.cu

Lines changed: 8 additions & 1 deletion
@@ -3224,7 +3224,14 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_group_norm(ctx, dst);
             break;
         case GGML_OP_CONCAT:
-            ggml_cuda_op_concat(ctx, dst);
+            if (fusion && i + 2 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
+                cgraph->nodes[i+2]->op == GGML_OP_CPY &&
+                ggml_cuda_concat_cpy(ctx, dst, cgraph->nodes[i+2])) {
+                i += 2;
+            } else {
+                ggml_cuda_op_concat(ctx, dst);
+            }
             break;
         case GGML_OP_UPSCALE:
             ggml_cuda_op_upscale(ctx, dst);
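The dispatch change above opportunistically fuses three consecutive graph nodes. As a hedged sketch (tensor names k_nope, k_pe, k_cache_l, and row_offset are illustrative, not taken from the commit), a CONCAT → VIEW → CPY triple of this kind typically arises when the build code concatenates the RoPE and non-RoPE parts of K and copies the result into a view of the K cache:

    // Illustrative only: how the fused node pattern is produced at graph-build time.
    struct ggml_tensor * cat  = ggml_concat(ctx0, k_nope, k_pe, 0);        // GGML_OP_CONCAT
    struct ggml_tensor * view = ggml_view_1d(ctx0, k_cache_l,              // GGML_OP_VIEW
                                             cat->ne[0], row_offset);
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, cat, view));              // GGML_OP_CPY

When ggml_cuda_concat_cpy() accepts the pattern, the fused kernel writes the concatenated row straight into the cache view and the VIEW and CPY nodes are skipped (i += 2).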

ggml/src/ggml-cuda/cpy.cu

Lines changed: 62 additions & 0 deletions
@@ -686,3 +686,65 @@ bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src1,
 #endif
     return true;
 }
+
+template <typename src_t, typename dst_t>
+static __global__ void concat_cpy(const char * csrc1, const char * csrc2, char * cdst, int ne1, int ne,
+                                  char ** dest_ptrs, int copy_index) {
+
+    auto dst  = (dst_t *)(dest_ptrs ? dest_ptrs[copy_index] : cdst);
+    auto src1 = (const src_t *)csrc1;
+    auto src2 = (const src_t *)csrc2;
+
+    for (int i = threadIdx.x; i < ne; i += blockDim.x) {
+        if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+            dst[i] = __float2bfloat16(i < ne1 ? src1[i] : src2[i - ne1]);
+        } else {
+            dst[i] = (dst_t)(i < ne1 ? src1[i] : src2[i - ne1]);
+        }
+    }
+}
+
+template <typename src_t, typename dst_t>
+static void ggml_concat_cpy_cuda(const char * src1, const char * src2, char * dst, int ne1, int ne, cudaStream_t stream,
+                                 char ** dest_ptrs, int & copy_index) {
+
+    int block_dim = std::min(ne, 768);
+    concat_cpy<src_t, dst_t><<<1, block_dim, 0, stream>>>(src1, src2, dst, ne1, ne, dest_ptrs, copy_index);
+    ++copy_index;
+}
+
+bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * concat, const ggml_tensor * dst,
+                          [[maybe_unused]] bool disable_indirection) {
+
+    if (dst->type != GGML_TYPE_F16 && dst->type != GGML_TYPE_BF16) return false;
+    //if (ggml_nrows(dst) > 1) return false;
+    if (dst->src[0] != concat) return false;
+    if (ggml_nrows(concat->src[0]) != 1 || ggml_nrows(concat->src[1]) != 1) return false;
+    if (concat->src[0]->type != GGML_TYPE_F32 || concat->src[1]->type != GGML_TYPE_F32) return false;
+    if (dst->ne[0] != concat->src[0]->ne[0] + concat->src[1]->ne[0]) return false;
+
+    char ** dest_ptrs = nullptr;
+    int graph_cpynode_index = -1;
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
+    if (ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
+        dest_ptrs = ctx.cuda_graph->dest_ptrs_d;
+        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
+    }
+#endif
+
+    if (dst->type == GGML_TYPE_F16) {
+        ggml_concat_cpy_cuda<float, half>((const char *)concat->src[0]->data, (const char *)concat->src[1]->data,
+            (char *)dst->data, concat->src[0]->ne[0], dst->ne[0], ctx.stream(), dest_ptrs, graph_cpynode_index);
+    } else {
+        ggml_concat_cpy_cuda<float, nv_bfloat16>((const char *)concat->src[0]->data, (const char *)concat->src[1]->data,
+            (char *)dst->data, concat->src[0]->ne[0], dst->ne[0], ctx.stream(), dest_ptrs, graph_cpynode_index);
+    }
+
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
+    if (ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
+        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
+    }
+#endif
+    return true;
+
+}
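As a sanity check of the kernel's logic, here is a self-contained sketch, not part of the commit (the file name and concat_cast_f32_f16 are made up), that runs the same single-block, strided concat-and-convert loop for the F32 → F16 case:

    // concat_cast_demo.cu -- illustrative only; build with: nvcc concat_cast_demo.cu
    #include <cuda_fp16.h>
    #include <algorithm>
    #include <cstdio>

    static __global__ void concat_cast_f32_f16(const float * src1, const float * src2,
                                               half * dst, int ne1, int ne) {
        // One block strides over the fused row: the first ne1 elements come from
        // src1, the remainder from src2, converted to F16 on the fly.
        for (int i = threadIdx.x; i < ne; i += blockDim.x) {
            dst[i] = __float2half(i < ne1 ? src1[i] : src2[i - ne1]);
        }
    }

    int main() {
        const int ne1 = 512, ne2 = 64, ne = ne1 + ne2;  // hypothetical row widths
        float * s1; float * s2; half * d;
        cudaMallocManaged(&s1, ne1 * sizeof(float));
        cudaMallocManaged(&s2, ne2 * sizeof(float));
        cudaMallocManaged(&d,  ne  * sizeof(half));
        for (int i = 0; i < ne1; ++i) s1[i] = (float)i;
        for (int i = 0; i < ne2; ++i) s2[i] = (float)(ne1 + i);
        concat_cast_f32_f16<<<1, std::min(ne, 768)>>>(s1, s2, d, ne1, ne);
        cudaDeviceSynchronize();
        printf("d[0] = %g, d[%d] = %g\n", __half2float(d[0]), ne - 1, __half2float(d[ne - 1]));
        return 0;
    }

A single launch replaces what would otherwise be separate concat and cpy kernels plus an intermediate F32 buffer, which is where the saving comes from at n_tokens == 1, where the row is small and launch overhead dominates.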

ggml/src/ggml-cuda/cpy.cuh

Lines changed: 3 additions & 0 deletions
@@ -12,3 +12,6 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
 
 bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
                      ggml_tensor * dst1, ggml_tensor * dst2, bool disable_indirection = false);
+
+bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * concat, const ggml_tensor * dst,
+                          bool disable_indirection = false);

src/llama-build-context.cpp

Lines changed: 6 additions & 4 deletions
@@ -6268,11 +6268,13 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
         kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
         cb(kqv, "kqv", il);
 
-        kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
-        cb(kqv, "kqv_perm", il);
-
-        cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
+        if (n_tokens > 1) {
+            kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
+            cb(kqv, "kqv_perm", il);
+        }
+        cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
         cb(cur, "kqv_2d", il);
+
     }
 
     ggml_build_forward_expand(gf, cur);
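Why ggml_cont() can be skipped for single-token decode: ggml stores ne0 fastest, so when n_tokens == 1 the mul_mat output is already laid out exactly like the desired {n_embd_head_v*n_head, n_tokens} tensor, and ggml_reshape_2d() can reinterpret it in place with no copy. A small sketch of the index argument, with made-up shape values:

    // layout_check.cpp -- illustrative only; shape values are hypothetical.
    #include <cassert>

    int main() {
        const int n_embd_head_v = 128, n_head = 16, n_tokens = 1;
        // Element (i, t, h) of an {n_embd_head_v, n_tokens, n_head} tensor sits at
        // i + t*n_embd_head_v + h*n_embd_head_v*n_tokens (ne0 fastest).
        for (int h = 0; h < n_head; ++h) {
            for (int i = 0; i < n_embd_head_v; ++i) {
                int idx_3d = i + 0 * n_embd_head_v + h * n_embd_head_v * n_tokens;
                int idx_2d = i + h * n_embd_head_v;  // row layout of the 2D reshape
                assert(idx_3d == idx_2d);  // identical layout => reshape needs no copy
            }
        }
        return 0;
    }

For n_tokens > 1 the token and head dimensions genuinely interleave, so the permute-plus-ggml_cont() copy is kept on that path.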
