Skip to content

Commit 5f7fc49

Browse files
authored
Remove position_id and fix context phase KV shapes for in-place cache buffer support (microsoft#1505)
- Remove position IDs
- Fix context phase KV shapes for in-place cache buffer support

@BLSharda @baijumeswani @kunal-vaishnavi
1 parent ba56bf1 commit 5f7fc49

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/models/model.cpp

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -307,15 +307,13 @@ void ConfigureMultiProfile(const Config& config, OrtSessionOptions& session_opti
307307
const auto add_input_shapes = [](std::ostringstream& shapes, int seq_len, bool append = false) {
308308
if (append) shapes << ",";
309309
shapes << Config::Defaults::InputIdsName << ":1x" << seq_len << ","
310-
<< Config::Defaults::AttentionMaskName << ":1x" << seq_len << ","
311-
<< Config::Defaults::PositionIdsName << ":1x" << seq_len;
310+
<< Config::Defaults::AttentionMaskName << ":1x" << seq_len;
312311
};
313312

314313
// Helper function to add generation phase input shapes
315314
const auto add_generation_input_shapes = [](std::ostringstream& shapes, int context_len) {
316315
shapes << "," << Config::Defaults::AttentionMaskName << ":1x" << context_len << ","
317-
<< Config::Defaults::InputIdsName << ":1x1,"
318-
<< Config::Defaults::PositionIdsName << ":1x1";
316+
<< Config::Defaults::InputIdsName << ":1x1";
319317
};
320318

321319
// Helper function to add empty KV cache shapes for all layers
@@ -369,7 +367,7 @@ void ConfigureMultiProfile(const Config& config, OrtSessionOptions& session_opti
369367

370368
// MAX SHAPES (prefill with maximum context and generation after maximum context)
371369
add_input_shapes(max_shapes, max_context_len);
372-
add_empty_key_value_cache_shapes(max_shapes, past_key_pattern, past_value_pattern, num_layers, num_kv_heads, head_dim);
370+
add_key_value_cache_shapes(max_shapes, past_key_pattern, past_value_pattern, max_context_len - 1, num_layers, num_kv_heads, head_dim);
373371
add_generation_input_shapes(max_shapes, max_context_len);
374372
add_key_value_cache_shapes(max_shapes, past_key_pattern, past_value_pattern, max_context_len - 1, num_layers, num_kv_heads, head_dim);
375373

0 commit comments

Comments
 (0)