Skip to content

Commit 060cc9d

Browse files
ai-edge-bot authored and copybara-github committed
Internal change
LiteRT-LM — PiperOrigin-RevId: 861951627
1 parent 41b4175 commit 060cc9d

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

runtime/executor/llm_executor_io_types.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ struct RuntimeState {
6464

6565
// Random generator for sampling step.
6666
std::shared_ptr<std::default_random_engine> rand_gen;
67+
68+
// Whether decode has been run ever after prefill.
69+
// This is only used by the compiled model executor to determine whether
70+
// KVCache preparation for prefill or decode should be done.
71+
bool ran_decode = false;
6772
};
6873

6974
// A resource interface to hold the llm context.

runtime/executor/llm_litert_compiled_model_executor.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,11 @@ absl::Status LlmLiteRtCompiledModelExecutorBase::RollBackProcessedTokens() {
502502

503503
absl::Status LlmLiteRtCompiledModelExecutorBase::PrepareFirstPrefillAfterDecode(
504504
int token_index_to_reduce) {
505-
if (!ran_decode_) {
505+
if (!llm_context_->runtime_state().ran_decode) {
506506
return absl::OkStatus();
507507
}
508508

509-
ran_decode_ = false;
509+
llm_context_->runtime_state().ran_decode = false;
510510

511511
int output_heads = 1;
512512
if (llm_context_->runtime_config().output_heads.has_value()) {
@@ -939,11 +939,11 @@ int LlmLiteRtCompiledModelExecutorBase::BindTensorsAndRunDecodeStatic(
939939
}
940940

941941
absl::Status LlmLiteRtCompiledModelExecutorBase::PrepareFirstDecode() {
942-
if (ran_decode_) {
942+
if (llm_context_->runtime_state().ran_decode) {
943943
return absl::OkStatus();
944944
}
945945
// Mark that we have run decode at least once.
946-
ran_decode_ = true;
946+
llm_context_->runtime_state().ran_decode = true;
947947

948948
int output_heads = 1;
949949
if (llm_context_->runtime_config().output_heads.has_value()) {
@@ -1050,7 +1050,7 @@ LlmLiteRtCompiledModelExecutorBase::DecodeLogits(
10501050
auto output_logits,
10511051
decode_output_buffers_[signatures_.output_logits].Duplicate());
10521052

1053-
bool last_run_is_decode = ran_decode_;
1053+
bool last_run_is_decode = llm_context_->runtime_state().ran_decode;
10541054
RETURN_IF_ERROR(PrepareFirstDecode());
10551055
ASSIGN_OR_RETURN(auto step_and_token, GetTokenToDecode(inputs));
10561056
RETURN_IF_ERROR(DecodeInternal(step_and_token.token, output_logits));

runtime/executor/llm_litert_compiled_model_executor.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,6 @@ class LlmLiteRtCompiledModelExecutorBase : public LlmExecutor {
284284
// 3. The processed tokens.(e.g. KVCache)
285285
std::unique_ptr<LlmContext> llm_context_;
286286

287-
// Whether decode has been run ever after prefill.
288-
// TODO: b/409401231 - Make sure this state is session dependent.
289-
bool ran_decode_ = false;
290-
291287
// Sampler for sampling logits.
292288
// For now, only CPU sampler is supported.
293289
std::unique_ptr<Sampler> sampler_;

0 commit comments

Comments (0)