diff --git a/common/arg.cpp b/common/arg.cpp
index 93f0108b2b9..5910864d6c2 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1774,6 +1774,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"--t-max-predict-ms"}, "MILLISECONDS",
+        string_format("time limit in ms for the prediction phase; generation stops when this limit is exceeded after a new-line has been generated (default: %ld)", params.t_max_predict_ms),
+        [](common_params & params, const std::string & value) {
+            params.t_max_predict_ms = std::stoll(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
         {"-s", "--seed"}, "SEED",
         string_format("RNG seed (default: %d, use random seed for %d)", params.sampling.seed, LLAMA_DEFAULT_SEED),
diff --git a/common/common.h b/common/common.h
index 87ea0606954..6f5491fd228 100644
--- a/common/common.h
+++ b/common/common.h
@@ -347,6 +347,7 @@ struct common_params {
     int32_t control_vector_layer_end = -1; // layer range for control vector
     bool offline = false;
 
+    int64_t t_max_predict_ms = 0; // max time in ms to predict after the first new line (0 = unlimited)
     int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
     int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
                                  //                                       (which is more convenient to use for plotting)
diff --git a/tools/main/main.cpp b/tools/main/main.cpp
index 865ea4a2f72..80581815ac9 100644
--- a/tools/main/main.cpp
+++ b/tools/main/main.cpp
@@ -562,6 +562,10 @@ int main(int argc, char ** argv) {
         embd_inp.push_back(decoder_start_token_id);
     }
 
+    // state for --t-max-predict-ms: the timer starts at the first generated new-line
+    bool seen_new_line = false;
+    int64_t t_start_generation = 0;
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {
@@ -739,6 +743,17 @@ int main(int argc, char ** argv) {
                 // Console/Stream Output
                 LOG("%s", token_str.c_str());
 
+                if (token_str.find('\n') != std::string::npos) {
+                    if (!seen_new_line) {
+                        seen_new_line = true;
+                        t_start_generation = ggml_time_us();
+                    } else if (params.t_max_predict_ms > 0 && (ggml_time_us() - t_start_generation > 1000.0f * params.t_max_predict_ms)) {
+                        LOG_DBG("stopped by time limit, t_max_predict_ms = %d ms\n", (int) params.t_max_predict_ms);
+                        n_remain = 0;
+                        break;
+                    }
+                }
+
                 // Record Displayed Tokens To Log
                 // Note: Generated tokens are created one by one hence this check
                 if (embd.size() > 1) {
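
Usage sketch (not part of the patch): with the option registered for LLAMA_EXAMPLE_MAIN as above, llama-cli would stop generating once the limit has elapsed after the first generated new-line; the model path, prompt, and limit below are placeholders:

    llama-cli -m model.gguf -p "List three facts about the Moon:" -n 256 --t-max-predict-ms 500

Note that the check in tools/main/main.cpp only runs when a token containing '\n' is produced, so generation actually stops at the first new-line emitted after the time limit has been exceeded.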