Skip to content

Commit d3a916d

Browse files
[QNN EP] Add Support for STFT Op in QNN EP (#26063)
- Added new STFT op builder
- Added unit tests for running the op on HTP
- Disabled QNN CPU for the STFT op due to a bug in QNN CPU

### Description
Added support for STFT with a new op builder.

### Motivation and Context
To enable audio and signal processing on models within the QNN framework.

---------

Co-authored-by: Copilot <[email protected]>
1 parent cced33b commit d3a916d

File tree

4 files changed

+448
-0
lines changed

4 files changed

+448
-0
lines changed

onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
212212
CreateThresholdedReluOpBuilder("ThresholdedRelu", *this);
213213
}
214214

215+
{
216+
CreateSTFTOpBuilder("STFT", *this);
217+
}
218+
215219
{
216220
CreateInverseOpBuilder("Inverse", *this);
217221
}

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ void CreateModOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
119119

120120
void CreateThresholdedReluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
121121

122+
void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
123+
122124
void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
123125

124126
} // namespace qnn
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include <cstring>

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"
8+
9+
namespace onnxruntime {
10+
namespace qnn {
11+
12+
class STFTOpBuilder : public BaseOpBuilder {
13+
public:
14+
STFTOpBuilder() : BaseOpBuilder("STFTOpBuilder") {}
15+
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(STFTOpBuilder);
16+
17+
protected:
18+
Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
19+
const NodeUnit& node_unit,
20+
const logging::Logger& logger,
21+
std::vector<std::string>& input_names,
22+
bool do_op_validation) const override;
23+
24+
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
25+
const NodeUnit& node_unit,
26+
std::vector<std::string>&& input_names,
27+
const logging::Logger& logger,
28+
bool do_op_validation) const override;
29+
30+
Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
31+
const NodeUnit& node_unit,
32+
const logging::Logger& logger) const override;
33+
};
34+
35+
// Checks if the given input is a window input (float type).
36+
static bool IsWindowInput(const NodeUnitIODef& input) {
37+
return input.node_arg.TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
38+
}
39+
40+
// Returns OK only when this STFT node can be lowered on the current backend.
Status STFTOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
                                    const NodeUnit& node_unit,
                                    const logging::Logger& logger) const {
  ORT_UNUSED_PARAMETER(logger);

  // TODO: STFT seg faults on QNN CPU
  const bool on_cpu_backend = IsCpuBackend(qnn_model_wrapper.GetQnnBackendType());
  ORT_RETURN_IF(on_cpu_backend, "QNN EP: STFT Op disabled in CPU backend.");

  // General Datatype checks on various QNN backend (HTP, CPU, GPU)
  ORT_RETURN_IF_ERROR(ProcessDataTypes(qnn_model_wrapper, node_unit));

  // Validation-only pass through the full builder pipeline.
  return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true);
}
51+
52+
// Adds `input` to the QNN model as a tensor wrapper unless a tensor with the
// same name was already registered. `description` is used in error messages.
static Status AddTensorIfMissing(QnnModelWrapper& qnn_model_wrapper,
                                 const NodeUnitIODef& input,
                                 const logging::Logger& logger,
                                 const char* description) {
  const std::string& name = input.node_arg.Name();
  if (qnn_model_wrapper.IsQnnTensorWrapperExist(name)) {
    LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << name;
    return Status::OK();
  }
  QnnTensorWrapper tensor_wrapper;
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input, tensor_wrapper));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(tensor_wrapper)),
                    "Failed to add ", description, " tensor.");
  return Status::OK();
}

