diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp
index e0a317aaff3..0ecc611ef6c 100644
--- a/examples/models/llama/runner/runner.cpp
+++ b/examples/models/llama/runner/runner.cpp
@@ -31,6 +31,7 @@ static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
 static constexpr auto kBosId = "get_bos_id";
 static constexpr auto kEosIds = "get_eos_ids";
 static constexpr auto kMaxSeqLen = "get_max_seq_len";
+static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
@@ -49,6 +50,7 @@ Runner::Runner(
       metadata_({
           {kEnableDynamicShape, false},
           {kMaxSeqLen, 128},
+          {kMaxContextLen, 128},
           {kUseKVCache, true},
           {kUseSDPAWithKVCache, false},
       }) {
@@ -201,9 +203,9 @@ Error Runner::generate(
   shouldStop_ = false;

   // Set the sequence length to the max seq length if not provided
-  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxSeqLen))
+  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxContextLen))
       ? seq_len
-      : metadata_.at(kMaxSeqLen);
+      : metadata_.at(kMaxContextLen);

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
       prompt,