llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch #9745
Changes in the perplexity example (`perplexity_v2`, `decode_helper`, `kl_divergence`):

```diff
@@ -412,13 +412,22 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         const int batch_start = start + j * n_batch;
         const int batch_size  = std::min(end - batch_start, n_batch);
 
+        llama_batch batch = llama_batch_init(batch_size, 0, 1);
+
+        for (int i = 0; i < batch_size; i++) {
+            batch.token[i]     = tokens[batch_start + i];
+            batch.pos[i]       = j*n_batch + i;
+            batch.logits[i]    = true;
+            batch.seq_id[i][0] = 0;
+        }
+
         //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-        // TODO: use llama_batch.logits instead of relying on logits_all == true
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+        if (llama_decode(ctx, batch)) {
             //LOG_ERR("%s : failed to eval\n", __func__);
             return {tokens, -1, logit_history, prob_history};
         }
 
+        llama_batch_free(batch);
+
         // save original token and restore it after eval
         const auto token_org = tokens[batch_start];
```
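For comparison, the same batch could also be populated with the helpers from `common.h` rather than by writing the fields directly. A minimal sketch, assuming the `common_batch_clear`/`common_batch_add` helpers and the same `tokens`, `batch_start`, `batch_size`, `j`, and `n_batch` as in the hunk above:

```cpp
// Sketch only: equivalent batch construction using the common helpers.
// common_batch_add sets the token, position, sequence id and logits flag,
// and also bumps batch.n_tokens.
llama_batch batch = llama_batch_init(batch_size, 0, 1);

common_batch_clear(batch);
for (int i = 0; i < batch_size; i++) {
    common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, { 0 }, true);
}

if (llama_decode(ctx, batch)) {
    // handle the failure as the surrounding code does
}

llama_batch_free(batch);
```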
```diff
@@ -704,7 +713,6 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             batch.n_seq_id + i,
             batch.seq_id   + i,
             batch.logits   + i,
-            0, 0, 0, // unused
         };
 
         const int ret = llama_decode(ctx, batch_view);
```
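The three dropped `0, 0, 0` initializers were the removed `all_pos_0`, `all_pos_1`, and `all_seq_id` members, so the aggregate initializer for `batch_view` now lines up with the slimmed-down struct. Roughly, the remaining layout looks like the sketch below (from memory; `llama.h` is authoritative):

```cpp
// Approximate post-change layout of llama_batch: the trailing
// all_pos_0 / all_pos_1 / all_seq_id members are gone.
typedef struct llama_batch {
    int32_t         n_tokens;

    llama_token  *  token;     // token ids (or NULL when embd is used)
    float        *  embd;      // input embeddings (or NULL)
    llama_pos    *  pos;       // position of each token
    int32_t      *  n_seq_id;  // number of sequence ids per token
    llama_seq_id ** seq_id;    // sequence ids per token
    int8_t       *  logits;    // whether to return logits for each token
} llama_batch;
```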
```diff
@@ -1803,12 +1811,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
             tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
         }
 
-        // TODO: use llama_batch.logits instead of relying on logits_all == true
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+        llama_batch batch = llama_batch_init(batch_size, 0, 1);
+        for (int i = 0; i < batch_size; i++) {
+            batch.token[i]     = tokens[batch_start + i];
+            batch.pos[i]       = j*n_batch + i;
+            batch.logits[i]    = true;
+            batch.seq_id[i][0] = 0;
+        }
+
+        if (llama_decode(ctx, batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return;
         }
 
+        llama_batch_free(batch);
+
         // restore the original token in case it was set to BOS
         tokens[batch_start] = token_org;
```
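One detail worth noting about the explicit init/free pattern in both hunks above: the early `return` on a failed `llama_decode` happens before `llama_batch_free`, so the batch appears to be leaked on that path. A hypothetical RAII guard (not part of this PR) would make the cleanup automatic; a minimal sketch:

```cpp
#include "llama.h"

// Hypothetical helper, not in llama.cpp: frees the batch when it goes out
// of scope, including on early returns and exceptions.
struct batch_guard {
    llama_batch batch;

    batch_guard(int32_t n_tokens, int32_t embd, int32_t n_seq_max)
        : batch(llama_batch_init(n_tokens, embd, n_seq_max)) {}

    ~batch_guard() { llama_batch_free(batch); }

    // non-copyable: the destructor must run exactly once
    batch_guard(const batch_guard &) = delete;
    batch_guard & operator=(const batch_guard &) = delete;
};
```

With such a guard, the loop bodies could fill `guard.batch` and return early on a decode failure without leaking.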
Changes in the save/load-state example:

```diff
@@ -49,7 +49,7 @@ int main(int argc, char ** argv) {
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
+    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
     n_past += tokens.size();
 
     // save state (rng, logits, embedding and kv_cache) to file
@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;
 
-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
             llama_free_model(model);
@@ -133,7 +133,7 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;
 
-        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -221,7 +221,7 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;
 
-        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
+        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx3);
             llama_free_model(model);
```
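These call sites change in only one way: the trailing position and sequence-id arguments disappear. That matches a two-argument `llama_batch_get_one`; a rough sketch of the new-style usage, with the old prototype reconstructed from the removed call sites (check `llama.h` for the exact signatures):

```cpp
#include <vector>

#include "llama.h"

// Rough before/after of the prototype, reconstructed from this diff:
//
//   before: llama_batch_get_one(tokens, n_tokens, pos_0, seq_id)
//   after:  llama_batch_get_one(tokens, n_tokens)
//
// Minimal usage in the new style, assuming an initialized ctx and a
// tokenized prompt; the n_past counter is still maintained by the caller,
// as in the hunks above.
static bool eval_prompt(llama_context * ctx, std::vector<llama_token> & tokens, int & n_past) {
    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
        return false;
    }
    n_past += (int) tokens.size();
    return true;
}
```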
Small explanation for what's happening: we're supposed to shift all tokens from `n_keep + n_discard + 1`, so the end of the range must be `n_past + 1` (or we can simply set it to `-1`, which means `[p0, inf)`).
Hm, I don't think `n_past + 1` is needed here. There shouldn't be a token with `pos == n_past` in the KV cache. But yes, using either `n_past` or `-1` would achieve the same thing. I think using `n_past` is more illustrative.
Ok, thanks. I figured out that I was counting the tokens from 1, not from 0. I fixed that in 5d99ae4.
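For context, the exchange above concerns the context-shift logic elsewhere in this PR: after keeping the first `n_keep` tokens, the next `n_discard` tokens are dropped and the remainder is shifted left. A minimal sketch of that pattern with the KV-cache sequence API (names and values are illustrative, not the PR's exact call site; positions are 0-based, so the open end of the shifted range is `n_past`, or equivalently `-1` for `[p0, inf)`):

```cpp
#include "llama.h"

// Sketch of the context shift discussed above: drop n_discard tokens after
// the first n_keep kept tokens of sequence 0, then shift the rest back so
// decoding can continue.
static void context_shift(llama_context * ctx, int & n_past, int n_keep, int n_discard) {
    // remove positions [n_keep, n_keep + n_discard)
    llama_kv_cache_seq_rm (ctx, 0, n_keep, n_keep + n_discard);

    // shift positions [n_keep + n_discard, n_past) back by n_discard;
    // passing -1 as p1 ("to the end of the sequence") would work as well
    llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_past, -n_discard);

    n_past -= n_discard;
}
```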