88
99namespace test {
1010namespace trt_ep {
11+
1112// char type for filesystem paths
1213using PathChar = ORTCHAR_T;
1314// string type for filesystem paths
1415using PathString = std::basic_string<PathChar>;
1516
16- class TensorrtExecutionProviderCacheTest : public testing ::TestWithParam<std::string> {};
17+ template <typename T>
18+ void VerifyOutptus (const std::vector<Ort::Value>& fetches,
19+ const std::vector<int64_t >& expected_dims,
20+ const std::vector<T>& expected_values) {
21+ ASSERT_EQ (1 , fetches.size ());
22+ const Ort::Value& actual_output = fetches[0 ];
23+ Ort::TensorTypeAndShapeInfo type_shape_info = actual_output.GetTensorTypeAndShapeInfo ();
24+ ONNXTensorElementDataType element_type = type_shape_info.GetElementType ();
25+ auto shape = type_shape_info.GetShape ();
1726
18- OrtStatus* CreateOrtSession (PathString model_name,
19- std::string lib_registration_name,
20- PathString lib_path) {
21- const OrtApi* ort_api = OrtGetApiBase ()->GetApi (ORT_API_VERSION);
22- Ort::Env env;
27+ ASSERT_EQ (element_type, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
28+ ASSERT_EQ (shape, expected_dims);
2329
24- // Register plugin TRT EP library with ONNX Runtime.
25- env.RegisterExecutionProviderLibrary (
26- lib_registration_name.c_str (), // Registration name can be anything the application chooses.
27- lib_path // Path to the plugin TRT EP library.
28- );
30+ size_t element_cnt = type_shape_info.GetElementCount ();
31+ const T* actual_values = actual_output.GetTensorData <T>();
2932
30- // Unregister the library using the application-specified registration name.
31- // Must only unregister a library after all sessions that use the library have been released.
32- auto unregister_plugin_eps_at_scope_exit =
33- gsl::finally ([&]() { env.UnregisterExecutionProviderLibrary (lib_registration_name.c_str ()); });
33+ ASSERT_EQ (element_cnt, expected_values.size ());
34+
35+ for (size_t i = 0 ; i != element_cnt; ++i) {
36+ ASSERT_EQ (actual_values[i], expected_values[i]);
37+ }
38+ }
39+
40+ static OrtStatus* CreateOrtSession (Ort::Env& env,
41+ PathString model_name,
42+ std::string ep_name,
43+ OrtSession** session) {
44+ const OrtApi* ort_api = OrtGetApiBase ()->GetApi (ORT_API_VERSION);
3445
3546 {
3647 std::vector<Ort::ConstEpDevice> ep_devices = env.GetEpDevices ();
37- // EP name should match the name assigned by the EP factory when creating the EP (i.e., in the implementation of
38- // OrtEP::CreateEp())
39- std::string ep_name = lib_registration_name;
4048
4149 // Find the Ort::EpDevice for "TensorRTEp".
4250 std::vector<Ort::ConstEpDevice> selected_ep_devices = {};
4351 for (Ort::ConstEpDevice ep_device : ep_devices) {
52+ // EP name should match the name assigned by the EP factory when creating the EP (i.e., in the implementation of
53+ // OrtEP::CreateEp())
4454 if (std::string (ep_device.EpName ()) == ep_name) {
4555 selected_ep_devices.push_back (ep_device);
4656 break ;
@@ -57,30 +67,118 @@ OrtStatus* CreateOrtSession(PathString model_name,
5767 Ort::SessionOptions session_options;
5868 session_options.AppendExecutionProvider_V2 (env, selected_ep_devices, ep_options);
5969
60- Ort::Session session (env, model_name.c_str (), session_options);
70+ Ort::Session ort_session (env, model_name.c_str (), session_options);
71+ *session = ort_session.release ();
72+ }
6173
62- // Get default ORT allocator
63- Ort::AllocatorWithDefaultOptions allocator;
74+ return nullptr ;
75+ }
6476
65- // Get input name
66- Ort::AllocatedStringPtr input_name_ptr =
67- session. GetInputNameAllocated ( 0 , allocator); // Keep the smart pointer alive to avoid dangling pointer
68- const char * input_name = input_name_ptr. get () ;
77+ static OrtStatus* RunInference (Ort::Session& session,
78+ std::vector< Ort::Value>& outputs) {
79+ // Get default ORT allocator
80+ Ort::AllocatorWithDefaultOptions allocator ;
6981
70- }
82+ RETURN_IF_NOT (session.GetInputCount () == 3 );
83+
84+ // Get input names
85+ Ort::AllocatedStringPtr input_name_ptr =
86+ session.GetInputNameAllocated (0 , allocator); // Keep the smart pointer alive to avoid dangling pointer
87+ const char * input_name = input_name_ptr.get ();
88+
89+ Ort::AllocatedStringPtr input_name2_ptr = session.GetInputNameAllocated (1 , allocator);
90+ const char * input_name2 = input_name2_ptr.get ();
91+
92+ Ort::AllocatedStringPtr input_name3_ptr = session.GetInputNameAllocated (2 , allocator);
93+ const char * input_name3 = input_name3_ptr.get ();
94+
95+ // Input data.
96+ std::vector<float > input_values = {1 .0f , 2 .0f , 3 .0f , 4 .0f , 5 .0f , 6 .0f };
97+
98+ // Input shape: (1, 3, 2)
99+ std::vector<int64_t > input_shape{1 , 3 , 2 };
100+
101+ // Create tensor
102+ Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu (OrtArenaAllocator, OrtMemTypeDefault);
103+
104+ // Create input data as an OrtValue.
105+ // Make input2 data and input3 data same as input1 data.
106+ Ort::Value input_tensor = Ort::Value::CreateTensor<float >(memory_info, input_values.data (), input_values.size (),
107+ input_shape.data (), input_shape.size ());
108+ Ort::Value input2_tensor = Ort::Value::CreateTensor<float >(memory_info, input_values.data (), input_values.size (),
109+ input_shape.data (), input_shape.size ());
110+ Ort::Value input3_tensor = Ort::Value::CreateTensor<float >(memory_info, input_values.data (), input_values.size (),
111+ input_shape.data (), input_shape.size ());
112+
113+ std::vector<Ort::Value> input_tensors;
114+ input_tensors.reserve (3 );
115+ input_tensors.push_back (std::move (input_tensor));
116+ input_tensors.push_back (std::move (input2_tensor));
117+ input_tensors.push_back (std::move (input3_tensor));
118+
119+ // Get output name
120+ Ort::AllocatedStringPtr output_name_ptr =
121+ session.GetOutputNameAllocated (0 , allocator); // Keep the smart pointer alive to avoid dangling pointer
122+ const char * output_name = output_name_ptr.get ();
71123
124+ // Run session
125+ std::vector<const char *> input_names{input_name, input_name2, input_name3};
126+ std::vector<const char *> output_names{output_name};
72127
128+ auto output_tensors = session.Run (Ort::RunOptions{nullptr }, input_names.data (), input_tensors.data (),
129+ input_tensors.size (), output_names.data (), 1 );
130+ outputs = std::move (output_tensors);
131+
132+ return nullptr ;
73133}
74134
75- TEST (TensorrtExecutionProviderTest, SessionCreationWithMultiThreadsAndInferenceWithMultiThreads) {
76- std::vector<std::thread> threads;
77- std::string model_name = " basic_model_for_test.onnx " ;
78- std::string graph_name = " basic_model " ;
135+
136+
137+ TEST (TensorrtExecutionProviderTest, CreateSessionAndRunInference) {
138+ Ort::Env env ;
79139 std::string lib_registration_name = " TensorRTEp" ;
140+ std::string& ep_name = lib_registration_name;
80141 PathString lib_path = ORT_TSTR (" TensorRTEp.dll" );
142+
143+ // Register plugin TRT EP library with ONNX Runtime.
144+ env.RegisterExecutionProviderLibrary (
145+ lib_registration_name.c_str (), // Registration name can be anything the application chooses.
146+ lib_path // Path to the plugin TRT EP library.
147+ );
148+
149+ // Unregister the library using the application-specified registration name.
150+ // Must only unregister a library after all sessions that use the library have been released.
151+ auto unregister_plugin_eps_at_scope_exit =
152+ gsl::finally ([&]() { env.UnregisterExecutionProviderLibrary (lib_registration_name.c_str ()); });
153+
154+
155+ std::string model_name = " basic_model_for_test.onnx" ;
156+ std::string graph_name = " basic_model" ;
81157 std::vector<int64_t > dims = {1 , 3 , 2 };
82158 CreateBaseModel (model_name, graph_name, dims);
83- CreateOrtSession (ToPathString (model_name), lib_registration_name, lib_path);
159+
160+ OrtSession* session = nullptr ;
161+ ASSERT_EQ (CreateOrtSession (env, ToPathString (model_name), ep_name, &session), nullptr );
162+ ASSERT_NE (session, nullptr );
163+ Ort::Session ort_session{session};
164+
165+ std::vector<Ort::Value> output_tensors;
166+ ASSERT_EQ (RunInference (ort_session, output_tensors), nullptr );
167+
168+ // Extract output data
169+ float * output_data = output_tensors.front ().GetTensorMutableData <float >();
170+
171+ std::cout << " Output:" << std::endl;
172+ for (int i = 0 ; i < 6 ; i++) {
173+ std::cout << output_data[i] << " " ;
174+ }
175+ std::cout << std::endl;
176+
177+ std::vector<float > expected_values = {3 .0f , 6 .0f , 9 .0f , 12 .0f , 15 .0f , 18 .0f };
178+ std::vector<int64_t > expected_shape{1 , 3 , 2 };
179+ VerifyOutptus (output_tensors, expected_shape, expected_values);
180+
181+
84182}
85183
86184} // namespace trt_ep
0 commit comments