Skip to content

Commit 081de36

Browse files
committed
update EP context model helper
1 parent 3ced4cf commit 081de36

File tree

4 files changed

+157
-640
lines changed

4 files changed

+157
-640
lines changed
Lines changed: 78 additions & 225 deletions
Original file line numberDiff line numberDiff line change
@@ -1,258 +1,111 @@
1-
#include <cassert>
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
24
#include <iostream>
35
#include <fstream>
6+
#include <filesystem>
7+
8+
#include "tensorrt_execution_provider_utils.h"
49
#include "onnx_ctx_model_helper.h"
5-
#include "tensorrt_execution_provider.h"
6-
#include "path_string.h"
710

8-
namespace onnxruntime {
9-
10-
bool GraphHasCtxNode(const OrtGraphViewer* graph_viewer) {
11-
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
12-
const OrtGraphApi* graph_api = api->GetGraphApi(ORT_API_VERSION);
13-
int maxNodeIndex = 0;
14-
graph_api->OrtGraph_MaxNodeIndex(graph_viewer, &maxNodeIndex);
15-
for (int i = 0; i < maxNodeIndex; ++i) {
16-
const OrtNode* node = nullptr;
17-
graph_api->OrtGraph_GetOrtNode(graph_viewer, i, &node);
18-
if (node == nullptr) {
19-
continue;
20-
}
21-
const char* opType = nullptr;
22-
graph_api->OrtNode_GetOpType(node, &opType);
23-
if (strcmp(opType, EPCONTEXT_OP.c_str()) == 0) {
24-
return true;
25-
}
26-
}
27-
return false;
28-
}
11+
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
2912

3013
/*
31-
* Return the directory where the ep context model locates
14+
* Check whether the graph has the EP context node.
15+
* The node can contain the precompiled engine info for TRT EP to directly load the engine.
16+
*
17+
* Note: Please see more details about "EPContext" contrib op in contrib_defs.cc
3218
*/
33-
std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) {
34-
if (ep_context_file_path.empty()) {
35-
return std::filesystem::path();
36-
}
37-
std::filesystem::path ctx_path(ep_context_file_path);
38-
if (std::filesystem::is_directory(ep_context_file_path)) {
39-
return ctx_path;
40-
} else {
41-
return ctx_path.parent_path();
42-
}
43-
}
19+
bool EPContextNodeHelper::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api) {
20+
size_t num_nodes = 0;
21+
RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
4422

45-
std::string GetCtxModelPath(const std::string& ep_context_file_path,
46-
const std::string& original_model_path) {
47-
std::string ctx_model_path;
23+
std::vector<const OrtNode*> nodes(num_nodes);
4824

49-
if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) {
50-
ctx_model_path = ep_context_file_path;
51-
} else {
52-
std::filesystem::path model_path = original_model_path;
53-
std::filesystem::path model_name_stem = model_path.stem(); // model_name.onnx -> model_name
54-
std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx";
25+
for (size_t i = 0; i < num_nodes; ++i) {
26+
auto node = nodes[i];
5527

56-
if (std::filesystem::is_directory(ep_context_file_path)) {
57-
std::filesystem::path model_directory = ep_context_file_path;
58-
ctx_model_path = model_directory.append(ctx_model_name).string();
59-
} else {
60-
ctx_model_path = ctx_model_name;
28+
const char* op_type = nullptr;
29+
RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node, &op_type));
30+
if (node != nullptr && op_type == "EPContext") {
31+
return true;
6132
}
6233
}
63-
return ctx_model_path;
64-
}
65-
66-
bool IsAbsolutePath(const std::string& path_string) {
67-
#ifdef _WIN32
68-
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
69-
auto path = std::filesystem::path(ort_path_string.c_str());
70-
return path.is_absolute();
71-
#else
72-
if (!path_string.empty() && path_string[0] == '/') {
73-
return true;
74-
}
7534
return false;
76-
#endif
77-
}
78-
79-
// Like "../file_path"
80-
bool IsRelativePathToParentPath(const std::string& path_string) {
81-
#ifdef _WIN32
82-
onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string);
83-
auto path = std::filesystem::path(ort_path_string.c_str());
84-
auto relative_path = path.lexically_normal().make_preferred().wstring();
85-
if (relative_path.find(L"..", 0) != std::string::npos) {
86-
return true;
87-
}
88-
return false;
89-
#else
90-
if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) {
91-
return true;
92-
}
93-
return false;
94-
#endif
9535
}
9636

