@@ -94,6 +94,7 @@ struct Args
     int batch_size = 4096;
     bool detect_thoughts = false;
     int penalty_window = 256;
+    int max_new_tokens = -1;
 };
 
 #define MULTI_LINE_END_MARKER_W L"\\."
@@ -165,6 +166,7 @@ void usage(const std::string &prog)
        << "  --multi                   enabled multiple lines of input [*]\n"
        << "                            when enabled, `" << MULTI_LINE_END_MARKER << "` marks the end of your input.\n"
        << "  --format FMT              conversion format (model specific, FMT = chat | completion | qa) (default: chat)\n"
+       << "  --max_new_tokens N        max number of new tokens in a round of generation (default: -1, i.e. unlimited)\n"
        << "Performance options:\n"
        << "  -n, --threads N           number of threads for inference (default: number of cores)\n"
        << "  -ngl, --n_gpu_layers N    number of the main model layers to offload to a backend device (GPU) (default: GPU not used)\n"
@@ -485,6 +487,7 @@ static size_t parse_args(Args &args, const std::vector<std::string> &argv)
        handle_para0("--batch_size",        batch_size,         std::stoi)
        handle_para0("--tts_export",        tts_export,         std::string)
        handle_para0("--re_quantize",       re_quantize,        std::string)
+       handle_para0("--max_new_tokens",    max_new_tokens,     std::stoi)
        else
            break;
 
@@ -898,6 +901,8 @@ void chat(Args &args, chatllm::Pipeline &pipeline, TextStreamer &streamer)
         pipeline.tokenizer->set_chat_format(args.format);
     }
 
+    pipeline.gen_max_tokens = args.max_new_tokens;
+
     if (args.tokenize)
     {
         auto ids = pipeline.tokenizer->encode(args.prompt);
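
The change threads the new flag from the CLI into the pipeline: parse_args stores the value in Args::max_new_tokens and chat() copies it into pipeline.gen_max_tokens, where -1 (the default) keeps generation unlimited, matching the help text above. The generation loop itself is not part of this diff; the sketch below is only a minimal illustration of how such a per-round budget is typically enforced, using a hypothetical DemoPipeline with stand-in generate_next_token()/is_end_token() helpers rather than the project's actual Pipeline internals.

    // Minimal sketch (not the project's actual Pipeline code): honoring a
    // gen_max_tokens budget where a negative value means "no limit".
    #include <vector>

    struct DemoPipeline
    {
        int gen_max_tokens = -1;    // mirrors pipeline.gen_max_tokens set from --max_new_tokens

        // hypothetical helpers standing in for the real sampler/tokenizer
        int  generate_next_token() { return 0; }
        bool is_end_token(int id)  { return id == 0; }

        std::vector<int> generate(const std::vector<int> &prompt_ids)
        {
            (void)prompt_ids;       // prompt handling omitted in this sketch
            std::vector<int> output;
            int produced = 0;
            while (true)
            {
                // stop once the per-round budget is exhausted (negative = unlimited)
                if (gen_max_tokens >= 0 && produced >= gen_max_tokens)
                    break;
                int id = generate_next_token();
                if (is_end_token(id))
                    break;
                output.push_back(id);
                produced++;
            }
            return output;
        }
    };

With this wiring in place, passing --max_new_tokens 512 on the command line caps each round of generation at 512 new tokens, while leaving the flag at its default of -1 preserves the previous unlimited behaviour.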