1+ // Copyright (c) Microsoft Corporation. All rights reserved.
2+ // Licensed under the MIT License.
3+
#include <cstdint>
#include <cstring>
#include <string>
#include <utility>
#include <vector>

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_utils.h"
8+
9+ namespace onnxruntime {
10+ namespace qnn {
11+
12+ class ReciprocalOpBuilder : public BaseOpBuilder {
13+ public:
14+ ReciprocalOpBuilder () : BaseOpBuilder(" ReciprocalOpBuilder" ) {}
15+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE (ReciprocalOpBuilder);
16+
17+ Status IsOpSupported (QnnModelWrapper& qnn_model_wrapper,
18+ const NodeUnit& node_unit,
19+ const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;
20+
21+ protected:
22+ Status ProcessAttributesAndOutputs (QnnModelWrapper& qnn_model_wrapper,
23+ const NodeUnit& node_unit,
24+ std::vector<std::string>&& input_names,
25+ const logging::Logger& logger,
26+ bool do_op_validation) const override ORT_MUST_USE_RESULT;
27+ };
28+
29+ Status ReciprocalOpBuilder::IsOpSupported (QnnModelWrapper& qnn_model_wrapper,
30+ const NodeUnit& node_unit,
31+ const logging::Logger& logger) const {
32+ ORT_UNUSED_PARAMETER (logger);
33+
34+ const auto & inputs = node_unit.Inputs ();
35+ ORT_RETURN_IF_NOT (inputs.size () == 1 , " Reciprocal operator must have exactly 1 input." );
36+
37+ const auto & outputs = node_unit.Outputs ();
38+ ORT_RETURN_IF_NOT (outputs.size () == 1 , " Reciprocal operator must have exactly 1 output." );
39+
40+ // Check input type is float for CPU.
41+ ORT_RETURN_IF_ERROR (DataTypeCheckForCpuBackend (qnn_model_wrapper, inputs[0 ].node_arg .Type ()));
42+
43+ return Status::OK ();
44+ }
45+
46+ Status ReciprocalOpBuilder::ProcessAttributesAndOutputs (QnnModelWrapper& qnn_model_wrapper,
47+ const NodeUnit& node_unit,
48+ std::vector<std::string>&& input_names,
49+ const logging::Logger& logger,
50+ bool do_op_validation) const {
51+ ORT_UNUSED_PARAMETER (logger);
52+
53+ // Create a constant tensor for the divisor (1.0)
54+ std::string divisor_name = node_unit.Name () + " _divisor" ;
55+ std::vector<uint32_t > divisor_shape{1 };
56+ std::vector<uint8_t > divisor_data;
57+
58+ TensorInfo input_info{};
59+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (node_unit.Inputs ()[0 ], input_info));
60+
61+ QnnQuantParamsWrapper divisor_quant_param = input_info.quant_param .Copy ();
62+ Qnn_DataType_t divisor_qnn_data_type = input_info.qnn_data_type ;
63+
64+ if (input_info.quant_param .IsQuantized ()) {
65+ // Create a quantized divisor tensor
66+ double divisor_value = 1.0 ;
67+ int quantized_divisor_value;
68+ ORT_RETURN_IF_ERROR (utils::Quantize (divisor_value, divisor_quant_param.Get ().scaleOffsetEncoding .scale ,
69+ divisor_quant_param.Get ().scaleOffsetEncoding .offset ,
70+ divisor_qnn_data_type, quantized_divisor_value));
71+ size_t element_size = qnn::utils::GetElementSizeByType (divisor_qnn_data_type);
72+ divisor_data.resize (element_size);
73+ std::memcpy (divisor_data.data (), &quantized_divisor_value, element_size);
74+ } else {
75+ // Create a float divisor tensor
76+ divisor_data.resize (sizeof (float ));
77+ float one = 1 .0f ;
78+ std::memcpy (divisor_data.data (), &one, sizeof (float ));
79+ }
80+
81+ QnnTensorWrapper divisor_tensorwrapper (divisor_name, QNN_TENSOR_TYPE_STATIC, divisor_qnn_data_type,
82+ std::move (divisor_quant_param), std::move (divisor_shape), std::move (divisor_data));
83+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (divisor_tensorwrapper)), " Failed to add divisor tensor." );
84+
85+ // Create the Div node
86+ const auto & outputs = node_unit.Outputs ();
87+ const std::string& output_name = outputs[0 ].node_arg .Name ();
88+ bool is_graph_output = qnn_model_wrapper.IsGraphOutput (output_name);
89+ Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
90+ TensorInfo output_info{};
91+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (outputs[0 ], output_info));
92+ QnnTensorWrapper output_tensorwrapper (output_name, tensor_type, output_info.qnn_data_type ,
93+ output_info.quant_param .Copy (), std::move (output_info.shape ));
94+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (output_tensorwrapper)), " Failed to add output tensor." );
95+
96+ ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (
97+ utils::GetNodeName (node_unit),
98+ QNN_OP_PACKAGE_NAME_QTI_AISW,
99+ QNN_OP_ELEMENT_WISE_DIVIDE,
100+ {divisor_name, input_names[0 ]},
101+ {output_name},
102+ {},
103+ do_op_validation),
104+ " Failed to create Div node." );
105+
106+ return Status::OK ();
107+ }
108+
// Registers a ReciprocalOpBuilder instance with the op-builder factory under
// the given ONNX operator type name.
void CreateReciprocalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<ReciprocalOpBuilder>());
}
112+
113+ } // namespace qnn
114+ } // namespace onnxruntime