@@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     return {tokens, std::exp(nll / count), logit_history, prob_history};
 }
 
-static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
     if (params.ppl_stride > 0) {
         return perplexity_v2(ctx, params);
     }
@@ -453,7 +453,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // BOS tokens will be added for each chunk before eval
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
 
     std::ofstream logits_stream;
     if (!params.logits_file.empty()) {
@@ -499,13 +498,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll2 = 0.0;
 
     const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    const int n_seq = std::max(1, n_batch / n_ctx);
+
+    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
+    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
+
+    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
 
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
@@ -518,10 +523,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         log_probs.resize(n_ctx * nv);
     }
 
-    for (int i = 0; i < n_chunk; ++i) {
+    // We get the logits for all the tokens in the context window (params.n_ctx)
+    // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+    // calculate the perplexity over the last half of the window (so the model always has
+    // some context to predict the token).
+    //
+    // We rely on the fact that attention in the forward pass only looks at previous
+    // tokens here, so the logits returned for each token are an accurate representation
+    // of what the model would have predicted at that point.
+    //
+    // Example, we have a context window of 512, we will compute perplexity for each of the
+    // last 256 tokens. Then, we split the input up into context window size chunks to
+    // process the entire prompt.
+    const int first = n_ctx/2;
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start = i * n_ctx;
         const int end = start + n_ctx;
 
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);
+
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
@@ -531,22 +552,37 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             const int batch_start = start + j * n_batch;
             const int batch_size = std::min(end - batch_start, n_batch);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
+            batch.n_tokens = 0;
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
 
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
+                }
+
+                for (int k = 0; k < batch_size; ++k) {
+                    const int idx = seq*n_ctx + k;
+                    batch.token[idx] = tokens[seq_start + k];
+                    batch.pos[idx] = j*n_batch + k;
+                    batch.n_seq_id[idx] = 1;
+                    batch.seq_id[idx][0] = seq;
+                    batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
+                }
+                batch.n_tokens += batch_size;
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
             }
 
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
             if (num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
                 logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
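Aside: the logits flags written in the per-sequence loop above request output only for positions in the second half of each window (pos >= first). A minimal standalone sketch of that selection (not part of the diff), using a hypothetical 8-token window rather than values from the commit:

// Standalone sketch: which positions get logits when
// batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0, with first = n_ctx/2.
#include <cstdio>

int main() {
    const int n_ctx = 8;         // hypothetical window size for illustration
    const int first = n_ctx/2;   // perplexity is only scored from here on

    for (int pos = 0; pos < n_ctx; ++pos) {
        printf("pos %d -> logits %d\n", pos, pos >= first ? 1 : 0);
    }
    return 0;
}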
@@ -558,45 +594,39 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         if (i == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total*n_chunk/n_seq);
             if (total_seconds >= 60*60) {
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
             fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
         }
 
-        // We get the logits for all the tokens in the context window (params.n_ctx)
-        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half of the window (so the model always has
-        // some context to predict the token).
-        //
-        // We rely on the fact that attention in the forward pass only looks at previous
-        // tokens here, so the logits returned for each token are an accurate representation
-        // of what the model would have predicted at that point.
-        //
-        // Example, we have a context window of 512, we will compute perplexity for each of the
-        // last 256 tokens. Then, we split the input up into context window size chunks to
-        // process the entire prompt.
-        const int first = n_ctx/2;
-        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        if (!params.logits_file.empty()) {
-            process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, log_probs, nll, nll2);
-        } else {
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-        }
-        count += n_ctx - first - 1;
-
-        // perplexity is e^(average negative log-likelihood)
-        if (params.ppl_output_type == 0) {
-            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        } else {
-            double av = nll/count;
-            double av2 = nll2/count - av*av;
-            if (av2 > 0) av2 = sqrt(av2/(count-1));
-            printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+        for (int seq = 0; seq < n_seq_batch; seq++) {
+            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
+            if (!params.logits_file.empty()) {
+                process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, log_probs, nll, nll2);
+            } else {
+                process_logits(n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, nll, nll2,
+                        logit_history.data() + start + seq*n_ctx + first,
+                        prob_history.data() + start + seq*n_ctx + first);
+            }
+            count += n_ctx - first - 1;
+
+            // perplexity is e^(average negative log-likelihood)
+            if (params.ppl_output_type == 0) {
+                printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+            } else {
+                double av = nll/count;
+                double av2 = nll2/count - av*av;
+                if (av2 > 0) av2 = sqrt(av2/(count-1));
+                printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+            }
         }
         fflush(stdout);
 
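Aside: the running statistic accumulated in the loop above is the average negative log-likelihood over the scored positions, and the printed perplexity is its exponential. A minimal standalone sketch of that bookkeeping (not part of the diff), with made-up token probabilities rather than model output:

// Standalone sketch: perplexity = exp(average negative log-likelihood).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical probabilities assigned to the correct next token
    const std::vector<double> probs = {0.25, 0.40, 0.10, 0.60};

    double nll = 0.0;
    int count = 0;
    for (const double p : probs) {
        nll += -std::log(p);
        ++count;
    }

    printf("ppl = %.4lf\n", std::exp(nll / count)); // about 3.59 for these values
    return 0;
}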
@@ -615,6 +645,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         printf("Unexpected negative standard deviation of log(prob)\n");
     }
 
+    llama_batch_free(batch);
+
     return {tokens, ppl, logit_history, prob_history};
 }
 
@@ -1782,13 +1814,24 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    params.n_batch = 512;
     if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
 
     params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    const int32_t n_ctx = params.n_ctx;
+
+    const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
+    if (ppl) {
+        int n_seq = std::max(1, params.n_batch / n_ctx);
+        int32_t n_kv = n_seq * n_ctx;
+        params.n_parallel = n_seq;
+        params.n_ctx = n_kv;
+        params.n_batch = std::min(params.n_batch, n_kv);
+    } else {
+        params.n_batch = std::min(params.n_batch, params.n_ctx);
+    }
 
     if (params.ppl_stride > 0) {
         fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
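Aside: a standalone sketch of the n_seq / n_kv arithmetic above (not part of the diff), using a hypothetical run with a 512-token perplexity window and a batch size of 2048; the batch is split into 4 sequences and the KV cache is sized to hold all of them:

// Standalone sketch: how n_batch and the requested window size determine
// the number of parallel sequences and the KV cache size.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx   = 512;   // hypothetical perplexity window requested by the user
    const int n_batch = 2048;  // hypothetical logical batch size

    const int n_seq = std::max(1, n_batch / n_ctx); // chunks scored per decode
    const int n_kv  = n_seq * n_ctx;                // KV cache must hold every sequence

    printf("n_seq = %d, n_kv = %d\n", n_seq, n_kv); // prints: n_seq = 4, n_kv = 2048
    return 0;
}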
@@ -1847,7 +1890,7 @@ int main(int argc, char ** argv) {
     } else if (params.kl_divergence) {
         kl_divergence(ctx, params);
     } else {
-        results = perplexity(ctx, params);
+        results = perplexity(ctx, params, n_ctx);
     }
 
     llama_print_timings(ctx);