44#include " onnxruntime_cxx_api.h"
55#undef ORT_API_MANUAL_INIT
66
7+ #include " tensorrt_provider_factory.h"
78#include " utils/provider_options.h"
89#include " tensorrt_execution_provider_info.h"
910#include " nv_includes.h"
@@ -150,11 +151,6 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
150151 std::vector<int64_t > output_shapes;
151152};
152153
153- using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<size_t , std::vector<std::vector<int64_t >>>>;
154-
155- template <typename T>
156- using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void (T*)>>;
157-
158154struct TensorrtComputeState {
159155 std::string fused_node_name;
160156 nvinfer1::IBuilder* builder;
@@ -166,6 +162,8 @@ struct TensorrtComputeState {
166162 std::vector<std::unordered_map<std::string, size_t >> output_info;
167163 std::unordered_map<std::string, std::unordered_map<size_t , std::vector<std::vector<int64_t >>>> input_shape_ranges;
168164 std::mutex* tensorrt_mu_ptr = nullptr ;
165+ std::string compute_capability;
166+ size_t max_workspace_size = 1 << 30 ; // 1GB;
169167 bool fp16_enable = false ;
170168 bool int8_enable = false ;
171169 bool int8_calibration_cache_available = false ;
@@ -178,7 +176,7 @@ struct TensorrtComputeState {
178176 std::vector<nvinfer1::IOptimizationProfile*> profiles;
179177 bool context_memory_sharing_enable = false ;
180178 size_t * max_context_mem_size_ptr = nullptr ;
181- IAllocatorUniquePtr <void >* context_memory = nullptr ;
179+ AllocatorUniquePtr <void >* context_memory = nullptr ;
182180 std::unordered_map<std::string, float > dynamic_range_map;
183181 bool engine_decryption_enable = false ;
184182 int (*engine_decryption)(const char *, char *, size_t *) = nullptr ;
@@ -193,10 +191,17 @@ struct TensorrtComputeState {
193191 int auxiliary_streams = -1 ;
194192 bool filter_tactic_sources = false ;
195193 nvinfer1::TacticSources tactic_sources;
196- bool cuda_graph_enable = 0 ;
194+ bool cuda_graph_enable = false ;
195+ bool weight_stripped_engine_enable = false ;
196+ bool weight_stripped_engine_refit = false ;
197+ char * model_path;
198+ std::string onnx_model_folder_path;
199+ const void * onnx_model_bytestream;
200+ size_t onnx_model_bytestream_size;
197201 std::string cache_prefix;
198202 std::string cache_suffix;
199203 bool engine_hw_compatible = false ;
204+ bool sync_stream_after_enqueue = true ;
200205};
201206
202207// Minimum information to construct kernel function state for direct engine load code path
@@ -211,6 +216,7 @@ struct TensorrtComputeStateForEPContext {
211216 std::mutex* tensorrt_mu_ptr = nullptr ;
212217};
213218
/// Nested-map alias. The identical type is spelled out in full for
/// TensorrtComputeState::input_shape_ranges.
/// NOTE(review): presumably tensor name -> dimension index -> list of shape ranges; confirm against the users.
219+ using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<size_t , std::vector<std::vector<int64_t >>>>;
/// Maps a string key to an OutputAllocator owned by the map (held through std::unique_ptr).
214220using DDSOutputAllocatorMap = std::unordered_map<std::string, std::unique_ptr<OutputAllocator>>;
/// Declaration only; the definition is not in this header excerpt.
/// Takes the engine cache path by value and returns a std::string.
215221std::string GetWeightRefittedEnginePath (std::string engine_cache_path);
216222
@@ -220,54 +226,51 @@ static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename";
220226
221227// / <summary>
222228// /
223- // / Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
224- // /
225- // / </summary>
226- struct TRTEpNodeComputeInfo : OrtNodeComputeInfo {
227- explicit TRTEpNodeComputeInfo (TensorrtExecutionProvider& ep);
228-
229- static OrtStatus* ORT_API_CALL CreateStateImpl (OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context,
230- void ** compute_state);
231- static OrtStatus* ORT_API_CALL ComputeImpl (OrtNodeComputeInfo* this_ptr, void * compute_state,
232- OrtKernelContext* kernel_context);
233- static void ORT_API_CALL ReleaseStateImpl (OrtNodeComputeInfo* this_ptr, void * compute_state);
234-
235- TensorrtExecutionProvider& ep;
236- };
237-
238- // / <summary>
239- // /
240- // / Plugin TensorRT EP that implements OrtEp
229+ // / Plugin TensorRT EP
241230// /
242231// / </summary>
243- struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
244- TensorrtExecutionProvider (ApiPtrs apis, const std::string& name, const OrtHardwareDevice& device,
245- const OrtSessionOptions& session_options, const OrtLogger& logger);
232+ struct TensorrtExecutionProvider : public OrtEp , public ApiPtrs {
233+ TensorrtExecutionProvider (TensorrtExecutionProviderFactory& factory, const std::string& name,
234+ const OrtHardwareDevice& device, const OrtSessionOptions& session_options,
235+ const OrtLogger& logger);
246236 ~TensorrtExecutionProvider ();
247237
238+ TensorrtExecutionProviderFactory& factory_;
248239 std::string name_;
249240 const OrtHardwareDevice& hardware_device_;
250241 const OrtSessionOptions& session_options_;
251242 const OrtLogger& logger_;
252243
253- SubGraphCollection_t GetSupportedList (SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations,
254- const OrtGraph* graph, bool * early_termination) const ;
244+ SubGraphCollection_t GetSupportedList (SubGraphCollection_t supported_nodes_list, int iterations,
245+ const int max_iterations, const OrtGraph* graph, bool * early_termination) const ;
255246
256247 OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine (OrtEp* this_ptr, const OrtGraph* graph,
257248 const OrtNode* fused_node,
258249 std::unordered_map<std::string, size_t >& input_map,
259250 std::unordered_map<std::string, size_t >& output_map,
260- OrtNodeComputeInfo* node_compute_info);
251+ OrtNodeComputeInfo** node_compute_info);
261252
262253 OrtStatus* CreateNodeComputeInfoFromGraph (OrtEp* this_ptr, const OrtGraph* graph, const OrtNode* fused_node,
263254 std::unordered_map<std::string, size_t >& input_map,
264255 std::unordered_map<std::string, size_t >& output_map,
265- OrtNodeComputeInfo* node_compute_info);
256+ OrtNodeComputeInfo** node_compute_info);
257+
/// Declaration of the engine-refit entry point; returns an OrtStatus*.
/// The ONNX model can be described both by filename/folder path and by an in-memory
/// byte stream (pointer + size); which one takes precedence is decided in the definition,
/// which is not visible here.
/// NOTE(review): `weight_stripped_engine_cath_path` looks like a typo for
/// `weight_stripped_engine_cache_path`; renaming it must be mirrored in the definition.
/// NOTE(review): `onnx_model_filename` is taken by value while the two other strings are
/// non-const references -- confirm whether the callee really mutates them.
258+ OrtStatus* RefitEngine (std::string onnx_model_filename, std::string& onnx_model_folder_path,
259+ std::string& weight_stripped_engine_cath_path, bool path_check,
260+ const void * onnx_model_bytestream, size_t onnx_model_bytestream_size,
261+ nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine,
262+ bool detailed_build_log);
266263
/// Returns a mutable reference to compute_states_, a map from a string key to an owned
/// TensorrtComputeState. Callers can insert into or erase from the map through this reference.
267264 std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>>& GetComputeStates () { return compute_states_; }
268265
269- std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>>& GetComputeStatesForEPContext () { return compute_states_; }
/// Returns a mutable reference to compute_states_ -- the same member, with the same mapped
/// type (TensorrtComputeState), that GetComputeStates() returns.
/// NOTE(review): the name suggests a separate map of TensorrtComputeStateForEPContext was
/// intended; confirm this is deliberate and not a copy-paste slip.
266+ std::unordered_map<std::string, std::unique_ptr<TensorrtComputeState>>& GetComputeStatesForEPContext () {
267+ return compute_states_;
268+ }
269+
/// Writes the stored allocator pointer (alloc_) to *alloc.
/// `alloc` is dereferenced unconditionally, so it must not be null.
/// The value written may itself be null if no allocator has been set -- TODO confirm alloc_'s initial value.
270+ void GetAllocator (OrtAllocator** alloc) const { *alloc = alloc_; }
270271
/// Stores `alloc` in alloc_. Only the raw pointer is copied; nothing here releases a
/// previously stored allocator.
272+ void SetAllocator (OrtAllocator* alloc) { alloc_ = alloc; }
273+
/// Returns a mutable reference to dds_output_allocator_maps_, a map from a string key to a
/// DDSOutputAllocatorMap.
/// NOTE(review): the outer key is presumably the fused node name; confirm against the callers.
271274 std::unordered_map<std::string, DDSOutputAllocatorMap>& GetDDSOutputAllocators () {
272275 return dds_output_allocator_maps_;
273276 }
@@ -312,6 +315,19 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
312315 std::unordered_map<std::string, std::string> cache_suffix_;
313316
314317 private:
/// Static callbacks using the ORT_API_CALL calling convention. Each receives the EP as an
/// OrtEp* `this_ptr` instead of an implicit `this`.
/// NOTE(review): presumably assigned to the OrtEp function-pointer slots in the constructor; confirm.

/// Returns the EP name as a C string; declared noexcept.
318+ static const char * ORT_API_CALL GetNameImpl (const OrtEp* this_ptr) noexcept ;
/// Reports, through `graph_support_info`, what this EP supports in `graph`.
319+ static OrtStatus* ORT_API_CALL GetCapabilityImpl (OrtEp* this_ptr, const OrtGraph* graph,
320+ OrtEpGraphSupportInfo* graph_support_info);
/// Takes `count` graphs and `count` fused nodes. Per the SAL annotations, `node_compute_infos`
/// is an output array whose `count` entries are all written, and `ep_context_nodes` is an
/// output array with room for `count` entries.
321+ static OrtStatus* ORT_API_CALL CompileImpl (_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
322+ _In_ const OrtNode** fused_nodes, _In_ size_t count,
323+ _Out_writes_all_ (count) OrtNodeComputeInfo** node_compute_infos,
324+ _Out_writes_(count) OrtNode** ep_context_nodes);
/// Receives an array of `num_node_compute_infos` OrtNodeComputeInfo pointers to release.
325+ static void ORT_API_CALL ReleaseNodeComputeInfosImpl (OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos,
326+ size_t num_node_compute_infos);
327+
/// Private helper declaration; returns an OrtStatus*.
/// `fused_nodes` is a read-only span of node pointers; `ep_context_nodes` is the output span
/// the helper fills.
/// NOTE(review): the two spans are presumably expected to have equal length; confirm in the definition.
328+ OrtStatus* CreateEpContextNodes (gsl::span<const OrtNode*> fused_nodes,
329+ /* out*/ gsl::span<OrtNode*> ep_context_nodes);
330+
315331 mutable TensorrtExecutionProviderInfo info_;
316332 bool external_stream_ = false ;
317333 cudaStream_t stream_ = nullptr ;
@@ -331,6 +347,8 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
331347 bool weight_stripped_engine_enable_ = false ;
332348 bool weight_stripped_engine_refit_ = false ;
333349 std::string onnx_model_folder_path_;
350+ const void * onnx_model_bytestream_;
351+ size_t onnx_model_bytestream_size_;
334352 bool build_heuristics_enable_ = false ;
335353 bool sparsity_enable_ = false ;
336354 int builder_optimization_level_ = 3 ;
@@ -344,7 +362,7 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
344362 bool context_memory_sharing_enable_ = false ;
345363 bool layer_norm_fp32_fallback_ = false ;
346364 size_t max_ctx_mem_size_ = 0 ;
347- IAllocatorUniquePtr <void > context_memory_ = nullptr ;
365+ AllocatorUniquePtr <void > context_memory_ = nullptr ;
348366 mutable char model_path_[4096 ] = {}; // Reserved for max path length
349367 bool engine_decryption_enable_ = false ;
350368 int (*engine_decryption_)(const char *, char *, size_t *) = nullptr;
@@ -419,3 +437,20 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
419437
420438 nvinfer1::IBuilder* GetBuilder (TensorrtLogger& trt_logger) const ;
421439};
440+
441+ // / <summary>
442+ // /
443+ // / Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
444+ // /
445+ // / </summary>
446+ struct TRTEpNodeComputeInfo : OrtNodeComputeInfo {
// Binds this compute-info object to `ep`; explicit, so no implicit conversion from the EP type.
447+ explicit TRTEpNodeComputeInfo (TensorrtExecutionProvider& ep);
448+
// Static ORT_API_CALL callbacks; each receives this object as an OrtNodeComputeInfo* `this_ptr`.
// CreateStateImpl hands back an opaque state through the `compute_state` out-pointer.
449+ static OrtStatus* ORT_API_CALL CreateStateImpl (OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context,
450+ void ** compute_state);
// ComputeImpl receives that opaque state together with the kernel context.
451+ static OrtStatus* ORT_API_CALL ComputeImpl (OrtNodeComputeInfo* this_ptr, void * compute_state,
452+ OrtKernelContext* kernel_context);
// ReleaseStateImpl receives the opaque state for disposal; it returns nothing.
453+ static void ORT_API_CALL ReleaseStateImpl (OrtNodeComputeInfo* this_ptr, void * compute_state);
454+
// Non-owning reference member: the referenced EP must outlive this object, and the
// reference makes this struct non-assignable.
455+ TensorrtExecutionProvider& ep;
456+ };
0 commit comments