// Registers the STFT inputs with the QNN model.
//
// input_names receives only the tensors QNN's STFT consumes directly: the
// signal (possibly rank-expanded) and, if present, the float window.
// frame_step and frame_length are registered but NOT appended, because
// ProcessAttributesAndOutputs converts them to scalar QNN parameters.
Status STFTOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                    const NodeUnit& node_unit,
                                    const logging::Logger& logger,
                                    std::vector<std::string>& input_names,
                                    bool do_op_validation) const {
  const auto& inputs = node_unit.Inputs();

  // --- signal (input 0) ---
  const auto& signal_input = inputs[0];
  const std::string& signal_input_name = signal_input.node_arg.Name();

  TensorInfo signal_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(signal_input, signal_info));

  if (signal_info.shape.size() == 2) {
    // QNN STFT expects a rank-3 signal; insert an ExpandDims node that
    // appends a trailing dimension of size 1.
    LOGS(logger, VERBOSE) << "Signal input is rank 2, adding ExpandDims op to convert to rank 3";

    ORT_RETURN_IF_ERROR(AddTensorIfMissing(qnn_model_wrapper, signal_input, logger, "signal"));

    // Intermediate tensor holding the rank-3 view of the signal.
    std::string expanded_tensor_name = signal_input_name + "_expanded";

    std::vector<uint32_t> expanded_shape = signal_info.shape;
    expanded_shape.push_back(1);

    TensorInfo expanded_tensor_info = signal_info;
    expanded_tensor_info.shape = expanded_shape;

    QnnTensorWrapper expanded_tensorwrapper;
    ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(expanded_tensor_info,
                                                            expanded_tensor_name,
                                                            expanded_tensorwrapper));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(expanded_tensorwrapper)),
                      "Failed to add expanded signal tensor.");

    // axis parameter for ExpandDims: add the new dimension at the end (axis 2).
    Qnn_Scalar_t axis_param = QNN_SCALAR_INIT;
    axis_param.dataType = QNN_DATATYPE_UINT_32;
    axis_param.uint32Value = static_cast<uint32_t>(2);

    QnnParamWrapper axis_param_wrapper(node_unit.Index(),
                                       node_unit.Name(),
                                       "axis",
                                       axis_param);

    std::vector<std::string> expand_dims_params;
    expand_dims_params.push_back(axis_param_wrapper.GetParamTensorName());
    qnn_model_wrapper.AddParamWrapper(std::move(axis_param_wrapper));

    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                          utils::GetUniqueName(signal_input_name, "_expand_dims"),
                          QNN_OP_PACKAGE_NAME_QTI_AISW,
                          QNN_OP_EXPAND_DIMS,
                          {signal_input_name},
                          {expanded_tensor_name},
                          std::move(expand_dims_params),
                          do_op_validation),
                      "Failed to create ExpandDims node.");

    // STFT consumes the expanded tensor, not the original signal.
    input_names.push_back(expanded_tensor_name);
  } else {
    // Rank 3 or higher: pass the signal through unchanged.
    ORT_RETURN_IF_ERROR(AddTensorIfMissing(qnn_model_wrapper, signal_input, logger, "signal"));
    input_names.push_back(signal_input_name);
  }

  // --- frame_step (input 1): registered, but handled as a QNN parameter in
  // ProcessAttributesAndOutputs, so not appended to input_names. ---
  if (inputs.size() > 1) {
    ORT_RETURN_IF_ERROR(AddTensorIfMissing(qnn_model_wrapper, inputs[1], logger, "frame_step"));
  }

  // --- window (input 2, optional, float) ---
  if (inputs.size() > 2 && IsWindowInput(inputs[2])) {
    ORT_RETURN_IF_ERROR(AddTensorIfMissing(qnn_model_wrapper, inputs[2], logger, "window"));
    input_names.push_back(inputs[2].node_arg.Name());
  }

  // --- frame_length (input 3, or input 2 when no window is given): also a
  // QNN parameter, so not appended to input_names. ---
  if (inputs.size() > 3 || (inputs.size() > 2 && !IsWindowInput(inputs[2]))) {
    const size_t frame_length_index = inputs.size() > 3 ? 3 : 2;
    ORT_RETURN_IF_ERROR(AddTensorIfMissing(qnn_model_wrapper,
                                           inputs[frame_length_index],
                                           logger,
                                           "frame_length"));
  }

  return Status::OK();
}
192+
193+
// Unpacks a scalar integer initializer (ONNX declares frame_step and
// frame_length as int64 scalars) and returns its value as uint32.
//
// Fixes vs. the original:
//  - verifies the input actually is a constant initializer before
//    dereferencing initializer_tensor (dynamic inputs would null-deref);
//  - uses memcpy instead of *reinterpret_cast<uint32_t*>, which was a
//    strict-aliasing/alignment violation and read only the low 4 bytes of
//    the int64 payload (endian-dependent).
static Status GetScalarUint32Param(QnnModelWrapper& qnn_model_wrapper,
                                   const NodeUnitIODef& input,
                                   const char* param_name,
                                   uint32_t& value_out) {
  TensorInfo info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, info));
  ORT_RETURN_IF_NOT(info.is_initializer && info.initializer_tensor != nullptr,
                    "QNN EP: STFT ", param_name, " input must be a constant initializer.");

  std::vector<uint8_t> raw;
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*info.initializer_tensor, raw));

  if (raw.size() >= sizeof(int64_t)) {
    int64_t value = 0;
    std::memcpy(&value, raw.data(), sizeof(value));
    value_out = static_cast<uint32_t>(value);
  } else {
    ORT_RETURN_IF_NOT(raw.size() >= sizeof(uint32_t),
                      "QNN EP: STFT ", param_name, " initializer is too small.");
    uint32_t value = 0;
    std::memcpy(&value, raw.data(), sizeof(value));
    value_out = value;
  }
  return Status::OK();
}

