Commit 8ef6439

[QNN-EP] Add support for Softmax operator with opset < 13

1 parent 9dcb99c commit 8ef6439

File tree

2 files changed: +167 -121 lines

onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc

Lines changed: 159 additions & 113 deletions
@@ -14,10 +14,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder {
   SoftmaxOpBuilder() : BaseOpBuilder("SoftmaxOpBuilder") {}
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SoftmaxOpBuilder);

-  Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
-                       const NodeUnit& node_unit,
-                       const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;
-
  protected:
   Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                        const NodeUnit& node_unit,
@@ -37,34 +33,25 @@ constexpr int32_t GetDefaultAxisAttribute(int opset_version) {
   return opset_version < 13 ? 1 : -1;
 }

-Status SoftmaxOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
-                                       const NodeUnit& node_unit,
-                                       const logging::Logger& logger) const {
-  ORT_UNUSED_PARAMETER(logger);
-  const int opset_version = node_unit.SinceVersion();
+std::vector<uint32_t> FlattenShapeFromAxis(std::vector<uint32_t>& input_shape, int32_t axis) {
+  /*
+  Return the shape with all dimensions from the specified axis onward multiplied together. If axis is 0, the
+  returned shape includes an additional batch of size 1 as the first dimension.
+  */
+  assert(axis >= 0 && axis < input_shape.size());
+  std::vector<uint32_t> output_shape(input_shape.begin(), input_shape.begin() + axis);

-  // The QNN HTP backend only supports an `axis` attribute that refers to the last input dimension.
-  // QNN EP is able to support arbitrary axis attributes by wrapping the QNN operator with transposes.
-  // However, the exception is Softmax/LogSoftmax with opset < 13. For these older ONNX operators, only
-  // axis == input_rank - 1 is supported.
-  if (opset_version < 13) {
-    const std::string& op_type = node_unit.OpType();
-
-    int32_t axis = GetDefaultAxisAttribute(opset_version);
-    Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
-    ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis));
-    std::vector<uint32_t> input_shape;
-    ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(node_unit.Inputs()[0].node_arg, input_shape),
-                      "QNN EP: Cannot get shape for Softmax input");
-    ORT_RETURN_IF(axis != static_cast<int32_t>(input_shape.size() - 1),
-                  "QNN ", op_type.c_str(),
-                  " only supports an `axis` attribute equal to input_rank-1 (or -1) for ONNX opset < 13");
+  if (axis == 0) {
+    output_shape.push_back(1);  // Additional batch dimension included
   }
+  output_shape.push_back(
+      std::accumulate(input_shape.begin() + axis, input_shape.end(), 1, std::multiplies<uint32_t>())
+  );

-  return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true);
+  return output_shape;
 }

-static std::vector<uint32_t> GetTransposePermToUseLastAxis(uint32_t input_rank, uint32_t axis) {
+std::vector<uint32_t> GetTransposePermToUseLastAxis(uint32_t input_rank, uint32_t axis) {
   assert(axis < input_rank);
   std::vector<uint32_t> transpose_perm;
   transpose_perm.reserve(input_rank);
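
For reference, the new FlattenShapeFromAxis helper can be exercised outside the EP. Below is a minimal standalone C++ sketch of the same logic; the const-ref signature, the main() driver, and the printed values are illustrative additions, not part of the commit.

// Standalone sketch of FlattenShapeFromAxis (illustrative; mirrors the helper above).
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

std::vector<uint32_t> FlattenShapeFromAxis(const std::vector<uint32_t>& input_shape, int32_t axis) {
  assert(axis >= 0 && static_cast<std::size_t>(axis) < input_shape.size());
  // Keep the dimensions before `axis` unchanged.
  std::vector<uint32_t> output_shape(input_shape.begin(), input_shape.begin() + axis);
  if (axis == 0) {
    output_shape.push_back(1);  // Insert a batch of size 1 so the result stays 2-D.
  }
  // Collapse everything from `axis` onward into a single dimension.
  output_shape.push_back(
      std::accumulate(input_shape.begin() + axis, input_shape.end(), 1u, std::multiplies<uint32_t>()));
  return output_shape;
}

int main() {
  const std::vector<uint32_t> shape = {3, 4, 5};
  const auto a1 = FlattenShapeFromAxis(shape, 1);  // {3, 20}
  const auto a0 = FlattenShapeFromAxis(shape, 0);  // {1, 60}
  std::cout << a1[0] << "x" << a1[1] << " and " << a0[0] << "x" << a0[1] << "\n";  // 3x20 and 1x60
}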
@@ -87,58 +74,86 @@ Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                        bool do_op_validation) const {
   const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType());
   const auto& inputs = node_unit.Inputs();
+  const std::string& input_name = inputs[0].node_arg.Name();
   assert(inputs.size() == 1);

-  int32_t axis = GetDefaultAxisAttribute(node_unit.SinceVersion());
+  const int opset_version = node_unit.SinceVersion();
+  int32_t axis = GetDefaultAxisAttribute(opset_version);
   Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
   ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis));

   TensorInfo input_info = {};
   ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info));
-  const size_t input_rank = input_info.shape.size();
-
-  // If the axis attribute refers to the last dimension, then process the input as normal.
-  if (!is_npu_backend || axis == static_cast<int32_t>(input_rank) - 1) {
-    return ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names);
-  }
-
-  //
-  // The axis does **not** refer to the last input dimension. Must wrap transposes around the operator to be able to use
-  // QNN's Softmax operator, which always uses an axis value that refers to the last dimension.
-  //
-
-  std::vector<uint32_t> transpose_perm = GetTransposePermToUseLastAxis(static_cast<uint32_t>(input_rank),
-                                                                       static_cast<uint32_t>(axis));
+  size_t input_rank = input_info.shape.size();
+  ORT_RETURN_IF(input_info.is_initializer, "QNN EP does not support (Log)Softmax with an initializer input, ",
+                "which should be optimized away by the ORT optimizer");

-  const std::string& input_name = inputs[0].node_arg.Name();
-  std::string op_input_name = input_info.is_initializer ? input_name : input_name + "_ort_qnn_ep_transpose";
-  input_names.push_back(op_input_name);
+  /*
+  For ONNX Softmax with opset < 13, the behavior is to flatten the input starting from the axis, perform
+  softmax along the flattened axis dimension, and then reshape back to the original input shape.
+  QNN EP supports an arbitrary axis attribute by wrapping reshapes around the operator.

-  std::vector<uint32_t> op_input_shape = input_info.shape;
-  op_input_shape[input_rank - 1] = input_info.shape[axis];
-  op_input_shape[axis] = input_info.shape[input_rank - 1];
+  For example:
+  Given an input with shape=(3, 4, 5) and axis=1, the input is reshaped to (3, 20), softmax is performed,
+  and the result is reshaped back to (3, 4, 5).

-  ORT_RETURN_IF(input_info.is_initializer, "QNN EP does not support (Log)Softmax with an initializer input, ",
-                "which should be optimized away by the ORT optimizer");
+  When axis equals 0, the reshape output shape includes an additional batch of size 1 as the first dimension.
+  For example:
+  Given an input with shape=(3, 4, 5) and axis=0, the input is reshaped to (1, 60), softmax is performed,
+  and the result is reshaped back to (3, 4, 5).
+  */
+  if (opset_version < 13) {
+    std::string reshape_output_name = input_name + "_ort_qnn_ep_reshape";
+    std::vector<uint32_t> reshape_output_shape = FlattenShapeFromAxis(input_info.shape, axis);

-  // Input is dynamic, so add transpose node before input.
-  const bool is_graph_input = qnn_model_wrapper.IsGraphInput(input_name);
+    // Input is dynamic, so add a reshape node before the input.
+    const bool is_graph_input = qnn_model_wrapper.IsGraphInput(input_name);

-  ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
-                                                         input_name,
-                                                         op_input_name,
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(input_name,
+                                                         reshape_output_name,
                                                          input_info.shape,
-                                                         transpose_perm,
-                                                         op_input_shape,
+                                                         reshape_output_shape,
                                                          input_info.qnn_data_type,
                                                          input_info.quant_param,
                                                          do_op_validation,
-                                                         is_graph_input));
-
-  Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(op_input_name);
-  QnnTensorWrapper input_tensorwrapper(op_input_name, tensor_type, input_info.qnn_data_type,
-                                       std::move(input_info.quant_param), std::move(op_input_shape), {});
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
+                                                         is_graph_input,
+                                                         false));
+    input_names.push_back(reshape_output_name);
+  }
+  /*
+  For ONNX Softmax with opset >= 13, the QNN HTP backend only supports an axis attribute that refers to the last
+  input dimension.
+  QNN EP supports an arbitrary axis attribute by wrapping transposes around the operator.
+  */
+  else if (is_npu_backend && axis != static_cast<int32_t>(input_rank) - 1) {
+    std::string transpose_output_name = input_name + "_ort_qnn_ep_transpose";
+    std::vector<uint32_t> transpose_perm = GetTransposePermToUseLastAxis(static_cast<uint32_t>(input_rank),
+                                                                         static_cast<uint32_t>(axis));
+
+    std::vector<uint32_t> transpose_output_shape = input_info.shape;
+    transpose_output_shape[input_rank - 1] = input_info.shape[axis];
+    transpose_output_shape[axis] = input_info.shape[input_rank - 1];
+
+    // Input is dynamic, so add a transpose node before the input.
+    const bool is_graph_input = qnn_model_wrapper.IsGraphInput(input_name);
+
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
+                                                           input_name,
+                                                           transpose_output_name,
+                                                           input_info.shape,
+                                                           transpose_perm,
+                                                           transpose_output_shape,
+                                                           input_info.qnn_data_type,
+                                                           input_info.quant_param,
+                                                           do_op_validation,
+                                                           is_graph_input,
+                                                           false));
+    input_names.push_back(transpose_output_name);
+  }
+  // Process the input as normal.
+  else {
+    return ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names);
+  }

   return Status::OK();
 }
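
To make the flatten-then-softmax semantics described in the comment above concrete: for shape (3, 4, 5) with axis=1, softmax runs over each row of the (3, 20) view. A small self-contained check of that arithmetic follows (illustrative only, not part of the commit).

// Softmax over each contiguous row of length `row_len`, i.e. over the flattened
// trailing dimensions of the 2-D view used by opset < 13 Softmax.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

void SoftmaxRows(std::vector<float>& data, std::size_t row_len) {
  for (std::size_t start = 0; start < data.size(); start += row_len) {
    // Subtract the row max for numerical stability, then exponentiate and normalize.
    float max_v = data[start];
    for (std::size_t i = 1; i < row_len; ++i) max_v = std::max(max_v, data[start + i]);
    float sum = 0.0f;
    for (std::size_t i = 0; i < row_len; ++i) {
      data[start + i] = std::exp(data[start + i] - max_v);
      sum += data[start + i];
    }
    for (std::size_t i = 0; i < row_len; ++i) data[start + i] /= sum;
  }
}

int main() {
  // Shape (3, 4, 5) with axis = 1: three rows of length 4 * 5 = 20 in the (3, 20) view.
  std::vector<float> x(3 * 4 * 5, 1.0f);
  SoftmaxRows(x, 4 * 5);
  std::cout << x[0] << "\n";  // Uniform input: every entry becomes 1/20 = 0.05.
}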
@@ -151,76 +166,107 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_
   const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType());
   const std::string& op_type = node_unit.OpType();
   const auto& outputs = node_unit.Outputs();
+  const std::string& orig_output_name = outputs[0].node_arg.Name();
   assert(outputs.size() == 1);

-  int32_t axis = GetDefaultAxisAttribute(node_unit.SinceVersion());
+  const int opset_version = node_unit.SinceVersion();
+  int32_t axis = GetDefaultAxisAttribute(opset_version);
   Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
   ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis));

   TensorInfo output_info = {};
   ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[0], output_info));
-  const size_t output_rank = output_info.shape.size();
-  const bool axis_is_last_dim = static_cast<size_t>(axis) == output_rank - 1;
+  size_t output_rank = output_info.shape.size();

-  // If axis refers to the last dimension, process outputs as usual.
-  if (!is_npu_backend || axis_is_last_dim) {
-    QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
+  if (opset_version < 13) {
+    std::string reshape_input_name = orig_output_name + "_ort_qnn_ep_reshape";

+    std::vector<uint32_t> reshape_input_shape = FlattenShapeFromAxis(output_info.shape, axis);
+    if (axis == 0) {
+      // Override the axis to account for the batch of size 1 inserted as the first dimension
+      axis_qnn_scalar.uint32Value = 1;
+    }
+
+    QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
     std::vector<std::string> param_tensor_names;
     param_tensor_names.push_back(axis_param.GetParamTensorName());
     qnn_model_wrapper.AddParamWrapper(std::move(axis_param));

-    return ProcessOutputs(qnn_model_wrapper, node_unit,
-                          std::move(input_names),
-                          std::move(param_tensor_names),
-                          logger, do_op_validation, GetQnnOpType(op_type));
-  }
-
-  //
-  // The axis does **not** refer to the last dimension. Must wrap the operator with Transposes to be able to use
-  // QNN's Softmax operator, which only supports an axis that refers to the last dimension.
-  //
-
-  axis_qnn_scalar.uint32Value = static_cast<uint32_t>(output_rank - 1);  // NOTE: override axis.
-  QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
-
-  std::vector<std::string> param_tensor_names;
-  param_tensor_names.push_back(axis_param.GetParamTensorName());
-  qnn_model_wrapper.AddParamWrapper(std::move(axis_param));
-
-  const std::string& orig_output_name = outputs[0].node_arg.Name();
-  std::string op_output_name = orig_output_name + "_ort_qnn_ep_transpose";
-
-  std::vector<uint32_t> op_output_shape = output_info.shape;
-  op_output_shape[output_rank - 1] = output_info.shape[axis];
-  op_output_shape[axis] = output_info.shape[output_rank - 1];
-
-  QnnTensorWrapper output_tensorwrapper(op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type,
-                                        output_info.quant_param.Copy(), std::vector<uint32_t>(op_output_shape));
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
-                                                    QNN_OP_PACKAGE_NAME_QTI_AISW,
-                                                    GetQnnOpType(node_unit.OpType()),
-                                                    std::move(input_names),
-                                                    {op_output_name},
-                                                    std::move(param_tensor_names)),
-                    "Failed to add node.");
-
-  const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(orig_output_name);
-  std::vector<uint32_t> transpose_perm = GetTransposePermToUseLastAxis(static_cast<uint32_t>(output_rank),
-                                                                       static_cast<uint32_t>(axis));
-
-  ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
-                                                         op_output_name,
+    QnnTensorWrapper output_tensorwrapper(reshape_input_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type,
+                                          output_info.quant_param.Copy(), std::vector<uint32_t>(reshape_input_shape));
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
+                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                      GetQnnOpType(node_unit.OpType()),
+                                                      std::move(input_names),
+                                                      {reshape_input_name},
+                                                      std::move(param_tensor_names)),
+                      "Failed to add node.");
+
+    const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(orig_output_name);
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(reshape_input_name,
                                                          orig_output_name,
-                                                         op_output_shape,
-                                                         transpose_perm,
+                                                         reshape_input_shape,
                                                          output_info.shape,
                                                          output_info.qnn_data_type,
                                                          output_info.quant_param,
                                                          do_op_validation,
                                                          false,
                                                          is_graph_output));
+  }
+  else if (is_npu_backend && axis != static_cast<int32_t>(output_rank) - 1) {
+    std::string transpose_input_name = orig_output_name + "_ort_qnn_ep_transpose";
+
+    std::vector<uint32_t> transpose_input_shape = output_info.shape;
+    transpose_input_shape[output_rank - 1] = output_info.shape[axis];
+    transpose_input_shape[axis] = output_info.shape[output_rank - 1];
+
+    // Override the axis to match the actual shape after the inserted transpose node
+    axis_qnn_scalar.uint32Value = static_cast<uint32_t>(output_rank) - 1;
+
+    QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
+    std::vector<std::string> param_tensor_names;
+    param_tensor_names.push_back(axis_param.GetParamTensorName());
+    qnn_model_wrapper.AddParamWrapper(std::move(axis_param));
+
+    QnnTensorWrapper output_tensorwrapper(transpose_input_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type,
+                                          output_info.quant_param.Copy(), std::vector<uint32_t>(transpose_input_shape));
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor.");
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
+                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                      GetQnnOpType(node_unit.OpType()),
+                                                      std::move(input_names),
+                                                      {transpose_input_name},
+                                                      std::move(param_tensor_names)),
+                      "Failed to add node.");
+
+    const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(orig_output_name);
+    std::vector<uint32_t> transpose_perm = GetTransposePermToUseLastAxis(static_cast<uint32_t>(output_rank),
+                                                                         static_cast<uint32_t>(axis));
+
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(),
+                                                           transpose_input_name,
+                                                           orig_output_name,
+                                                           transpose_input_shape,
+                                                           transpose_perm,
+                                                           output_info.shape,
+                                                           output_info.qnn_data_type,
+                                                           output_info.quant_param,
+                                                           do_op_validation,
+                                                           false,
+                                                           is_graph_output));
+  }
+  else {
+    QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
+    std::vector<std::string> param_tensor_names;
+    param_tensor_names.push_back(axis_param.GetParamTensorName());
+    qnn_model_wrapper.AddParamWrapper(std::move(axis_param));
+
+    return ProcessOutputs(qnn_model_wrapper, node_unit,
+                          std::move(input_names),
+                          std::move(param_tensor_names),
+                          logger, do_op_validation, GetQnnOpType(op_type));
+  }

   return Status::OK();
 }
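
GetTransposePermToUseLastAxis only loses its `static` qualifier in this diff; its body is unchanged and not shown. Judging from the shape swaps above (shape[axis] and shape[rank - 1] trade places), the permutation is presumably the identity with `axis` and the last index swapped. A hypothetical sketch under that assumption:

// Assumed behavior of GetTransposePermToUseLastAxis: identity permutation with
// `axis` moved to the end (consistent with the swapped output shapes above).
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

std::vector<uint32_t> TransposePermToUseLastAxisSketch(uint32_t input_rank, uint32_t axis) {
  assert(axis < input_rank);
  std::vector<uint32_t> perm(input_rank);
  std::iota(perm.begin(), perm.end(), 0u);      // {0, 1, ..., rank - 1}
  std::swap(perm[axis], perm[input_rank - 1]);  // move `axis` to the last position
  return perm;
}
// Example: rank 3, axis 1 -> perm {0, 2, 1}, which maps shape (3, 4, 5) to (3, 5, 4).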
