
Commit 1865637

qti-hungjuiw and Copilot authored and committed
[QNN-EP] Add CastLoneQFusion to transform Cast and QNode into Convert (microsoft#25667)
### Description

- Introduced `CastLoneQFusion` in the QNN EP to fuse a `Cast` followed by a `QuantizeLinear` into a single `Convert` operation.
- Added corresponding test cases for a **UINT8-to-FLOAT** `Cast` combined with `QuantizeLinear`, covering various **QuantType** scenarios.

### Motivation and Context

- Implemented this fusion to optimize the model by removing unnecessary QDQ nodes.

---------

Co-authored-by: Copilot <[email protected]>
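In graph terms, the fusion rewrites a Cast whose sole consumer is a QuantizeLinear, and whose input is not produced by a DequantizeLinear, into a single QNN Convert op. A minimal before/after sketch (illustrative only):

// Before:  x -> Cast(to=float) -> QuantizeLinear -> y
// After:   x -> Convert -> y
// (valid only when x is not the output of a DequantizeLinear)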
1 parent 8755a0f commit 1865637

File tree: 4 files changed (+221, -17 lines)
onnxruntime/core/providers/qnn/builder/qnn_node_group/cast_lone_q_fusion.cc

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/builder/qnn_node_group/cast_lone_q_fusion.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_node_group/utils.h"

namespace onnxruntime {
namespace qnn {

constexpr char kOpCast[] = "Cast";
constexpr char kOpConvert[] = "Convert";

Status CreateOrValidateOnQnn(
    QnnModelWrapper* qnn_model_wrapper,
    gsl::span<const NodeUnit* const> node_units,
    [[maybe_unused]] const logging::Logger& logger,
    bool validate) {
  const NodeUnit* cast = node_units[0];
  const NodeUnit* quantize_linear = node_units[1];

  // ProcessInputs
  const auto& input_name = cast->Inputs()[0].node_arg.Name();
  if (!qnn_model_wrapper->IsQnnTensorWrapperExist(input_name)) {
    TensorInfo cast_node_input_info = {};
    ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(cast->Inputs()[0], cast_node_input_info));
    QnnTensorWrapper input_tensor_wrapper;
    ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(cast_node_input_info, input_name, input_tensor_wrapper));
    ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(input_tensor_wrapper)),
                      "Failed to add input tensor for QNN Convert node.");
  }
  // ProcessAttributesAndOutputs
  const auto& output_name = quantize_linear->Outputs()[0].node_arg.Name();
  TensorInfo q_node_output_info = {};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(quantize_linear->Outputs()[0], q_node_output_info));
  QnnTensorWrapper output_tensor_wrapper;
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(q_node_output_info, output_name, output_tensor_wrapper));
  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(output_tensor_wrapper)),
                    "Failed to add output tensor for QNN Convert node.");
  ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(cast->Name() + "_ort_qnn_ep_convert",
                                                     QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                     QNN_OP_CONVERT,
                                                     {input_name},
                                                     {output_name},
                                                     {},
                                                     validate),
                    "Failed to add fused " + std::string(kOpConvert) + " node.");

  return Status::OK();
}

std::unique_ptr<IQnnNodeGroup> CastLoneQFusion::TryFusion(
    QnnModelWrapper& qnn_model_wrapper,
    const NodeUnit& cast_node_unit,
    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
    const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
    [[maybe_unused]] const logging::Logger& logger) {
  if (cast_node_unit.OpType() != kOpCast || cast_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
    return nullptr;
  }

  // Transform the pattern Non-DQ Node -> Cast -> Q into Non-DQ Node -> Convert
  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
  const std::array<std::string_view, 1> child_op_types{QUANTIZE_LINEAR};
  const NodeUnit* quantize_linear = GetOnlyChildOfType(
      graph_viewer, cast_node_unit, child_op_types,
      node_to_node_unit, node_unit_to_qnn_node_group);
  const std::array<std::string_view, 1> parent_op_types{DEQUANTIZE_LINEAR};
  const NodeUnit* dequantize_linear = GetParentOfType(
      graph_viewer, cast_node_unit, parent_op_types,
      node_to_node_unit, node_unit_to_qnn_node_group);

  if (quantize_linear == nullptr || dequantize_linear != nullptr) {
    return nullptr;
  }

  // Skip constant Cast inputs
  if (qnn_model_wrapper.IsConstantInput(cast_node_unit.Inputs()[0].node_arg.Name())) {
    return nullptr;
  }
  std::array<const NodeUnit*, 2> node_unit_array{&cast_node_unit, quantize_linear};
  auto node_units = gsl::make_span<const NodeUnit*>(node_unit_array.data(), 2);

  if (CreateOrValidateOnQnn(&qnn_model_wrapper, node_units, logger, /*validate=*/true) != Status::OK()) {
    return nullptr;
  }
  return std::make_unique<CastLoneQFusion>(node_units);
}

gsl::span<const NodeUnit* const> CastLoneQFusion::GetNodeUnits() const {
  return gsl::span<const NodeUnit* const>{node_units_.data(), node_units_.size()};
}

Status CastLoneQFusion::IsSupported(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), logger, true);
}

Status CastLoneQFusion::AddToModelBuilder(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), logger, false);
}

}  // namespace qnn
}  // namespace onnxruntime
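For reference, `TryFusion` above accepts only the lone-Q shape and bails out otherwise; a comment-form summary of the cases (illustrative, derived from the guards in the code):

// Accepted:  x -> Cast -> Q          (Q is Cast's only child, x not from a DQ, x not constant)
// Rejected:  DQ -> Cast -> Q         (Cast has a DequantizeLinear parent)
// Rejected:  const -> Cast -> Q      (constant Cast input is skipped)
// Rejected:  x -> Cast -> {Q, ...}   (Cast output feeds consumers besides the lone Q)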
onnxruntime/core/providers/qnn/builder/qnn_node_group/cast_lone_q_fusion.h

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"

namespace onnxruntime {
namespace qnn {
/// <summary>
/// Represents a fusion of the pattern: Quantize(Cast(x)) => Convert(x),
/// valid when x is not the output of Dequantize.
/// </summary>
class CastLoneQFusion : public IQnnNodeGroup {
 public:
  explicit CastLoneQFusion(gsl::span<const NodeUnit* const> node_units) {
    ORT_ENFORCE(node_units.size() == 2, "Pattern expects exactly 2 NodeUnits.");
    node_units_[0] = node_units[0];
    node_units_[1] = node_units[1];
  }

  ORT_DISALLOW_COPY_AND_ASSIGNMENT(CastLoneQFusion);

  Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
  Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
  gsl::span<const NodeUnit* const> GetNodeUnits() const override;
  const NodeUnit* GetTargetNodeUnit() const override { return node_units_[0]; }
  std::string_view Type() const override { return "CastLoneQFusion"; }

  /// <summary>
  /// Traverses the graph to check whether the given starting NodeUnit is part of a valid Cast -> Quantize sequence.
  /// If so, returns an IQnnNodeGroup containing the Cast and Quantize NodeUnits.
  /// </summary>
  static std::unique_ptr<IQnnNodeGroup> TryFusion(
      QnnModelWrapper& qnn_model_wrapper,
      const NodeUnit& cast_node_unit,
      const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
      const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
      const logging::Logger& logger);

 private:
  std::array<const NodeUnit*, 2> node_units_;
};

}  // namespace qnn
}  // namespace onnxruntime

onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
 #include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h"
 #include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h"
+#include "core/providers/qnn/builder/qnn_node_group/cast_lone_q_fusion.h"
 #include "core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.h"
 #include "core/providers/qnn/builder/qnn_node_group/udo_fusion.h"
 #include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h"
@@ -80,6 +81,7 @@ static std::unordered_map<std::string, std::vector<FusionFunc>> fusions = {
     {"MatMul", {LowPowerBlockQuantizedMatMulFusion::TryFusion}},
     {"Gemm", {LowPowerBlockQuantizedGemmFusion::TryFusion, ReshapeGemmFusion::TryFusion}},
     {"Mul", {ScaleSoftmaxFusion::TryFusion}},
+    {"Cast", {CastLoneQFusion::TryFusion}},
     {"Transpose", {ChannelShuffleFusion::TryFusion}}};

 void registerUDO(const std::string& node_type, const std::string& op_package) {
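The one-line registration above is all the wiring the new fusion needs: `fusions` maps an ONNX op type to the candidate `TryFusion` functions to probe. A minimal sketch of how such a registry is typically consumed, assuming a driver loop along these lines (hypothetical code, not the actual onnxruntime implementation; only the `TryFusion` signature is taken from this commit):

// Hypothetical consumer of the fusion registry above. Container and variable
// names are illustrative; this is not the actual onnxruntime driver loop.
for (const NodeUnit* node_unit : node_units_in_topological_order) {
  auto it = fusions.find(node_unit->OpType());
  if (it == fusions.end()) continue;
  for (FusionFunc try_fusion : it->second) {
    // Each TryFusion returns nullptr when its pattern does not match here.
    std::unique_ptr<IQnnNodeGroup> group = try_fusion(
        qnn_model_wrapper, *node_unit, node_to_node_unit,
        node_unit_to_qnn_node_group, logger);
    if (group != nullptr) {
      qnn_node_groups.push_back(std::move(group));
      break;  // The first matching fusion claims this NodeUnit.
    }
  }
}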

onnxruntime/test/providers/qnn/qnn_basic_test.cc

Lines changed: 70 additions & 17 deletions
@@ -1042,10 +1042,11 @@ TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) {
 // cast_input -> Cast -> Q -> DQ ----
 //                                  |
 // input2 -> Q -> DQ -> Add -> Q -> DQ -> output
-static GetTestModelFn BuildCastAddTestCase() {
-  return [](ModelTestBuilder& builder) {
+template <typename InputType, typename QuantType>
+static GetTestQDQModelFn<QuantType> BuildCastAddQDQTestCase() {
+  return [](ModelTestBuilder& builder, std::vector<QuantParams<QuantType>>& output_qparams) {
     // Create Cast node InputType -> float32
-    NodeArg* cast_input = MakeTestInput(builder, TestInputDef<int32_t>({2, 3}, false, {0, 1, 0, 1, 0, 1}));
+    NodeArg* cast_input = MakeTestInput(builder, TestInputDef<InputType>({2, 3}, false, {0, 1, 0, 1, 0, 1}));

     auto* cast_output = builder.MakeIntermediate();
     Node& cast_node = builder.AddNode("Cast", {cast_input}, {cast_output});
@@ -1054,18 +1055,36 @@ static GetTestModelFn BuildCastAddTestCase() {
     // Create Add node
     std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
     gsl::span<float> data_range = gsl::make_span(data);
-    QuantParams<uint8_t> q_parameter = GetDataQuantParams<uint8_t>(data_range);
-    auto* add_input1_qdq = AddQDQNodePair<uint8_t>(builder, cast_output, q_parameter.scale, q_parameter.zero_point);
+    QuantParams<QuantType> q_parameter = GetDataQuantParams<QuantType>(data_range);
+    auto* add_input1_qdq = AddQDQNodePair<QuantType>(builder, cast_output, q_parameter.scale, q_parameter.zero_point);

     NodeArg* add_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, data));
-    auto* add_input2_qdq = AddQDQNodePair<uint8_t>(builder, add_input2, q_parameter.scale, q_parameter.zero_point);
+    auto* add_input2_qdq = AddQDQNodePair<QuantType>(builder, add_input2, q_parameter.scale, q_parameter.zero_point);

     auto* add_output = builder.MakeIntermediate();

     builder.AddNode("Add", {add_input1_qdq, add_input2_qdq}, {add_output});

     // add_output -> Q -> DQ -> output
-    AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add_output, q_parameter.scale, q_parameter.zero_point);
+    AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, add_output, output_qparams[0].scale, output_qparams[0].zero_point);
+  };
+}
+
+template <typename InputType>
+static GetTestModelFn BuildCastAddTestCase() {
+  return [](ModelTestBuilder& builder) {
+    // Create Cast node InputType -> float32
+    NodeArg* cast_input = MakeTestInput(builder, TestInputDef<InputType>({2, 3}, false, {0, 1, 0, 1, 0, 1}));
+
+    auto* cast_output = builder.MakeIntermediate();
+    Node& cast_node = builder.AddNode("Cast", {cast_input}, {cast_output});
+    cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT));
+
+    // Create Add node
+    NodeArg* add_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f}));
+    auto* add_output = builder.MakeOutput();
+
+    builder.AddNode("Add", {cast_output, add_input2}, {add_output});
   };
 }

@@ -1091,19 +1110,53 @@ TEST_F(QnnHTPBackendTests, ProfilingTest) {
                   0.008f);
 }

-TEST_F(QnnHTPBackendTests, CastAddHTPAccuracyTest) {
+TEST_F(QnnHTPBackendTests, CastAddQDQU8) {
   ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
+  provider_options["backend_type"] = "htp";
   provider_options["offload_graph_io_quantization"] = "0";

-  RunQnnModelTest(BuildCastAddTestCase(),
-                  provider_options,
-                  13,  // opset
-                  ExpectedEPNodeAssignment::All);
+  TestQDQModelAccuracy<uint8_t>(BuildCastAddTestCase<uint8_t>(),
+                                BuildCastAddQDQTestCase<uint8_t, uint8_t>(),
+                                provider_options,
+                                21,
+                                ExpectedEPNodeAssignment::All);
+}
+
+TEST_F(QnnHTPBackendTests, CastAddQDQU16) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+
+  TestQDQModelAccuracy<uint16_t>(BuildCastAddTestCase<uint8_t>(),
+                                 BuildCastAddQDQTestCase<uint8_t, uint16_t>(),
+                                 provider_options,
+                                 21,
+                                 ExpectedEPNodeAssignment::All);
+}
+
+TEST_F(QnnHTPBackendTests, CastAddQDQS8) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+
+  TestQDQModelAccuracy<int8_t>(BuildCastAddTestCase<uint8_t>(),
+                               BuildCastAddQDQTestCase<uint8_t, int8_t>(),
+                               provider_options,
+                               21,
+                               ExpectedEPNodeAssignment::All);
+}
+
+TEST_F(QnnHTPBackendTests, CastAddQDQS16) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+
+  TestQDQModelAccuracy<int16_t>(BuildCastAddTestCase<uint8_t>(),
+                                BuildCastAddQDQTestCase<uint8_t, int16_t>(),
+                                provider_options,
+                                21,
+                                // QNN does not yet support S16 Quantize/Dequantize
+                                ExpectedEPNodeAssignment::Some);
 }

 // Test float32 model with FP16 precision
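Each new test pairs the float model (`BuildCastAddTestCase`) with its QDQ counterpart (`BuildCastAddQDQTestCase`) and compares HTP accuracy against the CPU baseline. As a sketch of how additional input types could be covered, here is a hypothetical INT32-input variant in the same pattern (not part of this commit; the pre-fusion test used an int32 Cast input, and the node-assignment expectation below would need verification on the HTP backend):

// Hypothetical test, not in this commit: int32 Cast input with uint8 quantization.
TEST_F(QnnHTPBackendTests, CastAddQDQU8_Int32Input) {
  ProviderOptions provider_options;
  provider_options["backend_type"] = "htp";
  provider_options["offload_graph_io_quantization"] = "0";

  TestQDQModelAccuracy<uint8_t>(BuildCastAddTestCase<int32_t>(),
                                BuildCastAddQDQTestCase<int32_t, uint8_t>(),
                                provider_options,
                                21,  // opset
                                // Assumption: all nodes assigned, as in the
                                // pre-existing int32 CastAdd test.
                                ExpectedEPNodeAssignment::All);
}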
