 #include <cstdio>
 #include <cstring>
 #include <ctime>
+#include <cinttypes>
 #include <fstream>
 #include <mutex>
 #include <random>
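The new `<cinttypes>` include pulls in the `PRId64` format macro that the updated `LOG_ERR` call at the bottom of this diff relies on. A minimal standalone sketch of what the macro does (the value is invented for illustration):

```cpp
// Minimal sketch: printing an int64_t portably with PRId64 from <cinttypes>.
#include <cinttypes>
#include <cstdio>

int main() {
    const int64_t n_vocab = 152064;               // hypothetical vocab size
    printf("n_vocab = %" PRId64 "\n", n_vocab);   // expands to the right length modifier
    return 0;
}
```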
@@ -103,7 +104,7 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }

-static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
+static results_log_softmax log_softmax(int64_t n_vocab, const float * logits, int tok) {
     float max_logit = logits[0];
     for (int i = 1; i < n_vocab; ++i) {
         max_logit = std::max(max_logit, logits[i]);
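Widening `n_vocab` from `int` to `int64_t` throughout these helpers matters because logits are handled as one flat buffer of `n_vocab * n_tokens` floats, and with today's large vocabularies that product can exceed what a 32-bit `int` holds. A hedged, self-contained illustration (the sizes below are invented, not from this commit):

```cpp
// Hedged sketch: why 32-bit index math on a flat logits buffer can wrap.
#include <cinttypes>
#include <cstdio>

int main() {
    const int64_t n_vocab  = 152064;              // e.g. a large modern vocabulary
    const int64_t n_tokens = 16384;               // rows of logits kept in the buffer
    const int64_t offset   = n_vocab * n_tokens;  // 2491416576 -- needs 64 bits
    const int32_t wrapped  = (int32_t) offset;    // what 32-bit arithmetic would yield
    printf("64-bit: %" PRId64 ", wrapped: %d\n", offset, wrapped);
    return 0;
}
```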
@@ -122,7 +123,7 @@ static inline int nearest_int(float fval) {
     return (i & 0x007fffff) - 0x00400000;
 }

-static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
+static double log_softmax(int64_t n_vocab, const float * logits, uint16_t * log_prob, int tok) {
     float max_logit = logits[0];
     float min_logit = logits[0];
     for (int i = 1; i < n_vocab; ++i) {
@@ -153,7 +154,7 @@ static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob
 }

 static void process_logits(
-    int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
+    int64_t n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
     double & nll, double & nll2, float * logit_history, float * prob_history
 ) {
     std::mutex mutex;
@@ -187,7 +188,7 @@ static void process_logits(
     }
 }

-static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
+static void process_logits(std::ostream& out, int64_t n_vocab, const float * logits, const int * tokens, int n_token,
     std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
     std::mutex mutex;
     const int nv = 2*((n_vocab + 1)/2) + 4;
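The `nv` stride above rounds `n_vocab` up to an even number of `uint16_t` slots and reserves four extra slots per row; my reading (an assumption, not spelled out in this diff) is that the padding keeps rows float-aligned and the four slots leave room for two per-row `float` values used to dequantize the 16-bit log-probs:

```cpp
// Hedged sketch of the row-stride arithmetic; the interpretation of the +4 is assumed.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_vocab = 32001;             // odd vocab size, for illustration
    const int nv = 2*((n_vocab + 1)/2) + 4;    // round up to even, then 4 spare slots
    printf("row stride: %d uint16_t slots\n", nv);   // prints 32006
    return 0;
}
```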
@@ -234,7 +235,7 @@ struct kl_divergence_result {
     size_t count = 0.0;
 };

-static std::pair<double, float> log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
+static std::pair<double, float> log_softmax(int64_t n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
     float max_logit = logits[0];
     int imax = 0;
     for (int i = 1; i < n_vocab; ++i) {
@@ -281,7 +282,9 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
     kld.sum_kld  += sum;
     kld.sum_kld2 += sum*sum;
     ++kld.count;
-    if (imax == imax_base) ++kld.n_same_top;
+    if (imax == imax_base) {
+        ++kld.n_same_top;
+    }

     const float p_base = expf(-nll_base);
     const float p = expf(-nll);
@@ -295,7 +298,7 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
     return std::make_pair(sum, p_diff);
 }

-static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
+static void process_logits(int64_t n_vocab, const float * logits, const int * tokens, int n_token,
     std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
     float * kld_values, float * p_diff_values) {
     std::mutex mutex;
@@ -383,9 +386,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;

     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;

+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll = 0.0;

@@ -521,9 +525,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const int n_chunk_max = tokens.size() / n_ctx;

     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;

+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll  = 0.0;
     double nll2 = 0.0;
@@ -723,7 +728,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<

 #define K_TOKEN_CHUNK 4

-static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
+static void compute_logprobs(const float * batch_logits, int64_t n_vocab, std::vector<std::thread>& workers,
     const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
     if (eval_results.size() != eval_pairs.size()) {
         eval_results.resize(eval_pairs.size());
@@ -877,10 +882,11 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

     double acc = 0.0f;

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

@@ -1158,10 +1164,11 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {

     LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

@@ -1509,10 +1516,11 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params

     LOG("\ntask\tacc_norm\n");

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

+    const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

@@ -1709,15 +1717,16 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
             __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
     }

-    int n_vocab, n_chunk;
+    int64_t n_vocab;
+    int64_t n_chunk;
     in.read((char *)&n_vocab, sizeof(n_vocab));
     in.read((char *)&n_chunk, sizeof(n_chunk));
     if (in.fail()) {
         LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
         return;
     }
     if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
-        LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
+        LOG_ERR("%s: inconsistent vocabulary (%" PRId64 " vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
     }

     std::vector<llama_token> tokens(n_ctx * n_chunk);
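Note that reading `n_vocab` and `n_chunk` as two `int64_t` fields changes the expected on-disk header of the logits file from 8 to 16 bytes, so the writer side has to emit the same layout. A hedged sketch of what that counterpart would look like (`write_header` is a hypothetical name; the actual writer lives elsewhere in this file):

```cpp
// Hedged sketch of the matching writer; write_header is a hypothetical helper.
#include <cstdint>
#include <fstream>

static void write_header(std::ofstream & out, int64_t n_vocab, int64_t n_chunk) {
    out.write((const char *) &n_vocab, sizeof(n_vocab)); // 8 bytes each, was 4
    out.write((const char *) &n_chunk, sizeof(n_chunk));
}
```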