// Converts frame_step/frame_length inputs and the onesided attribute into QNN
// scalar parameters, registers the output tensor, and creates the STFT node.
Status STFTOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                  const NodeUnit& node_unit,
                                                  std::vector<std::string>&& input_names,
                                                  const logging::Logger& logger,
                                                  bool do_op_validation) const {
  ORT_UNUSED_PARAMETER(logger);
  NodeAttrHelper node_helper(node_unit);
  const auto& inputs = node_unit.Inputs();

  // ONNX default for `onesided` is 1 (return only the non-redundant half).
  bool onesided = node_helper.Get("onesided", static_cast<bool>(1));

  std::vector<std::string> param_tensor_names;

  // Determine indices for frame_step and frame_length among the optional
  // inputs. frame_step is always input 1; frame_length is input 3 when a
  // window is present, otherwise it may appear as input 2 (non-float).
  int frame_step_index = inputs.size() >= 2 ? 1 : -1;
  int frame_length_index = -1;
  if (inputs.size() == 3) {
    if (!IsWindowInput(inputs[2])) {
      frame_length_index = 2;  // third input is frame_length, not window
    }
  } else if (inputs.size() > 3) {
    frame_length_index = 3;
  }

  // Process frame_step.
  if (frame_step_index != -1) {
    uint32_t frame_step_value = 0;
    ORT_RETURN_IF_ERROR(GetScalarUint32Param(qnn_model_wrapper, inputs[frame_step_index],
                                             "frame_step", frame_step_value));
    Qnn_Scalar_t frame_step_param = QNN_SCALAR_INIT;
    frame_step_param.dataType = QNN_DATATYPE_UINT_32;
    frame_step_param.uint32Value = frame_step_value;
    QnnParamWrapper frame_step_param_wrapper(node_unit.Index(),
                                             node_unit.Name(),
                                             QNN_OP_STFT_PARAM_FRAME_STEP,
                                             frame_step_param);
    param_tensor_names.push_back(frame_step_param_wrapper.GetParamTensorName());
    qnn_model_wrapper.AddParamWrapper(std::move(frame_step_param_wrapper));
  }

  // Process frame_length if it exists.
  if (frame_length_index != -1) {
    uint32_t frame_length_value = 0;
    ORT_RETURN_IF_ERROR(GetScalarUint32Param(qnn_model_wrapper, inputs[frame_length_index],
                                             "frame_length", frame_length_value));
    Qnn_Scalar_t frame_length_param = QNN_SCALAR_INIT;
    frame_length_param.dataType = QNN_DATATYPE_UINT_32;
    frame_length_param.uint32Value = frame_length_value;
    QnnParamWrapper frame_length_param_wrapper(node_unit.Index(),
                                               node_unit.Name(),
                                               QNN_OP_STFT_PARAM_FRAME_LENGTH,
                                               frame_length_param);
    param_tensor_names.push_back(frame_length_param_wrapper.GetParamTensorName());
    qnn_model_wrapper.AddParamWrapper(std::move(frame_length_param_wrapper));
  }

  // Register the output tensor.
  const auto& outputs = node_unit.Outputs();
  const std::string& output_name = outputs[0].node_arg.Name();

  const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name);
  const Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;

  TensorInfo output_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[0], output_info));

  QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type,
                                        output_info.quant_param.Copy(), std::move(output_info.shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add output tensor.");

  // onesided attribute -> QNN bool parameter.
  Qnn_Scalar_t onesided_param = QNN_SCALAR_INIT;
  onesided_param.dataType = QNN_DATATYPE_BOOL_8;
  onesided_param.bool8Value = static_cast<bool>(onesided);

  QnnParamWrapper onesided_param_wrapper(node_unit.Index(),
                                         node_unit.Name(),
                                         QNN_OP_STFT_PARAM_ONESIDED,
                                         onesided_param);
  param_tensor_names.push_back(onesided_param_wrapper.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(onesided_param_wrapper));

  // Finally create the STFT node itself.
  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                        utils::GetUniqueName(node_unit),
                        QNN_OP_PACKAGE_NAME_QTI_AISW,
                        QNN_OP_STFT,
                        std::move(input_names),
                        {output_name},
                        std::move(param_tensor_names),
                        do_op_validation),
                    "Failed to create STFT node.");

  return Status::OK();
}
304+
305+
// Factory hook: registers an STFTOpBuilder instance for `op_type` ("STFT").
void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  auto builder = std::make_unique<STFTOpBuilder>();
  op_registrations.AddOpBuilder(op_type, std::move(builder));
}
308+
309+
} // namespace qnn
310+
} // namespace onnxruntime

0 commit comments

Comments
 (0)