Skip to content

Commit 24e0b07

Browse files
authored
Cast Nodes Fusion (microsoft#24842)
### Description <!-- Describe your changes. --> We might have a case where multiple Cast nodes in the chain cast back to the original type. This fusion will remove extra nodes. E.g. `A ('float32') -> Cast (to='float16') -> Cast (to='int4') -> Cast (to='float32') -> Cast (to='float16') -> B ` will reduce to ` A ('float32') -> Cast (to='float16') -> B ` All the Cast nodes throughout the path need to have one input and one output to be considered for the fusion. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Gemma3 ONNX models used to have double casting, and many new models created by the model builder might have as well. Extra Casts might reduce accuracy and increase inference time.
1 parent 340b188 commit 24e0b07

File tree

11 files changed

+193
-10
lines changed

11 files changed

+193
-10
lines changed

include/onnxruntime/core/optimizer/graph_transformer_utils.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ namespace optimizer_utils {
3636
TODO: This is visible for testing at the moment, but we should rather make it private. */
3737
InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
3838
TransformerLevel level,
39-
const InlinedHashSet<std::string>& rules_to_disable = {});
39+
const InlinedHashSet<std::string>& rules_to_disable = {},
40+
const bool enable_cast_chain_elimination = false);
4041

4142
/** Given a TransformerLevel, this method generates a name for the rule-based graph transformer of that level. */
4243
std::string GenerateRuleBasedTransformerName(TransformerLevel level);
@@ -45,7 +46,8 @@ std::string GenerateRuleBasedTransformerName(TransformerLevel level);
4546
std::unique_ptr<RuleBasedGraphTransformer> GenerateRuleBasedGraphTransformer(
4647
TransformerLevel level,
4748
const InlinedHashSet<std::string>& rules_to_disable,
48-
const InlinedHashSet<std::string_view>& compatible_execution_providers);
49+
const InlinedHashSet<std::string_view>& compatible_execution_providers,
50+
const bool enable_cast_chain_elimination = false);
4951

5052
/** Generates all predefined (both rule-based and non-rule-based) transformers for this level.
5153
Any transformers or rewrite rules named in rules_and_transformers_to_disable will be excluded. */

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab
6767
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
6868
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
6969

70+
// Enable or disable Cast chain elimination in graph optimization. "0": disable; "1": enable. The default is "0".
71+
// CastElimination with chain elimination has side effects which may change the inference results. It is disabled by default due to this.
72+
static const char* const kOrtSessionOptionsEnableCastChainElimination = "optimization.enable_cast_chain_elimination";
73+
7074
// This setting controls whether to enable AheadOfTime function inlining.
7175
// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
7276
// as possible with the help of enabled execution providers.

onnxruntime/core/graph/graph_utils.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,11 @@ bool IsGraphInput(const Graph& graph, const NodeArg* input) {
610610
return std::find(graph_inputs.begin(), graph_inputs.end(), input) != graph_inputs.end();
611611
}
612612

