@@ -22,50 +22,8 @@ bool CheckStatus(const OrtApi* g_ort, OrtStatus* status) {
2222 return true ;
2323}
2424
25- template <typename T_QuantType>
26- void QuantizedData (T_QuantType* out, const float * in, int32_t offset, float scale, size_t num_elements) {
27- static_assert (std::is_unsigned<T_QuantType>::value, " QuantizedData supports unsigned only!" );
28-
29- if (nullptr == out || nullptr == in) {
30- throw Ort::Exception (" Received a nullptr" , OrtErrorCode::ORT_EP_FAIL);
31- }
32-
33- size_t data_type_size_in_bytes = sizeof (T_QuantType);
34- size_t bit_width = data_type_size_in_bytes * 8 ;
35- double true_bit_width_max = pow (2 , bit_width) - 1 ;
36- double encoding_min = offset * scale;
37- double encoding_max = (true_bit_width_max + offset) * scale;
38- double encoding_range = encoding_max - encoding_min;
39-
40- for (size_t i = 0 ; i < num_elements; ++i) {
41- int quantized_value = static_cast <int >(round (true_bit_width_max * (in[i] - encoding_min) / encoding_range));
42- if (quantized_value < 0 ) {
43- quantized_value = 0 ;
44- } else if (quantized_value > (int )true_bit_width_max) {
45- quantized_value = (int )true_bit_width_max;
46- }
47- out[i] = static_cast <T_QuantType>(quantized_value);
48- }
49- }
50-
51-
52- template <typename T_QuantType>
53- void DequantizedData (float * out, const T_QuantType* in, int32_t offset, float scale, size_t num_elements) {
54- static_assert (std::is_unsigned<T_QuantType>::value, " DequantizedData supports unsigned only!" );
55-
56- if (nullptr == out || nullptr == in) {
57- throw Ort::Exception (" Received a nullptr" , OrtErrorCode::ORT_EP_FAIL);
58- }
59-
60- for (size_t i = 0 ; i < num_elements; i++) {
61- double quantized_value = static_cast <double >(in[i]);
62- double offset_double = static_cast <double >(offset);
63- out[i] = static_cast <float >((quantized_value + offset_double) * scale);
64- }
65- }
66-
6725void run_ort_qnn_ep (const std::string& backend, const std::string& model_path, const std::string& input_path,
68- bool generated_from_native_qnn, bool generate_ctx, bool float32_model) {
26+ bool generate_ctx, bool float32_model) {
6927 std::wstring model_path_wstr = std::wstring (model_path.begin (), model_path.end ());
7028
7129 const OrtApi* g_ort = OrtGetApiBase ()->GetApi (ORT_API_VERSION);
@@ -195,19 +153,13 @@ void run_ort_qnn_ep(const std::string& backend, const std::string& model_path, c
195153 input_raw_file.read (reinterpret_cast <char *>(&input_data[0 ]), num_elements * sizeof (float ));
196154
197155 CheckStatus (g_ort, g_ort->CreateCpuMemoryInfo (OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
198- // QNN native tool chain generated quantized model use quantized data as inputs & outputs
199- if (generated_from_native_qnn) {
200- size_t input_data_length = input_data_size * sizeof (uint8_t );
201- QuantizedData (quantized_input_data.data (), input_data.data (), -116 , 0 .015875209f , input_data_size);
202- CheckStatus (g_ort, g_ort->CreateTensorWithDataAsOrtValue (
203- memory_info, reinterpret_cast <void *>(quantized_input_data.data ()), input_data_length,
204- input_node_dims[0 ].data (), input_node_dims[0 ].size (), input_types[0 ], &input_tensors[0 ]));
205- } else { // Ort generate QDQ model still use float32 data as inputs & outputs
206- size_t input_data_length = input_data_size * sizeof (float );
207- CheckStatus (g_ort, g_ort->CreateTensorWithDataAsOrtValue (
208- memory_info, reinterpret_cast <void *>(input_data.data ()), input_data_length,
209- input_node_dims[0 ].data (), input_node_dims[0 ].size (), input_types[0 ], &input_tensors[0 ]));
210- }
156+ // QNN native tool chain generated quantized model use quantized data as inputs & outputs by default,
157+ // We wrapped it with Q and DQ node in gen_qnn_ctx_onnx_model.py, so the inputs & outputs are still float
158+ // Ort generate QDQ model still use float32 data as inputs & outputs
159+ size_t input_data_length = input_data_size * sizeof (float );
160+ CheckStatus (g_ort, g_ort->CreateTensorWithDataAsOrtValue (
161+ memory_info, reinterpret_cast <void *>(input_data.data ()), input_data_length,
162+ input_node_dims[0 ].data (), input_node_dims[0 ].size (), input_types[0 ], &input_tensors[0 ]));
211163 g_ort->ReleaseMemoryInfo (memory_info);
212164
213165 CheckStatus (g_ort, g_ort->Run (session, nullptr , input_node_names.data (), (const OrtValue* const *)input_tensors.data (),
@@ -219,13 +171,7 @@ void run_ort_qnn_ep(const std::string& backend, const std::string& model_path, c
219171 void * output_buffer;
220172 CheckStatus (g_ort, g_ort->GetTensorMutableData (output_tensors[0 ], &output_buffer));
221173 float * float_buffer = nullptr ;
222- if (generated_from_native_qnn) {
223- uint8_t * buffer = reinterpret_cast <uint8_t *>(output_buffer);
224- DequantizedData (output_data.data (), buffer, -86 , 0 .08069417f , output_data_size);
225- float_buffer = output_data.data ();
226- } else {
227- float_buffer = reinterpret_cast <float *>(output_buffer);
228- }
174+ float_buffer = reinterpret_cast <float *>(output_buffer);
229175
230176 auto max = std::max_element (float_buffer, float_buffer + output_data_size);
231177 int max_index = static_cast <int >(std::distance (float_buffer, max));
@@ -278,7 +224,6 @@ int main(int argc, char* argv[]) {
278224 }
279225
280226 std::string backend = " " ;
281- bool generated_from_native_qnn = false ;
282227 bool float32_model = false ;
283228 if (strcmp (argv[1 ], CPUBACKEDN) == 0 ) {
284229 backend = " QnnCpu.dll" ;
@@ -290,7 +235,6 @@ int main(int argc, char* argv[]) {
290235 backend = " QnnHtp.dll" ;
291236 } else if (strcmp (argv[1 ], QNNCTXBINARY) == 0 ) {
292237 backend = " QnnHtp.dll" ;
293- generated_from_native_qnn = true ;
294238 if (generate_ctx) {
295239 std::cout << " --gen_ctx won't work with --qnn." << std::endl;
296240 return 1 ;
@@ -309,6 +253,6 @@ int main(int argc, char* argv[]) {
309253 std::string model_path (argv[2 ]);
310254 std::string input_path (argv[3 ]);
311255
312- run_ort_qnn_ep (backend, model_path, input_path, generated_from_native_qnn, generate_ctx, float32_model);
256+ run_ort_qnn_ep (backend, model_path, input_path, generate_ctx, float32_model);
313257 return 0 ;
314258}
0 commit comments