Skip to content

Commit 3269f73

Browse files
committed
Clean up CompileImpl
1 parent ed65a9f commit 3269f73

File tree

2 files changed

+38
-46
lines changed

2 files changed

+38
-46
lines changed

plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,52 +1642,46 @@ static OrtStatus* ORT_API_CALL CompileImpl(OrtEp* this_ptr, const OrtGraph** gra
16421642
gsl::span<const OrtValueInfo* const> node_inputs{};
16431643
gsl::span<const OrtValueInfo* const> node_outputs{};
16441644

1645-
RETURN_IF_ERROR(GetSpanFromConstPointerArray<OrtValueInfo>(inputs_array, node_inputs));
1646-
RETURN_IF_ERROR(GetSpanFromConstPointerArray<OrtValueInfo>(outputs_array, node_outputs));
1645+
GetSpanFromArrayOfConstObjects<OrtValueInfo>(inputs_array, node_inputs);
1646+
GetSpanFromArrayOfConstObjects<OrtValueInfo>(outputs_array, node_outputs);
16471647

16481648
// Gets number of node's inputs and outputs
16491649
size_t num_node_inputs = 0;
16501650
size_t num_node_outputs = 0;
1651-
RETURN_IF_ERROR(ep->ort_api.ConstPointerArray_GetSize(inputs_array, &num_node_inputs));
1652-
RETURN_IF_ERROR(ep->ort_api.ConstPointerArray_GetSize(outputs_array, &num_node_outputs));
1651+
RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(inputs_array, &num_node_inputs));
1652+
RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(outputs_array, &num_node_outputs));
16531653

16541654
// Builds map from input name to its index in input list
16551655
std::unordered_map<std::string, size_t> input_map;
16561656
input_map.reserve(num_node_inputs);
1657-
for (size_t i = 0, i < num_node_inputs; i++) {
1658-
std::string& name = node_inputs[i]->GetName();
1659-
input_map[name] = i;
1657+
for (size_t i = 0; i < num_node_inputs; i++) {
1658+
// TODO: Add ValueInfo_GetName() c api
1659+
//std::string& name = node_inputs[i]->GetName();
1660+
//input_map[name] = i;
16601661
}
16611662

16621663
// Builds map from output name to its index in output list
1663-
std::unordered_map<std::string, size_t> out_map;
1664+
std::unordered_map<std::string, size_t> output_map;
16641665
input_map.reserve(num_node_outputs);
1665-
for (size_t i = 0, i < num_node_outputs; i++) {
1666-
std::string& name = node_outputs[i]->GetName();
1667-
out_map[name] = i;
1668-
}
1669-
1670-
Status status;
1671-
if (GraphHasCtxNode(graph_body_viewer)) {
1672-
status = ep->CreateNodeComputeInfoFromPrecompiledEngine(graph_body_viewer,
1673-
fused_node,
1674-
input_map,
1675-
output_map,
1676-
node_compute_funcs);
1666+
for (size_t i = 0; i < num_node_outputs; i++) {
1667+
// TODO: Add ValueInfo_GetName() c api
1668+
//std::string& name = node_outputs[i]->GetName();
1669+
//output_map[name] = i;
1670+
}
1671+
1672+
OrtStatus* status;
1673+
//if (GraphHasCtxNode(graph_body_viewer)) {
1674+
if (false) {
1675+
status = ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[graph_idx], fused_node,
1676+
input_map,
1677+
output_map, node_compute_infos[graph_idx]);
16771678
} else {
1678-
status = ep->CreateNodeComputeInfoFromGraph(graph_body_viewer, fused_node, input_map, output_map, node_compute_funcs);
1679+
status = ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[graph_idx], fused_node, input_map,
1680+
output_map, node_compute_infos[graph_idx]);
16791681
}
1680-
if (status != Status::OK()) {
1681-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
1682-
}
1683-
1684-
/*
1685-
OrtArrayOfConstObjects* nodes_array = nullptr;
1686-
DeferOrtRelease<OrtArrayOfConstObjects> release_nodes(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects);
1687-
size_t num_nodes = 0;
1688-
RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[graph_idx], &nodes_array));
1689-
RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes));
1690-
*/
1682+
//if (status != Status::OK()) {
1683+
// return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
1684+
//}
16911685
}
16921686