613+
bool IsGraphOutput(const Graph& graph, const NodeArg* output) {
614+
const auto& graph_outputs = graph.GetOutputs();
615+
return std::find(graph_outputs.begin(), graph_outputs.end(), output) != graph_outputs.end();
616+
}
617+
613618
bool IsInitializer(const Graph& graph, const std::string& name, bool check_outer_scope) {
614619
bool is_initializer = false;
615620
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;

onnxruntime/core/graph/graph_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ bool IsOutputUsed(const Node& node, int index);
132132
/** Returns true if the graph has the given input.*/
133133
bool IsGraphInput(const Graph& graph, const NodeArg* input);
134134

135+
/** Returns true if the graph has the given output.*/
136+
bool IsGraphOutput(const Graph& graph, const NodeArg* output);
137+
135138
/** returns true if 'name' is an initializer in 'graph', or an ancestor graph if check_outer_scope is true.
136139
@param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'.
137140
*/
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/common/logging/logging.h"
5+
#include "core/optimizer/rewrite_rule.h"
6+
#include "core/optimizer/cast_chain_elimination.h"
7+
#include "core/optimizer/utils.h"
8+
#include "core/graph/graph.h"
9+
#include "core/graph/graph_utils.h"
10+
11+
namespace onnxruntime {
12+
13+
Status CastChainElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
14+
auto nextNodeIt = node.OutputNodesBegin();
15+
Node* next = graph.GetNode(nextNodeIt->Index());
16+
17+
// We can remove the current node.
18+
graph_utils::RemoveNodeOutputEdges(graph, node);
19+
20+
NodeArg* last_node_output_def = node.MutableOutputDefs()[0];
21+
const std::string& last_node_output_tensor_name = last_node_output_def->Name();
22+
23+
// Find the matching def slot, so we can wire the final node to the input of the removeable node.
24+
int slot = -1;
25+
26+
auto& inputs = next->MutableInputDefs();
27+
for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) {
28+
if (inputs[i]->Name() == last_node_output_tensor_name) {
29+
slot = i;
30+
break;
31+
}
32+
}
33+
34+
next->MutableInputDefs()[slot] = node.MutableInputDefs()[0];
35+
36+
graph_utils::MoveAllNodeInputEdges(graph, node, *next);
37+
38+
graph.RemoveNode(node.Index());
39+
40+
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
41+
42+
return Status::OK();
43+
}
44+
45+
bool CastChainElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
46+
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
47+
return false;
48+
}
49+
50+
// Skip nodes that don't have 1 output edge.
51+
if (node.GetOutputEdgesCount() != 1) {
52+
return false;
53+
}
54+
55+
const auto nextNodeIt = node.OutputNodesBegin();
56+
57+
const Node* next = graph.GetNode(nextNodeIt->Index());
58+
59+
// Skip if the next node is not of type Cast.
60+
if (next->OpType() != "Cast") {
61+
return false;
62+
}
63+
64+
return true;
65+
}
66+
} // namespace onnxruntime
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/optimizer/rewrite_rule.h"
7+
8+
namespace onnxruntime {
9+
10+
/**
11+
@Class CastElimination
12+
The transform that will try to find the longest chain of the type Cast where the 'to' attribute has the same data type as the input of the first Cast node in the chain.
13+
E.g.
14+
A ('float32') -> Cast (to='float16') -> Cast (to='int4') -> Cast (to='float32') -> Cast (to='float16') -> B
15+
will reduce to
16+
A ('float32') -> Cast (to='float16') -> B
17+
18+
All the Cast nodes throughout the path need to have one input and one output to be considered for the fusion.
19+
*/
20+
class CastChainElimination : public RewriteRule {
21+
public:
22+
CastChainElimination() noexcept : RewriteRule("CastChainElimination") {}
23+
24+
std::vector<std::string> TargetOpTypes() const noexcept override {
25+
return {"Cast"};
26+
}
27+
28+
private:
29+
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
30+
31+
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
32+
};
33+
34+
} // namespace onnxruntime

onnxruntime/core/optimizer/cast_elimination.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ bool CastElimination::SatisfyCondition(const Graph& graph, const Node& node, con
3131
return optimizer_utils::IsAttributeWithExpectedValue(node, "to", static_cast<int64_t>(input_type->tensor_type().elem_type()));
3232
}
3333

34-
} // namespace onnxruntime
34+
} // namespace onnxruntime

onnxruntime/core/optimizer/cast_elimination.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ class CastElimination : public RewriteRule {
2828
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
2929
};
3030

31-
} // namespace onnxruntime
31+
} // namespace onnxruntime