9737
/*
98-
* Get the weight-refitted engine cache path from a weight-stripped engine cache path
99-
*
100-
* Weight-stipped engine:
101-
* An engine with weights stripped and its size is smaller than a regualr engine.
102-
* The cache name of weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine
103-
*
104-
* Weight-refitted engine:
105-
* An engine that its weights have been refitted and it's simply a regular engine.
106-
* The cache name of weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine
38+
* Create EPContext OrtNode from a fused_node
10739
*/
108-
std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) {
109-
std::filesystem::path stripped_engine_cache_path(stripped_engine_cache);
110-
std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine";
111-
return refitted_engine_cache_path;
112-
}
113-
114-
bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) {
115-
// The weight-stripped engine cache has the naming of xxx.stripped.engine
116-
return engine_cache_path.stem().extension().string() == ".stripped";
117-
}
118-
119-
OrtStatusPtr TensorRTCacheModelHandler::GetEpContextFromGraph(const OrtGraphViewer* graph_viewer) {
120-
if (!ValidateEPCtxNode(graph_viewer)) {
121-
return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "It's not a valid EP Context node");
122-
}
123-
const OrtNode* node = nullptr;
124-
graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node);
125-
126-
int64_t embed_mode = -1;
127-
graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode);
128-
if (embed_mode) {
129-
// Get engine from byte stream.
130-
const char* context_binary_cstr = nullptr;
131-
size_t size;
132-
graph_api_->OrtNode_GetAttributeStrWithSize(node, EP_CACHE_CONTEXT.c_str(), &context_binary_cstr, &size);
133-
std::string context_binary(context_binary_cstr, size);
134-
*(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
135-
static_cast<size_t>(context_binary.length())));
136-
// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it";
137-
if (!(*trt_engine_)) {
138-
return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data");
40+
OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_cache_path,
41+
char* engine_data,
42+
size_t size,
43+
const int64_t embed_mode,
44+
const std::string& compute_capability,
45+
const std::string& onnx_model_path,
46+
OrtNode** ep_context_node) {
47+
48+
// Helper to collect input or output names from an array of OrtValueInfo instances.
49+
auto collect_input_output_names = [&](gsl::span<const OrtValueInfo* const> value_infos,
50+
std::vector<const char*>& result) -> OrtStatus* {
51+
size_t num_values = value_infos.size();
52+
std::vector<const char*> value_names(num_values);
53+
54+
for (size_t i = 0; i < num_values; ++i) {
55+
const OrtValueInfo* value_info = value_infos[i];
56+
RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_names[i]));
13957
}
140-
} else {
141-
// Get engine from cache file.
142-
const char* cache_path_cstr = nullptr;
143-
graph_api_->OrtNode_GetAttributeStr(node, EP_CACHE_CONTEXT.c_str(), &cache_path_cstr);
144-
std::string cache_path(cache_path_cstr);
14558

146-
// For security purpose, in the case of running context model, TRT EP won't allow
147-
// engine cache path to be the relative path like "../file_path" or the absolute path.
148-
// It only allows the engine cache to be in the same directory or sub directory of the context model.
149-
if (IsAbsolutePath(cache_path)) {
150-
return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path).c_str());
151-
}
152-
if (IsRelativePathToParentPath(cache_path)) {
153-
return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory.");
154-
}
59+
result = std::move(value_names);
60+
return nullptr;
61+
};
15562

156-
// The engine cache and context model (current model) should be in the same directory
157-
std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_));
158-
auto engine_cache_path = ctx_model_dir.append(cache_path);
159-
// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string();
63+
const char* fused_node_name = nullptr;
16064

161-
// If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled
162-
if (!weight_stripped_engine_refit_) {
163-
weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path);
164-
}
65+
RETURN_IF_ERROR(ort_api.Node_GetName(fused_node_, &fused_node_name));
16566

166-
// If the serialized refitted engine is present, use it directly without refitting the engine again
167-
if (weight_stripped_engine_refit_) {
168-
const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string());
169-
if (std::filesystem::exists(refitted_engine_cache_path)) {
170-
// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists.";
171-
engine_cache_path = refitted_engine_cache_path.string();
172-
weight_stripped_engine_refit_ = false;
173-
}
174-
}
67+
size_t num_fused_node_inputs = 0;
68+
size_t num_fused_node_outputs = 0;
69+
RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node_, &num_fused_node_inputs));
70+
RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node_, &num_fused_node_outputs));
17571

176-
if (!std::filesystem::exists(engine_cache_path)) {
177-
return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
178-
std::string("TensorRT EP can't find engine cache: " + engine_cache_path.string() +
179-
". Please make sure engine cache is in the same directory or sub-directory of context model.").c_str());
180-
}
72+
std::vector<const OrtValueInfo*> fused_node_inputs(num_fused_node_inputs);
73+
std::vector<const OrtValueInfo*> fused_node_outputs(num_fused_node_outputs);
74+
RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node_, fused_node_inputs.data(), fused_node_inputs.size()));
75+
RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node_, fused_node_outputs.data(), fused_node_outputs.size()));
18176

