Skip to content

Commit 8c363d5

Browse files
authored
Fix TRT-RTX EP regression (microsoft#1754)
Fix regression caused by PR microsoft#1711 TRT-RTX EP was broken due to this regression * Use right EP name for NvTensorRtRtx * Make sure user_compute_stream option is set in WinML path
1 parent d77033c commit 8c363d5

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

src/config.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ std::string_view NormalizeProviderName(std::string_view name) {
2727
} else if (lower_name == "vitisai") {
2828
return "VitisAI";
2929
} else if (lower_name == "nvtensorrtrtx") {
30-
return "NvTensorRTRTX";
30+
return "NvTensorRtRtx";
3131
}
3232
return name; // Return name unchanged
3333
}

src/models/model.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,14 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
591591
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.enable_cuda_graph", "1");
592592
}
593593
p_device = GetDeviceInterface(DeviceType::NvTensorRtRtx);
594+
595+
if (is_primary_session_options && p_device) {
596+
void* stream_ptr = p_device->GetCudaStream();
597+
std::stringstream stream_value;
598+
stream_value << reinterpret_cast<uintptr_t>(stream_ptr);
599+
std::string stream_value_str = stream_value.str();
600+
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.user_compute_stream", stream_value_str.c_str());
601+
}
594602
}
595603

596604
#if USE_WINML
@@ -716,16 +724,6 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
716724
}
717725
#else
718726
std::vector<const char*> keys, values;
719-
std::string stream_value_str;
720-
if (provider_options.name == "NvTensorRtRtx" && is_primary_session_options && p_device) {
721-
void* stream_ptr = p_device->GetCudaStream();
722-
std::stringstream stream_value;
723-
stream_value << reinterpret_cast<uintptr_t>(stream_ptr);
724-
stream_value_str = stream_value.str();
725-
726-
keys.emplace_back("user_compute_stream");
727-
values.emplace_back(stream_value_str.c_str());
728-
}
729727

730728
for (auto& option : provider_options.options) {
731729
keys.emplace_back(option.first.c_str());

0 commit comments

Comments
 (0)