
Commit 545df93

server : replace slot.n_prompt_tokens() with slot.task->n_tokens()
1 parent 6c46646 commit 545df93

1 file changed: +27 −27

tools/server/server.cpp

@@ -292,6 +292,10 @@ struct server_task {
 
     server_task(server_task_type type) : type(type) {}
 
+    int32_t n_tokens() const {
+        return tokens.size();
+    }
+
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
@@ -1644,10 +1648,6 @@ struct server_slot {
     int32_t n_prompt_tokens_cache = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-    int32_t n_prompt_tokens() const {
-        return task->tokens.size();
-    }
-
     size_t last_nl_pos = 0;
 
     std::string generated_text;
@@ -2864,8 +2864,8 @@ struct server_context {
            slot.stop = STOP_TYPE_LIMIT;
            slot.has_next_token = false;
 
-           SLT_DBG(slot, "stopped due to running out of context capacity, slot.prompt.n_tokens() = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
-                   slot.prompt.n_tokens(), slot.n_prompt_tokens(), slot.n_decoded, slot.n_ctx);
+           SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                   slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
        }
 
        // check the limits
@@ -2992,7 +2992,7 @@ struct server_context {
    }
 
    void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-       send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx);
+       send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
    }
 
    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
@@ -3029,7 +3029,7 @@ struct server_context {
 
        if (is_progress) {
            res->is_progress = true;
-           res->progress.total = slot.n_prompt_tokens();
+           res->progress.total = slot.task->n_tokens();
            res->progress.cache = slot.n_prompt_tokens_cache;
            res->progress.processed = slot.prompt.tokens.size();
            res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000);
@@ -3041,7 +3041,7 @@ struct server_context {
        }
 
        res->n_decoded = slot.n_decoded;
-       res->n_prompt_tokens = slot.n_prompt_tokens();
+       res->n_prompt_tokens = slot.task->n_tokens();
        res->post_sampling_probs = slot.task->params.post_sampling_probs;
 
        res->verbose = slot.task->params.verbose;
@@ -3077,7 +3077,7 @@ struct server_context {
 
        res->truncated = slot.truncated;
        res->n_decoded = slot.n_decoded;
-       res->n_prompt_tokens = slot.n_prompt_tokens();
+       res->n_prompt_tokens = slot.task->n_tokens();
        res->n_tokens_cached = slot.prompt.n_tokens();
        res->has_new_line = slot.has_new_line;
        res->stopping_word = slot.stopping_word;
@@ -3117,7 +3117,7 @@ struct server_context {
        auto res = std::make_unique<server_task_result_embd>();
        res->id = slot.task->id;
        res->index = slot.task->index;
-       res->n_tokens = slot.n_prompt_tokens();
+       res->n_tokens = slot.task->n_tokens();
        res->oaicompat = slot.task->params.oaicompat;
 
        const int n_embd = llama_model_n_embd(model);
@@ -3162,7 +3162,7 @@ struct server_context {
        auto res = std::make_unique<server_task_result_rerank>();
        res->id = slot.task->id;
        res->index = slot.task->index;
-       res->n_tokens = slot.n_prompt_tokens();
+       res->n_tokens = slot.task->n_tokens();
 
        for (int i = 0; i < batch.n_tokens; ++i) {
            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -3561,7 +3561,7 @@ struct server_context {
            }
 
            // Shift context
-           int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep;
+           int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
 
            if (add_bos_token) {
                n_keep += 1;
@@ -3646,8 +3646,8 @@ struct server_context {
 
                slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
-               SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n",
-                       slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens());
+               SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
+                       slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
 
                // print prompt tokens (for debugging)
                /*if (1) {
@@ -3684,19 +3684,19 @@ struct server_context {
                }
 
                if (!slot.can_split()) {
-                   if (slot.n_prompt_tokens() > n_ubatch) {
+                   if (slot.task->n_tokens() > n_ubatch) {
                        send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
                        slot.release();
                        continue;
                    }
 
-                   if (slot.n_prompt_tokens() > slot.n_ctx) {
+                   if (slot.task->n_tokens() > slot.n_ctx) {
                        send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                        slot.release();
                        continue;
                    }
                } else {
-                   if (slot.n_prompt_tokens() >= slot.n_ctx) {
+                   if (slot.task->n_tokens() >= slot.n_ctx) {
                        send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                        slot.release();
                        continue;
@@ -3874,8 +3874,8 @@ struct server_context {
                }
 
                // [TAG_PROMPT_LOGITS]
-               if (n_past == slot.n_prompt_tokens() && n_past > 0) {
-                   SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", n_past, slot.n_prompt_tokens());
+               if (n_past == slot.task->n_tokens() && n_past > 0) {
+                   SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
                    n_past--;
                    SLT_WRN(slot, "n_past was set to %d\n", n_past);
                }
@@ -3888,7 +3888,7 @@ struct server_context {
 
                if (!slot.can_split()) {
                    // cannot fit the prompt in the current batch - will try next iter
-                   if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) {
+                   if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
                        continue;
                    }
                }
@@ -3910,7 +3910,7 @@ struct server_context {
                slot.prompt.tokens.keep_first(slot.prompt.n_tokens());
 
                // check if we should process the image
-               if (slot.prompt.n_tokens() < slot.n_prompt_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
+               if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                    // process the image
                    int32_t new_n_past;
                    int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.id, new_n_past);
@@ -3962,7 +3962,7 @@ struct server_context {
                );
 
                // add prompt tokens for processing in the current batch
-               while (slot.prompt.n_tokens() < slot.n_prompt_tokens() && batch.n_tokens < n_batch) {
+               while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
                    // get next token to process
                    llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
                    if (cur_tok == LLAMA_TOKEN_NULL) {
@@ -3984,25 +3984,25 @@ struct server_context {
                    slot.n_prompt_tokens_processed++;
 
                    // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
-                   if (do_checkpoint && slot.n_prompt_tokens() - slot.prompt.n_tokens() == 64) {
+                   if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
                        break;
                    }
                }
 
                // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
 
-               SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.n_prompt_tokens());
+               SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
 
                // entire prompt has been processed
-               if (slot.prompt.n_tokens() == slot.n_prompt_tokens()) {
+               if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
                    slot.state = SLOT_STATE_DONE_PROMPT;
 
                    GGML_ASSERT(batch.n_tokens > 0);
 
                    common_sampler_reset(slot.smpl);
 
                    // Process all prompt tokens through sampler system
-                   for (int i = 0; i < slot.n_prompt_tokens(); ++i) {
+                   for (int i = 0; i < slot.task->n_tokens(); ++i) {
                        llama_token id = input_tokens[i];
                        if (id != LLAMA_TOKEN_NULL) {
                            common_sampler_accept(slot.smpl, id, false);
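
In short, the slot-level helper n_prompt_tokens() is removed and callers now read the prompt length directly from the task that owns the input tokens. The shape of the refactor, as a minimal sketch with simplified stand-in types (the std::unique_ptr ownership, the llama_token alias, and the prompt_fits() helper below are assumptions for illustration, not the actual server.cpp definitions):

    #include <cstdint>
    #include <memory>
    #include <vector>

    using llama_token = int32_t;

    struct server_task {
        std::vector<llama_token> tokens;   // input (prompt) tokens owned by the task

        // accessor added by this commit: the prompt length is a property of the task
        int32_t n_tokens() const {
            return tokens.size();
        }
    };

    struct server_slot {
        std::unique_ptr<server_task> task; // simplified ownership, for illustration only

        // the removed slot-level helper merely forwarded to the task:
        //   int32_t n_prompt_tokens() const { return task->tokens.size(); }
    };

    // callers now ask the task directly, e.g. when checking the prompt against the context size
    bool prompt_fits(const server_slot & slot, int32_t n_ctx) {
        return slot.task->n_tokens() <= n_ctx;   // was: slot.n_prompt_tokens() <= n_ctx
    }

Keeping the count on server_task avoids having a second name (n_prompt_tokens) for the same quantity and makes it explicit that the value describes the task's input rather than any per-slot state.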
