1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"
8
+
9
+ namespace onnxruntime {
10
+ namespace qnn {
11
+
12
+ class STFTOpBuilder : public BaseOpBuilder {
13
+ public:
14
+ STFTOpBuilder () : BaseOpBuilder(" STFTOpBuilder" ) {}
15
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE (STFTOpBuilder);
16
+
17
+ protected:
18
+ Status ProcessInputs (QnnModelWrapper& qnn_model_wrapper,
19
+ const NodeUnit& node_unit,
20
+ const logging::Logger& logger,
21
+ std::vector<std::string>& input_names,
22
+ bool do_op_validation) const override ;
23
+
24
+ Status ProcessAttributesAndOutputs (QnnModelWrapper& qnn_model_wrapper,
25
+ const NodeUnit& node_unit,
26
+ std::vector<std::string>&& input_names,
27
+ const logging::Logger& logger,
28
+ bool do_op_validation) const override ;
29
+
30
+ Status IsOpSupported (QnnModelWrapper& qnn_model_wrapper,
31
+ const NodeUnit& node_unit,
32
+ const logging::Logger& logger) const override ;
33
+ };
34
+
35
+ // Checks if the given input is a window input (float type).
36
+ static bool IsWindowInput (const NodeUnitIODef& input) {
37
+ return input.node_arg .TypeAsProto ()->tensor_type ().elem_type () == ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
38
+ }
39
+
40
+ Status STFTOpBuilder::IsOpSupported (QnnModelWrapper& qnn_model_wrapper,
41
+ const NodeUnit& node_unit,
42
+ const logging::Logger& logger) const {
43
+ ORT_UNUSED_PARAMETER (logger);
44
+ // TODO: STFT seg faults on QNN CPU
45
+ bool is_cpu_backend = IsCpuBackend (qnn_model_wrapper.GetQnnBackendType ());
46
+ ORT_RETURN_IF (is_cpu_backend, " QNN EP: STFT Op disabled in CPU backend." );
47
+ // General Datatype checks on various QNN backend (HTP, CPU, GPU)
48
+ ORT_RETURN_IF_ERROR (ProcessDataTypes (qnn_model_wrapper, node_unit));
49
+ return AddToModelBuilder (qnn_model_wrapper, node_unit, logger, true );
50
+ }
51
+
52
+ Status STFTOpBuilder::ProcessInputs (QnnModelWrapper& qnn_model_wrapper,
53
+ const NodeUnit& node_unit,
54
+ const logging::Logger& logger,
55
+ std::vector<std::string>& input_names,
56
+ bool do_op_validation) const {
57
+ const auto & inputs = node_unit.Inputs ();
58
+
59
+ // Process signal input (first input)
60
+ const auto & signal_input = inputs[0 ];
61
+ const std::string& signal_input_name = signal_input.node_arg .Name ();
62
+
63
+ // Get the shape of the signal input
64
+ TensorInfo signal_info{};
65
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (signal_input, signal_info));
66
+
67
+ // Check if the signal input is rank 2 (needs expansion to rank 3)
68
+ if (signal_info.shape .size () == 2 ) {
69
+ LOGS (logger, VERBOSE) << " Signal input is rank 2, adding ExpandDims op to convert to rank 3" ;
70
+
71
+ // Add the original signal tensor
72
+ if (!qnn_model_wrapper.IsQnnTensorWrapperExist (signal_input_name)) {
73
+ QnnTensorWrapper signal_tensorwrapper;
74
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.MakeTensorWrapper (signal_input, signal_tensorwrapper));
75
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (signal_tensorwrapper)),
76
+ " Failed to add signal tensor." );
77
+ }
78
+
79
+ // Create a name for the expanded tensor
80
+ std::string expanded_tensor_name = signal_input_name + " _expanded" ;
81
+
82
+ // Create the expanded tensor with an extra dimension
83
+ std::vector<uint32_t > expanded_shape = signal_info.shape ;
84
+ expanded_shape.push_back (1 );
85
+
86
+ // Create a tensor info for the expanded tensor based on the signal tensor info
87
+ TensorInfo expanded_tensor_info = signal_info;
88
+ expanded_tensor_info.shape = expanded_shape;
89
+
90
+ // Create the tensor wrapper using MakeTensorWrapper
91
+ QnnTensorWrapper expanded_tensorwrapper;
92
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.MakeTensorWrapper (expanded_tensor_info,
93
+ expanded_tensor_name,
94
+ expanded_tensorwrapper));
95
+
96
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (expanded_tensorwrapper)),
97
+ " Failed to add expanded signal tensor." );
98
+
99
+ // Create axis parameter for ExpandDims (add dimension at the end)
100
+ Qnn_Scalar_t axis_param = QNN_SCALAR_INIT;
101
+ axis_param.dataType = QNN_DATATYPE_UINT_32;
102
+ axis_param.uint32Value = static_cast <uint32_t >(2 ); // Add at the end
103
+
104
+ QnnParamWrapper axis_param_wrapper (node_unit.Index (),
105
+ node_unit.Name (),
106
+ " axis" ,
107
+ axis_param);
108
+
109
+ std::vector<std::string> expand_dims_params;
110
+ expand_dims_params.push_back (axis_param_wrapper.GetParamTensorName ());
111
+ qnn_model_wrapper.AddParamWrapper (std::move (axis_param_wrapper));
112
+
113
+ // Create the ExpandDims node
114
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (
115
+ utils::GetUniqueName (signal_input_name, " _expand_dims" ),
116
+ QNN_OP_PACKAGE_NAME_QTI_AISW,
117
+ QNN_OP_EXPAND_DIMS,
118
+ {signal_input_name},
119
+ {expanded_tensor_name},
120
+ std::move (expand_dims_params),
121
+ do_op_validation),
122
+ " Failed to create ExpandDims node." );
123
+
124
+ // Use the expanded tensor for STFT
125
+ input_names.push_back (expanded_tensor_name);
126
+ } else {
127
+ // Process as normal for rank 3 or higher inputs
128
+ if (qnn_model_wrapper.IsQnnTensorWrapperExist (signal_input_name)) {
129
+ LOGS (logger, VERBOSE) << " Tensor already added, skip it: " << signal_input_name;
130
+ } else {
131
+ QnnTensorWrapper signal_tensorwrapper;
132
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.MakeTensorWrapper (signal_input, signal_tensorwrapper));
133
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (signal_tensorwrapper)),
134
+ " Failed to add signal tensor." );
135
+ }
136
+ input_names.push_back (signal_input_name);
137
+ }
138
+
139
+ // Process frame_step input (second input)
140
+ if (inputs.size () > 1 ) {
141
+ const auto & frame_step_input = inputs[1 ];
142
+ const std::string& frame_step_input_name = frame_step_input.node_arg .Name ();
143
+
144
+ if (qnn_model_wrapper.IsQnnTensorWrapperExist (frame_step_input_name)) {
145
+ LOGS (logger, VERBOSE) << " Tensor already added, skip it: " << frame_step_input_name;
146
+ } else {
147
+ QnnTensorWrapper frame_step_tensorwrapper;
148
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.MakeTensorWrapper (frame_step_input, frame_step_tensorwrapper));
149
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (frame_step_tensorwrapper)),
150
+ " Failed to add frame_step tensor." );
151
+ }
152
+ // We don't add frame_step to input_names because it will be processed as a parameter
153
+ }
154
+
155
+ // Process the 'window' input if it exists and is of type float
156
+ if (inputs.size () > 2 ) {
157
+ const auto & window_input = inputs[2 ];
158
+ if (IsWindowInput (window_input)) {
159
+ const std::string& window_input_name = window_input.node_arg .Name ();
160
+
161
+ if (qnn_model_wrapper.IsQnnTensorWrapperExist (window_input_name)) {
162
+ LOGS (logger, VERBOSE) << " Tensor already added, skip it: " << window_input_name;
163
+ } else {
164
+ QnnTensorWrapper window_tensorwrapper;
165
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.MakeTensorWrapper (window_input, window_tensorwrapper));
166
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (window_tensorwrapper)),
167
+ " Failed to add window tensor." );
168
+ }
169
+ input_names.push_back (window_input_name);
170
+ }
171
+ }
172
+
173
+ // Process frame_length input if it exists
174
+ if (inputs.size () > 3 || (inputs.size () > 2 && !IsWindowInput (inputs[2 ]))) {
175
+ int frame_length_index = inputs.size () > 3 ? 3 : 2 ;
176
+ const auto & frame_length_input = inputs[frame_length_index];
177
+ const std::string& frame_length_input_name = frame_length_input.node_arg .Name ();
178
+
179
+ if (qnn_model_wrapper.IsQnnTensorWrapperExist (frame_length_input_name)) {
180
+ LOGS (logger, VERBOSE) << " Tensor already added, skip it: " << frame_length_input_name;
181
+ } else {
182
+ QnnTensorWrapper frame_length_tensorwrapper;
183
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.MakeTensorWrapper (frame_length_input, frame_length_tensorwrapper));
184
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (frame_length_tensorwrapper)),
185
+ " Failed to add frame_length tensor." );
186
+ }
187
+ // We don't add frame_length to input_names because it will be processed as a parameter
188
+ }
189
+
190
+ return Status::OK ();
191
+ }
192
+
193
+ Status STFTOpBuilder::ProcessAttributesAndOutputs (QnnModelWrapper& qnn_model_wrapper,
194
+ const NodeUnit& node_unit,
195
+ std::vector<std::string>&& input_names,
196
+ const logging::Logger& logger,
197
+ bool do_op_validation) const {
198
+ ORT_UNUSED_PARAMETER (logger);
199
+ NodeAttrHelper node_helper (node_unit);
200
+ const auto & inputs = node_unit.Inputs ();
201
+ bool onesided = node_helper.Get (" onesided" , static_cast <bool >(1 ));
202
+
203
+ std::vector<std::string> param_tensor_names;
204
+ // Extract frame_step from the inputs if it exists
205
+ uint32_t frame_step_info = 0 ;
206
+ TensorInfo frame_step = {};
207
+ TensorInfo frame_length = {};
208
+ uint32_t frame_length_info = 0 ;
209
+
210
+ int frame_length_index = -1 ;
211
+ int frame_step_index = -1 ;
212
+
213
+ // Determine indices for frame_step and frame_length
214
+ if (inputs.size () >= 2 ) {
215
+ frame_step_index = 1 ; // frame_step is always the second input
216
+ }
217
+
218
+ if (inputs.size () == 3 ) {
219
+ // Check if the third input is window or frame_length
220
+ const auto & third_input = inputs[2 ];
221
+ if (!IsWindowInput (third_input)) {
222
+ frame_length_index = 2 ; // It's frame_length
223
+ }
224
+ } else if (inputs.size () > 3 ) {
225
+ frame_length_index = 3 ; // frame_length is the fourth input
226
+ }
227
+
228
+ // Process frame_step
229
+ if (frame_step_index != -1 ) {
230
+ const auto & frame_step_input = inputs[frame_step_index];
231
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (frame_step_input, frame_step));
232
+ std::vector<uint8_t > frame_step_data;
233
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.UnpackInitializerData (*frame_step.initializer_tensor , frame_step_data));
234
+ frame_step_info = *reinterpret_cast <uint32_t *>(frame_step_data.data ());
235
+ Qnn_Scalar_t frame_step_param = QNN_SCALAR_INIT;
236
+ frame_step_param.dataType = QNN_DATATYPE_UINT_32;
237
+ frame_step_param.uint32Value = frame_step_info;
238
+ QnnParamWrapper frame_step_param_wrapper (node_unit.Index (),
239
+ node_unit.Name (),
240
+ QNN_OP_STFT_PARAM_FRAME_STEP,
241
+ frame_step_param);
242
+ param_tensor_names.push_back (frame_step_param_wrapper.GetParamTensorName ());
243
+ qnn_model_wrapper.AddParamWrapper (std::move (frame_step_param_wrapper));
244
+ }
245
+
246
+ // Process frame_length if it exists
247
+ if (frame_length_index != -1 ) {
248
+ const auto & frame_length_input = inputs[frame_length_index];
249
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (frame_length_input, frame_length));
250
+
251
+ std::vector<uint8_t > frame_length_data;
252
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.UnpackInitializerData (*frame_length.initializer_tensor , frame_length_data));
253
+ frame_length_info = *reinterpret_cast <uint32_t *>(frame_length_data.data ());
254
+
255
+ // Create frame_length parameter
256
+ Qnn_Scalar_t frame_length_param = QNN_SCALAR_INIT;
257
+ frame_length_param.dataType = QNN_DATATYPE_UINT_32;
258
+ frame_length_param.uint32Value = frame_length_info;
259
+ QnnParamWrapper frame_length_param_wrapper (node_unit.Index (),
260
+ node_unit.Name (),
261
+ QNN_OP_STFT_PARAM_FRAME_LENGTH,
262
+ frame_length_param);
263
+ param_tensor_names.push_back (frame_length_param_wrapper.GetParamTensorName ());
264
+ qnn_model_wrapper.AddParamWrapper (std::move (frame_length_param_wrapper));
265
+ }
266
+
267
+ const auto & outputs = node_unit.Outputs ();
268
+ const std::string& output_name = outputs[0 ].node_arg .Name ();
269
+
270
+ bool is_graph_output = qnn_model_wrapper.IsGraphOutput (output_name);
271
+ Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
272
+
273
+ TensorInfo output_info{};
274
+ ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (outputs[0 ], output_info));
275
+
276
+ QnnTensorWrapper output_tensorwrapper (output_name, tensor_type, output_info.qnn_data_type ,
277
+ output_info.quant_param .Copy (), std::move (output_info.shape ));
278
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (output_tensorwrapper)), " Failed to add output tensor." );
279
+
280
+ Qnn_Scalar_t onesided_param = QNN_SCALAR_INIT;
281
+ onesided_param.dataType = QNN_DATATYPE_BOOL_8;
282
+ onesided_param.bool8Value = static_cast <bool >(onesided);
283
+
284
+ QnnParamWrapper onesided_param_wrapper (node_unit.Index (),
285
+ node_unit.Name (),
286
+ QNN_OP_STFT_PARAM_ONESIDED,
287
+ onesided_param);
288
+
289
+ param_tensor_names.push_back (onesided_param_wrapper.GetParamTensorName ());
290
+ qnn_model_wrapper.AddParamWrapper (std::move (onesided_param_wrapper));
291
+
292
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (
293
+ utils::GetUniqueName (node_unit),
294
+ QNN_OP_PACKAGE_NAME_QTI_AISW,
295
+ QNN_OP_STFT,
296
+ std::move (input_names),
297
+ {output_name},
298
+ std::move (param_tensor_names),
299
+ do_op_validation),
300
+ " Failed to create STFT node." );
301
+
302
+ return Status::OK ();
303
+ }
304
+
305
+ void CreateSTFTOpBuilder (const std::string& op_type, OpBuilderRegistrations& op_registrations) {
306
+ op_registrations.AddOpBuilder (op_type, std::make_unique<STFTOpBuilder>());
307
+ }
308
+
309
+ } // namespace qnn
310
+ } // namespace onnxruntime
0 commit comments