|
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>

#include "tensorrt_execution_provider_utils.h"
#include "onnx_ctx_model_helper.h"
7 | 10 |
|
8 | | -namespace onnxruntime { |
9 | | - |
10 | | -bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) { |
11 | | - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); |
12 | | - const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION); |
13 | | - int maxNodeIndex = 0; |
14 | | - graph_api->OrtGraph_MaxNodeIndex(graph_viewer, &maxNodeIndex); |
15 | | - for (int i = 0; i < maxNodeIndex; ++i) { |
16 | | - const OrtNode* node = nullptr; |
17 | | - graph_api->OrtGraph_GetOrtNode(graph_viewer, i, &node); |
18 | | - if (node == nullptr) { |
19 | | - continue; |
20 | | - } |
21 | | - const char* opType = nullptr; |
22 | | - graph_api->OrtNode_GetOpType(node, &opType); |
23 | | - if (strcmp(opType, EPCONTEXT_OP.c_str()) == 0) { |
24 | | - return true; |
25 | | - } |
26 | | - } |
27 | | - return false; |
28 | | -} |
| 11 | +extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); |
29 | 12 |
|
30 | 13 | /* |
31 | | - * Return the directory where the ep context model locates |
| 14 | + * Check whether the graph has the EP context node. |
| 15 | + * The node can contain the precompiled engine info for TRT EP to directly load the engine. |
| 16 | + * |
| 17 | + * Note: Please see more details about "EPContext" contrib op in contrib_defs.cc |
32 | 18 | */ |
33 | | -std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) { |
34 | | - if (ep_context_file_path.empty()) { |
35 | | - return std::filesystem::path(); |
36 | | - } |
37 | | - std::filesystem::path ctx_path(ep_context_file_path); |
38 | | - if (std::filesystem::is_directory(ep_context_file_path)) { |
39 | | - return ctx_path; |
40 | | - } else { |
41 | | - return ctx_path.parent_path(); |
42 | | - } |
43 | | -} |
| 19 | +bool EPContextNodeHelper::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api) { |
| 20 | + size_t num_nodes = 0; |
| 21 | + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); |
44 | 22 |
|
45 | | -std::string GetCtxModelPath(const std::string& ep_context_file_path, |
46 | | - const std::string& original_model_path) { |
47 | | - std::string ctx_model_path; |
| 23 | + std::vector<const OrtNode*> nodes(num_nodes); |
48 | 24 |
|
49 | | - if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) { |
50 | | - ctx_model_path = ep_context_file_path; |
51 | | - } else { |
52 | | - std::filesystem::path model_path = original_model_path; |
53 | | - std::filesystem::path model_name_stem = model_path.stem(); // model_name.onnx -> model_name |
54 | | - std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx"; |
| 25 | + for (size_t i = 0; i < num_nodes; ++i) { |
| 26 | + auto node = nodes[i]; |
55 | 27 |
|
56 | | - if (std::filesystem::is_directory(ep_context_file_path)) { |
57 | | - std::filesystem::path model_directory = ep_context_file_path; |
58 | | - ctx_model_path = model_directory.append(ctx_model_name).string(); |
59 | | - } else { |
60 | | - ctx_model_path = ctx_model_name; |
| 28 | + const char* op_type = nullptr; |
| 29 | + RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node, &op_type)); |
| 30 | + if (node != nullptr && op_type == "EPContext") { |
| 31 | + return true; |
61 | 32 | } |
62 | 33 | } |
63 | | - return ctx_model_path; |
64 | | -} |
65 | | - |
66 | | -bool IsAbsolutePath(const std::string& path_string) { |
67 | | -#ifdef _WIN32 |
68 | | - onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); |
69 | | - auto path = std::filesystem::path(ort_path_string.c_str()); |
70 | | - return path.is_absolute(); |
71 | | -#else |
72 | | - if (!path_string.empty() && path_string[0] == '/') { |
73 | | - return true; |
74 | | - } |
75 | 34 | return false; |
76 | | -#endif |
77 | | -} |
78 | | - |
79 | | -// Like "../file_path" |
80 | | -bool IsRelativePathToParentPath(const std::string& path_string) { |
81 | | -#ifdef _WIN32 |
82 | | - onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); |
83 | | - auto path = std::filesystem::path(ort_path_string.c_str()); |
84 | | - auto relative_path = path.lexically_normal().make_preferred().wstring(); |
85 | | - if (relative_path.find(L"..", 0) != std::string::npos) { |
86 | | - return true; |
87 | | - } |
88 | | - return false; |
89 | | -#else |
90 | | - if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) { |
91 | | - return true; |
92 | | - } |
93 | | - return false; |
94 | | -#endif |
95 | 35 | } |
96 | 36 |
|
97 | 37 | /* |
98 | | - * Get the weight-refitted engine cache path from a weight-stripped engine cache path |
99 | | - * |
100 | | - * Weight-stipped engine: |
101 | | - * An engine with weights stripped and its size is smaller than a regualr engine. |
102 | | - * The cache name of weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine |
103 | | - * |
104 | | - * Weight-refitted engine: |
105 | | - * An engine that its weights have been refitted and it's simply a regular engine. |
106 | | - * The cache name of weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine |
| 38 | + * Create EPContext OrtNode from a fused_node |
107 | 39 | */ |
108 | | -std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { |
109 | | - std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); |
110 | | - std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; |
111 | | - return refitted_engine_cache_path; |
112 | | -} |
113 | | - |
114 | | -bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) { |
115 | | - // The weight-stripped engine cache has the naming of xxx.stripped.engine |
116 | | - return engine_cache_path.stem().extension().string() == ".stripped"; |
117 | | -} |
118 | | - |
119 | | -OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphViewer* graph_viewer) { |
120 | | - if (!ValidateEPCtxNode(graph_viewer)) { |
121 | | - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "It's not a valid EP Context node"); |
122 | | - } |
123 | | - const OrtNode* node = nullptr; |
124 | | - graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); |
125 | | - |
126 | | - int64_t embed_mode = -1; |
127 | | - graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode); |
128 | | - if (embed_mode) { |
129 | | - // Get engine from byte stream. |
130 | | - const char* context_binary_cstr = nullptr; |
131 | | - size_t size; |
132 | | - graph_api_->OrtNode_GetAttributeStrWithSize(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr, &size); |
133 | | - std::string context_binary(context_binary_cstr, size); |
134 | | - *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()), |
135 | | - static_cast<size_t>(context_binary.length()))); |
136 | | -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; |
137 | | - if (!(*trt_engine_)) { |
138 | | - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data"); |
| 40 | +OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_cache_path, |
| 41 | + char* engine_data, |
| 42 | + size_t size, |
| 43 | + const int64_t embed_mode, |
| 44 | + const std::string& compute_capability, |
| 45 | + const std::string& onnx_model_path, |
| 46 | + OrtNode** ep_context_node) { |
| 47 | + |
| 48 | + // Helper to collect input or output names from an array of OrtValueInfo instances. |
| 49 | + auto collect_input_output_names = [&](gsl::span<const OrtValueInfo* const> value_infos, |
| 50 | + std::vector<const char*>& result) -> OrtStatus* { |
| 51 | + size_t num_values = value_infos.size(); |
| 52 | + std::vector<const char*> value_names(num_values); |
| 53 | + |
| 54 | + for (size_t i = 0; i < num_values; ++i) { |
| 55 | + const OrtValueInfo* value_info = value_infos[i]; |
| 56 | + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_names[i])); |
139 | 57 | } |
140 | | - } else { |
141 | | - // Get engine from cache file. |
142 | | - const char* cache_path_cstr = nullptr; |
143 | | - graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str(), &cache_path_cstr); |
144 | | - std::string cache_path(cache_path_cstr); |
145 | 58 |
|
146 | | - // For security purpose, in the case of running context model, TRT EP won't allow |
147 | | - // engine cache path to be the relative path like "../file_path" or the absolute path. |
148 | | - // It only allows the engine cache to be in the same directory or sub directory of the context model. |
149 | | - if (IsAbsolutePath(cache_path)) { |
150 | | - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path).c_str()); |
151 | | - } |
152 | | - if (IsRelativePathToParentPath(cache_path)) { |
153 | | - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory."); |
154 | | - } |
| 59 | + result = std::move(value_names); |
| 60 | + return nullptr; |
| 61 | + }; |
155 | 62 |
|
156 | | - // The engine cache and context model (current model) should be in the same directory |
157 | | - std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); |
158 | | - auto engine_cache_path = ctx_model_dir.append(cache_path); |
159 | | -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); |
| 63 | + const char* fused_node_name = nullptr; |
160 | 64 |
|
161 | | - // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled |
162 | | - if (!weight_stripped_engine_refit_) { |
163 | | - weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path); |
164 | | - } |
| 65 | + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node_, &fused_node_name)); |
165 | 66 |
|
166 | | - // If the serialized refitted engine is present, use it directly without refitting the engine again |
167 | | - if (weight_stripped_engine_refit_) { |
168 | | - const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); |
169 | | - if (std::filesystem::exists(refitted_engine_cache_path)) { |
170 | | -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists."; |
171 | | - engine_cache_path = refitted_engine_cache_path.string(); |
172 | | - weight_stripped_engine_refit_ = false; |
173 | | - } |
174 | | - } |
| 67 | + size_t num_fused_node_inputs = 0; |
| 68 | + size_t num_fused_node_outputs = 0; |
| 69 | + RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node_, &num_fused_node_inputs)); |
| 70 | + RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node_, &num_fused_node_outputs)); |
175 | 71 |
|
176 | | - if (!std::filesystem::exists(engine_cache_path)) { |
177 | | - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, |
178 | | - std::string("TensorRT EP can't find engine cache: " + engine_cache_path.string() + |
179 | | - ". Please make sure engine cache is in the same directory or sub-directory of context model.").c_str()); |
180 | | - } |
| 72 | + std::vector<const OrtValueInfo*> fused_node_inputs(num_fused_node_inputs); |
| 73 | + std::vector<const OrtValueInfo*> fused_node_outputs(num_fused_node_outputs); |
| 74 | + RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node_, fused_node_inputs.data(), fused_node_inputs.size())); |
| 75 | + RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node_, fused_node_outputs.data(), fused_node_outputs.size())); |
181 | 76 |
|
182 | | - std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); |
183 | | - engine_file.seekg(0, std::ios::end); |
184 | | - size_t engine_size = engine_file.tellg(); |
185 | | - engine_file.seekg(0, std::ios::beg); |
186 | | - std::unique_ptr<char[]> engine_buf{new char[engine_size]}; |
187 | | - engine_file.read((char*)engine_buf.get(), engine_size); |
188 | | - *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); |
189 | | - if (!(*trt_engine_)) { |
190 | | - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, |
191 | | - std::string("TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string()).c_str()); |
192 | | - } |
193 | | -// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); |
| 77 | + std::vector<const char*> input_names; |
| 78 | + std::vector<const char*> output_names; |
194 | 79 |
|
195 | | - if (weight_stripped_engine_refit_) { |
196 | | - const char* onnx_model_filename_cstr = nullptr; |
197 | | - graph_api_->OrtNode_GetAttributeStr(node, ONNX_MODEL_FILENAME.c_str(), &onnx_model_filename_cstr); |
198 | | - const std::string onnx_model_filename(onnx_model_filename_cstr); |
199 | | - std::string weight_stripped_engine_cache = engine_cache_path.string(); |
200 | | - auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename, |
201 | | - onnx_model_folder_path_, |
202 | | - weight_stripped_engine_cache, |
203 | | - true /* path check for security */, |
204 | | - (*trt_engine_).get(), |
205 | | - true /* serialize refitted engine to disk */, |
206 | | - detailed_build_log_); |
207 | | - if (status != nullptr) { |
208 | | - return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status)); |
209 | | - } |
210 | | - } |
211 | | - } |
212 | | - return nullptr; |
213 | | -} |
| 80 | + RETURN_IF_ERROR(collect_input_output_names(fused_node_inputs, /*out*/ input_names)); |
| 81 | + RETURN_IF_ERROR(collect_input_output_names(fused_node_outputs, /*out*/ output_names)); |
214 | 82 |
|
215 | | -bool TensorRTCacheModelHandler::ValidateEPCtxNode(const OrtGraphViewer* graph_viewer) { |
216 | | - int node_count = 0; |
217 | | - graph_api_->OrtGraph_NumberOfNodes(graph_viewer, &node_count); |
218 | | - assert(node_count == 1); |
219 | | - const OrtNode* node = nullptr; |
220 | | - graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node); |
221 | | - const char* opType = nullptr; |
222 | | - graph_api_->OrtNode_GetOpType(node, &opType); |
223 | | - assert(strcmp(opType, EPCONTEXT_OP.c_str()) == 0); |
| 83 | + // Create node attributes. The CreateNode() function copies the attributes, so we have to release them. |
| 84 | + std::array<OrtOpAttr*, 4> attributes = {}; |
| 85 | + DeferOrtRelease<OrtOpAttr> defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); |
224 | 86 |
|
225 | | - size_t key_count = 0; |
226 | | - graph_api_->OrtNode_GetAttributeKeyCount(node, COMPUTE_CAPABILITY.c_str(), &key_count); |
227 | | - // Show the warning if compute capability is not matched |
228 | | - if (key_count > 0) { |
229 | | - const char* model_compute_capability = nullptr; |
230 | | - graph_api_->OrtNode_GetAttributeStr(node, COMPUTE_CAPABILITY.c_str(), &model_compute_capability); |
231 | | - // Verify if engine was compiled with ampere+ hardware compatibility enabled |
232 | | - if (strcmp(model_compute_capability, "80+") == 0) { |
233 | | -// if (std::stoi(compute_capability_) < 80) { |
234 | | -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] However, this GPU doesn't match. The compute capability of the GPU: " << compute_capability_; |
235 | | -// } |
236 | | - } else if (strcmp(model_compute_capability, compute_capability_.c_str()) != 0) { |
237 | | -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal"; |
238 | | -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability; |
239 | | -// LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_; |
| 87 | + RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[0])); |
| 88 | + |
| 89 | + std::string engine_data_str = ""; |
| 90 | + if (embed_mode) { |
| 91 | + if (size > 0) { |
| 92 | + engine_data_str.assign(engine_data, size); |
240 | 93 | } |
| 94 | + RETURN_IF_ERROR( |
| 95 | + ort_api.CreateOpAttr("ep_cache_context", engine_data_str.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1])); |
| 96 | + } else { |
| 97 | + RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", engine_cache_path.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1])); |
241 | 98 | } |
242 | 99 |
|
243 | | - // "embed_mode" attr and "ep_cache_context" attr should be present |
244 | | - graph_api_->OrtNode_GetAttributeKeyCount(node, EMBED_MODE.c_str(), &key_count); |
245 | | - assert(key_count > 0); |
246 | | - graph_api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT.c_str(), &key_count); |
247 | | - assert(key_count > 0); |
| 100 | + |
| 101 | + ort_api.CreateOpAttr("hardware_architecture", compute_capability.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[2]); |
| 102 | + ort_api.CreateOpAttr("onnx_model_filename", std::filesystem::path(onnx_model_path).filename().string().c_str(), 1, |
| 103 | + ORT_OP_ATTR_STRING, &attributes[3]); |
248 | 104 |
|
249 | | - int64_t embed_mode = -1; |
250 | | - graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode); |
251 | | - if (embed_mode == 1) { |
252 | | - // engine binary data |
253 | | -// LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; |
254 | | - } |
255 | 105 |
|
256 | | - return true; |
257 | | -} |
| 106 | + RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, input_names.data(), |
| 107 | + input_names.size(), output_names.data(), output_names.size(), |
| 108 | + attributes.data(), attributes.size(), ep_context_node)); |
| 109 | + |
| 110 | + return nullptr; |
258 | 111 | } |
0 commit comments