16931687
return nullptr;
@@ -1724,6 +1718,7 @@ struct TensorrtExecutionProvider : TensorrtExecutionProvider(ApiPtrs apis, const
17241718
// The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to
17251719
// the session option configurations with the key prefix "ep.<lowercase_ep_name>.".
17261720
const std::string key_prefix = OrtSessionOptions::GetProviderOptionPrefix(name_.c_str());
1721+
const ConfigOptions& config_options = session_options.GetConfigOptions();
17271722
const std::unordered_map<std::string, std::string>& config_options_map = config_options.GetConfigOptionsMap();
17281723

17291724
// Get provider options as key-value pair strings

plugin_execution_providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,17 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
246246
SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, const int max_iterations,
247247
const OrtGraph* graph, bool* early_termination) const;
248248

249+
// Per-graph helper called from CompileImpl for one (graph, fused node) pair; the returned OrtStatus* is
// stored in CompileImpl's local `status`.
// NOTE(review): the call site in CompileImpl is currently guarded by `if (false)` (the GraphHasCtxNode
// check is commented out there), so this path is never taken as written.
// NOTE(review): `graphs`, `fused_nodes` and `node_compute_infos` keep plural names but are single
// pointers here; CompileImpl passes graphs[graph_idx] and node_compute_infos[graph_idx].
// NOTE(review): `node_compute_infos` is received by value -- if this function is expected to create the
// compute info and hand it back to the caller, confirm that a by-value OrtNodeComputeInfo* is sufficient.
// input_map / output_map: name -> index maps declared in CompileImpl; they are left empty there until a
// ValueInfo_GetName() C API is available (see the TODO in CompileImpl).
OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, const OrtGraph* graphs,
250+
const OrtNode* fused_nodes,
251+
std::unordered_map<std::string, size_t>& input_map,
252+
std::unordered_map<std::string, size_t>& output_map,
253+
OrtNodeComputeInfo* node_compute_infos);
254+
255+
// Per-graph helper called from CompileImpl for one (graph, fused node) pair; with the precompiled-engine
// branch disabled via `if (false)`, this is the call CompileImpl actually makes for every graph.
// NOTE(review): CompileImpl stores the returned OrtStatus* in `status` but its error check is commented
// out, so a non-null status is currently neither propagated nor released by the caller.
// NOTE(review): `graphs`, `fused_nodes` and `node_compute_infos` keep plural names but are single
// pointers here; CompileImpl passes graphs[graph_idx] and node_compute_infos[graph_idx].
// NOTE(review): `node_compute_infos` is received by value -- if this function is expected to create the
// compute info and hand it back to the caller, confirm that a by-value OrtNodeComputeInfo* is sufficient.
// input_map / output_map: name -> index maps declared in CompileImpl; they are left empty there until a
// ValueInfo_GetName() C API is available (see the TODO in CompileImpl).
OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, const OrtGraph* graphs, const OrtNode* fused_nodes,
256+
std::unordered_map<std::string, size_t>& input_map,
257+
std::unordered_map<std::string, size_t>& output_map,
258+
OrtNodeComputeInfo* node_compute_infos);
259+
249260
/*
250261
bool IsGraphCaptured(int graph_annotation_id) const { return false; }
251262
@@ -386,20 +397,6 @@ struct TensorrtExecutionProvider : OrtEp, ApiPtrs {
386397
// to allocate enough memory in Arena before graph capturing.
387398
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.
388399

389-
OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr,
390-
const OrtGraph** graphs,
391-
const OrtNode** fused_nodes,
392-
std::unordered_map<std::string, size_t>& input_map,
393-
std::unordered_map<std::string, size_t>& output_map,
394-
OrtNodeComputeInfo** node_compute_infos);
395-
396-
OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr,
397-
const OrtGraph** graphs,
398-
const OrtNode** fused_nodes,
399-
std::unordered_map<std::string, size_t>& input_map,
400-
std::unordered_map<std::string, size_t>& output_map,
401-
OrtNodeComputeInfo** node_compute_infos);
402-
403400
bool IsGraphCaptureAllowed() const { return false; };
404401

405402
nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const;

0 commit comments

Comments
 (0)