File tree Expand file tree Collapse file tree 3 files changed +10
-9
lines changed
Expand file tree Collapse file tree 3 files changed +10
-9
lines changed Original file line number Diff line number Diff line change @@ -64,6 +64,11 @@ struct RuntimeState {
6464
6565 // Random generator for sampling step.
6666 std::shared_ptr<std::default_random_engine> rand_gen;
67+
68+ // Whether decode has run since the most recent prefill.
69+ // Used only by the compiled model executor to decide whether the KVCache
70+ // must be prepared for prefill or for decode.
71+ bool ran_decode = false ;
6772};
6873
6974// A resource interface to hold the llm context.
Original file line number Diff line number Diff line change @@ -502,11 +502,11 @@ absl::Status LlmLiteRtCompiledModelExecutorBase::RollBackProcessedTokens() {
502502
503503absl::Status LlmLiteRtCompiledModelExecutorBase::PrepareFirstPrefillAfterDecode (
504504 int token_index_to_reduce) {
505- if (!ran_decode_ ) {
505+ if (!llm_context_-> runtime_state (). ran_decode ) {
506506 return absl::OkStatus ();
507507 }
508508
509- ran_decode_ = false ;
509+ llm_context_-> runtime_state (). ran_decode = false ;
510510
511511 int output_heads = 1 ;
512512 if (llm_context_->runtime_config ().output_heads .has_value ()) {
@@ -939,11 +939,11 @@ int LlmLiteRtCompiledModelExecutorBase::BindTensorsAndRunDecodeStatic(
939939}
940940
941941absl::Status LlmLiteRtCompiledModelExecutorBase::PrepareFirstDecode () {
942- if (ran_decode_ ) {
942+ if (llm_context_-> runtime_state (). ran_decode ) {
943943 return absl::OkStatus ();
944944 }
945945 // Mark that we have run decode at least once.
946- ran_decode_ = true ;
946+ llm_context_-> runtime_state (). ran_decode = true ;
947947
948948 int output_heads = 1 ;
949949 if (llm_context_->runtime_config ().output_heads .has_value ()) {
@@ -1050,7 +1050,7 @@ LlmLiteRtCompiledModelExecutorBase::DecodeLogits(
10501050 auto output_logits,
10511051 decode_output_buffers_[signatures_.output_logits ].Duplicate ());
10521052
1053- bool last_run_is_decode = ran_decode_ ;
1053+ bool last_run_is_decode = llm_context_-> runtime_state (). ran_decode ;
10541054 RETURN_IF_ERROR (PrepareFirstDecode ());
10551055 ASSIGN_OR_RETURN (auto step_and_token, GetTokenToDecode (inputs));
10561056 RETURN_IF_ERROR (DecodeInternal (step_and_token.token , output_logits));
Original file line number Diff line number Diff line change @@ -284,10 +284,6 @@ class LlmLiteRtCompiledModelExecutorBase : public LlmExecutor {
284284 // 3. The processed tokens.(e.g. KVCache)
285285 std::unique_ptr<LlmContext> llm_context_;
286286
287- // Whether decode has been run ever after prefill.
288- // TODO: b/409401231 - Make sure this state is session dependent.
289- bool ran_decode_ = false ;
290-
291287 // Sampler for sampling logits.
292288 // For now, only CPU sampler is supported.
293289 std::unique_ptr<Sampler> sampler_;
You can’t perform that action at this time.
0 commit comments