Commit aa26a58

added logprobs api and logprobs viewer
1 parent 6731dd6 commit aa26a58

5 files changed: +229 -29 lines changed

5 files changed

+229
-29
lines changed

expose.cpp

Lines changed: 20 additions & 2 deletions
@@ -294,11 +294,29 @@ extern "C"
         return output;
     }

+    static std::vector<TopPicksData> last_logprob_toppicks;
+    static std::vector<logprob_item> last_logprob_items;
     last_logprobs_outputs last_logprobs()
     {
         last_logprobs_outputs output;
-        std::vector<TopPicksData> toppicks = gpttype_get_top_picks_data(); //copy top picks
-        output.count = 0;
+        last_logprob_items.clear();
+        last_logprob_toppicks.clear();
+        last_logprob_toppicks = gpttype_get_top_picks_data(); //copy top picks
+        for(int i=0;i<last_logprob_toppicks.size();++i)
+        {
+            logprob_item itm;
+            itm.option_count = last_logprob_toppicks[i].tokenid.size();
+            itm.selected_token = last_logprob_toppicks[i].selected_token.c_str();
+            itm.selected_logprob = last_logprob_toppicks[i].selected_logprob;
+            itm.logprobs = last_logprob_toppicks[i].logprobs.data();
+            for(int j=0;j<itm.option_count && j<logprobs_max;++j)
+            {
+                itm.tokens[j] = last_logprob_toppicks[i].tokens[j].c_str();
+            }
+            last_logprob_items.push_back(itm);
+        }
+        output.count = last_logprob_items.size();
+        output.logprob_items = last_logprob_items.data();
         return output;
     }

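For orientation, here is a rough caller-side sketch (not part of this commit) of how the new last_logprobs() export could be consumed from C++. The struct names come from expose.h (diffed below); the extern "C" prototype is an assumption, since the export is normally reached through the shared library rather than a header.

    // Hypothetical caller-side sketch: walk the data returned by last_logprobs().
    // Pointers inside each logprob_item reference storage retained by the static
    // vectors in expose.cpp, so they are assumed valid only until the next generation.
    #include <cstdio>
    #include "expose.h" // logprob_item, last_logprobs_outputs, logprobs_max

    extern "C" last_logprobs_outputs last_logprobs(); // exported by expose.cpp

    void dump_last_logprobs()
    {
        last_logprobs_outputs out = last_logprobs();
        for (int i = 0; i < out.count; ++i)
        {
            const logprob_item & itm = out.logprob_items[i];
            printf("picked %s (logprob %.4f)\n", itm.selected_token, itm.selected_logprob);
            for (int j = 0; j < itm.option_count && j < logprobs_max; ++j)
            {
                printf("  cand %s : %.4f\n", itm.tokens[j], itm.logprobs[j]);
            }
        }
    }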
expose.h

Lines changed: 12 additions & 4 deletions
@@ -3,6 +3,7 @@

 const int tensor_split_max = 16;
 const int images_max = 4;
+const int logprobs_max = 5;

 // match kobold's sampler list and order
 enum samplers
@@ -111,19 +112,26 @@ struct generation_outputs
 {
     int status = -1;
     int stopreason = stop_reason::INVALID;
+    int prompt_tokens = 0;
+    int completion_tokens = 0;
     const char * text; //response will now be stored in c++ allocated memory
 };
 struct token_count_outputs
 {
     int count = 0;
     int * ids; //we'll just use shared memory for this one, bit of a hack
 };
+
+struct logprob_item {
+    int option_count;
+    const char * selected_token;
+    float selected_logprob;
+    const char * tokens[logprobs_max];
+    float * logprobs = nullptr;
+};
 struct last_logprobs_outputs {
     int count = 0;
-    char ** selected_token;
-    float * selected_logprob;
-    char * tokens[5];
-    float * logprobs[5];
+    logprob_item * logprob_items = nullptr;
 };
 struct sd_load_model_inputs
 {

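Worth noting about the layout above (an observation, not text from the commit): logprob_item does not own its strings or float buffer; they point into the static last_logprob_toppicks / last_logprob_items vectors kept in expose.cpp, so they are presumably only valid until the next call refreshes them. A hypothetical helper that deep-copies one item into owned C++ objects might look like this:

    // Hypothetical helper, not in this commit: copy one logprob_item into owned
    // C++ objects before the next generation invalidates its borrowed pointers.
    #include <string>
    #include <vector>
    #include "expose.h" // logprob_item, logprobs_max

    struct owned_logprob_item {
        std::string selected_token;
        float selected_logprob = 0.0f;
        std::vector<std::string> tokens; // top candidate tokens
        std::vector<float> logprobs;     // matching log probabilities
    };

    owned_logprob_item copy_item(const logprob_item & itm)
    {
        owned_logprob_item out;
        out.selected_token = itm.selected_token ? itm.selected_token : "";
        out.selected_logprob = itm.selected_logprob;
        int n = itm.option_count < logprobs_max ? itm.option_count : logprobs_max;
        for (int j = 0; j < n; ++j)
        {
            out.tokens.push_back(itm.tokens[j] ? itm.tokens[j] : "");
            out.logprobs.push_back(itm.logprobs ? itm.logprobs[j] : 0.0f);
        }
        return out;
    }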
gpttype_adapter.cpp

Lines changed: 10 additions & 3 deletions
@@ -597,13 +597,13 @@ llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng
     int idx = dist(rng);

     newpick.selected_token = FileFormatTokenizeID(candidates->data[idx].id, file_format, true);
-    newpick.selected_logprob = candidates->data[idx].logit;
+    newpick.selected_logprob = logf(candidates->data[idx].p);
     newpick.selected_probability = candidates->data[idx].p;
     newpick.selected_tokenid = candidates->data[idx].id;
-    for (size_t i = 0; (i < candidates->size && i<5); ++i)
+    for (size_t i = 0; (i < candidates->size && i<logprobs_max); ++i)
     {
         newpick.tokens.push_back(FileFormatTokenizeID(candidates->data[i].id, file_format, true));
-        newpick.logprobs.push_back(candidates->data[i].logit);
+        newpick.logprobs.push_back(logf(candidates->data[i].p));
         newpick.p.push_back(candidates->data[i].p);
         newpick.tokenid.push_back(candidates->data[i].id);
     }
@@ -2467,6 +2467,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         printf("\nWarning: KCPP text generation not initialized!\n");
         output.text = nullptr;
         output.status = 0;
+        output.prompt_tokens = output.completion_tokens = 0;
         output.stopreason = stop_reason::INVALID;
         generation_finished = true;
         return output;
@@ -3142,6 +3143,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         fprintf(stderr, "\nFailed to predict at %d! Check your context buffer sizes!\n",n_past);
         output.text = nullptr;
         output.status = 0;
+        output.prompt_tokens = output.completion_tokens = 0;
         output.stopreason = stop_reason::INVALID;
         generation_finished = true;
         return output;
@@ -3471,6 +3473,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         fprintf(stderr, "\nFailed to eval llava image at %d!\n",n_past);
         output.text = nullptr;
         output.status = 0;
+        output.prompt_tokens = output.completion_tokens = 0;
         output.stopreason = stop_reason::INVALID;
         generation_finished = true;
         return output;
@@ -3482,6 +3485,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         fprintf(stderr, "\nLLAVA image tokens mismatch at %d! (%d vs %d tokens)\n",n_past,llavatokenscounted,llavatokensevaled);
         output.text = nullptr;
         output.status = 0;
+        output.prompt_tokens = output.completion_tokens = 0;
         output.stopreason = stop_reason::INVALID;
         generation_finished = true;
         return output;
@@ -3534,6 +3538,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     printf("\nCtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
     fflush(stdout);
     output.status = 1;
+    int finaltokcount = (int)current_context_tokens.size()-realnpredict;
+    output.prompt_tokens = (finaltokcount<0?0:finaltokcount);
+    output.completion_tokens = realnpredict;
     output.stopreason = last_stop_reason;
     last_eval_time = pt2;
     last_process_time = pt1;

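A note on the first gpttype_adapter.cpp hunk: selected_logprob and the per-candidate logprobs now store logf() of the post-sampling probability instead of the raw logit, which fits the usual convention that a logprob is the log of a normalized probability (always <= 0). A minimal standalone sketch of the difference, with made-up logits:

    // Standalone illustration, not from the commit: raw logits are unnormalized,
    // while logf() of the softmax probability is a proper log-probability.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<float> logits = {2.0f, 1.0f, 0.5f}; // made-up values
        float maxl = *std::max_element(logits.begin(), logits.end());
        float denom = 0.0f;
        for (float l : logits) denom += expf(l - maxl); // softmax denominator
        for (float l : logits)
        {
            float p = expf(l - maxl) / denom;           // normalized probability
            printf("logit %5.2f -> p %.3f -> logprob %.3f\n", l, p, logf(p));
        }
        return 0;
    }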