@@ -433,20 +433,25 @@ void ConfigureNvTensorRtRTxProfile(const Config& config, OrtSessionOptions& sess
433433 session_options.AddConfigEntry (" ep.nvtensorrtrtxexecutionprovider.nv_profile_opt_shapes" , opt_shapes.str ().c_str ());
434434 session_options.AddConfigEntry (" ep.nvtensorrtrtxexecutionprovider.nv_profile_max_shapes" , max_shapes.str ().c_str ());
435435 } else {
436- // Single profile mode: simple shapes with batch_dim=[0, 1,1] and seq_dim=[0,1 ,max_context_len]
436+ // Single profile mode: simple shapes with batch_dim=[1,1,batch_size] and seq_dim=[1,min(max_context_len/2,1024),max_context_len]
437437 std::ostringstream min_shapes, opt_shapes, max_shapes;
438438
439- // MIN SHAPES: batch_dim=0, seq_dim=0
440- min_shapes << Config::Defaults::InputIdsName << " :0x0,"
441- << Config::Defaults::AttentionMaskName << " :0x0" ;
442- add_key_value_cache_shapes (min_shapes, 0 , past_key_pattern, past_value_pattern, 0 , num_layers, num_kv_heads, head_dim);
443-
444- // OPT SHAPES: batch_dim=1, seq_dim=1
445- opt_shapes << Config::Defaults::InputIdsName << " :1x1,"
446- << Config::Defaults::AttentionMaskName << " :1x1" ;
447- add_key_value_cache_shapes (opt_shapes, 1 , past_key_pattern, past_value_pattern, 1 , num_layers, num_kv_heads, head_dim);
448-
449- // MAX SHAPES: batch_dim=1, seq_dim=max_context_len
439+ // MIN SHAPES: batch_dim=1, seq_dim=1
440+ constexpr int min_context_len = 1 ;
441+ constexpr int min_batch_size = 1 ;
442+ min_shapes << Config::Defaults::InputIdsName << " :" << min_batch_size << " x" << min_context_len << " ,"
443+ << Config::Defaults::AttentionMaskName << " :" << min_batch_size << " x" << min_context_len;
444+ add_key_value_cache_shapes (min_shapes, min_batch_size, past_key_pattern, past_value_pattern, 0 , num_layers, num_kv_heads, head_dim);
445+
446+ // OPT SHAPES: batch_dim=1; input_ids seq_dim=1, attention_mask/KV-cache seq_dim=min(max_context_len/2, 1024)
447+ const int opt_context_len = std::min (max_context_len / 2 , 1024 ); // Use a reasonable opt context length
448+ constexpr int opt_batch_size = 1 ; // Use a opt batch size of 1
449+ // keeping seq length to 1 as optimizing for the gen phase
450+ opt_shapes << Config::Defaults::InputIdsName << " :" << opt_batch_size << " x" << 1 << " ,"
451+ << Config::Defaults::AttentionMaskName << " :" << opt_batch_size << " x" << opt_context_len;
452+ add_key_value_cache_shapes (opt_shapes, opt_batch_size, past_key_pattern, past_value_pattern, opt_context_len, num_layers, num_kv_heads, head_dim);
453+
454+ // MAX SHAPES: seq_dim=max_context_len
450455 max_shapes << Config::Defaults::InputIdsName << " :" << batch_size << " x" << max_context_len << " ,"
451456 << Config::Defaults::AttentionMaskName << " :" << batch_size << " x" << max_context_len;
452457 add_key_value_cache_shapes (max_shapes, batch_size, past_key_pattern, past_value_pattern, max_context_len, num_layers, num_kv_heads, head_dim);
0 commit comments