Skip to content

Commit dcfa538

Browse files
authored
[Executorch][llama] Change runner to decouple prompt length from sequence
length Differential Revision: D69073908 Pull Request resolved: #9594
1 parent 10ef2c0 commit dcfa538

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

examples/models/llama/runner/runner.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
3131
static constexpr auto kBosId = "get_bos_id";
3232
static constexpr auto kEosIds = "get_eos_ids";
3333
static constexpr auto kMaxSeqLen = "get_max_seq_len";
34+
static constexpr auto kMaxContextLen = "get_max_context_len";
3435
static constexpr auto kVocabSize = "get_vocab_size";
3536
static constexpr auto kUseKVCache = "use_kv_cache";
3637
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
@@ -49,6 +50,7 @@ Runner::Runner(
4950
metadata_({
5051
{kEnableDynamicShape, false},
5152
{kMaxSeqLen, 128},
53+
{kMaxContextLen, 128},
5254
{kUseKVCache, true},
5355
{kUseSDPAWithKVCache, false},
5456
}) {
@@ -201,9 +203,9 @@ Error Runner::generate(
201203
shouldStop_ = false;
202204

203205
// Set the sequence length to the max seq length if not provided
204-
seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxSeqLen))
206+
seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxContextLen))
205207
? seq_len
206-
: metadata_.at(kMaxSeqLen);
208+
: metadata_.at(kMaxContextLen);
207209

208210
::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
209211
prompt,

0 commit comments

Comments
 (0)