Skip to content

Commit 0a34c76

Browse files
authored
Overlap prompt processing KV cache update for WindowedKeyValueCache in DecoderOnlyPipelineState (microsoft#1526)
1 parent dc448b5 commit 0a34c76

File tree

9 files changed

+507
-242
lines changed

9 files changed

+507
-242
lines changed

CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -95,8 +95,7 @@ else()
9595
endif()
9696

9797
target_include_directories(onnxruntime-genai PRIVATE ${ORT_HEADER_DIR})
98-
target_include_directories(onnxruntime-genai PRIVATE ${onnxruntime_extensions_SOURCE_DIR}/include)
99-
target_include_directories(onnxruntime-genai PRIVATE ${onnxruntime_extensions_SOURCE_DIR}/shared/api/)
98+
target_include_directories(onnxruntime-genai PRIVATE ${onnxruntime_extensions_SOURCE_DIR}/shared/api)
10099
target_link_libraries(onnxruntime-genai PRIVATE onnxruntime_extensions)
101100
target_link_directories(onnxruntime-genai PRIVATE ${ORT_LIB_DIR})
102101
target_link_libraries(onnxruntime-genai PRIVATE Threads::Threads)

src/models/decoder_only_pipeline.cpp

Lines changed: 103 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -88,16 +88,20 @@ static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config)
8888
return m;
8989
}
9090

