Skip to content

Commit 0385779

Browse files
[QNN EP] Add Support for Reciprocal Op in QNN EP (microsoft#25035)
- Implemented ReciprocalOpBuilder to support the ONNX Reciprocal op in the QNN EP.
- Decomposed Reciprocal into a Div op.
- Added unit tests that run the Reciprocal op on HTP.

### Description
Adds support for the ONNX Reciprocal operator in the QNN EP via Div decomposition.

### Motivation and Context
Enables execution of models that use Reciprocal on the QNN backend, improving operator coverage.
1 parent 13c0631 commit 0385779

File tree

7 files changed

+196
-0
lines changed

7 files changed

+196
-0
lines changed

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,37 @@ bool EinsumNodeGroupSelector::Check(const GraphViewer& graph_viewer,
433433
return true;
434434
}
435435

436+
bool ReciprocalNodeGroupSelector::Check(const GraphViewer& graph_viewer,
437+
const Node& node, const Node* redundant_clip_node,
438+
const std::vector<const Node*>& dq_nodes,
439+
const std::vector<const Node*>& q_nodes) const {
440+
if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, /*num_dq_inputs=*/-1,
441+
/*is_empty_q_nodes_allowed=*/true)) {
442+
return false;
443+
}
444+
size_t num_dq_inputs = dq_nodes.size();
445+
for (size_t i = 0; i < num_dq_inputs; ++i) {
446+
int32_t dt_input = dq_nodes[i]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
447+
if (!allow_int8_ && dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
448+
return false;
449+
}
450+
if (!allow_16bit_ && Is16BitIntType(dt_input)) {
451+
return false;
452+
}
453+
if (!allow_4bit_ && Is4BitIntType(dt_input)) {
454+
return false;
455+
}
456+
}
457+
if (!q_nodes.empty()) {
458+
int32_t dt_input0 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
459+
int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
460+
if (dt_input0 != dt_output) {
461+
return false;
462+
}
463+
}
464+
return true;
465+
}
466+
436467
bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
437468
const std::vector<const Node*>& dq_nodes,
438469
const std::vector<const Node*>& q_nodes) const {

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,21 @@ class EinsumNodeGroupSelector : public NodeGroupSelector {
198198
bool allow_4bit_;
199199
};
200200

201+
class ReciprocalNodeGroupSelector : public NodeGroupSelector {
202+
public:
203+
explicit ReciprocalNodeGroupSelector(bool allow_int8 = true, bool allow_16bit = true, bool allow_4bit = true)
204+
: allow_int8_(allow_int8), allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {}
205+
206+
private:
207+
bool Check(const GraphViewer& graph_viewer,
208+
const Node& node, const Node* redundant_clip_node,
209+
const std::vector<const Node*>& dq_nodes,
210+
const std::vector<const Node*>& q_nodes) const override;
211+
bool allow_int8_;
212+
bool allow_16bit_;
213+
bool allow_4bit_;
214+
};
215+
201216
// 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
202217
// The lack of a trailing Q isn't really a QDQ node group, so we default support for that to off.
203218
class MatMulNodeGroupSelector : public NodeGroupSelector {

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ static const OpVersionsAndSelector::OpVersionsMap GetConvTransposeOpVersionsMap(
116116
// Op-type -> versions table used when registering the Einsum selector.
// NOTE(review): the version list is empty; presumably that means "any opset version" - confirm
// against OpVersionsAndSelector.
static const OpVersionsAndSelector::OpVersionsMap GetEinsumOpVersionsMap() {
  return OpVersionsAndSelector::OpVersionsMap{{"Einsum", {}}};
}
119+
120+
// Op-type -> versions table used when registering the Reciprocal selector.
// NOTE(review): the version list is empty; presumably that means "any opset version" - confirm
// against OpVersionsAndSelector.
static const OpVersionsAndSelector::OpVersionsMap GetReciprocalOpVersionsMap() {
  return OpVersionsAndSelector::OpVersionsMap{{"Reciprocal", {}}};
}
123+
119124
// Op-type -> versions table used when registering the MatMul selector.
// NOTE(review): the version list is empty; presumably that means "any opset version" - confirm
// against OpVersionsAndSelector.
static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() {
  return OpVersionsAndSelector::OpVersionsMap{{"MatMul", {}}};
}
@@ -215,6 +220,13 @@ void RegisterEinsumSelector(Selectors& qdq_selectors) {
215220
std::move(selector));
216221
}
217222

223+
// Adds a default-constructed ReciprocalNodeGroupSelector to `qdq_selectors`,
// keyed by the op/version table from GetReciprocalOpVersionsMap().
void RegisterReciprocalSelector(Selectors& qdq_selectors) {
  /* register selector for Reciprocal op */
  std::unique_ptr<NodeGroupSelector> reciprocal_selector{std::make_unique<ReciprocalNodeGroupSelector>()};
  const auto op_versions = GetReciprocalOpVersionsMap();
  qdq_selectors.RegisterSelector(op_versions, std::move(reciprocal_selector));
}
229+
218230
void RegisterMatMulSelector(Selectors& qdq_selectors) {
219231
/* register selector for matmul op */
220232
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<MatMulNodeGroupSelector>();
@@ -288,6 +300,7 @@ void SelectorManager::CreateSelectors() {
288300
RegisterConvSelector(qdq_selectors_);
289301
RegisterConvTransposeSelector(qdq_selectors_);
290302
RegisterEinsumSelector(qdq_selectors_);
303+
RegisterReciprocalSelector(qdq_selectors_);
291304
RegisterMatMulSelector(qdq_selectors_);
292305
RegisterGemmSelector(qdq_selectors_);
293306
RegisterInstanceAndLayerNormalizationSelector(qdq_selectors_);

onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
166166
CreateTransposeOpBuilder("Transpose", *this);
167167
}
168168

169+
{
170+
CreateReciprocalOpBuilder("Reciprocal", *this);
171+
}
172+
169173
{
170174
CreatePadOpBuilder("Pad", *this);
171175
}

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
9393

9494
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
9595

96+
// Registers a ReciprocalOpBuilder in `op_registrations` under the name `op_type`.
void CreateReciprocalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
97+
9698
void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
9799

98100
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
5+
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
6+
#include "core/providers/qnn/builder/op_builder_factory.h"
7+
#include "core/providers/qnn/builder/qnn_utils.h"
8+
9+
namespace onnxruntime {
10+
namespace qnn {
11+
12+
class ReciprocalOpBuilder : public BaseOpBuilder {
13+
public:
14+
ReciprocalOpBuilder() : BaseOpBuilder("ReciprocalOpBuilder") {}
15+
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ReciprocalOpBuilder);
16+
17+
Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
18+
const NodeUnit& node_unit,
19+
const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;
20+
21+
protected:
22+
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
23+
const NodeUnit& node_unit,
24+
std::vector<std::string>&& input_names,
25+
const logging::Logger& logger,
26+
bool do_op_validation) const override ORT_MUST_USE_RESULT;
27+
};
28+
29+
// Reports whether this Reciprocal node unit can be handled by the builder.
// Fails unless the node unit has exactly one input and exactly one output, and the input's type
// passes DataTypeCheckForCpuBackend(). `logger` is not used.
Status ReciprocalOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
                                          const NodeUnit& node_unit,
                                          const logging::Logger& logger) const {
  ORT_UNUSED_PARAMETER(logger);

  // Arity checks: one tensor in, one tensor out.
  const auto& inputs = node_unit.Inputs();
  ORT_RETURN_IF_NOT(inputs.size() == 1, "Reciprocal operator must have exactly 1 input.");
  const auto& outputs = node_unit.Outputs();
  ORT_RETURN_IF_NOT(outputs.size() == 1, "Reciprocal operator must have exactly 1 output.");

  // Check input type is float for CPU.
  const auto input_type = inputs[0].node_arg.Type();
  ORT_RETURN_IF_ERROR(DataTypeCheckForCpuBackend(qnn_model_wrapper, input_type));

  return Status::OK();
}
45+
46+
// Lowers ONNX Reciprocal to a QNN element-wise divide.
//
// Steps:
//   1. Create a static one-element tensor holding 1.0. It reuses the input's QNN data type and
//      quantization parameters, so for a quantized input the value 1.0 is quantized with the input's
//      scale/offset.
//   2. Create the output tensor (APP_READ if it is a graph output, NATIVE otherwise).
//   3. Create a QNN_OP_ELEMENT_WISE_DIVIDE node with inputs {constant, input_names[0]}.
//
// Note that the constant is named "<node>_divisor" but is passed as the FIRST operand of the divide.
//
// Parameters:
//   qnn_model_wrapper - model being built; receives the new tensors and node.
//   node_unit         - the Reciprocal node unit (one input, one output).
//   input_names       - QNN tensor names of the already-processed inputs; element 0 is used.
//   logger            - unused.
//   do_op_validation  - forwarded to CreateQnnNode().
//
// Returns Status::OK() on success, or a failure Status if tensor info lookup, quantization,
// tensor registration or node creation fails.
Status ReciprocalOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                        const NodeUnit& node_unit,
                                                        std::vector<std::string>&& input_names,
                                                        const logging::Logger& logger,
                                                        bool do_op_validation) const {
  ORT_UNUSED_PARAMETER(logger);

  ORT_RETURN_IF_NOT(!input_names.empty(), "Reciprocal operator is missing its QNN input tensor name.");

  // Derive the constant's name from the same helper that names the QNN node below, so both stay
  // consistent. NodeUnit::Name() on its own may be empty for unnamed ONNX nodes, in which case every
  // such node would ask for the same "_divisor" tensor.
  const std::string node_name = utils::GetNodeName(node_unit);

  // Create a constant tensor for the divisor (1.0)
  std::string divisor_name = node_name + "_divisor";
  std::vector<uint32_t> divisor_shape{1};
  std::vector<uint8_t> divisor_data;

  TensorInfo input_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));

  QnnQuantParamsWrapper divisor_quant_param = input_info.quant_param.Copy();
  Qnn_DataType_t divisor_qnn_data_type = input_info.qnn_data_type;

  if (input_info.quant_param.IsQuantized()) {
    // The scale/offset fields live in a union; they are only meaningful for per-tensor
    // scale/offset encoding, so refuse anything else rather than read the wrong union member.
    const auto& encoding = divisor_quant_param.Get();
    ORT_RETURN_IF_NOT(encoding.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET,
                      "Reciprocal operator only supports per-tensor scale/offset quantized inputs.");

    // Create a quantized divisor tensor.
    // NOTE(review): 1.0 is quantized with the INPUT's scale/offset. If 1.0 lies outside the input's
    // representable range the constant will presumably saturate - confirm whether a dedicated
    // encoding for the constant is needed.
    const double divisor_value = 1.0;
    int quantized_divisor_value = 0;
    ORT_RETURN_IF_ERROR(utils::Quantize(divisor_value, encoding.scaleOffsetEncoding.scale,
                                        encoding.scaleOffsetEncoding.offset,
                                        divisor_qnn_data_type, quantized_divisor_value));

    const size_t element_size = qnn::utils::GetElementSizeByType(divisor_qnn_data_type);
    // The bytes are copied out of an `int`, so the element must fit inside it.
    ORT_RETURN_IF_NOT(element_size > 0 && element_size <= sizeof(quantized_divisor_value),
                      "Unexpected element size for quantized Reciprocal divisor tensor.");
    divisor_data.resize(element_size);
    // NOTE(review): copying the leading bytes of the int yields the low-order bytes only on a
    // little-endian host - confirm this is an accepted assumption for QNN EP targets.
    std::memcpy(divisor_data.data(), &quantized_divisor_value, element_size);
  } else {
    // Create a float divisor tensor
    const float one = 1.0f;
    divisor_data.resize(sizeof(float));
    std::memcpy(divisor_data.data(), &one, sizeof(float));
  }

  QnnTensorWrapper divisor_tensorwrapper(divisor_name, QNN_TENSOR_TYPE_STATIC, divisor_qnn_data_type,
                                         std::move(divisor_quant_param), std::move(divisor_shape),
                                         std::move(divisor_data));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(divisor_tensorwrapper)),
                    "Failed to add divisor tensor.");

  // Create the output tensor of the Div node
  const auto& outputs = node_unit.Outputs();
  const std::string& output_name = outputs[0].node_arg.Name();
  const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name);
  const Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
  TensorInfo output_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[0], output_info));
  QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type,
                                        output_info.quant_param.Copy(), std::move(output_info.shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)),
                    "Failed to add output tensor.");

  // Create the Div node: first operand is the constant 1.0, second is the Reciprocal input.
  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                        node_name,
                        QNN_OP_PACKAGE_NAME_QTI_AISW,
                        QNN_OP_ELEMENT_WISE_DIVIDE,
                        {divisor_name, input_names[0]},
                        {output_name},
                        {},
                        do_op_validation),
                    "Failed to create Div node.");

  return Status::OK();
}
108+
109+
void CreateReciprocalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
110+
op_registrations.AddOpBuilder(op_type, std::make_unique<ReciprocalOpBuilder>());
111+
}
112+
113+
} // namespace qnn
114+
} // namespace onnxruntime

