Commit 7adf3b7

NvTensorRTRTx: Enable CUDA graph via config and fix attention_mask shape handling (#1594)
- Add an option to enable CUDA graph for the NvTensorRTRTx EP through the provider config.
- Fix handling of attention_mask shapes when `enable_cuda_graph` is false for NvTensorRTRTx:
  - When `past_present_share_buffer` (in-place KV cache) is enabled, NvTensorRTRTx expects an attention_mask of shape `[b, max_seq_len]` with masking applied. Previously, this shape was only sent when both `past_present_share_buffer` and graph capture were enabled. This PR ensures the correct shape is passed to TRT for the in-place KV cache, aligning with expected behavior.

@baijumeswani for review
1 parent 070a034 commit 7adf3b7
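
For reviewers trying the change out, here is a minimal sketch of a call site (assumed, not part of this commit) that drives the new option through the helpers declared in `src/config.h`. Only the strings `"NvTensorRtRtx"`, `"enable_cuda_graph"`, and `"1"` are taken from the commit; obtaining `config` is assumed to happen elsewhere:

```cpp
// Hypothetical call site, assuming a Generators::Config `config` loaded
// elsewhere. The provider and option strings are exactly what
// IsGraphCaptureEnabled() matches on after this commit.
Generators::SetProviderOption(config, "NvTensorRtRtx", "enable_cuda_graph", "1");

// Reports true only when the option above is set to "1"; previously the
// NvTensorRtRtx branch returned true unconditionally.
bool graph_capture =
    Generators::IsGraphCaptureEnabled(config.model.decoder.session_options);
```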

File tree

5 files changed: +31 −11 lines changed

src/config.cpp

Lines changed: 7 additions & 2 deletions
```diff
@@ -828,7 +828,7 @@ void SetProviderOption(Config& config, std::string_view provider_name, std::stri
   JSON::Parse(element, json.str());
 }
 
-bool IsGraphCaptureEnabled(Config::SessionOptions& session_options) {
+bool IsGraphCaptureEnabled(const Config::SessionOptions& session_options) {
   for (const auto& provider : session_options.providers) {
     const auto provider_options = std::find_if(session_options.provider_options.begin(),
                                                session_options.provider_options.end(),
@@ -846,7 +846,12 @@ bool IsGraphCaptureEnabled(Config::SessionOptions& session_options) {
       } else if (provider_options->name == "DML") {
        return true;
      } else if (provider_options->name == "NvTensorRtRtx") {
-        return true;
+        for (const auto& value : provider_options->options) {
+          if (value.first == "enable_cuda_graph" && value.second == "1") {
+            return true;
+          }
+        }
+        return false;
      }
    }
  }
```

src/config.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -276,7 +276,7 @@ void SetSearchBool(Config::Search& search, std::string_view name, bool value);
 void ClearProviders(Config& config);
 void SetProviderOption(Config& config, std::string_view provider_name, std::string_view option_name, std::string_view option_value);
 void OverlayConfig(Config& config, std::string_view json);
-bool IsGraphCaptureEnabled(Config::SessionOptions& session_options);
+bool IsGraphCaptureEnabled(const Config::SessionOptions& session_options);
 bool IsMultiProfileEnabled(const Config::SessionOptions& session_options);
 
 }  // namespace Generators
```

src/models/model.cpp

Lines changed: 3 additions & 0 deletions
```diff
@@ -541,6 +541,9 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
     if (IsMultiProfileEnabled(config.model.decoder.session_options)) {
       ConfigureMultiProfile(config, session_options);
     }
+    if (IsGraphCaptureEnabled(config.model.decoder.session_options)) {
+      session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.nv_cuda_graph_enable", "1");
+    }
     p_device = GetDeviceInterface(DeviceType::NvTensorRtRtx);
   }
```

src/models/position_inputs.cpp

Lines changed: 14 additions & 8 deletions
```diff
@@ -141,7 +141,7 @@ void DefaultPositionInputs::UpdatePositionIDs(int total_length, int new_kv_lengt
 }
 
 void DefaultPositionInputs::CreateNextAttentionMaskTensor(int total_length) {
-  if (state_.params_->use_graph_capture)
+  if (ShouldUseStaticMaskHandling())
     return;
   attention_mask_shape_[1] = total_length;
   attention_mask_next_->CreateTensor(attention_mask_shape_);
@@ -154,26 +154,26 @@ void DefaultPositionInputs::UpdateAttentionMask(int total_length, int new_kv_len
   CreateNextAttentionMaskTensor(total_length);
 
   // Update the attention mask on the device. If it fails, copy to CPU, update there, and copy back to device.
-  if (!model_.p_device_inputs_->UpdateAttentionMask(state_.params_->use_graph_capture ? nullptr : attention_mask_next_->GetMutableRawData(),
+  if (!model_.p_device_inputs_->UpdateAttentionMask(ShouldUseStaticMaskHandling() ? nullptr : attention_mask_next_->GetMutableRawData(),
                                                     attention_mask_->GetMutableRawData(),
                                                     static_cast<int>(attention_mask_shape_[0]),
                                                     new_kv_length,
                                                     total_length,
                                                     state_.params_->search.max_length,
-                                                    state_.params_->use_graph_capture,
+                                                    ShouldUseStaticMaskHandling(),
                                                     type_)) {
     // auto* attention_mask_next_span = state_.params_->use_graph_capture ? &attention_mask_next_->GetByteSpan() : nullptr;
     DeviceSpan<uint8_t> attention_mask_next_span;
-    if (!state_.params_->use_graph_capture)
+    if (!ShouldUseStaticMaskHandling())
       attention_mask_next_span = attention_mask_next_->GetByteSpan();
     auto attention_mask_span = attention_mask_->GetByteSpan();
-    GetDeviceInterface(DeviceType::CPU)->UpdateAttentionMask(state_.params_->use_graph_capture ? nullptr : attention_mask_next_span.CopyDeviceToCpu().data(), attention_mask_span.CopyDeviceToCpu().data(), static_cast<int>(attention_mask_shape_[0]), new_kv_length, total_length, state_.params_->search.max_length, state_.params_->use_graph_capture, type_);
-    if (!state_.params_->use_graph_capture)
+    GetDeviceInterface(DeviceType::CPU)->UpdateAttentionMask(ShouldUseStaticMaskHandling() ? nullptr : attention_mask_next_span.CopyDeviceToCpu().data(), attention_mask_span.CopyDeviceToCpu().data(), static_cast<int>(attention_mask_shape_[0]), new_kv_length, total_length, state_.params_->search.max_length, ShouldUseStaticMaskHandling(), type_);
+    if (!ShouldUseStaticMaskHandling())
       attention_mask_next_span.CopyCpuToDevice();
     attention_mask_span.CopyCpuToDevice();
   }
 
-  if (!state_.params_->use_graph_capture) {
+  if (!ShouldUseStaticMaskHandling()) {
     attention_mask_->ort_tensor_ = std::move(attention_mask_next_->ort_tensor_);
     state_.inputs_[mask_input_index_] = attention_mask_->GetOrtTensor();
   }
@@ -256,7 +256,7 @@ void DefaultPositionInputs::CreateAndInitializeAttentionMask(DeviceSpan<int32_t>
     }
   }
 
-  if (state_.params_->use_graph_capture) {
+  if (ShouldUseStaticMaskHandling()) {
     InitializeStaticMask<T>(*attention_mask);
   } else {
     attention_mask = model_.ExpandInputs(attention_mask, state_.params_->search.num_beams);
@@ -291,6 +291,12 @@ void DefaultPositionInputs::RewindMask(size_t index) {
   }
 }
 
+bool DefaultPositionInputs::ShouldUseStaticMaskHandling() const {
+  return state_.params_->use_graph_capture ||
+         (state_.params_->search.past_present_share_buffer &&
+          model_.p_device_->GetType() == DeviceType::NvTensorRtRtx);
+}
+
 // TODO: SlidingWindow does not support graph capture
 WindowedPositionInputs::WindowedPositionInputs(State& state)
     : state_{state} {
```
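
The new `ShouldUseStaticMaskHandling()` predicate extends static mask handling from the graph-capture case to the NvTensorRtRtx in-place KV cache case. As a standalone sketch of what "static" means here (an illustrative helper, not repository code), the mask keeps the fixed `[b, max_seq_len]` shape from the commit message and only its contents advance as tokens are generated:

```cpp
#include <cstdint>
#include <vector>

// Standalone illustration of the static attention_mask layout described in
// the commit message: a fixed [batch, max_seq_len] buffer where positions
// below total_length are valid (1) and the rest stay masked out (0).
std::vector<int32_t> MakeStaticMask(int batch, int max_seq_len, int total_length) {
  std::vector<int32_t> mask(static_cast<size_t>(batch) * max_seq_len, 0);
  for (int b = 0; b < batch; ++b) {
    for (int i = 0; i < total_length; ++i) {
      mask[static_cast<size_t>(b) * max_seq_len + i] = 1;  // valid positions
    }
    // Positions [total_length, max_seq_len) remain 0: not-yet-used slots.
  }
  return mask;
}

// e.g. batch=1, max_seq_len=8, total_length=3 -> 1 1 1 0 0 0 0 0
```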

src/models/position_inputs.h

Lines changed: 6 additions & 0 deletions
```diff
@@ -38,6 +38,12 @@ struct DefaultPositionInputs : PositionInputs {
 
   void RewindMask(size_t index);
 
+  // This returns true when either:
+  // 1. Graph capture is enabled, OR
+  // 2. Past-present buffer sharing is enabled AND the device is NvTensorRtRtx
+  // Both scenarios require static mask allocation and special shape handling for optimization
+  bool ShouldUseStaticMaskHandling() const;
+
   const Model& model_;
   State& state_;
   std::string attention_mask_name_;
```
