Skip to content

Commit bfa9e11

Browse files
authored
Minor changes in TRT-RTX min and opt profile (#1659)
- Keep the minimum batch size and sequence length at 1 instead of 0.
- Change the opt profile for the attention-mask and past-KV shapes to use a larger sequence length instead of 1.
1 parent f9a57f5 commit bfa9e11

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

src/models/model.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -433,20 +433,25 @@ void ConfigureNvTensorRtRTxProfile(const Config& config, OrtSessionOptions& sess
433433
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.nv_profile_opt_shapes", opt_shapes.str().c_str());
434434
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.nv_profile_max_shapes", max_shapes.str().c_str());
435435
} else {
436-
// Single profile mode: simple shapes with batch_dim=[0,1,1] and seq_dim=[0,1,max_context_len]
436+
// Single profile mode: batch_dim=[1,1,batch_size]; seq_dim is [1,1,max_context_len] for input_ids
// and [1,opt_context_len,max_context_len] for the attention mask / KV cache
437437
std::ostringstream min_shapes, opt_shapes, max_shapes;
438438

439-
// MIN SHAPES: batch_dim=0, seq_dim=0
440-
min_shapes << Config::Defaults::InputIdsName << ":0x0,"
441-
<< Config::Defaults::AttentionMaskName << ":0x0";
442-
add_key_value_cache_shapes(min_shapes, 0, past_key_pattern, past_value_pattern, 0, num_layers, num_kv_heads, head_dim);
443-
444-
// OPT SHAPES: batch_dim=1, seq_dim=1
445-
opt_shapes << Config::Defaults::InputIdsName << ":1x1,"
446-
<< Config::Defaults::AttentionMaskName << ":1x1";
447-
add_key_value_cache_shapes(opt_shapes, 1, past_key_pattern, past_value_pattern, 1, num_layers, num_kv_heads, head_dim);
448-
449-
// MAX SHAPES: batch_dim=1, seq_dim=max_context_len
439+
// MIN SHAPES: batch_dim=1, seq_dim=1
440+
constexpr int min_context_len = 1;
441+
constexpr int min_batch_size = 1;
442+
min_shapes << Config::Defaults::InputIdsName << ":" << min_batch_size << "x" << min_context_len << ","
443+
<< Config::Defaults::AttentionMaskName << ":" << min_batch_size << "x" << min_context_len;
444+
add_key_value_cache_shapes(min_shapes, min_batch_size, past_key_pattern, past_value_pattern, 0, num_layers, num_kv_heads, head_dim);
445+
446+
// OPT SHAPES: batch_dim=1, seq_dim=1024
447+
const int opt_context_len = std::min(max_context_len / 2, 1024); // Use a reasonable opt context length
448+
constexpr int opt_batch_size = 1;  // Optimize for a batch size of 1
449+
// Keep the input_ids seq length at 1, since we optimize for the generation (decode) phase
450+
opt_shapes << Config::Defaults::InputIdsName << ":" << opt_batch_size << "x" << 1 << ","
451+
<< Config::Defaults::AttentionMaskName << ":" << opt_batch_size << "x" << opt_context_len;
452+
add_key_value_cache_shapes(opt_shapes, opt_batch_size, past_key_pattern, past_value_pattern, opt_context_len, num_layers, num_kv_heads, head_dim);
453+
454+
// MAX SHAPES: seq_dim=max_context_len
450455
max_shapes << Config::Defaults::InputIdsName << ":" << batch_size << "x" << max_context_len << ","
451456
<< Config::Defaults::AttentionMaskName << ":" << batch_size << "x" << max_context_len;
452457
add_key_value_cache_shapes(max_shapes, batch_size, past_key_pattern, past_value_pattern, max_context_len, num_layers, num_kv_heads, head_dim);

0 commit comments

Comments
 (0)