182-
std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in);
183-
engine_file.seekg(0, std::ios::end);
184-
size_t engine_size = engine_file.tellg();
185-
engine_file.seekg(0, std::ios::beg);
186-
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
187-
engine_file.read((char*)engine_buf.get(), engine_size);
188-
*(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
189-
if (!(*trt_engine_)) {
190-
return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL,
191-
std::string("TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string()).c_str());
192-
}
193-
// LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string();
77+
std::vector<const char*> input_names;
78+
std::vector<const char*> output_names;
19479

195-
if (weight_stripped_engine_refit_) {
196-
const char* onnx_model_filename_cstr = nullptr;
197-
graph_api_->OrtNode_GetAttributeStr(node, ONNX_MODEL_FILENAME.c_str(), &onnx_model_filename_cstr);
198-
const std::string onnx_model_filename(onnx_model_filename_cstr);
199-
std::string weight_stripped_engine_cache = engine_cache_path.string();
200-
auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
201-
onnx_model_folder_path_,
202-
weight_stripped_engine_cache,
203-
true /* path check for security */,
204-
(*trt_engine_).get(),
205-
true /* serialize refitted engine to disk */,
206-
detailed_build_log_);
207-
if (status != nullptr) {
208-
return api_->CreateStatus(OrtErrorCode::ORT_EP_FAIL, api_->GetErrorMessage(status));
209-
}
210-
}
211-
}
212-
return nullptr;
213-
}
80+
RETURN_IF_ERROR(collect_input_output_names(fused_node_inputs, /*out*/ input_names));
81+
RETURN_IF_ERROR(collect_input_output_names(fused_node_outputs, /*out*/ output_names));
21482

215-
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const OrtGraphViewer* graph_viewer) {
216-
int node_count = 0;
217-
graph_api_->OrtGraph_NumberOfNodes(graph_viewer, &node_count);
218-
assert(node_count == 1);
219-
const OrtNode* node = nullptr;
220-
graph_api_->OrtGraph_GetOrtNode(graph_viewer, 0, &node);
221-
const char* opType = nullptr;
222-
graph_api_->OrtNode_GetOpType(node, &opType);
223-
assert(strcmp(opType, EPCONTEXT_OP.c_str()) == 0);
83+
// Create node attributes. The CreateNode() function copies the attributes, so we have to release them.
84+
std::array<OrtOpAttr*, 4> attributes = {};
85+
DeferOrtRelease<OrtOpAttr> defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr);
22486

225-
size_t key_count = 0;
226-
graph_api_->OrtNode_GetAttributeKeyCount(node, COMPUTE_CAPABILITY.c_str(), &key_count);
227-
// Show the warning if compute capability is not matched
228-
if (key_count > 0) {
229-
const char* model_compute_capability = nullptr;
230-
graph_api_->OrtNode_GetAttributeStr(node, COMPUTE_CAPABILITY.c_str(), &model_compute_capability);
231-
// Verify if engine was compiled with ampere+ hardware compatibility enabled
232-
if (strcmp(model_compute_capability, "80+") == 0) {
233-
// if (std::stoi(compute_capability_) < 80) {
234-
// LOGS_DEFAULT(WARNING) << "[TensorRT EP] However, this GPU doesn't match. The compute capability of the GPU: " << compute_capability_;
235-
// }
236-
} else if (strcmp(model_compute_capability, compute_capability_.c_str()) != 0) {
237-
// LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal";
238-
// LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability;
239-
// LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_;
87+
RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[0]));
88+
89+
std::string engine_data_str = "";
90+
if (embed_mode) {
91+
if (size > 0) {
92+
engine_data_str.assign(engine_data, size);
24093
}
94+
RETURN_IF_ERROR(
95+
ort_api.CreateOpAttr("ep_cache_context", engine_data_str.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1]));
96+
} else {
97+
RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", engine_cache_path.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[1]));
24198
}
24299

243-
// "embed_mode" attr and "ep_cache_context" attr should be present
244-
graph_api_->OrtNode_GetAttributeKeyCount(node, EMBED_MODE.c_str(), &key_count);
245-
assert(key_count > 0);
246-
graph_api_->OrtNode_GetAttributeKeyCount(node, EP_CACHE_CONTEXT.c_str(), &key_count);
247-
assert(key_count > 0);
100+
101+
ort_api.CreateOpAttr("hardware_architecture", compute_capability.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[2]);
102+
ort_api.CreateOpAttr("onnx_model_filename", std::filesystem::path(onnx_model_path).filename().string().c_str(), 1,
103+
ORT_OP_ATTR_STRING, &attributes[3]);
248104

249-
int64_t embed_mode = -1;
250-
graph_api_->OrtNode_GetAttributeInt(node, EMBED_MODE.c_str(), &embed_mode);
251-
if (embed_mode == 1) {
252-
// engine binary data
253-
// LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
254-
}
255105

256-
return true;
257-
}
106+
RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, input_names.data(),
107+
input_names.size(), output_names.data(), output_names.size(),
108+
attributes.data(), attributes.size(), ep_context_node));
109+
110+
return nullptr;
258111
}

0 commit comments

Comments
 (0)