-
Notifications
You must be signed in to change notification settings - Fork 13.4k
llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch #9745
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch #9745
Conversation
I don't see a clear motivation for removing this. I believe that single sequence usage is by far the most common way llama.cpp is used, and removing this function will require most applications to add a lot of boilerplate. We should aim to make the llama.cpp API as simple as possible to use. |
My main motivation for this PR is that instead of having an API call solely for keeping backward-compatibility, we could keep it as an utility, not a core API. Second motivation is that Keeping these backward-compat struct member makes the code inside
|
I think in this use case, simple specify So if we really want to simplify the usage for end user, we could allow user to only set Even more simple, |
There is a lot we could do to simplify the
|
Let me clarify a bit more, what I mean was that in all examples, we always set:
So I assume that 99% of the case, if user want to work with single-sequence (the most basic usage), then
The problem with such change is that even without touching It seems OK for me to keep In any cases, I still strongly prefer to remove |
Sounds goods to me. Other than causing an ABI break, removing |
697a3f9
to
1c48616
Compare
// - pos : the positions of the respective token in the sequence | ||
// (if set to NULL, the token position will be tracked automatically by llama_decode) | ||
// - seq_id : the sequence to which the respective token belongs | ||
// (if set to NULL, the sequence ID will be assumed to be 0) | ||
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output | ||
// (if set to NULL, only the logits for last token will be returned) | ||
// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@slaren @ggerganov I updated the behavior of llama_batch
to adapt to the removal of all_pos_0, all_pos_1, all_seq_id
, please let me know what you think about this implementation. Thank you!
result2 += next_token_str; | ||
|
||
if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) { | ||
if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will generate a batch for seq_id == 0
and it needs to be seq_id == 1
make -j && ./llama-save-load-state -m ${some_model}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for spotting that! Fixed in 6395174
examples/perplexity/perplexity.cpp
Outdated
const int batch_start = start + j * n_batch; | ||
const int batch_size = std::min(end - batch_start, n_batch); | ||
|
||
llama_batch batch = llama_batch_init(batch_size, 0, 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move the llama_batch
outside the loop and reuse it. Maybe utilize the common_batch_
API to make it little less cumbersome.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/llama.cpp
Outdated
batch.n_seq_id = n_seq_id.data(); | ||
} | ||
if (!batch.seq_id) { | ||
seq_id.resize(batch.n_tokens); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make this also NULL
terminated for consistency (see llama_batch_init
):
seq_id.resize(batch.n_tokens); | |
seq_id.resize(batch.n_tokens + 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 7264596
examples/infill/infill.cpp
Outdated
|
||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); | ||
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); | ||
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past + 1, -n_discard); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small explanation for what's happening: We suppose to shift all tokens from n_keep + n_discard + 1
, so the end of must be n_past + 1
(or we can simply set it to -1
, which means [p0, inf)
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, I don't think n_past + 1
is needed here. There shouldn't be a token with pos == n_past
in the KV cache.
But yes, using either n_past
or -1
would achieve the same thing. Think using n_past
is more illustrative.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok thanks, I figured out that I counted the token from 1, not from 0. I fixed that in 5d99ae4
…l-org#9745) * refactor llama_batch_get_one * adapt all examples * fix simple.cpp * fix llama_bench * fix * fix context shifting * free batch before return * use common_batch_add, reuse llama_batch in loop * null terminated seq_id list * fix save-load-state example * fix perplexity * correct token pos in llama_batch_allocr
…l-org#9745) * refactor llama_batch_get_one * adapt all examples * fix simple.cpp * fix llama_bench * fix * fix context shifting * free batch before return * use common_batch_add, reuse llama_batch in loop * null terminated seq_id list * fix save-load-state example * fix perplexity * correct token pos in llama_batch_allocr
…l-org#9745) * refactor llama_batch_get_one * adapt all examples * fix simple.cpp * fix llama_bench * fix * fix context shifting * free batch before return * use common_batch_add, reuse llama_batch in loop * null terminated seq_id list * fix save-load-state example * fix perplexity * correct token pos in llama_batch_allocr
Motivation
While working on the ability to add both embeddings and tokens to the same batch, I noticed that the old API for
llama_batch
, namelyall_pos_0
,all_post_1
andall_seq_id
has been there for quite a long time.Migration guide
The recommended way is to use
llama_batch_init
andllama_batch_free
:If the binary is linked against
common
, you can use some helper functions:common_batch_add
to add a new token into the batchcommon_batch_clear
to remove all tokens from the batchIf your use case is using single sequence, then you can adapt to the new call signature of
llama_batch_get_one
(although, this is not recommended):The position of tokens will be tracked automatically by
llama_decode
. For example, if the first time, you callllama_decode
on a batch of 10 tokens, then the next timellama_decode
will start decoding from position 11.