onnxruntime/core/optimizer/graph_transformer_utils.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "core/optimizer/bias_gelu_fusion.h"
2525
#include "core/optimizer/bias_softmax_fusion.h"
2626
#include "core/optimizer/cast_elimination.h"
27+
#include "core/optimizer/cast_chain_elimination.h"
2728
#include "core/optimizer/common_subexpression_elimination.h"
2829
#include "core/optimizer/constant_folding.h"
2930
#include "core/optimizer/constant_sharing.h"
@@ -115,8 +116,10 @@ std::string GenerateRuleBasedTransformerName(TransformerLevel level) {
115116

116117
InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
117118
TransformerLevel level,
118-
const InlinedHashSet<std::string>& rules_to_disable) {
119+
const InlinedHashSet<std::string>& rules_to_disable,
120+
const bool enable_cast_chain_elimination) {
119121
InlinedVector<std::unique_ptr<RewriteRule>> rules;
122+
120123
switch (level) {
121124
case TransformerLevel::Level1:
122125
rules.push_back(std::make_unique<EliminateIdentity>());
@@ -125,6 +128,9 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
125128
rules.push_back(std::make_unique<EliminateDropout>());
126129
rules.push_back(std::make_unique<ExpandElimination>());
127130
rules.push_back(std::make_unique<CastElimination>());
131+
if (enable_cast_chain_elimination) {
132+
rules.push_back(std::make_unique<CastChainElimination>());
133+
}
128134
rules.push_back(std::make_unique<PreShapeNodeElimination>());
129135
rules.push_back(std::make_unique<NoopElimination>());
130136
rules.push_back(std::make_unique<DivMulFusion>());
@@ -175,8 +181,9 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
175181
std::unique_ptr<RuleBasedGraphTransformer> GenerateRuleBasedGraphTransformer(
176182
TransformerLevel level,
177183
const InlinedHashSet<std::string>& rules_to_disable,
178-
const InlinedHashSet<std::string_view>& compatible_execution_providers) {
179-
auto rewrite_rules_to_register = GenerateRewriteRules(level, rules_to_disable);
184+
const InlinedHashSet<std::string_view>& compatible_execution_providers,
185+
const bool enable_cast_chain_elimination) {
186+
auto rewrite_rules_to_register = GenerateRewriteRules(level, rules_to_disable, enable_cast_chain_elimination);
180187
if (rewrite_rules_to_register.empty()) {
181188
return nullptr;
182189
}
@@ -202,6 +209,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
202209
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
203210
const bool disable_quant_qdq =
204211
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
212+
const bool enable_cast_chain_elimination =
213+
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableCastChainElimination, "0") == "1";
205214
#ifndef DISABLE_CONTRIB_OPS
206215
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
207216
const InlinedHashSet<std::string_view> cpu_acl_eps = {onnxruntime::kCpuExecutionProvider,
@@ -215,7 +224,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
215224
// RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run)
216225
// so run them first so there is potentially less for the more intensive optimizations like ConstantFolding,
217226
// CommonSubexpressionElimination and TransposeOptimizer to do.
218-
auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
227+
auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}, enable_cast_chain_elimination);
219228
if (rule_transformer != nullptr) {
220229
transformers.emplace_back(std::move(rule_transformer));
221230
}
@@ -269,7 +278,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
269278
} break;
270279

271280
case TransformerLevel::Level2: {
272-
auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {});
281+
auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}, enable_cast_chain_elimination);
273282
if (rule_transformer != nullptr) {
274283
transformers.emplace_back(std::move(rule_transformer));
275284
}

onnxruntime/test/optimizer/graph_transform_test.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "core/optimizer/bias_gelu_fusion.h"
2626
#include "core/optimizer/bias_softmax_fusion.h"
2727
#include "core/optimizer/cast_elimination.h"
28+
#include "core/optimizer/cast_chain_elimination.h"
2829
#include "core/optimizer/common_subexpression_elimination.h"
2930
#include "core/optimizer/concat_slice_elimination.h"
3031
#include "core/optimizer/constant_folding.h"
@@ -4362,7 +4363,7 @@ TEST_F(GraphTransformationTests, ExpandElimination) {
43624363
ASSERT_TRUE(op_to_count["Expand"] == 3);
43634364
}
43644365

4365-
TEST_F(GraphTransformationTests, CastElimination) {
4366+
TEST_F(GraphTransformationTests, CastEliminationSimple) {
43664367
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "cast_elimination.onnx";
43674368
std::shared_ptr<Model> model;
43684369
ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
@@ -4380,6 +4381,25 @@ TEST_F(GraphTransformationTests, CastElimination) {
43804381
ASSERT_TRUE(op_to_count["Cast"] == 4);
43814382
}
43824383

4384+
TEST_F(GraphTransformationTests, CastChainEliminationRepeatedPattern) {
4385+
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "cast_elimination_complex.onnx";
4386+
4387+
std::shared_ptr<Model> model;
4388+
ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
4389+
Graph& graph = model->MainGraph();
4390+
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
4391+
ASSERT_TRUE(op_to_count["Cast"] == 7);
4392+
4393+
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
4394+
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<CastChainElimination>()));
4395+
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
4396+
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
4397+
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
4398+
4399+
op_to_count = CountOpsInGraph(graph);
4400+
ASSERT_TRUE(op_to_count["Cast"] == 3);
4401+
}
4402+
43834403
TEST_F(GraphTransformationTests, PreShapeNodeElimination) {
43844404
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "pre_shape_node_elimination.onnx";
43854405
std::shared_ptr<Model> model;

0 commit comments

Comments
 (0)