@@ -307,15 +307,13 @@ void ConfigureMultiProfile(const Config& config, OrtSessionOptions& session_opti
307
307
const auto add_input_shapes = [](std::ostringstream& shapes, int seq_len, bool append = false ) {
308
308
if (append) shapes << " ," ;
309
309
shapes << Config::Defaults::InputIdsName << " :1x" << seq_len << " ,"
310
- << Config::Defaults::AttentionMaskName << " :1x" << seq_len << " ,"
311
- << Config::Defaults::PositionIdsName << " :1x" << seq_len;
310
+ << Config::Defaults::AttentionMaskName << " :1x" << seq_len;
312
311
};
313
312
314
313
// Helper function to add generation phase input shapes
315
314
const auto add_generation_input_shapes = [](std::ostringstream& shapes, int context_len) {
316
315
shapes << " ," << Config::Defaults::AttentionMaskName << " :1x" << context_len << " ,"
317
- << Config::Defaults::InputIdsName << " :1x1,"
318
- << Config::Defaults::PositionIdsName << " :1x1" ;
316
+ << Config::Defaults::InputIdsName << " :1x1" ;
319
317
};
320
318
321
319
// Helper function to add empty KV cache shapes for all layers
@@ -369,7 +367,7 @@ void ConfigureMultiProfile(const Config& config, OrtSessionOptions& session_opti
369
367
370
368
// MAX SHAPES (prefill with maximum context and generation after maximum context)
371
369
add_input_shapes (max_shapes, max_context_len);
372
- add_empty_key_value_cache_shapes (max_shapes, past_key_pattern, past_value_pattern, num_layers, num_kv_heads, head_dim);
370
+ add_key_value_cache_shapes (max_shapes, past_key_pattern, past_value_pattern, max_context_len - 1 , num_layers, num_kv_heads, head_dim);
373
371
add_generation_input_shapes (max_shapes, max_context_len);
374
372
add_key_value_cache_shapes (max_shapes, past_key_pattern, past_value_pattern, max_context_len - 1 , num_layers, num_kv_heads, head_dim);
375
373
0 commit comments