onnxruntime/test/providers/qnn/simple_op_htp_test.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,23 @@ TEST_F(QnnHTPBackendTests, BinaryOp_And4D) {
991991
ExpectedEPNodeAssignment::All);
992992
}
993993

994+
// Test Reciprocal on HTP
995+
TEST_F(QnnHTPBackendTests, Reciprocal_Basic_FLOAT) {
996+
RunOpTest<float>("Reciprocal",
997+
{TestInputDef<float>({2, 2}, false, {1.0f, 2.0f, 0.5f, 4.0f})},
998+
{}, // No attributes
999+
13,
1000+
ExpectedEPNodeAssignment::All);
1001+
}
1002+
1003+
// QDQ case with uint8_t quantization: a 2x2 input filled by GetFloatDataInRange(1.0f, 5.0f, 4),
// opset 13, expecting ExpectedEPNodeAssignment::All.
TEST_F(QnnHTPBackendTests, Reciprocal_QU8) {
  const TestInputDef<float> input_def({2, 2}, false, GetFloatDataInRange(1.0f, 5.0f, 4));
  constexpr int kOpset = 13;

  RunQDQOpTest<uint8_t>("Reciprocal",
                        {input_def},
                        {},  // No attributes
                        kOpset,
                        ExpectedEPNodeAssignment::All);
}
1010+
9941011
// Test ScatterND op on HTP
9951012
TEST_F(QnnHTPBackendTests, ScatterND_int64_int64) {
9961013
std::vector<int64_t> data = {0, 1, 2, 3};

0 commit comments

Comments
 (0)