91-
static std::vector<size_t> DetectLayerIndicesFromPastKeyNameInputs(
91+
static std::vector<size_t> GetLayerIndicesSetFromPastKeyNameInputs(
9292
const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span<const std::string> inputs) {
93-
std::vector<size_t> detected_layer_indices{};
93+
std::vector<size_t> layer_indices{};
9494
for (const auto& input_name : inputs) {
9595
const auto it = past_key_name_to_layer_idx.find(input_name);
9696
if (it != past_key_name_to_layer_idx.end()) {
97-
detected_layer_indices.push_back(it->second);
97+
layer_indices.push_back(it->second);
9898
}
9999
}
100-
return detected_layer_indices;
100+
// sort and remove duplicates
101+
std::sort(layer_indices.begin(), layer_indices.end());
102+
layer_indices.erase(std::unique(layer_indices.begin(), layer_indices.end()),
103+
layer_indices.end());
104+
return layer_indices;
101105
}
102106

103107
DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineModel& model,
@@ -107,8 +111,7 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
107111
model_{model},
108112
input_ids_{CreateInputIDs(*this)},
109113
key_value_cache_{CreateKeyValueCache(*this)},
110-
do_key_value_cache_partial_token_generation_update_{
111-
key_value_cache_ && key_value_cache_->IsPartialTokenGenerationUpdateSupported()},
114+
do_key_value_cache_partial_update_{key_value_cache_ && key_value_cache_->IsPartialUpdateSupported()},
112115
position_inputs_{CreatePositionInputs(*this, sequence_lengths, model_.config_->model.decoder.inputs.attention_mask)} {
113116
input_ids_->Add();
114117
position_inputs_->Add();
@@ -118,41 +121,68 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
118121
}
119122
extra_inputs_.Add();
120123

121-
const auto past_key_name_to_layer_idx = [&]() -> std::optional<NameToLayerIdxMap> {
122-
if (do_key_value_cache_partial_token_generation_update_) {
123-
return GeneratePastKeyNameToLayerIdxMap(*model_.config_);
124-
}
125-
return std::nullopt;
126-
}();
124+
const auto& config_pipeline = model_.config_->model.decoder.pipeline;
127125

128-
for (const auto& pipeline_model : model_.config_->model.decoder.pipeline) {
126+
for (size_t i = 0; i < config_pipeline.size(); ++i) {
129127
auto pipeline_model_state = std::make_unique<IntermediatePipelineState>(model_, params, pipeline_states_.size());
128+
pipeline_states_.emplace_back(std::move(pipeline_model_state));
129+
}
130130

131-
auto overlapped_kv_cache_update_record = [&]() -> std::optional<OverlappedKeyValueCacheUpdateRecord> {
132-
if (do_key_value_cache_partial_token_generation_update_) {
133-
const bool token_gen_only = !pipeline_model.run_on_prompt && pipeline_model.run_on_token_gen;
134-
if (token_gen_only) {
135-
auto layer_indices = DetectLayerIndicesFromPastKeyNameInputs(*past_key_name_to_layer_idx,
136-
pipeline_model.inputs);
137-
if (!layer_indices.empty()) {
138-
// token generation model with KV cache tensors - we should overlap KV cache update
139-
auto record = OverlappedKeyValueCacheUpdateRecord{};
140-
record.layer_indices = std::move(layer_indices);
141-
return record;
142-
}
131+
if (do_key_value_cache_partial_update_) {
132+
const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap(*model_.config_);
133+
134+
std::map<std::vector<size_t>, size_t> layer_indices_to_update_record_idx{};
135+
std::unordered_set<size_t> layer_indices_encountered{};
136+
137+
for (size_t i = 0; i < config_pipeline.size(); ++i) {
138+
const auto& pipeline_model = config_pipeline[i];
139+
140+
const auto layer_indices = GetLayerIndicesSetFromPastKeyNameInputs(past_key_name_to_layer_idx,
141+
pipeline_model.inputs);
142+
143+
if (layer_indices.empty()) {
144+
continue;
145+
}
146+
147+
size_t record_idx{};
148+
149+
if (auto layer_indices_to_update_record_it = layer_indices_to_update_record_idx.find(layer_indices);
150+
layer_indices_to_update_record_it != layer_indices_to_update_record_idx.end()) {
151+
// we have seen this exact set of layer indices before. reuse the existing record.
152+
record_idx = layer_indices_to_update_record_it->second;
153+
} else {
154+
// verify that the new set of layer indices is valid.
155+
// i.e., it is disjoint with the set of all layer indices we've seen so far.
156+
const bool layer_indices_valid =
157+
std::all_of(layer_indices.begin(), layer_indices.end(),
158+
[&layer_indices_encountered](size_t layer_idx) {
159+
return layer_indices_encountered.find(layer_idx) == layer_indices_encountered.end();
160+
});
161+
162+
if (!layer_indices_valid) {
163+
throw std::runtime_error(
164+
"Invalid layer indices. Layer index sets for partial key value cache update must be either an exact "
165+
"match with another set or disjoint with all other sets.");
143166
}
167+
168+
// add a new record
169+
auto record = PartialKeyValueCacheUpdateRecord{};
170+
record.layer_indices = layer_indices;
171+
172+
partial_kv_cache_update_records_.emplace_back(std::move(record));
173+
record_idx = partial_kv_cache_update_records_.size() - 1;
174+
175+
// add layer_indices to what we've seen so far
176+
layer_indices_encountered.insert(layer_indices.begin(), layer_indices.end());
177+
layer_indices_to_update_record_idx.emplace(layer_indices, record_idx);
144178
}
145-
return std::nullopt;
146-
}();
147179

148-
pipeline_states_.emplace_back(std::move(pipeline_model_state));
149-
pipeline_overlapped_kv_cache_update_records_.emplace_back(std::move(overlapped_kv_cache_update_record));
150-
}
180+
pipeline_state_id_to_partial_kv_cache_update_record_idx_.emplace(i, record_idx);
181+
}
151182

152-
if (std::any_of(pipeline_overlapped_kv_cache_update_records_.begin(),
153-
pipeline_overlapped_kv_cache_update_records_.end(),
154-
[](const auto& record) { return record.has_value(); })) {
155-
key_value_cache_update_worker_thread_.emplace();
183+
if (!partial_kv_cache_update_records_.empty()) {
184+
key_value_cache_update_worker_thread_.emplace();
185+
}
156186
}
157187
}
158188

@@ -175,6 +205,23 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
175205
(const_cast<DecoderOnlyPipelineModel*>(&model_))->sessions_[model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx].reset();
176206
}
177207

208+
auto* const partial_kv_cache_update_record = [&]() -> PartialKeyValueCacheUpdateRecord* {
209+
auto it = pipeline_state_id_to_partial_kv_cache_update_record_idx_.find(pipeline_state->id_);
210+
if (it != pipeline_state_id_to_partial_kv_cache_update_record_idx_.end()) {
211+
return &partial_kv_cache_update_records_[it->second];
212+
}
213+
return nullptr;
214+
}();
215+
216+
// If there is any outstanding partial KV cache update, wait for it to finish.
217+
// It is important to synchronize at this point, before setting input/output tensors for this pipeline state run,
218+
// because a KV cache update may replace the KV cache input/output tensors.
219+
if (partial_kv_cache_update_record) {
220+
if (partial_kv_cache_update_record->outstanding_update.valid()) {
221+
partial_kv_cache_update_record->outstanding_update.get();
222+
}
223+
}
224+
178225
// Clear the intermediate pipeline state outputs from the previous runs.
179226
// These outputs will be replaced by the outputs from the current run.
180227
for (const auto& output_name : pipeline_state->output_names_) {
@@ -251,26 +298,18 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
251298
}
252299
}
253300

254-
auto& overlapped_kv_update_record = pipeline_overlapped_kv_cache_update_records_[pipeline_state->id_];
255-
if (overlapped_kv_update_record.has_value()) {
256-
// wait for any outstanding KV cache update to finish
257-
if (overlapped_kv_update_record->outstanding_update.valid()) {
258-
overlapped_kv_update_record->outstanding_update.get();
259-
}
260-
}
261-
262301
// Run the intermediate pipeline state
263302
pipeline_state->Run(total_length, next_tokens, next_indices);
264303

265-
if (overlapped_kv_update_record.has_value()) {
304+
// If there is any partial KV cache update to start, enqueue it.
305+
if (partial_kv_cache_update_record) {
266306
assert(key_value_cache_update_worker_thread_.has_value());
267-
// enqueue the next KV cache update
268307
auto update_fn = [&key_value_cache = *key_value_cache_.get(),
269-
layer_indices = overlapped_kv_update_record->layer_indices,
308+
layer_indices = partial_kv_cache_update_record->layer_indices,
270309
next_indices, total_length]() {
271-
key_value_cache.PartialTokenGenerationUpdate(next_indices, total_length, layer_indices);
310+
key_value_cache.PartialUpdate(next_indices, total_length, layer_indices);
272311
};
273-
overlapped_kv_update_record->outstanding_update = key_value_cache_update_worker_thread_->Enqueue(update_fn);
312+
partial_kv_cache_update_record->outstanding_update = key_value_cache_update_worker_thread_->Enqueue(update_fn);
274313
}
275314

276315
// Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store.
@@ -307,7 +346,7 @@ DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int
307346
if (model_.config_->model.decoder.sliding_window.has_value() && i < num_chunks - 1) {
308347
// Sliding the window over the input_ids, key_cache, and value_cache, position_ids, and attention_mask
309348
input_ids_->Update(next_tokens);
310-
if (key_value_cache_) key_value_cache_->Update(next_indices, total_length);
349+
UpdateKeyValueCache(next_indices, total_length);
311350
position_inputs_->Update(next_tokens, total_length, static_cast<int>(input_ids_->GetShape()[1]));
312351
}
313352
}
@@ -330,27 +369,30 @@ DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int
330369
return logits_.Get();
331370
}
332371

