Skip to content

Commit 340a7ce

Browse files
committed
Upate test
1 parent 86e6056 commit 340a7ce

File tree

2 files changed

+188
-31
lines changed

2 files changed

+188
-31
lines changed

plugin_execution_providers/test/tensorrt/tensorrt_basic_test.cc

Lines changed: 129 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,39 +8,49 @@
88

99
namespace test {
1010
namespace trt_ep {
11+
1112
// char type for filesystem paths
1213
using PathChar = ORTCHAR_T;
1314
// string type for filesystem paths
1415
using PathString = std::basic_string<PathChar>;
1516

16-
class TensorrtExecutionProviderCacheTest : public testing::TestWithParam<std::string> {};
17+
template <typename T>
18+
void VerifyOutptus(const std::vector<Ort::Value>& fetches,
19+
const std::vector<int64_t>& expected_dims,
20+
const std::vector<T>& expected_values) {
21+
ASSERT_EQ(1, fetches.size());
22+
const Ort::Value& actual_output = fetches[0];
23+
Ort::TensorTypeAndShapeInfo type_shape_info = actual_output.GetTensorTypeAndShapeInfo();
24+
ONNXTensorElementDataType element_type = type_shape_info.GetElementType();
25+
auto shape = type_shape_info.GetShape();
1726

18-
OrtStatus* CreateOrtSession(PathString model_name,
19-
std::string lib_registration_name,
20-
PathString lib_path) {
21-
const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
22-
Ort::Env env;
27+
ASSERT_EQ(element_type, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
28+
ASSERT_EQ(shape, expected_dims);
2329

24-
// Register plugin TRT EP library with ONNX Runtime.
25-
env.RegisterExecutionProviderLibrary(
26-
lib_registration_name.c_str(), // Registration name can be anything the application chooses.
27-
lib_path // Path to the plugin TRT EP library.
28-
);
30+
size_t element_cnt = type_shape_info.GetElementCount();
31+
const T* actual_values = actual_output.GetTensorData<T>();
2932

30-
// Unregister the library using the application-specified registration name.
31-
// Must only unregister a library after all sessions that use the library have been released.
32-
auto unregister_plugin_eps_at_scope_exit =
33-
gsl::finally([&]() { env.UnregisterExecutionProviderLibrary(lib_registration_name.c_str()); });
33+
ASSERT_EQ(element_cnt, expected_values.size());
34+
35+
for (size_t i = 0; i != element_cnt; ++i) {
36+
ASSERT_EQ(actual_values[i], expected_values[i]);
37+
}
38+
}
39+
40+
static OrtStatus* CreateOrtSession(Ort::Env& env,
41+
PathString model_name,
42+
std::string ep_name,
43+
OrtSession** session) {
44+
const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
3445

3546
{
3647
std::vector<Ort::ConstEpDevice> ep_devices = env.GetEpDevices();
37-
// EP name should match the name assigned by the EP factory when creating the EP (i.e., in the implementation of
38-
// OrtEP::CreateEp())
39-
std::string ep_name = lib_registration_name;
4048

4149
// Find the Ort::EpDevice for "TensorRTEp".
4250
std::vector<Ort::ConstEpDevice> selected_ep_devices = {};
4351
for (Ort::ConstEpDevice ep_device : ep_devices) {
52+
// EP name should match the name assigned by the EP factory when creating the EP (i.e., in the implementation of
53+
// OrtEP::CreateEp())
4454
if (std::string(ep_device.EpName()) == ep_name) {
4555
selected_ep_devices.push_back(ep_device);
4656
break;
@@ -57,30 +67,118 @@ OrtStatus* CreateOrtSession(PathString model_name,
5767
Ort::SessionOptions session_options;
5868
session_options.AppendExecutionProvider_V2(env, selected_ep_devices, ep_options);
5969

60-
Ort::Session session(env, model_name.c_str(), session_options);
70+
Ort::Session ort_session(env, model_name.c_str(), session_options);
71+
*session = ort_session.release();
72+
}
6173

62-
// Get default ORT allocator
63-
Ort::AllocatorWithDefaultOptions allocator;
74+
return nullptr;
75+
}
6476

65-
// Get input name
66-
Ort::AllocatedStringPtr input_name_ptr =
67-
session.GetInputNameAllocated(0, allocator); // Keep the smart pointer alive to avoid dangling pointer
68-
const char* input_name = input_name_ptr.get();
77+
static OrtStatus* RunInference(Ort::Session& session,
78+
std::vector<Ort::Value>& outputs) {
79+
// Get default ORT allocator
80+
Ort::AllocatorWithDefaultOptions allocator;
6981

70-
}
82+
RETURN_IF_NOT(session.GetInputCount() == 3);
83+
84+
// Get input names
85+
Ort::AllocatedStringPtr input_name_ptr =
86+
session.GetInputNameAllocated(0, allocator); // Keep the smart pointer alive to avoid dangling pointer
87+
const char* input_name = input_name_ptr.get();
88+
89+
Ort::AllocatedStringPtr input_name2_ptr = session.GetInputNameAllocated(1, allocator);
90+
const char* input_name2 = input_name2_ptr.get();
91+
92+
Ort::AllocatedStringPtr input_name3_ptr = session.GetInputNameAllocated(2, allocator);
93+
const char* input_name3 = input_name3_ptr.get();
94+
95+
// Input data.
96+
std::vector<float> input_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
97+
98+
// Input shape: (1, 3, 2)
99+
std::vector<int64_t> input_shape{1, 3, 2};
100+
101+
// Create tensor
102+
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
103+
104+
// Create input data as an OrtValue.
105+
// Make input2 data and input3 data same as input1 data.
106+
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_values.data(), input_values.size(),
107+
input_shape.data(), input_shape.size());
108+
Ort::Value input2_tensor = Ort::Value::CreateTensor<float>(memory_info, input_values.data(), input_values.size(),
109+
input_shape.data(), input_shape.size());
110+
Ort::Value input3_tensor = Ort::Value::CreateTensor<float>(memory_info, input_values.data(), input_values.size(),
111+
input_shape.data(), input_shape.size());
112+
113+
std::vector<Ort::Value> input_tensors;
114+
input_tensors.reserve(3);
115+
input_tensors.push_back(std::move(input_tensor));
116+
input_tensors.push_back(std::move(input2_tensor));
117+
input_tensors.push_back(std::move(input3_tensor));
118+
119+
// Get output name
120+
Ort::AllocatedStringPtr output_name_ptr =
121+
session.GetOutputNameAllocated(0, allocator); // Keep the smart pointer alive to avoid dangling pointer
122+
const char* output_name = output_name_ptr.get();
71123

124+
// Run session
125+
std::vector<const char*> input_names{input_name, input_name2, input_name3};
126+
std::vector<const char*> output_names{output_name};
72127

128+
auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(),
129+
input_tensors.size(), output_names.data(), 1);
130+
outputs = std::move(output_tensors);
131+
132+
return nullptr;
73133
}
74134

75-
TEST(TensorrtExecutionProviderTest, SessionCreationWithMultiThreadsAndInferenceWithMultiThreads) {
76-
std::vector<std::thread> threads;
77-
std::string model_name = "basic_model_for_test.onnx";
78-
std::string graph_name = "basic_model";
135+
136+
137+
TEST(TensorrtExecutionProviderTest, CreateSessionAndRunInference) {
138+
Ort::Env env;
79139
std::string lib_registration_name = "TensorRTEp";
140+
std::string& ep_name = lib_registration_name;
80141
PathString lib_path = ORT_TSTR("TensorRTEp.dll");
142+
143+
// Register plugin TRT EP library with ONNX Runtime.
144+
env.RegisterExecutionProviderLibrary(
145+
lib_registration_name.c_str(), // Registration name can be anything the application chooses.
146+
lib_path // Path to the plugin TRT EP library.
147+
);
148+
149+
// Unregister the library using the application-specified registration name.
150+
// Must only unregister a library after all sessions that use the library have been released.
151+
auto unregister_plugin_eps_at_scope_exit =
152+
gsl::finally([&]() { env.UnregisterExecutionProviderLibrary(lib_registration_name.c_str()); });
153+
154+
155+
std::string model_name = "basic_model_for_test.onnx";
156+
std::string graph_name = "basic_model";
81157
std::vector<int64_t> dims = {1, 3, 2};
82158
CreateBaseModel(model_name, graph_name, dims);
83-
CreateOrtSession(ToPathString(model_name), lib_registration_name, lib_path);
159+
160+
OrtSession* session = nullptr;
161+
ASSERT_EQ(CreateOrtSession(env, ToPathString(model_name), ep_name, &session), nullptr);
162+
ASSERT_NE(session, nullptr);
163+
Ort::Session ort_session{session};
164+
165+
std::vector<Ort::Value> output_tensors;
166+
ASSERT_EQ(RunInference(ort_session, output_tensors), nullptr);
167+
168+
// Extract output data
169+
float* output_data = output_tensors.front().GetTensorMutableData<float>();
170+
171+
std::cout << "Output:" << std::endl;
172+
for (int i = 0; i < 6; i++) {
173+
std::cout << output_data[i] << " ";
174+
}
175+
std::cout << std::endl;
176+
177+
std::vector<float> expected_values = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
178+
std::vector<int64_t> expected_shape{1, 3, 2};
179+
VerifyOutptus(output_tensors, expected_shape, expected_values);
180+
181+
84182
}
85183

86184
} // namespace trt_ep

plugin_execution_providers/test/tensorrt/test_trt_ep_utils.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,68 @@
33

44
namespace test {
55
namespace trt_ep {
6+
67
std::string ToUTF8String(std::wstring_view s);
78
std::wstring ToWideString(std::string_view s);
89

10+
#define ENFORCE(condition, ...) \
11+
do { \
12+
if (!(condition)) { \
13+
throw std::runtime_error(std::string(__VA_ARGS__)); \
14+
} \
15+
} while (false)
16+
17+
#define THROW(...) throw std::runtime_error(std::string(__VA_ARGS__));
18+
19+
#define RETURN_IF_ORTSTATUS_ERROR(fn) RETURN_IF_ERROR(fn)
20+
21+
#define RETURN_IF_ERROR(fn) \
22+
do { \
23+
OrtStatus* _status = (fn); \
24+
if (_status != nullptr) { \
25+
return _status; \
26+
} \
27+
} while (0)
28+
29+
#define RETURN_IF_ORT_STATUS_ERROR(fn) \
30+
do { \
31+
auto _status = (fn); \
32+
if (!_status.IsOK()) { \
33+
return _status; \
34+
} \
35+
} while (0)
36+
37+
#define RETURN_IF(cond, ...) \
38+
do { \
39+
if ((cond)) { \
40+
return Ort::GetApi().CreateStatus(ORT_EP_FAIL, std::string(__VA_ARGS__).c_str()); \
41+
} \
42+
} while (0)
43+
44+
#define RETURN_IF_NOT(condition, ...) RETURN_IF(!(condition), __VA_ARGS__)
45+
46+
#define MAKE_STATUS(error_code, msg) Ort::GetApi().CreateStatus(error_code, (msg));
47+
48+
#define THROW_IF_ERROR(expr) \
49+
do { \
50+
auto _status = (expr); \
51+
if (_status != nullptr) { \
52+
std::ostringstream oss; \
53+
oss << Ort::GetApi().GetErrorMessage(_status); \
54+
Ort::GetApi().ReleaseStatus(_status); \
55+
throw std::runtime_error(oss.str()); \
56+
} \
57+
} while (0)
58+
59+
#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn) \
60+
do { \
61+
OrtStatus* status = (fn); \
62+
if (status != nullptr) { \
63+
std::cerr << Ort::GetApi().GetErrorMessage(status) << std::endl; \
64+
return false; \
65+
} \
66+
} while (0)
67+
968
void CreateBaseModel(const std::string& model_path,
1069
const std::string& graph_name,
1170
const std::vector<int64_t>& dims,

0 commit comments

Comments
 (0)