Skip to content

Commit 6413908

Browse files
authored
Skip hash computation for EPContext models (microsoft#25106)
### Description Add early exit in ComputeModelGraphHash when EPContext nodes are present, returning "0" to indicate pre-compiled model state. Conditionally skip ComputeModelWeightHash when graph hash is "0" to avoid unnecessary computation for pre-compiled models. This optimization reduces overhead for models containing EPContext nodes, which represent execution provider pre-compiled subgraphs. ### Motivation and Context Currently, the hash generated by ComputeModelGraphHash function when the graph contains EPContext nodes does not correctly represent the graph because we do not hash the contents of the context pointed to by the EPContext node. Thus, it makes more sense to skip hashing for cases involving EPContext nodes.
1 parent 0ef8213 commit 6413908

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

onnxruntime/core/session/inference_session.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,11 +1774,17 @@ static bool ModelHasFP16Inputs(const Graph& graph) {
17741774

17751775
#ifdef _WIN32
17761776
[[maybe_unused]] static std::string ComputeModelGraphHash(const Graph& graph) {
1777-
std::size_t final_hash = 0;
1778-
const std::size_t node_hash_count = TelemetrySampleCount;
1777+
// Skip hashing if the graph contains an EPContext node.
1778+
const auto& nodes = graph.Nodes();
1779+
for (const auto& node : nodes) {
1780+
if (node.OpType() == "EPContext") {
1781+
return "0";
1782+
}
1783+
}
17791784

17801785
// Graph Hash
1781-
const auto& nodes = graph.Nodes();
1786+
std::size_t final_hash = 0;
1787+
const std::size_t node_hash_count = TelemetrySampleCount;
17821788
const std::size_t total_nodes = graph.NumberOfNodes();
17831789
const std::size_t node_step = (total_nodes > node_hash_count) ? (total_nodes / node_hash_count) : 1;
17841790

@@ -2077,7 +2083,7 @@ common::Status InferenceSession::Initialize() {
20772083
#endif
20782084
#ifdef _WIN32
20792085
model_graph_hash = ComputeModelGraphHash(graph);
2080-
model_weight_hash = ComputeModelWeightHash(initializers);
2086+
model_weight_hash = (model_graph_hash == "0") ? "0" : ComputeModelWeightHash(initializers);
20812087
SetGraphHash(model_graph_hash);
20822088
SetWeightHash(model_weight_hash);
20832089
#endif

0 commit comments

Comments
 (0)