@@ -1156,7 +1156,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
     weight_stripped_engine_refit_ = true;
   }
 
-  std::unique_ptr<nvinfer1::IHostMemory> serialized_engine = nullptr;
+  std::unique_ptr<nvinfer1::IHostMemory> serialized_engine;
 
   if (!has_dynamic_shape) {
     std::string timing_cache_path = "";
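
Note: the `+` side drops the redundant initializer because a default-constructed std::unique_ptr already holds nullptr. A minimal standalone sketch (the Blob type is a hypothetical stand-in for nvinfer1::IHostMemory):

    #include <cassert>
    #include <memory>

    struct Blob {};  // hypothetical stand-in for nvinfer1::IHostMemory

    int main() {
      std::unique_ptr<Blob> p;  // default-constructed unique_ptr is already null
      assert(p == nullptr);     // so an explicit "= nullptr" adds nothing
      return 0;
    }
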
@@ -1258,7 +1258,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
     }
 
     serialized_engine =
-        std::make_unique<nvinfer1::IHostMemory>(trt_builder->buildSerializedNetwork(*trt_network, *trt_config));
+        std::unique_ptr<nvinfer1::IHostMemory>(trt_builder->buildSerializedNetwork(*trt_network, *trt_config));
 
     if (serialized_engine == nullptr) {
       std::string err_msg = "TensorRT EP failed to create engine from network for fused node: " + fused_node_name;
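
Note: nvinfer1::IBuilder::buildSerializedNetwork() returns a raw nvinfer1::IHostMemory* that the caller owns. std::make_unique<nvinfer1::IHostMemory>(ptr) would attempt to construct a brand-new IHostMemory from that pointer, which is ill-formed for an abstract interface (and would copy rather than adopt even for a concrete type); the `+` side uses unique_ptr's adopting constructor instead. A sketch of the pattern with hypothetical stand-in types, assuming a public virtual destructor as in recent TensorRT releases so the default deleter applies:

    #include <memory>

    // Hypothetical stand-ins for the TensorRT types, for illustration only.
    struct IHostMemory {
      virtual ~IHostMemory() = default;
      virtual const void* data() const = 0;  // abstract: not directly constructible
    };

    struct HostMemory : IHostMemory {
      const void* data() const override { return nullptr; }
    };

    // Mimics buildSerializedNetwork(): returns a caller-owned raw pointer,
    // or nullptr on failure.
    IHostMemory* buildSerializedNetwork() { return new HostMemory(); }

    int main() {
      // Ill-formed: make_unique tries to *construct* an IHostMemory.
      // auto bad = std::make_unique<IHostMemory>(buildSerializedNetwork());

      // The fix: adopt the factory's pointer; unique_ptr now owns and deletes it.
      std::unique_ptr<IHostMemory> serialized_engine(buildSerializedNetwork());
      return serialized_engine == nullptr;  // null signals build failure
    }
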
@@ -1390,32 +1390,9 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
   input_shape_ranges_[fused_node_name] = input_implicit_shape_ranges;
   profiles_.emplace(fused_node_name, std::move(trt_profiles));
 
-  /*
-  // For dynamic shape input model, firstly TRT EP creates a model proto which includes inputs, outputs and empty
-  // engine. TRT EP will serialize the model at inference time due to engine can be updated and the updated engine
-  // should be included in the model. However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize
-  // it here.
-  if (dump_ep_context_model_ && has_dynamic_shape) {
-    // "ep_cache_context" node attribute should be a relative path to context model directory
-    if (ep_cache_context_attr_.empty()) {
-      auto cache_file_name = std::filesystem::path(engine_cache_path).filename();
-      ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir)
-                                   .append(cache_file_name.string())
-                                   .string();
-    }
-    std::string compute_capability_hw_compat = compute_capability_;
-    if (engine_cache_enable_ && engine_hw_compatible_) {
-      compute_capability_hw_compat = "80+";
-    }
-    model_proto_.reset(CreateCtxModel(graph_body_viewer, ep_cache_context_attr_, nullptr, 0, ep_context_embed_mode_,
-                                      compute_capability_hw_compat, model_path_, GetLogger()));
-    if (ep_context_embed_mode_ == 0) {
-      DumpCtxModel(model_proto_.get(), ctx_model_path_);
-    }
-  }
-  */
 
-  std::unique_ptr<EPContextNodeHelper> ep_ctx_node_helper = std::make_unique<EPContextNodeHelper>(graph, fused_node);
+  // Create EP Context nodes
+  std::unique_ptr<EPContextNodeHelper> ep_ctx_node_helper = std::make_unique<EPContextNodeHelper>(*ep, graph, fused_node);
   if (dump_ep_context_model_) {
     std::string compute_capability_hw_compat = compute_capability_;
     if (engine_cache_enable_ && engine_hw_compatible_) {
@@ -1490,6 +1467,8 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
                                   engine_hw_compatible_,
                                   sync_stream_after_enqueue_};
 
+  ep->compute_states_[fused_node_name] = std::move(compute_state);
+
   // Update the OrtNodeComputeInfo associated with the graph.
   auto ep_node_compute_info = std::make_unique<TRTEpNodeComputeInfo>(*ep);
   *node_compute_info = ep_node_compute_info.release();
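
Note: the added `ep->compute_states_[fused_node_name] = std::move(compute_state);` hands ownership of the per-node state to the EP, keyed by the fused node's name; CreateStateImpl (hunk at 2266 below) later looks the entry up by the same key and returns a borrowed pointer. A minimal sketch of this store-then-lookup pattern, all names hypothetical:

    #include <memory>
    #include <string>
    #include <unordered_map>

    struct ComputeState { /* per-node engine, bindings, ... */ };

    std::unordered_map<std::string, std::unique_ptr<ComputeState>> compute_states;

    void Register(const std::string& fused_node_name, std::unique_ptr<ComputeState> state) {
      compute_states[fused_node_name] = std::move(state);  // the EP keeps ownership
    }

    ComputeState* Lookup(const std::string& fused_node_name) {
      auto it = compute_states.find(fused_node_name);
      return it == compute_states.end() ? nullptr : it->second.get();  // borrowed
    }
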
@@ -1554,10 +1533,10 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
   auto supported_control_flow_op = [&](const OrtNode* node) {
     OrtStatus* status = nullptr;
     size_t num_subgraphs = 0;
-    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs), ort_api);
+    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs));
 
     std::vector<const OrtGraph*> node_subgraphs(num_subgraphs);
-    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr), ort_api);
+    RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr));
 
 
     // Iterate the node's subgraphs
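
Note: every call site now invokes RETURN_FALSE_AND_PRINT_IF_ERROR with a single argument, so the macro presumably picks up the ort_api already visible in the enclosing scope rather than taking it as a parameter. The real definition is not part of this diff; the following is only a guess at its shape, using the OrtApi calls GetErrorMessage and ReleaseStatus, which do exist in the C API:

    // Hypothetical reconstruction; the actual macro lives elsewhere in the EP
    // sources and may differ. Assumes a `const OrtApi& ort_api` is visible at
    // every expansion site.
    #include <iostream>
    #include "onnxruntime_c_api.h"

    #define RETURN_FALSE_AND_PRINT_IF_ERROR(expr)                 \
      do {                                                        \
        if (OrtStatus* _status = (expr); _status != nullptr) {    \
          std::cerr << ort_api.GetErrorMessage(_status) << "\n";  \
          ort_api.ReleaseStatus(_status);                         \
          return false;                                           \
        }                                                         \
      } while (0)
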
@@ -1566,7 +1545,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
 
       // Get number of subgraph's nodes
       size_t num_subgraph_nodes = 0;
-      RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes), ort_api);
+      RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes));
 
       // TRT EP should consider the empty subgraph is fully supported by TRT.
       if (num_subgraph_nodes == 0) {
@@ -1926,13 +1905,11 @@ OrtStatus* TensorrtExecutionProvider::RefitEngine(
 /// </summary>
 TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory,
                                                      const std::string& name,
-                                                     const OrtHardwareDevice& device,
                                                      const OrtSessionOptions& session_options,
                                                      const OrtLogger& logger)
     : ApiPtrs{static_cast<const ApiPtrs&>(factory)},
       factory_(factory),
       name_{name},
-      hardware_device_{device},
       session_options_{session_options},
       logger_{logger} {
 
@@ -2176,7 +2153,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa
    * Please refer to ParserProfileShapes() for more details)
    *
    */
-  bool status = true;
+  // bool status = true;
   // if (status) {
   //   status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_);
   //   if (!status) {
@@ -2266,14 +2243,14 @@ OrtStatus* TRTEpNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, O
   TensorrtExecutionProvider& ep = node_compute_info->ep;
 
   std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context);
-  auto state_it = ep.GetComputeStates().find(fused_node_name);
-  if (state_it == ep.GetComputeStates().end()) {
+  auto state_it = ep.compute_states_.find(fused_node_name);
+  if (state_it == ep.compute_states_.end()) {
     std::string message = "Unable to TensorRT EP's compute state for fused node with name " + fused_node_name;
     return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str());
   }
 
-  TensorrtComputeState& compute_state = *state_it->second;
-  *compute_state = &compute_state;
+  TensorrtComputeState& trt_ep_compute_state = *state_it->second;
+  *compute_state = &trt_ep_compute_state;
   return nullptr;
 }
 
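
Note: the old lines declared a local reference named compute_state while the function already has a `void** compute_state` parameter; redeclaring a parameter's name in the function's outermost block is ill-formed, and the out-assignment needs the parameter anyway. Renaming the local resolves both. A compact illustration with hypothetical types:

    struct State { int value = 0; };
    static State g_state;  // stand-in for the entry owned by ep.compute_states_

    // Same out-parameter shape as the CreateState callback.
    int CreateState(void** compute_state) {
      // State& compute_state = g_state;  // ill-formed: redeclares the parameter
      State& trt_state = g_state;         // a distinct name compiles
      *compute_state = &trt_state;        // return the state as an opaque pointer
      return 0;
    }
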
@@ -2335,7 +2312,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void*
   bool context_update = false;
   std::unordered_set<std::string> input_names;
 
-  std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_maps = ep.GetDDSOutputAllocators();
+  std::unordered_map<std::string, DDSOutputAllocatorMap>& dds_output_allocator_maps = ep.GetDDSOutputAllocators();
   auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name];
 
   // Get default OrtMemoryInfo from factory
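
Note: without the `&`, the declaration copies the map returned by GetDDSOutputAllocators(), so any allocator inserted through dds_output_allocator_maps[fused_node_name] dies with the temporary copy when ComputeImpl returns. Binding a reference makes the insertion persist in the EP-owned map. A self-contained demonstration of the pitfall:

    #include <cassert>
    #include <string>
    #include <unordered_map>

    using Map = std::unordered_map<std::string, int>;
    Map g_map;
    Map& GetMap() { return g_map; }

    int main() {
      Map copy = GetMap();  // silent copy: writes never reach g_map
      copy["node"] = 1;
      assert(g_map.find("node") == g_map.end());

      Map& ref = GetMap();  // reference: writes persist in the owning map
      ref["node"] = 1;
      assert(g_map.find("node") != g_map.end());
      return 0;
    }
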
@@ -2911,7 +2888,7 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void*
 
 void TRTEpNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) {
   (void)this_ptr;
-  TensorrtComputeState& compute_state = *reinterpret_cast<TensorrtComputeState*>(compute_state);
-  (void)compute_state;
+  TensorrtComputeState& trt_ep_compute_state = *reinterpret_cast<TensorrtComputeState*>(compute_state);
+  (void)trt_ep_compute_state;
   // Do nothing for here.
 }