333-
void DecoderOnlyPipelineState::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens,
334-
DeviceSpan<int32_t> beam_indices, int total_length) {
335-
input_ids_->Update(next_tokens);
336-
size_t new_length = input_ids_->GetShape()[1];
337-
position_inputs_->Update(next_tokens, total_length, static_cast<int>(new_length));
338-
372+
void DecoderOnlyPipelineState::UpdateKeyValueCache(DeviceSpan<int32_t> beam_indices, int total_length) {
339373
if (key_value_cache_) {
340-
const bool outstanding_key_value_cache_partial_token_generation_update =
341-
do_key_value_cache_partial_token_generation_update_ &&
342-
std::any_of(pipeline_overlapped_kv_cache_update_records_.rbegin(),
343-
pipeline_overlapped_kv_cache_update_records_.rend(),
344-
[](const std::optional<OverlappedKeyValueCacheUpdateRecord>& record) {
345-
return record.has_value() && record->outstanding_update.valid();
374+
const bool outstanding_key_value_cache_partial_update =
375+
do_key_value_cache_partial_update_ &&
376+
std::any_of(partial_kv_cache_update_records_.rbegin(),
377+
partial_kv_cache_update_records_.rend(),
378+
[](const PartialKeyValueCacheUpdateRecord& record) {
379+
return record.outstanding_update.valid();
346380
});
347381

348-
if (outstanding_key_value_cache_partial_token_generation_update) {
382+
if (outstanding_key_value_cache_partial_update) {
349383
// If there is any outstanding partial KV cache update, don't update the KV cache here.
350384
} else {
351385
key_value_cache_->Update(beam_indices, total_length);
352386
}
353387
}
388+
}
389+
390+
void DecoderOnlyPipelineState::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens,
391+
DeviceSpan<int32_t> beam_indices, int total_length) {
392+
input_ids_->Update(next_tokens);
393+
size_t new_length = input_ids_->GetShape()[1];
394+
position_inputs_->Update(next_tokens, total_length, static_cast<int>(new_length));
395+
UpdateKeyValueCache(beam_indices, total_length);
354396

355397
logits_.Update(next_tokens, new_length);
356398
}

src/models/decoder_only_pipeline.h

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -68,18 +68,21 @@ struct DecoderOnlyPipelineState : State {
6868
DeviceSpan<int32_t> next_indices);
6969

7070
private:
71+
void UpdateKeyValueCache(DeviceSpan<int32_t> beam_indices, int total_length);
72+
7173
void UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices,
7274
int total_length);
7375

7476
const DecoderOnlyPipelineModel& model_;
7577
std::vector<std::unique_ptr<IntermediatePipelineState>> pipeline_states_;
7678

77-
struct OverlappedKeyValueCacheUpdateRecord {
79+
struct PartialKeyValueCacheUpdateRecord {
7880
std::vector<size_t> layer_indices{}; // indicates which layers of the KV cache are to be updated
7981
std::future<void> outstanding_update{}; // future for an outstanding update task
8082
};
8183

82-
std::vector<std::optional<OverlappedKeyValueCacheUpdateRecord>> pipeline_overlapped_kv_cache_update_records_;
84+
std::map<size_t, size_t> pipeline_state_id_to_partial_kv_cache_update_record_idx_;
85+
std::vector<PartialKeyValueCacheUpdateRecord> partial_kv_cache_update_records_;
8386

8487
// Stores all the outputs from the previous pipeline state(s)
8588
std::unordered_map<std::string, std::unique_ptr<OrtValue>> ortvalue_store_;
@@ -88,7 +91,7 @@ struct DecoderOnlyPipelineState : State {
8891
Logits logits_{*this};
8992

9093
std::unique_ptr<KeyValueCache> key_value_cache_;
91-
const bool do_key_value_cache_partial_token_generation_update_;
94+
const bool do_key_value_cache_partial_update_;
9295
std::optional<WorkerThread> key_value_cache_update_worker_thread_{};
9396

9497
std::unique_ptr<PositionInputs> position_inputs_;

src/models/kv_cache.h

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
14
#pragma once
25

36
#include "model.h"
@@ -15,15 +18,15 @@ struct KeyValueCache {
1518

1619
virtual void RewindTo(size_t index) = 0;
1720

18-
// Note: PartialTokenGenerationUpdate() is mainly for supporting DecoderOnlyPipelineState usage where we update
21+
// Note: PartialUpdate() is mainly for supporting DecoderOnlyPipelineState usage where we update
1922
// part of the KV cache after running part of the pipeline.
2023
// An alternative may be to have a dedicated KV cache per IntermediatePipelineState.
2124

22-
virtual bool IsPartialTokenGenerationUpdateSupported() const { return false; }
25+
virtual bool IsPartialUpdateSupported() const { return false; }
2326

24-
virtual void PartialTokenGenerationUpdate(DeviceSpan<int32_t> beam_indices, int total_length,
25-
std::span<const size_t> layer_indices_to_update) {
26-
throw std::runtime_error("PartialTokenGenerationUpdate is not supported.");
27+
virtual void PartialUpdate(DeviceSpan<int32_t> beam_indices, int total_length,
28+
std::span<const size_t> layer_indices_to_update) {
29+
throw std::runtime_error("PartialUpdate is not supported.");
2730
}
2831
};
2932

0 commit comments

Comments
 (0)