Add --ignore-context-overflow to the perplexity example.

When the flag is set, a HellaSwag / WinoGrande / multiple-choice task that does
not fit in the context window is logged and skipped instead of aborting the
scoring function. The std::mt19937 generators in the three scoring functions
are now seeded from std::random_device instead of the fixed seed 1.

NOTE(review): leading indentation and column alignment of the context lines
were reconstructed; apply with `git apply --ignore-whitespace` and confirm
against the target tree.

diff --git a/common/arg.cpp b/common/arg.cpp
index cdf8970254446..c25a0faf9d489 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2030,6 +2030,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.multiple_choice_tasks = value;
         }
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    add_opt(common_arg(
+        {"--ignore-context-overflow"},
+        string_format("ignores context window overflow when computing scores (default: %s)", params.ctx_overflow ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.ctx_overflow = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"--kl-divergence"},
         "computes KL-divergence to logits provided via --kl-divergence-base",
diff --git a/common/common.h b/common/common.h
index 725b5123d24f9..c742c39e6fd77 100644
--- a/common/common.h
+++ b/common/common.h
@@ -303,6 +303,8 @@ struct common_params {
     bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
     size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
 
+    bool ctx_overflow = false; // ignore context window overflow during HellaSwag, WinoGrande or Multiple Choice evaluation
+
     bool kl_divergence = false; // compute KL divergence
 
     bool usage = false; // print usage
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 175f2804b5da0..54149d874a2ad 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -786,8 +786,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         hs_task_count = params.hellaswag_tasks;
     }
 
-    // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
-    std::mt19937 rng(1);
+    std::mt19937 rng(std::random_device{}()); // seeded from std::random_device: the generated sequence differs between runs
 
     // Dataholder for hellaswag tasks
     struct hs_data_t {
@@ -921,6 +920,11 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
 
         if (i0 == i1) {
             LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0);
+            if (params.ctx_overflow) {
+                LOG_ERR("%s : ignoring offending task: %s\n", __func__, hs_data[i0].context.c_str());
+                continue; // --ignore-context-overflow: skip this task instead of aborting the run
+            }
+
             return;
         }
 
@@ -1111,7 +1115,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
 
     if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
         LOG_INF("%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
-        std::mt19937 rng(1);
+        std::mt19937 rng(std::random_device{}()); // seeded from std::random_device: the generated sequence differs between runs
         std::vector<int> aux(data.size());
         for (int i = 0; i < int(data.size()); ++i) {
             aux[i] = i;
@@ -1214,6 +1218,11 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
 
         if (i0 == i1) {
             LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0);
+            if (params.ctx_overflow) {
+                LOG_ERR("%s : ignoring offending task: %s\n", __func__, data[i0].first.c_str());
+                continue; // --ignore-context-overflow: skip this task instead of aborting the run
+            }
+
             return;
         }
 
@@ -1437,7 +1446,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
     }
     else {
         LOG_INF("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
-        std::mt19937 rng(1);
+        std::mt19937 rng(std::random_device{}()); // seeded from std::random_device: the generated sequence differs between runs
         std::vector<int> aux(n_task);
         for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
         float scale = 1.f/(1.f + (float)std::mt19937::max());
@@ -1586,6 +1595,11 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
 
         if (i0 == i1) {
             LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0);
+            if (params.ctx_overflow) {
+                LOG_ERR("%s : ignoring offending task: %s\n", __func__, tasks[i0].question.c_str());
+                continue; // --ignore-context-overflow: skip this task instead of aborting the run
+            }
+
             return;
         }
 