Skip to content

Commit ab8cd70

Browse files
committed
Add default logger for TRT logger
1 parent 2472a15 commit ab8cd70

File tree

5 files changed

+32
-16
lines changed

5 files changed

+32
-16
lines changed

plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
#include "onnx_ctx_model_helper.h"
1010
#include "onnx/onnx_pb.h"
1111

12-
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
12+
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log, const OrtLogger& ort_default_logger,
13+
const OrtApi* ort_api);
1314

1415
/*
1516
* Check whether the graph has the EP context node.

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,11 @@ void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims con
8484
}
8585
}
8686

87-
TensorrtLogger& GetTensorrtLogger(bool verbose_log) {
87+
TensorrtLogger& GetTensorrtLogger(bool verbose_log,
88+
const OrtLogger& ort_default_logger,
89+
const OrtApi* ort_api) {
8890
const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING;
89-
static TensorrtLogger trt_logger(log_level);
91+
static TensorrtLogger trt_logger(ort_default_logger, ort_api, log_level);
9092
if (log_level != trt_logger.get_level()) {
9193
trt_logger.set_level(verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING);
9294
}
@@ -1041,7 +1043,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
10411043
model_proto.SerializeToOstream(&dump);
10421044
}
10431045

1044-
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
1046+
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_, logger_, &ort_api);
10451047
auto trt_builder = GetBuilder(trt_logger);
10461048
auto network_flags = 0;
10471049
#if NV_TENSORRT_MAJOR > 8
@@ -2021,7 +2023,7 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine(
20212023
}
20222024

20232025
// weight-stripped engine refit logic
2024-
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log);
2026+
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log, logger_, &ort_api);
20252027
auto refitter = std::unique_ptr<nvinfer1::IRefitter>(nvinfer1::createInferRefitter(*trt_engine, trt_logger));
20262028
auto parser_refitter =
20272029
std::unique_ptr<nvonnxparser::IParserRefitter>(nvonnxparser::createParserRefitter(*refitter, trt_logger));
@@ -2378,7 +2380,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
23782380

23792381
{
23802382
auto lock = GetApiLock();
2381-
runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));
2383+
runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_, logger_, &ort_api)));
23822384
}
23832385

23842386
// EP Context setting

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,14 @@ using DestroyFunc = void (*)(void*, void*);
6666

6767
class TensorrtLogger : public nvinfer1::ILogger {
6868
nvinfer1::ILogger::Severity verbosity_;
69+
const OrtLogger& ort_default_logger_;
70+
const OrtApi* ort_api_ = nullptr;
6971

7072
public:
71-
TensorrtLogger(Severity verbosity = Severity::kWARNING)
72-
: verbosity_(verbosity) {}
73+
TensorrtLogger(const OrtLogger& ort_default_logger,
74+
const OrtApi* ort_api,
75+
Severity verbosity = Severity::kWARNING)
76+
: ort_default_logger_{ort_default_logger}, ort_api_{ort_api}, verbosity_(verbosity) {}
7377
void log(Severity severity, const char* msg) noexcept override {
7478
if (severity <= verbosity_) {
7579
time_t rawtime = std::time(0);
@@ -87,11 +91,19 @@ class TensorrtLogger : public nvinfer1::ILogger {
8791
: severity == Severity::kWARNING ? "WARNING"
8892
: severity == Severity::kINFO ? " INFO"
8993
: "UNKNOWN");
94+
OrtLoggingLevel ort_severity;
9095
if (severity <= Severity::kERROR) {
91-
// LOGS_DEFAULT(ERROR) << "[" << buf << " " << sevstr << "] " << msg;
92-
} else {
93-
// LOGS_DEFAULT(WARNING) << "[" << buf << " " << sevstr << "] " << msg;
96+
ort_severity = ORT_LOGGING_LEVEL_ERROR;
9497
}
98+
else {
99+
ort_severity = ORT_LOGGING_LEVEL_WARNING;
100+
}
101+
102+
std::string message = "[" + std::string(buf) + " " + std::string(sevstr) + "] " + std::string(msg);
103+
104+
Ort::ThrowOnError(ort_api_->Logger_LogMessage(&ort_default_logger_,
105+
ort_severity,
106+
message.c_str(), ORT_FILE, __LINE__, __FUNCTION__));
95107
}
96108
}
97109
void set_level(Severity verbosity) {

plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
#include <unordered_map>
1212
#include <vector>
1313

14-
TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis)
15-
: ApiPtrs(apis), ep_name_{ep_name} {
14+
TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis)
15+
: ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} {
1616
ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with.
1717
GetName = GetNameImpl;
1818
GetVendor = GetVendorImpl;
@@ -280,14 +280,14 @@ extern "C" {
280280
// Public symbols
281281
//
282282
EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base,
283-
const OrtLogger*,
283+
const OrtLogger* default_logger,
284284
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
285285
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
286286
const OrtEpApi* ort_ep_api = ort_api->GetEpApi();
287287
const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi();
288288

289289
// Factory could use registration_name or define its own EP name.
290-
std::unique_ptr<OrtEpFactory> factory = std::make_unique<TensorrtExecutionProviderFactory>(registration_name, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api});
290+
std::unique_ptr<OrtEpFactory> factory = std::make_unique<TensorrtExecutionProviderFactory>(registration_name, *default_logger, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api});
291291

292292
if (max_factories < 1) {
293293
return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,

plugin_execution_providers/tensorrt/tensorrt_provider_factory.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using MemoryInfoUniquePtr = std::unique_ptr<OrtMemoryInfo, std::function<void(Or
1111
///
1212
struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
1313
public:
14-
TensorrtExecutionProviderFactory(const char* ep_name, ApiPtrs apis);
14+
TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis);
1515

1616
OrtStatus* CreateMemoryInfoForDevices(int num_devices);
1717

@@ -62,4 +62,5 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs {
6262
const std::string ep_name_; // EP name
6363
const std::string vendor_{"Nvidia"}; // EP vendor name
6464
const std::string ep_version_{"0.1.0"}; // EP version
65+
const OrtLogger& default_logger_;
6566
};

0 commit comments

Comments
 (0)