Commit c0f2999

compilade authored and Nexesenex committed
graph : reduce splits for recurrent and hybrid models (ggml-org#14825)
* graph : avoid creating redundant s_copy views

* graph : comment the s_copy views
1 parent f605be8 commit c0f2999

File tree: src/llama-graph.cpp, src/llama-graph.h

2 files changed: +34 -21 lines
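In brief, the change hoists the two `ggml_view_1d` views of `s_copy` out of `build_rs`: previously each call (one per layer) created its own pair of views inline, and those redundant nodes contributed to the graph splits the commit title refers to; now `build_rs_inp_impl` creates one `s_copy_main`/`s_copy_extra` pair per graph and every layer shares it. A minimal standalone sketch of the hoisted-view pattern follows; the sizes are illustrative placeholders, not values from a real ubatch.

// Sketch of the hoisted-view pattern (not the PR's code verbatim): the
// index views are built once per graph and reused by every layer's
// build_rs call instead of being re-created per layer.
#include "ggml.h"

int main() {
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
    ggml_context * ctx0 = ggml_init(params);

    const int64_t n_rs   = 8; // recurrent-state slots touched by this ubatch (placeholder)
    const int64_t n_seqs = 3; // sequences actually computed (placeholder)

    ggml_tensor * s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
    ggml_set_input(s_copy);

    // before the patch: each build_rs call re-created this pair (one pair per layer);
    // after the patch: created once here and shared across layers
    ggml_tensor * s_copy_main  = ggml_view_1d(ctx0, s_copy, n_seqs, 0);
    ggml_tensor * s_copy_extra = ggml_view_1d(ctx0, s_copy, n_rs - n_seqs, n_seqs*s_copy->nb[0]);

    (void) s_copy_main; (void) s_copy_extra;
    ggml_free(ctx0);
    return 0;
}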

src/llama-graph.cpp

Lines changed: 23 additions & 15 deletions
@@ -1757,56 +1757,62 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
 ggml_tensor * llm_graph_context::build_rs(
          ggml_tensor * s,
-         ggml_tensor * state_copy,
+         ggml_tensor * state_copy_main,
+         ggml_tensor * state_copy_extra,
              int32_t   state_size,
              int32_t   n_seqs,
-            uint32_t   n_kv,
-            uint32_t   kv_head,
-            uint32_t   kv_size,
+            uint32_t   n_rs,
+            uint32_t   rs_head,
+            uint32_t   rs_size,
              int32_t   rs_zero,
     const llm_graph_get_rows_fn & get_state_rows) const {
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
     ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
     // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // {state_size, kv_size} -> {state_size, n_seqs}
-    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
+    // {state_size, rs_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
     ggml_build_forward_expand(gf, output_states);
 
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    // copy extra states which won't be changed further (between n_seqs and n_rs)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
 
     return output_states;
 }
 
 static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
           ggml_context * ctx0,
+    const llama_ubatch & ubatch,
     const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = mctx_cur->get_n_rs();
+    const int64_t n_rs   = mctx_cur->get_n_rs();
+    const int64_t n_seqs = ubatch.n_seqs;
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
+    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
+
     return inp;
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
 
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
@@ -1819,7 +1825,9 @@ ggml_tensor * llm_graph_context::build_rs(
     const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
 
-    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
+            kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
+            get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1866,7 +1874,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
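Putting the rewritten `build_rs` together: the first `n_seqs` entries of `s_copy` select the rows that become the layer's working states, while the remaining `n_rs - n_seqs` entries select rows that are only relocated inside the state buffer `s` and "won't be changed further", per the comment in the diff. Below is a graph-construction-only sketch of that node pattern, assuming `rs_head = 0`, no state to clear (the `rs_zero` branch is elided), and the default `ggml_get_rows` hook; nothing is actually computed.

// Sketch of the node pattern build_rs emits, under the simplifying
// assumptions above. Dimensions are illustrative placeholders.
#include "ggml.h"

int main() {
    ggml_init_params params = { 16*1024*1024, nullptr, true /*no_alloc*/ };
    ggml_context * ctx0 = ggml_init(params);
    ggml_cgraph  * gf   = ggml_new_graph(ctx0);

    const int64_t state_size = 4, n_rs = 8, n_seqs = 3, rs_size = 8;

    ggml_tensor * s      = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, state_size*rs_size);
    ggml_tensor * s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);

    // the shared views from build_rs_inp_impl
    ggml_tensor * s_copy_main  = ggml_view_1d(ctx0, s_copy, n_seqs, 0);
    ggml_tensor * s_copy_extra = ggml_view_1d(ctx0, s_copy, n_rs - n_seqs, n_seqs*s_copy->nb[0]);

    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);

    // rows 0..n_seqs-1 become the working states: {state_size, rs_size} -> {state_size, n_seqs}
    ggml_tensor * output_states = ggml_get_rows(ctx0, states, s_copy_main);
    ggml_build_forward_expand(gf, output_states);

    // rows n_seqs..n_rs-1 are only relocated, written back right after the active slots
    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, s_copy_extra);
    ggml_build_forward_expand(gf,
        ggml_cpy(ctx0, states_extra,
            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), n_seqs*state_size*ggml_element_size(s))));

    ggml_free(ctx0);
    return 0;
}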

src/llama-graph.h

Lines changed: 11 additions & 6 deletions
@@ -214,7 +214,12 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * s_copy; // I32 [kv_size]
+    ggml_tensor * s_copy; // I32 [n_rs]
+
+    // views of s_copy, computed once per graph
+    // and shared across layers which use build_rs
+    ggml_tensor * s_copy_main;  // I32 [n_seqs]
+    ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
 };
@@ -730,20 +735,20 @@ struct llm_graph_context {
     // recurrent
     //
 
-    // TODO: avoid notion of "kv"
     // TODO: move this implementation to llama_memory_recurrent.
     //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
     //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //       `llama_memory_recurrent`
     ggml_tensor * build_rs(
         ggml_tensor * s,
-        ggml_tensor * state_copy,
+        ggml_tensor * state_copy_main,
+        ggml_tensor * state_copy_extra,
            int32_t    state_size,
            int32_t    n_seqs,
-          uint32_t    n_kv,
-          uint32_t    kv_head,
-          uint32_t    kv_size,
+          uint32_t    n_rs,
+          uint32_t    rs_head,
+          uint32_t    rs_size,
            int32_t    rs_zero,
        const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
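One design note on the new header fields: `s_copy_main` and `s_copy_extra` are `ggml_view_1d` views, so they alias `s_copy`'s buffer. `set_input` can keep filling only `s_copy`, and both views read through to the same data, which is what lets them be computed once per graph and shared across layers. A small self-contained illustration of that aliasing (not code from the PR; the index values are arbitrary):

// Views alias their source tensor's buffer, so one host write to s_copy
// is visible through both views. Standalone and illustrative only.
#include "ggml.h"
#include <cstdint>
#include <cstdio>

int main() {
    ggml_init_params params = { 1024*1024, nullptr, false /*allocate data*/ };
    ggml_context * ctx0 = ggml_init(params);

    const int64_t n_rs = 5, n_seqs = 2;

    ggml_tensor * s_copy       = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
    ggml_tensor * s_copy_main  = ggml_view_1d(ctx0, s_copy, n_seqs, 0);
    ggml_tensor * s_copy_extra = ggml_view_1d(ctx0, s_copy, n_rs - n_seqs, n_seqs*s_copy->nb[0]);

    // fill the parent tensor once, as llm_graph_input_rs::set_input would
    for (int64_t i = 0; i < n_rs; ++i) {
        ((int32_t *) s_copy->data)[i] = (int32_t) i;
    }

    // both views read through to the same buffer
    printf("main[0]  = %d\n", ((int32_t *) s_copy_main->data)[0]);  // prints 0
    printf("extra[0] = %d\n", ((int32_t *) s_copy_extra->data)[0]); // prints 2

    ggml_free(ctx0);
    return 0;
}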
