@@ -88,16 +88,20 @@ static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config)
88
88
return m;
89
89
}
90
90
91
- static std::vector<size_t > DetectLayerIndicesFromPastKeyNameInputs (
91
+ static std::vector<size_t > GetLayerIndicesSetFromPastKeyNameInputs (
92
92
const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span<const std::string> inputs) {
93
- std::vector<size_t > detected_layer_indices {};
93
+ std::vector<size_t > layer_indices {};
94
94
for (const auto & input_name : inputs) {
95
95
const auto it = past_key_name_to_layer_idx.find (input_name);
96
96
if (it != past_key_name_to_layer_idx.end ()) {
97
- detected_layer_indices .push_back (it->second );
97
+ layer_indices .push_back (it->second );
98
98
}
99
99
}
100
- return detected_layer_indices;
100
+ // sort and remove duplicates
101
+ std::sort (layer_indices.begin (), layer_indices.end ());
102
+ layer_indices.erase (std::unique (layer_indices.begin (), layer_indices.end ()),
103
+ layer_indices.end ());
104
+ return layer_indices;
101
105
}
102
106
103
107
DecoderOnlyPipelineState::DecoderOnlyPipelineState (const DecoderOnlyPipelineModel& model,
@@ -107,8 +111,7 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
107
111
model_{model},
108
112
input_ids_{CreateInputIDs (*this )},
109
113
key_value_cache_{CreateKeyValueCache (*this )},
110
- do_key_value_cache_partial_token_generation_update_{
111
- key_value_cache_ && key_value_cache_->IsPartialTokenGenerationUpdateSupported ()},
114
+ do_key_value_cache_partial_update_{key_value_cache_ && key_value_cache_->IsPartialUpdateSupported ()},
112
115
position_inputs_{CreatePositionInputs (*this , sequence_lengths, model_.config_ ->model .decoder .inputs .attention_mask )} {
113
116
input_ids_->Add ();
114
117
position_inputs_->Add ();
@@ -118,41 +121,68 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
118
121
}
119
122
extra_inputs_.Add ();
120
123
121
- const auto past_key_name_to_layer_idx = [&]() -> std::optional<NameToLayerIdxMap> {
122
- if (do_key_value_cache_partial_token_generation_update_) {
123
- return GeneratePastKeyNameToLayerIdxMap (*model_.config_ );
124
- }
125
- return std::nullopt ;
126
- }();
124
+ const auto & config_pipeline = model_.config_ ->model .decoder .pipeline ;
127
125
128
- for (const auto & pipeline_model : model_. config_ -> model . decoder . pipeline ) {
126
+ for (size_t i = 0 ; i < config_pipeline. size (); ++i ) {
129
127
auto pipeline_model_state = std::make_unique<IntermediatePipelineState>(model_, params, pipeline_states_.size ());
128
+ pipeline_states_.emplace_back (std::move (pipeline_model_state));
129
+ }
130
130
131
- auto overlapped_kv_cache_update_record = [&]() -> std::optional<OverlappedKeyValueCacheUpdateRecord> {
132
- if (do_key_value_cache_partial_token_generation_update_) {
133
- const bool token_gen_only = !pipeline_model.run_on_prompt && pipeline_model.run_on_token_gen ;
134
- if (token_gen_only) {
135
- auto layer_indices = DetectLayerIndicesFromPastKeyNameInputs (*past_key_name_to_layer_idx,
136
- pipeline_model.inputs );
137
- if (!layer_indices.empty ()) {
138
- // token generation model with KV cache tensors - we should overlap KV cache update
139
- auto record = OverlappedKeyValueCacheUpdateRecord{};
140
- record.layer_indices = std::move (layer_indices);
141
- return record;
142
- }
131
+ if (do_key_value_cache_partial_update_) {
132
+ const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap (*model_.config_ );
133
+
134
+ std::map<std::vector<size_t >, size_t > layer_indices_to_update_record_idx{};
135
+ std::unordered_set<size_t > layer_indices_encountered{};
136
+
137
+ for (size_t i = 0 ; i < config_pipeline.size (); ++i) {
138
+ const auto & pipeline_model = config_pipeline[i];
139
+
140
+ const auto layer_indices = GetLayerIndicesSetFromPastKeyNameInputs (past_key_name_to_layer_idx,
141
+ pipeline_model.inputs );
142
+
143
+ if (layer_indices.empty ()) {
144
+ continue ;
145
+ }
146
+
147
+ size_t record_idx{};
148
+
149
+ if (auto layer_indices_to_update_record_it = layer_indices_to_update_record_idx.find (layer_indices);
150
+ layer_indices_to_update_record_it != layer_indices_to_update_record_idx.end ()) {
151
+ // we have seen this exact set of layer indices before. reuse the existing record.
152
+ record_idx = layer_indices_to_update_record_it->second ;
153
+ } else {
154
+ // verify that the new set of layer indices is valid.
155
+ // i.e., it is disjoint with the set of all layer indices we've seen so far.
156
+ const bool layer_indices_valid =
157
+ std::all_of (layer_indices.begin (), layer_indices.end (),
158
+ [&layer_indices_encountered](size_t layer_idx) {
159
+ return layer_indices_encountered.find (layer_idx) == layer_indices_encountered.end ();
160
+ });
161
+
162
+ if (!layer_indices_valid) {
163
+ throw std::runtime_error (
164
+ " Invalid layer indices. Layer index sets for partial key value cache update must be either an exact "
165
+ " match with another set or disjoint with all other sets." );
143
166
}
167
+
168
+ // add a new record
169
+ auto record = PartialKeyValueCacheUpdateRecord{};
170
+ record.layer_indices = layer_indices;
171
+
172
+ partial_kv_cache_update_records_.emplace_back (std::move (record));
173
+ record_idx = partial_kv_cache_update_records_.size () - 1 ;
174
+
175
+ // add layer_indices to what we've seen so far
176
+ layer_indices_encountered.insert (layer_indices.begin (), layer_indices.end ());
177
+ layer_indices_to_update_record_idx.emplace (layer_indices, record_idx);
144
178
}
145
- return std::nullopt ;
146
- }();
147
179
148
- pipeline_states_.emplace_back (std::move (pipeline_model_state));
149
- pipeline_overlapped_kv_cache_update_records_.emplace_back (std::move (overlapped_kv_cache_update_record));
150
- }
180
+ pipeline_state_id_to_partial_kv_cache_update_record_idx_.emplace (i, record_idx);
181
+ }
151
182
152
- if (std::any_of (pipeline_overlapped_kv_cache_update_records_.begin (),
153
- pipeline_overlapped_kv_cache_update_records_.end (),
154
- [](const auto & record) { return record.has_value (); })) {
155
- key_value_cache_update_worker_thread_.emplace ();
183
+ if (!partial_kv_cache_update_records_.empty ()) {
184
+ key_value_cache_update_worker_thread_.emplace ();
185
+ }
156
186
}
157
187
}
158
188
@@ -175,6 +205,23 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
175
205
(const_cast <DecoderOnlyPipelineModel*>(&model_))->sessions_ [model_.config_ ->model .decoder .pipeline [pipeline_state->id_ ].reset_session_idx ].reset ();
176
206
}
177
207
208
+ auto * const partial_kv_cache_update_record = [&]() -> PartialKeyValueCacheUpdateRecord* {
209
+ auto it = pipeline_state_id_to_partial_kv_cache_update_record_idx_.find (pipeline_state->id_ );
210
+ if (it != pipeline_state_id_to_partial_kv_cache_update_record_idx_.end ()) {
211
+ return &partial_kv_cache_update_records_[it->second ];
212
+ }
213
+ return nullptr ;
214
+ }();
215
+
216
+ // If there is any outstanding partial KV cache update, wait for it to finish.
217
+ // It is important to synchronize at this point, before setting input/output tensors for this pipeline state run,
218
+ // because a KV cache update may replace the KV cache input/output tensors.
219
+ if (partial_kv_cache_update_record) {
220
+ if (partial_kv_cache_update_record->outstanding_update .valid ()) {
221
+ partial_kv_cache_update_record->outstanding_update .get ();
222
+ }
223
+ }
224
+
178
225
// Clear the intermediate pipeline state outputs from the previous runs.
179
226
// These outputs will be replaced by the outputs from the current run.
180
227
for (const auto & output_name : pipeline_state->output_names_ ) {
@@ -251,26 +298,18 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
251
298
}
252
299
}
253
300
254
- auto & overlapped_kv_update_record = pipeline_overlapped_kv_cache_update_records_[pipeline_state->id_ ];
255
- if (overlapped_kv_update_record.has_value ()) {
256
- // wait for any outstanding KV cache update to finish
257
- if (overlapped_kv_update_record->outstanding_update .valid ()) {
258
- overlapped_kv_update_record->outstanding_update .get ();
259
- }
260
- }
261
-
262
301
// Run the intermediate pipeline state
263
302
pipeline_state->Run (total_length, next_tokens, next_indices);
264
303
265
- if (overlapped_kv_update_record.has_value ()) {
304
+ // If there is any partial KV cache update to start, enqueue it.
305
+ if (partial_kv_cache_update_record) {
266
306
assert (key_value_cache_update_worker_thread_.has_value ());
267
- // enqueue the next KV cache update
268
307
auto update_fn = [&key_value_cache = *key_value_cache_.get (),
269
- layer_indices = overlapped_kv_update_record ->layer_indices ,
308
+ layer_indices = partial_kv_cache_update_record ->layer_indices ,
270
309
next_indices, total_length]() {
271
- key_value_cache.PartialTokenGenerationUpdate (next_indices, total_length, layer_indices);
310
+ key_value_cache.PartialUpdate (next_indices, total_length, layer_indices);
272
311
};
273
- overlapped_kv_update_record ->outstanding_update = key_value_cache_update_worker_thread_->Enqueue (update_fn);
312
+ partial_kv_cache_update_record ->outstanding_update = key_value_cache_update_worker_thread_->Enqueue (update_fn);
274
313
}
275
314
276
315
// Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store.
@@ -307,7 +346,7 @@ DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int
307
346
if (model_.config_ ->model .decoder .sliding_window .has_value () && i < num_chunks - 1 ) {
308
347
// Sliding the window over the input_ids, key_cache, and value_cache, position_ids, and attention_mask
309
348
input_ids_->Update (next_tokens);
310
- if (key_value_cache_) key_value_cache_-> Update (next_indices, total_length);
349
+ UpdateKeyValueCache (next_indices, total_length);
311
350
position_inputs_->Update (next_tokens, total_length, static_cast <int >(input_ids_->GetShape ()[1 ]));
312
351
}
313
352
}
@@ -330,27 +369,30 @@ DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int
330
369
return logits_.Get ();
331
370
}
332
371
333
- void DecoderOnlyPipelineState::UpdateInputsOutputs (DeviceSpan<int32_t >& next_tokens,
334
- DeviceSpan<int32_t > beam_indices, int total_length) {
335
- input_ids_->Update (next_tokens);
336
- size_t new_length = input_ids_->GetShape ()[1 ];
337
- position_inputs_->Update (next_tokens, total_length, static_cast <int >(new_length));
338
-
372
+ void DecoderOnlyPipelineState::UpdateKeyValueCache (DeviceSpan<int32_t > beam_indices, int total_length) {
339
373
if (key_value_cache_) {
340
- const bool outstanding_key_value_cache_partial_token_generation_update =
341
- do_key_value_cache_partial_token_generation_update_ &&
342
- std::any_of (pipeline_overlapped_kv_cache_update_records_ .rbegin (),
343
- pipeline_overlapped_kv_cache_update_records_ .rend (),
344
- [](const std::optional<OverlappedKeyValueCacheUpdateRecord> & record) {
345
- return record.has_value () && record-> outstanding_update .valid ();
374
+ const bool outstanding_key_value_cache_partial_update =
375
+ do_key_value_cache_partial_update_ &&
376
+ std::any_of (partial_kv_cache_update_records_ .rbegin (),
377
+ partial_kv_cache_update_records_ .rend (),
378
+ [](const PartialKeyValueCacheUpdateRecord & record) {
379
+ return record.outstanding_update .valid ();
346
380
});
347
381
348
- if (outstanding_key_value_cache_partial_token_generation_update ) {
382
+ if (outstanding_key_value_cache_partial_update ) {
349
383
// If there is any outstanding partial KV cache update, don't update the KV cache here.
350
384
} else {
351
385
key_value_cache_->Update (beam_indices, total_length);
352
386
}
353
387
}
388
+ }
389
+
390
+ void DecoderOnlyPipelineState::UpdateInputsOutputs (DeviceSpan<int32_t >& next_tokens,
391
+ DeviceSpan<int32_t > beam_indices, int total_length) {
392
+ input_ids_->Update (next_tokens);
393
+ size_t new_length = input_ids_->GetShape ()[1 ];
394
+ position_inputs_->Update (next_tokens, total_length, static_cast <int >(new_length));
395
+ UpdateKeyValueCache (beam_indices, total_length);
354
396
355
397
logits_.Update (next_tokens, new_length);
356
398
}
0 commit comments