Skip to content

Commit 4da9f90

Browse files
committed
update ep factory
1 parent 3269f73 commit 4da9f90

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#define ORT_API_MANUAL_INIT
22
#include "onnxruntime_cxx_api.h"
33
#undef ORT_API_MANUAL_INIT
4-
#include <tensorrt_provider_factory.h>
4+
#include "tensorrt_provider_factory.h"
5+
#include "tensorrt_execution_provider.h"
56

67
#include <gsl/gsl>
78
#include <cassert>
@@ -11,7 +12,7 @@
1112
#include <unordered_map>
1213
#include <vector>
1314

14-
struct TensorrtExecutionProvider;
15+
//struct TensorrtExecutionProvider;
1516

1617
static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) {
1718
const auto* factory = static_cast<const TensorrtExecutionProviderFactory*>(this_ptr);
@@ -30,21 +31,22 @@ static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr,
3031
size_t max_ep_devices,
3132
size_t* p_num_ep_devices) {
3233
size_t& num_ep_devices = *p_num_ep_devices;
33-
auto* factory = static_cast<ExampleEpFactory*>(this_ptr);
34+
auto* factory = static_cast<TensorrtExecutionProviderFactory*>(this_ptr);
3435

3536
for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
3637
// C API
3738
const OrtHardwareDevice& device = *devices[i];
38-
if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) {
39-
// these can be returned as nullptr if you have nothing to add.
39+
if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
40+
// These can be returned as nullptr if you have nothing to add.
4041
OrtKeyValuePairs* ep_metadata = nullptr;
4142
OrtKeyValuePairs* ep_options = nullptr;
4243
factory->ort_api.CreateKeyValuePairs(&ep_metadata);
4344
factory->ort_api.CreateKeyValuePairs(&ep_options);
4445

45-
// random example using made up values
46-
factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1");
47-
factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true");
46+
// The ep options can be provided here as default values.
47+
// Users can override these defaults by passing their own ep options to the SessionOptionsAppendExecutionProvider_V2 C API.
48+
factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); // random example using made up values
49+
factory->ort_api.AddKeyValuePair(ep_options, "trt_builder_optimization_level", "3");
4850

4951
// OrtEpDevice copies ep_metadata and ep_options.
5052
auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options,
@@ -61,11 +63,11 @@ static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr,
6163
// C++ API equivalent. Throws on error.
6264
//{
6365
// Ort::ConstHardwareDevice device(devices[i]);
64-
// if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) {
66+
// if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
6567
// Ort::KeyValuePairs ep_metadata;
6668
// Ort::KeyValuePairs ep_options;
6769
// ep_metadata.Add("version", "0.1");
68-
// ep_options.Add("run_really_fast", "true");
70+
// ep_options.Add("trt_builder_optimization_level", "3");
6971
// Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()};
7072
// ep_devices[num_ep_devices++] = ep_device.release();
7173
// }
@@ -102,14 +104,14 @@ static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr,
102104
// const OrtHardwareDevice* device = devices[0];
103105
// const OrtKeyValuePairs* ep_metadata = ep_metadata[0];
104106

105-
auto dummy_ep = std::make_unique<TensorrtExecutionProvider>(*factory, factory->ep_name_, *session_options, *logger);
107+
auto dummy_ep = std::make_unique<onnxruntime::TensorrtExecutionProvider>(*factory, factory->ep_name_, *session_options, *logger);
106108

107109
*ep = dummy_ep.release();
108110
return nullptr;
109111
}
110112

111113
// Releases an OrtEp: downcasts `ep` to the concrete TensorrtExecutionProvider type and deletes it.
// The factory pointer (`this_ptr`) is not used.
// NOTE(review): presumably `ep` is the object CreateEpImpl handed out via `dummy_ep.release()` — confirm
// that no other OrtEp type can reach this callback, since the static_cast is unchecked.
static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) {
112-
ExampleEp* dummy_ep = static_cast<TensorrtExecutionProvider*>(ep);
114+
onnxruntime::TensorrtExecutionProvider* dummy_ep = static_cast<onnxruntime::TensorrtExecutionProvider*>(ep);
113115
delete dummy_ep;
114116
}
115117

plugin_execution_providers/tensorrt/tensorrt_provider_factory.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,9 @@ struct ApiPtrs {
3939
const OrtEpApi& ep_api;
4040
};
4141

42-
/// <summary>
4342
///
4443
/// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices.
4544
///
46-
/// </summary>
4745
struct TensorrtExecutionProviderFactory : OrtEpFactory, ApiPtrs {
4846
TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} {
4947
ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with.
@@ -53,6 +51,6 @@ struct TensorrtExecutionProviderFactory : OrtEpFactory, ApiPtrs {
5351
CreateEp = CreateEpImpl;
5452
ReleaseEp = ReleaseEpImpl;
5553
}
56-
const std::string ep_name_; // EP name
54+
const std::string ep_name_; // EP name
5755
const std::string vendor_{"Nvidia"}; // EP vendor name
5856
};

0 commit comments

Comments
 (0)