graph : reuse SSM graphs #16490
base: master
Conversation
Very cool! I'll test shortly with Granite 4. The only thought I've had about why this might be difficult is around implementing the SSD version of SSM_SCAN. In

Review comment on:

```cpp
    bool res = true;

    res &= s_copy->ne[0] == mctx->get_n_rs();
```
`mctx->get_head()` (the start of the slot) and `mctx->get_rs_z()` (the first zeroed state) are used in view offsets, and so would need to match too, otherwise the graph can't really be re-used.

The case where they wouldn't match (but `n_rs` matches) is when ubatches of the same size with different sequences are used. E.g. seq_ids 0, 1 with 1 token, and then seq_ids 2, 3 with 1 token, in consecutive ubatches, repeatedly. This probably happens when using `-ub 1` in the `llama-parallel` example, I think (because it uses a single `seq_id` per ubatch at a time, but ends up using different seq_ids while using the same size of ubatches).

(Note that I didn't actually test the changes yet, so I don't know if this is a real problem.)
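To illustrate why these values matter for reuse, here is a rough sketch (simplified placeholder code, not the exact llama.cpp implementation; the function and variable names are assumptions) of how a slot-dependent byte offset ends up baked into a cached graph:

```cpp
#include "ggml.h"

// Sketch only: the recurrent states for the current slot are addressed through
// a view whose byte offset depends on the slot head.
static ggml_tensor * view_slot_states(
        ggml_context * ctx0,
        ggml_tensor  * s_all,   // all recurrent states [state_size, total_cells]
        int64_t        n_rs,    // number of states used by this ubatch's slot
        uint32_t       head) {  // first cell of the slot (mctx->get_head())
    // The offset below is fixed at graph-build time, so a cached graph built
    // for one value of `head` would read/write the wrong rows if reused for
    // a ubatch whose slot starts elsewhere.
    return ggml_view_2d(ctx0, s_all,
            s_all->ne[0], n_rs,
            s_all->nb[1],
            head * s_all->nb[1]);
}
```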
There is a check earlier for whether the sequences are the same:
Lines 443 to 457 in 638e2c2
```cpp
    // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
    // the reason is because the set of attention streams would be different for different sequences
    if (can_reuse_ubatch && ubatch.equal_seqs()) {
        if (!ubatch.data) {
            // if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
            //   therefore we cannot perform the sequence id check. normally should never happen
            can_reuse_ubatch = false;
        } else {
            for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
                can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
            }
        }
    }
```
This check applies to all graphs and if not satisfied, we don't attempt to reuse the graph. I think this should cover this case.
> There is a check earlier for whether the sequences are the same

Right, this should cover the case where different sequences are used. However, I don't think it covers the case when a sequence is cleared (which will make `mctx->get_rs_z()` differ).
I'm noticing different perplexity with and without graph reuse with a Q8_0 `mamba-130m` on CPU (this is on the first 10 chunks of `calibration_datav3`):
| params | LLAMA_GRAPH_REUSE_DISABLE | PPL |
|---|---|---|
| `-b 512` | 0 | 7.7852 |
| `-b 2048` | 0 | 7.8628 |
| `-b 512` | 1 | 7.7852 |
| `-b 2048` | 1 | 7.7852 |
I'm not sure what exactly is causing it, but I suspect it's related to either `rs_z` or `head` (since this doesn't seem to happen with non-recurrent models; I tested with a Q8_0 TinyLlama).
@ggerganov Checking for `head` and `rs_z` mismatch does seem to help with the case in my previous comment, making the graph-reuse case have the same PPL as when it's not used.
Patch with changes:

```diff
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 7f0c974f1..aad42d62d 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -258,6 +258,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     bool res = true;
 
+    res &= this->head == mctx->get_head();
+    res &= this->rs_z == mctx->get_rs_z();
+
     res &= s_copy->ne[0] == mctx->get_n_rs();
     res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
@@ -482,6 +485,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
     res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
 
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
     res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
     res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
@@ -1827,6 +1833,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 394e88432..a596461bb 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -234,6 +234,10 @@ public:
     ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
+
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t  rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
```
It might not be ideal to expose another way to get `head` and `rs_z`. But the constructor of `llm_graph_input_rs` would need access to `llama-memory-recurrent.h` to use `mctx->get_head()` and `mctx->get_rs_z()`.

Strangely enough, hybrid models like Falcon-H1 don't manifest the same problem as `mamba-130m`; I can't reproduce the original problem with them.
> It might not be ideal to expose another way to get head and rs_z. But the constructor of llm_graph_input_rs would need access to llama-memory-recurrent.h to use mctx->get_head() and mctx->get_rs_z().

Can you clarify what you mean here? The proposed solution seems OK to me.

On a related topic, would it be possible to avoid these offsets through the use of `ggml_set_rows()`, in a similar way to how we avoided the KV cache offset for the regular attention?
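For context, a minimal sketch of that idea (assuming the existing `ggml_set_rows()` API; the function and tensor names here are placeholders, not proposed code): instead of writing the updated states through a view whose byte offset depends on `head`, the destination rows are selected by an index tensor that is filled at compute time, so the graph itself stays offset-free and is easier to reuse.

```cpp
#include "ggml.h"

// Sketch only: scatter the updated recurrent states into the full state buffer
// by row index, rather than through a head-dependent view offset.
static ggml_tensor * store_states_by_row(
        ggml_context * ctx0,
        ggml_tensor  * s_all,   // all states  [state_size, total_cells]
        ggml_tensor  * s_new,   // new states  [state_size, n_rs]
        ggml_tensor  * rows) {  // I64 indices [n_rs], filled via set_input()
    // The row indices are a graph input, so the same cached graph works for
    // any slot position; only the data fed into `rows` changes per ubatch.
    return ggml_set_rows(ctx0, s_all, s_new, rows);
}
```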
@compilade I improved the state management of the recurrent memory with 6589d3b. The recurrent memory context now keeps immutable values such as `head`, `rs_z`, etc. These can be used in the `can_reuse()` logic without duplicating this state in the inputs.
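A rough sketch of what that enables (this is not the actual 6589d3b change; it only illustrates the idea under the assumption that `get_head()`/`get_rs_z()` are exposed on the recurrent memory context and that `this->mctx` is the context captured at graph-build time):

```cpp
// Sketch only: compare the immutable per-context values of the old and new
// memory contexts instead of storing head/rs_z copies in the graph input.
bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);

    bool res = true;

    res &= s_copy->ne[0]      == mctx->get_n_rs();
    res &= s_copy_main->ne[0] == params.ubatch.n_seqs;

    // the view offsets depend on these, so they must match for the cached
    // graph to remain valid
    res &= this->mctx->get_head() == mctx->get_head();
    res &= this->mctx->get_rs_z() == mctx->get_rs_z();

    return res;
}
```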
@gabe-l-hart What is SSD?

Sorry, commenting from my phone at the airport! SSD is the State Space Duality part of the
Results looking good for Metal.

Reuse on, fa on: `./bin/llama-batched-bench -m $(find-ollama-gguf.sh granite4:micro-h) -c 2048 -b 2048 -ub 512 -npp 128,256 -ntg 128 -npl 1,2,4 -ngl 99 -fa on`

Reuse on, fa off: `./bin/llama-batched-bench -m $(find-ollama-gguf.sh granite4:micro-h) -c 2048 -b 2048 -ub 512 -npp 128,256 -ntg 128 -npl 1,2,4 -ngl 99 -fa off`

Reuse off, fa on: `LLAMA_GRAPH_REUSE_DISABLE=1 ./bin/llama-batched-bench -m $(find-ollama-gguf.sh granite4:micro-h) -c 2048 -b 2048 -ub 512 -npp 128,256 -ntg 128 -npl 1,2,4 -ngl 99 -fa on`

Reuse off, fa off: `LLAMA_GRAPH_REUSE_DISABLE=1 ./bin/llama-batched-bench -m $(find-ollama-gguf.sh granite4:micro-h) -c 2048 -b 2048 -ub 512 -npp 128,256 -ntg 128 -npl 1,2,4 -ngl 99 -fa off`
I haven't followed the graph reuse implementation closely enough to review the code changes here in depth, but the performance changes are working well for me on Metal, and I've validated that results match perfectly with and without graph reuse for single-concurrency.

@gabe-l-hart as a side note, I've added
@pwilkin I thought I saw that while trying to keep up with the comments! It's high on my todo list after this conference to get into your PR (partly selfishly, because I want to reuse these parts).

@gabe-l-hart Parallel performance of SSMs should be fixed with #16494

Thank you for digging into these performance improvements!
I'm hitting errors on `lldb ./bin/llama-cli -- -m $(find-ollama-gguf.sh granite4:micro-h) -no-cnv -p "tell me a story about a developer and their dog?" -ngl 99 --temp 0`. I'll investigate further, but wanted to post in case it's about to be merged.
It looks like it broke in

Should be ok now. I mistakenly thought that the old

Confirmed, it's working again for me! I'll test a little further with parallel sequences, but I think it's probably ready.
Hitting assertions with `lldb ./bin/llama-parallel -- -m $(find-ollama-gguf.sh granite4:micro-h) -ngl 99 -fa on -ns 10 -np 10`.

Just confirmed that I don't hit these on
Force-pushed from 18212b0 to 2744d61.

Running cleanly with those reverts.
In case it's helpful, I was seeing it consistently on the second call (debug logs attached).
So the change in 00f115f does not work for some reason. We want to eventually extract the state of the recurrent memory into the memory context, as we do with the KV cache implementations. But I think there is something being mutated when it should not be. For now, let's revert this and figure it out later. To clarify, the design is that when building the graph we should only reference data that is stored in the memory context (i.e. in

Got it, that makes sense.
Force-pushed from 2744d61 to 16d57ca.
This reverts commit 00f115f.
Force-pushed from 16d57ca to 7641e6f.
Not sure if there is a reason not to enable graph reuse for recurrent graphs (Mamba, hybrids, SSMs, etc.). Did a few tests and it seems to work, resulting in some modest perf improvements. cc @gabe-l-hart @compilade

Without graph reuse:

```
make -j && LLAMA_GRAPH_REUSE_DISABLE=1 ./bin/llama-bench -m ../models/mamba-130m/ggml-model-f16.gguf -m ../models/granite-4-h-tiny/ggml-model-q8_0.gguf -m ../models/ai21-jamba-mini-1.7/ggml-model-q8_0.gguf -m ../models/liquidai-lfm2-2.6b/ggml-model-q4_k.gguf -fa 1 -t 1 -n 32
```

With graph reuse:

```
make -j && ./bin/llama-bench -m ../models/mamba-130m/ggml-model-f16.gguf -m ../models/granite-4-h-tiny/ggml-model-q8_0.gguf -m ../models/ai21-jamba-mini-1.7/ggml-model-q8_0.gguf -m ../models/liquidai-lfm2-2.6b/ggml-model-q4_k.gguf -fa 1 -t 1 -n 32
```