Skip to content

Commit 3ad7736

Browse files
committed
Update CompileImpl
1 parent 549b29d commit 3ad7736

File tree

1 file changed

+37
-43
lines changed

1 file changed

+37
-43
lines changed

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 37 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
//#include "tensorrt_execution_provider_utils.h"
1414
#include "tensorrt_execution_provider.h"
1515
#include "cuda_allocator.h"
16-
//#include "onnx_ctx_model_helper.h"
16+
#include "onnx_ctx_model_helper.h"
1717
#include "onnx/onnx_pb.h"
1818
#include "cuda/unary_elementwise_ops_impl.h"
1919

@@ -1480,8 +1480,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this
14801480
}
14811481

14821482

1483-
OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph,
1484-
OrtEpGraphSupportInfo* graph_support_info) {
1483+
OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, OrtEpGraphSupportInfo* graph_support_info) {
14851484
TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr);
14861485
const OrtApi& ort_api = ep->ort_api;
14871486

@@ -1780,71 +1779,66 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this
17801779
return nullptr;
17811780
}
17821781

1783-
OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
1784-
_In_ const OrtNode** fused_nodes, _In_ size_t count,
1782+
OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr,
1783+
_In_ const OrtGraph** graphs,
1784+
_In_ const OrtNode** fused_nodes,
1785+
_In_ size_t count,
17851786
_Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos,
17861787
_Out_writes_(count) OrtNode** ep_context_nodes) {
17871788

17881789
TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr);
1790+
const OrtApi& ort_api = ep->ort_api;
17891791

17901792
gsl::span<OrtNodeComputeInfo*> result(node_compute_infos, count);
17911793

1792-
for (size_t graph_idx = 0; graph_idx < count; graph_idx++) {
1793-
auto fused_node = fused_nodes[graph_idx];
1794-
1795-
// Gets node's inputs and outputs as pointer array
1796-
OrtArrayOfConstObjects* inputs_array = nullptr;
1797-
OrtArrayOfConstObjects* outputs_array = nullptr;
1798-
DeferOrtRelease<OrtArrayOfConstObjects> release_inputs(&inputs_array, ep->ort_api.ReleaseArrayOfConstObjects);
1799-
DeferOrtRelease<OrtArrayOfConstObjects> release_outputs(&outputs_array, ep->ort_api.ReleaseArrayOfConstObjects);
1800-
1801-
RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(fused_node, &inputs_array));
1802-
RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(fused_node, &outputs_array));
1803-
1804-
// Gets node's inputs and outputs as OrtValueInfo in gsl::span
1805-
gsl::span<const OrtValueInfo* const> node_inputs{};
1806-
gsl::span<const OrtValueInfo* const> node_outputs{};
1807-
1808-
GetSpanFromArrayOfConstObjects<OrtValueInfo>(inputs_array, node_inputs);
1809-
GetSpanFromArrayOfConstObjects<OrtValueInfo>(outputs_array, node_outputs);
1794+
for (size_t fused_node_idx = 0; fused_node_idx < count; fused_node_idx++) {
1795+
auto fused_node = fused_nodes[fused_node_idx];
18101796

18111797
// Gets number of node's inputs and outputs
18121798
size_t num_node_inputs = 0;
1813-
size_t num_node_outputs = 0;
1814-
RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(inputs_array, &num_node_inputs));
1815-
RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(outputs_array, &num_node_outputs));
1799+
RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_node_inputs));
1800+
1801+
std::vector<const OrtValueInfo*> node_inputs(num_node_inputs);
1802+
RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, node_inputs.data(), node_inputs.size()));
18161803

18171804
// Builds map from input name to its index in input list
18181805
std::unordered_map<std::string, size_t> input_map;
18191806
input_map.reserve(num_node_inputs);
18201807
for (size_t i = 0; i < num_node_inputs; i++) {
1821-
// TODO: Add ValueInfo_GetName() c api
1822-
//std::string& name = node_inputs[i]->GetName();
1823-
//input_map[name] = i;
1808+
const OrtValueInfo* value_info = node_inputs[i];
1809+
const char* name = nullptr;
1810+
RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &name));
1811+
1812+
input_map.emplace(name, i);
18241813
}
18251814

1815+
// Gets number of node's outputs
1816+
size_t num_node_outputs = 0;
1817+
RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node, &num_node_outputs));
1818+
1819+
std::vector<const OrtValueInfo*> node_outputs(num_node_outputs);
1820+
RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, node_outputs.data(), node_outputs.size()));
1821+
18261822
// Builds map from output name to its index in output list
18271823
std::unordered_map<std::string, size_t> output_map;
1828-
input_map.reserve(num_node_outputs);
1824+
output_map.reserve(num_node_outputs);
18291825
for (size_t i = 0; i < num_node_outputs; i++) {
1830-
// TODO: Add ValueInfo_GetName() c api
1831-
//std::string& name = node_outputs[i]->GetName();
1832-
//output_map[name] = i;
1833-
}
1826+
const OrtValueInfo* value_info = node_outputs[i];
1827+
const char* name = nullptr;
1828+
RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &name));
18341829

1830+
output_map.emplace(name, i);
1831+
}
1832+
18351833
OrtStatus* status;
1836-
//if (GraphHasCtxNode(graph_body_viewer)) {
1837-
if (false) {
1838-
status = ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[graph_idx], fused_node,
1834+
if (GraphHasCtxNode(graphs[fused_node_idx], ort_api)) {
1835+
RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node,
18391836
input_map,
1840-
output_map, &result[graph_idx]);
1837+
output_map, &result[fused_node_idx]));
18411838
} else {
1842-
status = ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[graph_idx], fused_node, input_map,
1843-
output_map, &result[graph_idx]);
1839+
RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map,
1840+
output_map, &result[fused_node_idx]));
18441841
}
1845-
//if (status != Status::OK()) {
1846-
// return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
1847-
//}
18481842
}
18491843

18501844
return nullptr;

0 commit comments

Comments
 (0)