Skip to content

Commit 486fa63

Browse files
committed
Add initial test infra
1 parent 5ab50ac commit 486fa63

File tree

6 files changed

+504
-0
lines changed

6 files changed

+504
-0
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# usage:
2+
# cd build/
3+
# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/path/to/ort_package/onnxruntime-win-x64-gpu-1.23.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/path/to/tensorrt/TensorRT-10.3.0.26 -DCMAKE_POSITION_INDEPENDENT_CODE=ON (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits")
4+
# cmake --build ./ --config Debug
5+
cmake_minimum_required(VERSION 3.26)
6+
project(tensorrt_ep_test VERSION 1.0)
7+
set(CMAKE_CXX_STANDARD 17)
8+
9+
# CMake config to force dynamic debug CRT or dynamic release CRT globally for all dependencies.
10+
# This is to address the issue of:
11+
# libprotobufd.lib(common.obj) : error LNK2038: mismatch detected for 'RuntimeLibrary': value 'MTd_StaticDebug' doesn't match value 'MDd_DynamicDebug' in unary_elementwise_ops_impl.obj
12+
if (WIN32)
13+
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
14+
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDebugDLL" CACHE STRING "" FORCE) # /MDd
15+
set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime
16+
endif()
17+
18+
if(CMAKE_BUILD_TYPE STREQUAL "Release")
19+
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDLL" CACHE STRING "" FORCE)
20+
set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime
21+
endif()
22+
endif()
23+
24+
add_definitions(-DONNX_NAMESPACE=onnx)
25+
add_definitions(-DONNX_ML)
26+
add_definitions(-DNOMINMAX)
27+
file(GLOB tensorrt_ep_test_src "./*.cc" "./*.h")
28+
add_executable(tensorrt_ep_test ${tensorrt_ep_test_src})
29+
30+
if (NOT ORT_HOME)
31+
message(FATAL_ERROR "Please specify ORT_HOME, e.g. -DORT_HOME=/path/to/ort/")
32+
endif()
33+
34+
# Use release mode if not specified
35+
if (NOT CMAKE_BUILD_TYPE)
36+
set(CMAKE_BUILD_TYPE "Release")
37+
endif()
38+
39+
# Add dependencies
40+
include(FetchContent)
41+
42+
# Add protobuf
43+
FetchContent_Declare(
44+
protobuf
45+
GIT_REPOSITORY https://github.com/protocolbuffers/protobuf.git
46+
GIT_TAG v21.12 # Use a specific tag or commit
47+
)
48+
49+
if (WIN32)
50+
# Sometimes, protobuf ignores CMAKE_MSVC_RUNTIME_LIBRARY. To ensure it works:
51+
set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "" FORCE)
52+
endif()
53+
54+
set(protobuf_BUILD_TESTS OFF CACHE BOOL "" FORCE)
55+
56+
FetchContent_MakeAvailable(protobuf)
57+
58+
# Add ONNX
59+
FetchContent_Declare(
60+
onnx
61+
GIT_REPOSITORY https://github.com/onnx/onnx.git
62+
GIT_TAG v1.18.0 # Use a specific tag or commit
63+
)
64+
65+
FetchContent_MakeAvailable(onnx)
66+
67+
# Add GSL
68+
FetchContent_Declare(
69+
gsl
70+
GIT_REPOSITORY https://github.com/microsoft/GSL.git
71+
GIT_TAG v4.0.0 # Use a specific tag or commit
72+
)
73+
74+
FetchContent_MakeAvailable(gsl)
75+
76+
# Add GoogleTest
77+
FetchContent_Declare(
78+
googletest
79+
URL https://github.com/google/googletest/archive/refs/heads/main.zip
80+
)
81+
# For Windows: prevents overriding parent project's runtime library settings
82+
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
83+
FetchContent_MakeAvailable(googletest)
84+
85+
set(DEPS_PATH "${CMAKE_BINARY_DIR}/_deps")
86+
87+
if (WIN32) # Windows
88+
set(ORT_LIB "${ORT_HOME}/lib/onnxruntime.lib")
89+
90+
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
91+
set(DEPS_LIBS ${DEPS_LIBS}
92+
"${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib"
93+
"${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotocd.lib")
94+
else()
95+
set(DEPS_LIBS ${DEPS_LIBS}
96+
"${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib"
97+
"${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib")
98+
endif()
99+
100+
set(DEPS_LIBS "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib"
101+
"${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib")
102+
103+
else() # Linux
104+
set(ORT_LIB "${ORT_HOME}/lib/libonnxruntime.so")
105+
106+
set(DEPS_LIBS "${DEPS_PATH}/onnx-build/libonnx.a"
107+
"${DEPS_PATH}/onnx-build/libonnx_proto.a")
108+
109+
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
110+
set(DEPS_LIBS ${DEPS_LIBS}
111+
"${DEPS_PATH}/protobuf-build/libprotobufd.a"
112+
"${DEPS_PATH}/protobuf-build/libprotocd.a")
113+
else()
114+
set(DEPS_LIBS ${DEPS_LIBS}
115+
"${DEPS_PATH}/protobuf-build/libprotobuf.a"
116+
"${DEPS_PATH}/protobuf-build/libprotoc.a")
117+
endif()
118+
endif()
119+
120+
MESSAGE(STATUS "Looking for following dependencies ...")
121+
MESSAGE(STATUS "ORT lib : ${ORT_LIB}")
122+
MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}")
123+
124+
125+
target_include_directories(tensorrt_ep_test PUBLIC "${ORT_HOME}/include"
126+
"${DEPS_PATH}/gsl-src/include" # GSL is header-only
127+
"${DEPS_PATH}/onnx-src"
128+
"${DEPS_PATH}/onnx-build"
129+
"${DEPS_PATH}/protobuf-src/src"
130+
)
131+
132+
target_link_libraries(tensorrt_ep_test PUBLIC #${DEPS_LIBS}
133+
GTest::gtest GTest::gtest_main
134+
protobuf::libprotobuf onnx
135+
${ORT_LIB}
136+
)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#ifdef _WIN32
5+
#include <Windows.h>
6+
#include <assert.h>
7+
#endif
8+
9+
#include <stdexcept>
10+
11+
#ifdef ORT_NO_EXCEPTIONS
12+
#if defined(__ANDROID__)
13+
#include <android/log.h>
14+
#else
15+
#include <iostream>
16+
#endif
17+
#endif
18+
19+
#include <string>
20+
21+
#define THROW(...) throw std::runtime_error(std::string(__VA_ARGS__));
22+
23+
#ifdef _WIN32
24+
std::string ToUTF8String(std::wstring_view s) {
25+
if (s.size() >= static_cast<size_t>(std::numeric_limits<int>::max()))
26+
THROW("length overflow");
27+
28+
const int src_len = static_cast<int>(s.size() + 1);
29+
const int len = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, nullptr, 0, nullptr, nullptr);
30+
assert(len > 0);
31+
std::string ret(static_cast<size_t>(len) - 1, '\0');
32+
#pragma warning(disable : 4189)
33+
const int r = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, (char*)ret.data(), len, nullptr, nullptr);
34+
assert(len == r);
35+
#pragma warning(default : 4189)
36+
return ret;
37+
}
38+
39+
std::wstring ToWideString(std::string_view s) {
40+
if (s.size() >= static_cast<size_t>(std::numeric_limits<int>::max()))
41+
THROW("length overflow");
42+
43+
const int src_len = static_cast<int>(s.size() + 1);
44+
const int len = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, nullptr, 0);
45+
assert(len > 0);
46+
std::wstring ret(static_cast<size_t>(len) - 1, '\0');
47+
#pragma warning(disable : 4189)
48+
const int r = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, (wchar_t*)ret.data(), len);
49+
assert(len == r);
50+
#pragma warning(default : 4189)
51+
return ret;
52+
}
53+
#endif // #ifdef _WIN32
54+
55+
#ifdef NO_EXCEPTIONS
56+
void PrintFinalMessage(const char* msg) {
57+
#if defined(__ANDROID__)
58+
__android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg);
59+
#else
60+
// TODO, consider changing the output of the error message from std::cerr to logging when the
61+
// exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output
62+
// might not be easily accessible on some systems such as mobile
63+
// TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS
64+
std::cerr << msg << std::endl;
65+
#endif
66+
}
67+
#endif // #ifdef NO_EXCEPTIONS
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <string>
7+
#include <type_traits>
8+
9+
// for std::tolower or std::towlower
10+
#ifdef _WIN32
11+
#include <cwctype>
12+
#else
13+
#include <cctype>
14+
#endif
15+
16+
#include "onnxruntime_c_api.h"
17+
18+
// char type for filesystem paths
19+
using PathChar = ORTCHAR_T;
20+
// string type for filesystem paths
21+
using PathString = std::basic_string<PathChar>;
22+
23+
inline std::string ToUTF8String(const std::string& s) { return s; }
24+
#ifdef _WIN32
25+
/**
26+
* Convert a wide character string to a UTF-8 string
27+
*/
28+
std::string ToUTF8String(std::wstring_view s);
29+
inline std::string ToUTF8String(const wchar_t* s) {
30+
return ToUTF8String(std::wstring_view{s});
31+
}
32+
inline std::string ToUTF8String(const std::wstring& s) {
33+
return ToUTF8String(std::wstring_view{s});
34+
}
35+
std::wstring ToWideString(std::string_view s);
36+
inline std::wstring ToWideString(const char* s) {
37+
return ToWideString(std::string_view{s});
38+
}
39+
inline std::wstring ToWideString(const std::string& s) {
40+
return ToWideString(std::string_view{s});
41+
}
42+
inline std::wstring ToWideString(const std::wstring& s) { return s; }
43+
inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; }
44+
#else
45+
inline std::string ToWideString(const std::string& s) { return s; }
46+
inline std::string ToWideString(const char* s) { return s; }
47+
inline std::string ToWideString(std::string_view s) { return std::string{s}; }
48+
#endif
49+
50+
inline PathString ToPathString(const PathString& s) {
51+
return s;
52+
}
53+
54+
#ifdef _WIN32
55+
56+
static_assert(std::is_same<PathString, std::wstring>::value, "PathString is not std::wstring!");
57+
58+
inline PathString ToPathString(std::string_view s) {
59+
return ToWideString(s);
60+
}
61+
inline PathString ToPathString(const char* s) {
62+
return ToWideString(s);
63+
}
64+
inline PathString ToPathString(const std::string& s) {
65+
return ToWideString(s);
66+
}
67+
68+
inline PathChar ToLowerPathChar(PathChar c) {
69+
return std::towlower(c);
70+
}
71+
72+
inline std::string PathToUTF8String(const PathString& s) {
73+
return ToUTF8String(s);
74+
}
75+
76+
#else
77+
78+
static_assert(std::is_same<PathString, std::string>::value, "PathString is not std::string!");
79+
80+
inline PathString ToPathString(const char* s) {
81+
return s;
82+
}
83+
84+
inline PathString ToPathString(std::string_view s) {
85+
return PathString{s};
86+
}
87+
88+
inline PathChar ToLowerPathChar(PathChar c) {
89+
return std::tolower(c);
90+
}
91+
92+
inline std::string PathToUTF8String(const PathString& s) {
93+
return s;
94+
}
95+
96+
#endif
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#include <gtest/gtest.h>
2+
#include <string>
3+
#include <gsl/gsl>
4+
5+
#include "onnxruntime_cxx_api.h"
6+
#include "test_trt_ep_utils.h"
7+
#include "path_string.h"
8+
9+
namespace test {
10+
namespace trt_ep {
11+
// char type for filesystem paths
12+
using PathChar = ORTCHAR_T;
13+
// string type for filesystem paths
14+
using PathString = std::basic_string<PathChar>;
15+
16+
class TensorrtExecutionProviderCacheTest : public testing::TestWithParam<std::string> {};
17+
18+
OrtStatus* CreateOrtSession(PathString model_name,
19+
std::string lib_registration_name,
20+
PathString lib_path) {
21+
const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
22+
Ort::Env env;
23+
24+
// Register plugin TRT EP library with ONNX Runtime.
25+
env.RegisterExecutionProviderLibrary(
26+
lib_registration_name.c_str(), // Registration name can be anything the application chooses.
27+
lib_path // Path to the plugin TRT EP library.
28+
);
29+
30+
// Unregister the library using the application-specified registration name.
31+
// Must only unregister a library after all sessions that use the library have been released.
32+
auto unregister_plugin_eps_at_scope_exit =
33+
gsl::finally([&]() { env.UnregisterExecutionProviderLibrary(lib_registration_name.c_str()); });
34+
35+
{
36+
std::vector<Ort::ConstEpDevice> ep_devices = env.GetEpDevices();
37+
// EP name should match the name assigned by the EP factory when creating the EP (i.e., in the implementation of
38+
// OrtEP::CreateEp())
39+
std::string ep_name = lib_registration_name;
40+
41+
// Find the Ort::EpDevice for "TensorRTEp".
42+
std::vector<Ort::ConstEpDevice> selected_ep_devices = {};
43+
for (Ort::ConstEpDevice ep_device : ep_devices) {
44+
if (std::string(ep_device.EpName()) == ep_name) {
45+
selected_ep_devices.push_back(ep_device);
46+
break;
47+
}
48+
}
49+
50+
if (selected_ep_devices[0] == nullptr) {
51+
// Did not find EP. Report application error ...
52+
std::string message = "Did not find EP: " + ep_name;
53+
return ort_api->CreateStatus(ORT_EP_FAIL, message.c_str());
54+
}
55+
56+
std::unordered_map<std::string, std::string> ep_options; // Optional EP options.
57+
Ort::SessionOptions session_options;
58+
session_options.AppendExecutionProvider_V2(env, selected_ep_devices, ep_options);
59+
60+
Ort::Session session(env, model_name.c_str(), session_options);
61+
62+
// Get default ORT allocator
63+
Ort::AllocatorWithDefaultOptions allocator;
64+
65+
// Get input name
66+
Ort::AllocatedStringPtr input_name_ptr =
67+
session.GetInputNameAllocated(0, allocator); // Keep the smart pointer alive to avoid dangling pointer
68+
const char* input_name = input_name_ptr.get();
69+
70+
}
71+
72+
73+
}
74+
75+
TEST(TensorrtExecutionProviderTest, SessionCreationWithMultiThreadsAndInferenceWithMultiThreads) {
76+
std::vector<std::thread> threads;
77+
std::string model_name = "basic_model_for_test.onnx";
78+
std::string graph_name = "basic_model";
79+
std::string lib_registration_name = "TensorRTEp";
80+
PathString lib_path = ORT_TSTR("TensorRTEp.dll");
81+
std::vector<int64_t> dims = {1, 3, 2};
82+
CreateBaseModel(model_name, graph_name, dims);
83+
CreateOrtSession(ToPathString(model_name), lib_registration_name, lib_path);
84+
}
85+
86+
} // namespace trt_ep
87+
} // namespace test

0 commit comments

Comments
 (0)