|
13 | 13 | //#include "tensorrt_execution_provider_utils.h" |
14 | 14 | #include "tensorrt_execution_provider.h" |
15 | 15 | #include "cuda_allocator.h" |
16 | | -//#include "onnx_ctx_model_helper.h" |
| 16 | +#include "onnx_ctx_model_helper.h" |
17 | 17 | #include "onnx/onnx_pb.h" |
18 | 18 | #include "cuda/unary_elementwise_ops_impl.h" |
19 | 19 |
|
@@ -1480,8 +1480,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this |
1480 | 1480 | } |
1481 | 1481 |
|
1482 | 1482 |
|
1483 | | -OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, |
1484 | | - OrtEpGraphSupportInfo* graph_support_info) { |
| 1483 | +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, OrtEpGraphSupportInfo* graph_support_info) { |
1485 | 1484 | TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr); |
1486 | 1485 | const OrtApi& ort_api = ep->ort_api; |
1487 | 1486 |
|
@@ -1780,71 +1779,66 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this |
1780 | 1779 | return nullptr; |
1781 | 1780 | } |
1782 | 1781 |
|
1783 | | -OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, |
1784 | | - _In_ const OrtNode** fused_nodes, _In_ size_t count, |
| 1782 | +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr, |
| 1783 | + _In_ const OrtGraph** graphs, |
| 1784 | + _In_ const OrtNode** fused_nodes, |
| 1785 | + _In_ size_t count, |
1785 | 1786 | _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, |
1786 | 1787 | _Out_writes_(count) OrtNode** ep_context_nodes) { |
1787 | 1788 |
|
1788 | 1789 | TensorrtExecutionProvider* ep = static_cast<TensorrtExecutionProvider*>(this_ptr); |
| 1790 | + const OrtApi& ort_api = ep->ort_api; |
1789 | 1791 |
|
1790 | 1792 | gsl::span<OrtNodeComputeInfo*> result(node_compute_infos, count); |
1791 | 1793 |
|
1792 | | - for (size_t graph_idx = 0; graph_idx < count; graph_idx++) { |
1793 | | - auto fused_node = fused_nodes[graph_idx]; |
1794 | | - |
1795 | | - // Gets node's inputs and outputs as pointer array |
1796 | | - OrtArrayOfConstObjects* inputs_array = nullptr; |
1797 | | - OrtArrayOfConstObjects* outputs_array = nullptr; |
1798 | | - DeferOrtRelease<OrtArrayOfConstObjects> release_inputs(&inputs_array, ep->ort_api.ReleaseArrayOfConstObjects); |
1799 | | - DeferOrtRelease<OrtArrayOfConstObjects> release_outputs(&outputs_array, ep->ort_api.ReleaseArrayOfConstObjects); |
1800 | | - |
1801 | | - RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(fused_node, &inputs_array)); |
1802 | | - RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(fused_node, &outputs_array)); |
1803 | | - |
1804 | | - // Gets node's inputs and outputs as OrtValueInfo in gsl::span |
1805 | | - gsl::span<const OrtValueInfo* const> node_inputs{}; |
1806 | | - gsl::span<const OrtValueInfo* const> node_outputs{}; |
1807 | | - |
1808 | | - GetSpanFromArrayOfConstObjects<OrtValueInfo>(inputs_array, node_inputs); |
1809 | | - GetSpanFromArrayOfConstObjects<OrtValueInfo>(outputs_array, node_outputs); |
| 1794 | + for (size_t fused_node_idx = 0; fused_node_idx < count; fused_node_idx++) { |
| 1795 | + auto fused_node = fused_nodes[fused_node_idx]; |
1810 | 1796 |
|
1811 | 1797 | // Gets number of node's inputs and outputs |
1812 | 1798 | size_t num_node_inputs = 0; |
1813 | | - size_t num_node_outputs = 0; |
1814 | | - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(inputs_array, &num_node_inputs)); |
1815 | | - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(outputs_array, &num_node_outputs)); |
| 1799 | + RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_node_inputs)); |
| 1800 | + |
| 1801 | + std::vector<const OrtValueInfo*> node_inputs(num_node_inputs); |
| 1802 | + RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, node_inputs.data(), node_inputs.size())); |
1816 | 1803 |
|
1817 | 1804 | // Builds map from input name to its index in input list |
1818 | 1805 | std::unordered_map<std::string, size_t> input_map; |
1819 | 1806 | input_map.reserve(num_node_inputs); |
1820 | 1807 | for (size_t i = 0; i < num_node_inputs; i++) { |
1821 | | - // TODO: Add ValueInfo_GetName() c api |
1822 | | - //std::string& name = node_inputs[i]->GetName(); |
1823 | | - //input_map[name] = i; |
| 1808 | + const OrtValueInfo* value_info = node_inputs[i]; |
| 1809 | + const char* name = nullptr; |
| 1810 | + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &name)); |
| 1811 | + |
| 1812 | + input_map.emplace(name, i); |
1824 | 1813 | } |
1825 | 1814 |
|
| 1815 | + // Gets number of node's outputs |
| 1816 | + size_t num_node_outputs = 0; |
| 1817 | + RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node, &num_node_outputs)); |
| 1818 | + |
| 1819 | + std::vector<const OrtValueInfo*> node_outputs(num_node_outputs); |
| 1820 | + RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, node_outputs.data(), node_outputs.size())); |
| 1821 | + |
1826 | 1822 | // Builds map from output name to its index in output list |
1827 | 1823 | std::unordered_map<std::string, size_t> output_map; |
1828 | | - input_map.reserve(num_node_outputs); |
| 1824 | + output_map.reserve(num_node_outputs); |
1829 | 1825 | for (size_t i = 0; i < num_node_outputs; i++) { |
1830 | | - // TODO: Add ValueInfo_GetName() c api |
1831 | | - //std::string& name = node_outputs[i]->GetName(); |
1832 | | - //output_map[name] = i; |
1833 | | - } |
| 1826 | + const OrtValueInfo* value_info = node_outputs[i]; |
| 1827 | + const char* name = nullptr; |
| 1828 | + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &name)); |
1834 | 1829 |
|
| 1830 | + output_map.emplace(name, i); |
| 1831 | + } |
| 1832 | + |
1835 | 1833 | OrtStatus* status; |
1836 | | - //if (GraphHasCtxNode(graph_body_viewer)) { |
1837 | | - if (false) { |
1838 | | - status = ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[graph_idx], fused_node, |
| 1834 | + if (GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { |
| 1835 | + RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, |
1839 | 1836 | input_map, |
1840 | | - output_map, &result[graph_idx]); |
| 1837 | + output_map, &result[fused_node_idx])); |
1841 | 1838 | } else { |
1842 | | - status = ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[graph_idx], fused_node, input_map, |
1843 | | - output_map, &result[graph_idx]); |
| 1839 | + RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map, |
| 1840 | + output_map, &result[fused_node_idx])); |
1844 | 1841 | } |
1845 | | - //if (status != Status::OK()) { |
1846 | | - // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); |
1847 | | - //} |
1848 | 1842 | } |
1849 | 1843 |
|
1850 | 1844 | return nullptr; |
|
0 commit comments