@@ -14,10 +14,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder {
1414 SoftmaxOpBuilder () : BaseOpBuilder(" SoftmaxOpBuilder" ) {}
1515 ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE (SoftmaxOpBuilder);
1616
17- Status IsOpSupported (QnnModelWrapper& qnn_model_wrapper,
18- const NodeUnit& node_unit,
19- const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;
20-
2117 protected:
2218 Status ProcessInputs (QnnModelWrapper& qnn_model_wrapper,
2319 const NodeUnit& node_unit,
@@ -37,34 +33,25 @@ constexpr int32_t GetDefaultAxisAttribute(int opset_version) {
3733 return opset_version < 13 ? 1 : -1 ;
3834}
3935
40- Status SoftmaxOpBuilder::IsOpSupported (QnnModelWrapper& qnn_model_wrapper,
41- const NodeUnit& node_unit,
42- const logging::Logger& logger) const {
43- ORT_UNUSED_PARAMETER (logger);
44- const int opset_version = node_unit.SinceVersion ();
// Returns `input_shape` with every dimension from `axis` onward collapsed (multiplied)
// into a single trailing dimension.
//
// Example: shape=(3, 4, 5), axis=1  ->  (3, 20)
//
// Special case: when axis == 0 the result gains a leading batch dimension of size 1,
// e.g. shape=(3, 4, 5), axis=0  ->  (1, 60). This matches ONNX (Log)Softmax (opset < 13)
// "flatten from axis" semantics, which always produce a rank-2 view.
//
// \param input_shape  The original tensor shape. Not modified.
// \param axis         Flattening start axis; must satisfy 0 <= axis < rank (asserted).
// \return             The flattened shape.
std::vector<uint32_t> FlattenShapeFromAxis(const std::vector<uint32_t>& input_shape, int32_t axis) {
  assert(axis >= 0 && static_cast<size_t>(axis) < input_shape.size());

  // Keep the leading dimensions [0, axis) untouched.
  std::vector<uint32_t> output_shape(input_shape.begin(), input_shape.begin() + axis);

  if (axis == 0) {
    output_shape.push_back(1);  // Insert the implicit batch of size 1.
  }

  // Collapse dimensions [axis, rank) into one. Use an unsigned init value so the
  // accumulation is carried out in uint32_t rather than int (avoids signed overflow
  // and an implicit narrowing conversion for large shapes).
  output_shape.push_back(
      std::accumulate(input_shape.begin() + axis, input_shape.end(), uint32_t{1}, std::multiplies<uint32_t>()));

  return output_shape;
}
6653
67- static std::vector<uint32_t > GetTransposePermToUseLastAxis (uint32_t input_rank, uint32_t axis) {
54+ std::vector<uint32_t > GetTransposePermToUseLastAxis (uint32_t input_rank, uint32_t axis) {
6855 assert (axis < input_rank);
6956 std::vector<uint32_t > transpose_perm;
7057 transpose_perm.reserve (input_rank);
@@ -87,58 +74,86 @@ Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
8774 bool do_op_validation) const {
8875 const bool is_npu_backend = IsNpuBackend (qnn_model_wrapper.GetQnnBackendType ());
8976 const auto & inputs = node_unit.Inputs ();
77+ const std::string& input_name = inputs[0 ].node_arg .Name ();
9078 assert (inputs.size () == 1 );
9179
92- int32_t axis = GetDefaultAxisAttribute (node_unit.SinceVersion ());
80+ const int opset_version = node_unit.SinceVersion ();
81+ int32_t axis = GetDefaultAxisAttribute (opset_version);
9382 Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
9483 ORT_RETURN_IF_ERROR (ProcessAxisAttribute (qnn_model_wrapper, node_unit, axis_qnn_scalar, axis));
9584
9685 TensorInfo input_info = {};
9786 ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (inputs[0 ], input_info));
98- const size_t input_rank = input_info.shape .size ();
99-
100- // If the axis attribute refers to the last dimension, then process the input as normal.
101- if (!is_npu_backend || axis == static_cast <int32_t >(input_rank) - 1 ) {
102- return ProcessInput (qnn_model_wrapper, inputs[0 ], logger, input_names);
103- }
104-
105- //
106- // The axis does **not** refer to the last input dimension. Must wrap transposes around the operator to be able to use
107- // QNN's Softmax operator, which always uses an axis value that refers to the last dimension.
108- //
109-
110- std::vector<uint32_t > transpose_perm = GetTransposePermToUseLastAxis (static_cast <uint32_t >(input_rank),
111- static_cast <uint32_t >(axis));
87+ size_t input_rank = input_info.shape .size ();
88+ ORT_RETURN_IF (input_info.is_initializer , " QNN EP does not support (Log)Softmax with an initializer input, " ,
89+ " which should be optimized away by the ORT optimizer" );
11290
113- const std::string& input_name = inputs[0 ].node_arg .Name ();
114- std::string op_input_name = input_info.is_initializer ? input_name : input_name + " _ort_qnn_ep_transpose" ;
115- input_names.push_back (op_input_name);
91+ /*
92+ For Onnx Softmax with opset < 13, its behavior is to flatten the input starting from the axis, and perform
93+ softmax operation along the axis dimension, then reshape back to the original input shape.
94+ QNN EP is able to support arbitrary axis attribute by wrapping reshapes around the operator.
11695
117- std::vector< uint32_t > op_input_shape = input_info. shape ;
118- op_input_shape[input_rank - 1 ] = input_info. shape [ axis];
119- op_input_shape[axis] = input_info. shape [input_rank - 1 ];
96+ Here provides an example:
97+ Given an input with shape=(3, 4, 5) and axis=1. Its behavior is to reshape the input to (3, 20), perform softmax,
98+ and then reshape back to (3, 4, 5).
12099
121- ORT_RETURN_IF (input_info.is_initializer , " QNN EP does not support (Log)Softmax with an initializer input, " ,
122- " which should be optimized away by the ORT optimizer" );
100+ When axis equals 0, the reshape output shape includes an additional batch of size 1 as the first dimension.
101+ Here provides an example:
102+ Given an input with shape=(3, 4, 5) and axis=0. Its behavior is to reshape the input to (1, 60), perform softmax,
103+ and then reshape back to (3, 4, 5).
104+ */
105+ if (opset_version < 13 ) {
106+ std::string reshape_output_name = input_name + " _ort_qnn_ep_reshape" ;
107+ std::vector<uint32_t > reshape_output_shape = FlattenShapeFromAxis (input_info.shape , axis);
123108
124- // Input is dynamic, so add transpose node before input.
125- const bool is_graph_input = qnn_model_wrapper.IsGraphInput (input_name);
109+ // Input is dynamic, so add reshape node before input.
110+ const bool is_graph_input = qnn_model_wrapper.IsGraphInput (input_name);
126111
127- ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddTransposeNode (node_unit.Index (),
128- input_name,
129- op_input_name,
112+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddReshapeNode (input_name,
113+ reshape_output_name,
130114 input_info.shape ,
131- transpose_perm,
132- op_input_shape,
115+ reshape_output_shape,
133116 input_info.qnn_data_type ,
134117 input_info.quant_param ,
135118 do_op_validation,
136- is_graph_input));
137-
138- Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType (op_input_name);
139- QnnTensorWrapper input_tensorwrapper (op_input_name, tensor_type, input_info.qnn_data_type ,
140- std::move (input_info.quant_param ), std::move (op_input_shape), {});
141- ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (input_tensorwrapper)), " Failed to add tensor." );
119+ is_graph_input,
120+ false ));
121+ input_names.push_back (reshape_output_name);
122+ }
123+ /*
124+ For Onnx Softmax with opset >= 13, the QNN HTP backend only supports the axis attribute that refers to the last
125+ input dimension.
126+ QNN EP is able to support arbitrary axis attribute by wrapping transposes around the operator.
127+ */
128+ else if (is_npu_backend && axis != static_cast <int32_t >(input_rank) - 1 ) {
129+ std::string transpose_output_name = input_name + " _ort_qnn_ep_transpose" ;
130+ std::vector<uint32_t > transpose_perm = GetTransposePermToUseLastAxis (static_cast <uint32_t >(input_rank),
131+ static_cast <uint32_t >(axis));
132+
133+ std::vector<uint32_t > transpose_output_shape = input_info.shape ;
134+ transpose_output_shape[input_rank - 1 ] = input_info.shape [axis];
135+ transpose_output_shape[axis] = input_info.shape [input_rank - 1 ];
136+
137+ // Input is dynamic, so add transpose node before input.
138+ const bool is_graph_input = qnn_model_wrapper.IsGraphInput (input_name);
139+
140+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddTransposeNode (node_unit.Index (),
141+ input_name,
142+ transpose_output_name,
143+ input_info.shape ,
144+ transpose_perm,
145+ transpose_output_shape,
146+ input_info.qnn_data_type ,
147+ input_info.quant_param ,
148+ do_op_validation,
149+ is_graph_input,
150+ false ));
151+ input_names.push_back (transpose_output_name);
152+ }
153+ // Process the input as normal.
154+ else {
155+ return ProcessInput (qnn_model_wrapper, inputs[0 ], logger, input_names);
156+ }
142157
143158 return Status::OK ();
144159}
@@ -151,76 +166,107 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_
151166 const bool is_npu_backend = IsNpuBackend (qnn_model_wrapper.GetQnnBackendType ());
152167 const std::string& op_type = node_unit.OpType ();
153168 const auto & outputs = node_unit.Outputs ();
169+ const std::string& orig_output_name = outputs[0 ].node_arg .Name ();
154170 assert (outputs.size () == 1 );
155171
156- int32_t axis = GetDefaultAxisAttribute (node_unit.SinceVersion ());
172+ const int opset_version = node_unit.SinceVersion ();
173+ int32_t axis = GetDefaultAxisAttribute (opset_version);
157174 Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
158175 ORT_RETURN_IF_ERROR (ProcessAxisAttribute (qnn_model_wrapper, node_unit, axis_qnn_scalar, axis));
159176
160177 TensorInfo output_info = {};
161178 ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (outputs[0 ], output_info));
162- const size_t output_rank = output_info.shape .size ();
163- const bool axis_is_last_dim = static_cast <size_t >(axis) == output_rank - 1 ;
179+ size_t output_rank = output_info.shape .size ();
164180
165- // If axis refers to the last dimension, process outputs as usual.
166- if (!is_npu_backend || axis_is_last_dim) {
167- QnnParamWrapper axis_param (node_unit.Index (), node_unit.Name (), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
181+ if (opset_version < 13 ) {
182+ std::string reshape_input_name = orig_output_name + " _ort_qnn_ep_reshape" ;
168183
184+ std::vector<uint32_t > reshape_input_shape = FlattenShapeFromAxis (output_info.shape , axis);
185+ if (axis == 0 ) {
186+ // Override axis due to the inserted batch=1 to the first dimension
187+ axis_qnn_scalar.uint32Value = 1 ;
188+ }
189+
190+ QnnParamWrapper axis_param (node_unit.Index (), node_unit.Name (), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
169191 std::vector<std::string> param_tensor_names;
170192 param_tensor_names.push_back (axis_param.GetParamTensorName ());
171193 qnn_model_wrapper.AddParamWrapper (std::move (axis_param));
172194
173- return ProcessOutputs (qnn_model_wrapper, node_unit,
174- std::move (input_names),
175- std::move (param_tensor_names),
176- logger, do_op_validation, GetQnnOpType (op_type));
177- }
178-
179- //
180- // The axis **does** not refer to the last dimension. Must wrap the operator with Transposes to be able to use
181- // QNN's Softmax operator, which only supports an axis that refers to the last dimension.
182- //
183-
184- axis_qnn_scalar.uint32Value = static_cast <uint32_t >(output_rank - 1 ); // NOTE: override axis.
185- QnnParamWrapper axis_param (node_unit.Index (), node_unit.Name (), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
186-
187- std::vector<std::string> param_tensor_names;
188- param_tensor_names.push_back (axis_param.GetParamTensorName ());
189- qnn_model_wrapper.AddParamWrapper (std::move (axis_param));
190-
191- const std::string& orig_output_name = outputs[0 ].node_arg .Name ();
192- std::string op_output_name = orig_output_name + " _ort_qnn_ep_transpose" ;
193-
194- std::vector<uint32_t > op_output_shape = output_info.shape ;
195- op_output_shape[output_rank - 1 ] = output_info.shape [axis];
196- op_output_shape[axis] = output_info.shape [output_rank - 1 ];
197-
198- QnnTensorWrapper output_tensorwrapper (op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type ,
199- output_info.quant_param .Copy (), std::vector<uint32_t >(op_output_shape));
200- ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (output_tensorwrapper)), " Failed to add tensor." );
201- ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (utils::GetNodeName (node_unit),
202- QNN_OP_PACKAGE_NAME_QTI_AISW,
203- GetQnnOpType (node_unit.OpType ()),
204- std::move (input_names),
205- {op_output_name},
206- std::move (param_tensor_names)),
207- " Failed to add node." );
208-
209- const bool is_graph_output = qnn_model_wrapper.IsGraphOutput (orig_output_name);
210- std::vector<uint32_t > transpose_perm = GetTransposePermToUseLastAxis (static_cast <uint32_t >(output_rank),
211- static_cast <uint32_t >(axis));
212-
213- ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddTransposeNode (node_unit.Index (),
214- op_output_name,
195+ QnnTensorWrapper output_tensorwrapper (reshape_input_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type ,
196+ output_info.quant_param .Copy (), std::vector<uint32_t >(reshape_input_shape));
197+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (output_tensorwrapper)), " Failed to add tensor." );
198+ ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (utils::GetNodeName (node_unit),
199+ QNN_OP_PACKAGE_NAME_QTI_AISW,
200+ GetQnnOpType (node_unit.OpType ()),
201+ std::move (input_names),
202+ {reshape_input_name},
203+ std::move (param_tensor_names)),
204+ " Failed to add node." );
205+
206+ const bool is_graph_output = qnn_model_wrapper.IsGraphOutput (orig_output_name);
207+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddReshapeNode (reshape_input_name,
215208 orig_output_name,
216- op_output_shape,
217- transpose_perm,
209+ reshape_input_shape,
218210 output_info.shape ,
219211 output_info.qnn_data_type ,
220212 output_info.quant_param ,
221213 do_op_validation,
222214 false ,
223215 is_graph_output));
216+ }
217+ else if (is_npu_backend && axis != static_cast <int32_t >(output_rank) - 1 ) {
218+ std::string transpose_input_name = orig_output_name + " _ort_qnn_ep_transpose" ;
219+
220+ std::vector<uint32_t > transpose_input_shape = output_info.shape ;
221+ transpose_input_shape[output_rank - 1 ] = output_info.shape [axis];
222+ transpose_input_shape[axis] = output_info.shape [output_rank - 1 ];
223+
224+ // Override axis due to the actual shape after the inserted transpose node
225+ axis_qnn_scalar.uint32Value = static_cast <uint32_t >(output_rank) - 1 ;
226+
227+ QnnParamWrapper axis_param (node_unit.Index (), node_unit.Name (), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
228+ std::vector<std::string> param_tensor_names;
229+ param_tensor_names.push_back (axis_param.GetParamTensorName ());
230+ qnn_model_wrapper.AddParamWrapper (std::move (axis_param));
231+
232+ QnnTensorWrapper output_tensorwrapper (transpose_input_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type ,
233+ output_info.quant_param .Copy (), std::vector<uint32_t >(transpose_input_shape));
234+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (output_tensorwrapper)), " Failed to add tensor." );
235+ ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (utils::GetNodeName (node_unit),
236+ QNN_OP_PACKAGE_NAME_QTI_AISW,
237+ GetQnnOpType (node_unit.OpType ()),
238+ std::move (input_names),
239+ {transpose_input_name},
240+ std::move (param_tensor_names)),
241+ " Failed to add node." );
242+
243+ const bool is_graph_output = qnn_model_wrapper.IsGraphOutput (orig_output_name);
244+ std::vector<uint32_t > transpose_perm = GetTransposePermToUseLastAxis (static_cast <uint32_t >(output_rank),
245+ static_cast <uint32_t >(axis));
246+
247+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.AddTransposeNode (node_unit.Index (),
248+ transpose_input_name,
249+ orig_output_name,
250+ transpose_input_shape,
251+ transpose_perm,
252+ output_info.shape ,
253+ output_info.qnn_data_type ,
254+ output_info.quant_param ,
255+ do_op_validation,
256+ false ,
257+ is_graph_output));
258+ }
259+ else {
260+ QnnParamWrapper axis_param (node_unit.Index (), node_unit.Name (), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
261+ std::vector<std::string> param_tensor_names;
262+ param_tensor_names.push_back (axis_param.GetParamTensorName ());
263+ qnn_model_wrapper.AddParamWrapper (std::move (axis_param));
264+
265+ return ProcessOutputs (qnn_model_wrapper, node_unit,
266+ std::move (input_names),
267+ std::move (param_tensor_names),
268+ logger, do_op_validation, GetQnnOpType (op_type));
269+ }
224270
225271 return Status::OK ();
226272}
0 commit comments