Skip to content

Commit b8904ac

Browse files
authored
NvTensorRtRtx dependency on CUDA device name removed (microsoft#1485)
1 parent fe9657c commit b8904ac

File tree

4 files changed

+35
-18
lines changed

4 files changed

+35
-18
lines changed

src/cuda/interface.cpp

Lines changed: 21 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,10 @@
99
#include "kernels.h"
1010
#include <cstdarg>
1111

12+
#if defined(_WIN32) || defined(_WIN64)
13+
#define strcasecmp _stricmp
14+
#endif
15+
1216
namespace Generators {
1317

1418
GenaiInterface* gp_genai{};
@@ -68,16 +72,14 @@ struct GpuMemory final : DeviceBuffer {
6872
bool owned_; // If we own the memory, we delete it on destruction
6973
};
7074

71-
struct CudaInterfaceImpl final : DeviceInterface {
72-
CudaInterfaceImpl() {
75+
struct CudaInterfaceImplBase : DeviceInterface {
76+
CudaInterfaceImplBase() {
7377
g_stream.Create();
7478
}
7579

76-
~CudaInterfaceImpl() {
80+
~CudaInterfaceImplBase() {
7781
}
7882

79-
DeviceType GetType() const override { return DeviceType::CUDA; }
80-
8183
void InitOrt(const OrtApi& api, Ort::Allocator& allocator) override {
8284
Ort::api = &api;
8385
assert(!ort_allocator_);
@@ -164,6 +166,14 @@ struct CudaInterfaceImpl final : DeviceInterface {
164166
}
165167
};
166168

169+
struct CudaInterfaceImpl final : CudaInterfaceImplBase {
170+
DeviceType GetType() const override { return DeviceType::CUDA; }
171+
};
172+
173+
struct NvTensorRtRtxInterfaceImpl final : CudaInterfaceImplBase {
174+
DeviceType GetType() const override { return DeviceType::NvTensorRtRtx; }
175+
};
176+
167177
std::unique_ptr<DeviceInterface> g_cuda_device;
168178

169179
DeviceInterface& GetCudaDeviceInterface() { return *g_cuda_device; }
@@ -205,9 +215,13 @@ void operator delete(void* p, size_t /*size*/) noexcept { Generators::gp_genai->
205215
#endif
206216

207217
extern "C" {
208-
Generators::DeviceInterface* GetInterface(GenaiInterface* p_genai) {
218+
Generators::DeviceInterface* GetInterface(GenaiInterface* p_genai, const char* deviceType) {
209219
Generators::gp_genai = p_genai;
210-
Generators::g_cuda_device = std::make_unique<Generators::CudaInterfaceImpl>();
220+
if (strcasecmp(deviceType, "NvTensorRtRtx") == 0) {
221+
Generators::g_cuda_device = std::make_unique<Generators::NvTensorRtRtxInterfaceImpl>();
222+
} else {
223+
Generators::g_cuda_device = std::make_unique<Generators::CudaInterfaceImpl>();
224+
}
211225
return Generators::g_cuda_device.get();
212226
}
213227
}

src/generators.cpp

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,8 @@ struct LibraryHandle {
178178
};
179179
#endif
180180

181-
DeviceInterface* GetCudaInterface() {
181+
DeviceInterface* GetCudaInterface(DeviceType type) {
182+
assert(type == DeviceType::NvTensorRtRtx || type == DeviceType::CUDA);
182183
try {
183184
#if defined(_WIN32)
184185
static LibraryHandle library{"onnxruntime-genai-cuda.dll"};
@@ -190,8 +191,10 @@ DeviceInterface* GetCudaInterface() {
190191
if (!library)
191192
throw std::runtime_error("Shared library load failure (see first error)");
192193

193-
Generators::DeviceInterface* GetInterface(GenaiInterface * p_genai);
194-
static DeviceInterface* cuda_interface = reinterpret_cast<decltype(&GetInterface)>(library.GetSymbol("GetInterface"))(&g_genai);
194+
Generators::DeviceInterface* GetInterface(GenaiInterface * p_genai, const char* deviceType);
195+
static DeviceInterface* cuda_interface =
196+
reinterpret_cast<decltype(&GetInterface)>(
197+
library.GetSymbol("GetInterface"))(&g_genai, to_string(type).c_str());
195198

196199
return cuda_interface;
197200
} catch (const std::exception& e) {
@@ -213,6 +216,8 @@ std::string to_string(DeviceType device_type) {
213216
return "QnnWithSharedMemory";
214217
case DeviceType::OpenVINO:
215218
return "OpenVINO";
219+
case DeviceType::NvTensorRtRtx:
220+
return "NvTensorRtRtx";
216221
default:
217222
throw std::runtime_error("Unknown device type");
218223
}
@@ -224,7 +229,8 @@ DeviceInterface* GetDeviceInterface(DeviceType type) {
224229
case DeviceType::CPU:
225230
return GetCpuInterface();
226231
case DeviceType::CUDA:
227-
return GetCudaInterface();
232+
case DeviceType::NvTensorRtRtx:
233+
return GetCudaInterface(type);
228234
#if USE_DML
229235
case DeviceType::DML:
230236
return GetDmlInterface();

src/models/model.cpp

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -483,14 +483,10 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
483483
session_options.AddConfigEntry("session.inter_op.allow_spinning", "0");
484484
session_options.AddConfigEntry("session.intra_op.allow_spinning", "0");
485485
} else if (provider_options.name == "NvTensorRtRtx") {
486-
// After setting the NvTensorRtRtx provider in Onnxruntime, GenAI will then treat it as the cuda device.
487-
session_options.AddConfigEntry("ep.nvtensorrtrtxexecutionprovider.nv_cuda_graph_enable", "1");
488-
489486
if (IsMultiProfileEnabled(config.model.decoder.session_options)) {
490487
ConfigureMultiProfile(config, session_options);
491488
}
492-
493-
p_device = GetDeviceInterface(DeviceType::CUDA);
489+
p_device = GetDeviceInterface(DeviceType::NvTensorRtRtx);
494490
}
495491

496492
std::vector<const char*> keys, values;
@@ -536,7 +532,7 @@ void EnsureDeviceOrtInit(DeviceInterface& device, const Config& config) {
536532
// This ensures memory allocated on-device for model inputs/outputs is valid for the lifetime of GenAI.
537533

538534
// Names for the device types used by 'SetProviderSessionOptions'
539-
static const char* device_type_names[] = {"CPU (Not used, see above)", "cuda", "DML", "WebGPU", "QNN", "OpenVINO (Not used, see above)"};
535+
static const char* device_type_names[] = {"CPU (Not used, see above)", "cuda", "DML", "WebGPU", "QNN", "OpenVINO (Not used, see above)", "NvTensorRtRtx"};
540536
static_assert(std::size(device_type_names) == static_cast<size_t>(DeviceType::MAX));
541537

542538
// Create an OrtSessionOptions and set the options to use the DeviceType we're using here
@@ -555,7 +551,7 @@ void EnsureDeviceOrtInit(DeviceInterface& device, const Config& config) {
555551
allocator.session_ = OrtSession::Create(GetOrtEnv(), g_trivial_model, sizeof(g_trivial_model), session_options.get());
556552

557553
// Names for the device memory types used by 'OrtMemoryInfo::Create'
558-
static const char* device_memory_type_names[] = {"CPU (Not used, see above)", "Cuda", "DML", "WebGPU_Buffer", "QnnHtpShared", "OpenVINO (Not used, see above)"};
554+
static const char* device_memory_type_names[] = {"CPU (Not used, see above)", "Cuda", "DML", "WebGPU_Buffer", "QnnHtpShared", "OpenVINO (Not used, see above)", "Cuda"};
559555
static_assert(std::size(device_memory_type_names) == static_cast<size_t>(DeviceType::MAX));
560556

561557
// Get the allocator from the OrtSession for the DeviceType (it's called 'AllocatorCreate' but it's really 'AllocatorGet')

src/smartptrs.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,7 @@ enum struct DeviceType {
9191
WEBGPU,
9292
QNN,
9393
OpenVINO,
94+
NvTensorRtRtx,
9495
MAX
9596
};
9697

0 commit comments

Comments
 (0)