 // ==============================================================================
 
 #include "HTP/core/constraints.h"
+#include "HTP/core/intrinsics.h"
 #include "HTP/core/op_package_feature_support.h"
 #include "HTP/core/op_register_ext.h"
 #include "HTP/core/optimize.h"
 #include "HTP/core/simple_reg.h"
 #include "QnnOpPackage.h"
+#include "hexagon_protos.h"
 #include "hexagon_types.h"
 #include "hvx_hexagon_protos.h"
 
@@ -163,7 +165,51 @@ GraphStatus examplecustomopImpl(TensorType& out_0, const TensorType& in_0)
   if (input_intfc.dtype == DType::Float32) {
     const float* p_input = static_cast<const float*>(in_0.raw_data_const());
     float* p_output = static_cast<float*>(out_0.raw_data());
-    const int multiplier = 3;
+    const size_t N = in_0.total_storage_elements();
+
+    // Allocate temporary FP16 buffers on stack or heap
+    std::vector<Float16> tmp_in(N);
+    std::vector<Float16> tmp_out(N);
+
+    // 1. Convert FP32 -> FP16
+    for (size_t i = 0; i < N; ++i) {
+      tmp_in[i] = static_cast<Float16>(p_input[i]);
+    }
+
+#ifdef __hexagon__
+    // 2. Run HVX multiply (FP16 domain)
+    union {
+      Float16 f16;
+      uint16_t bits;
+    } f3 = {static_cast<Float16>(3.0f)};
+    HVX_Vector v_mul = Q6_Vh_vsplat_R(f3.bits);
+
+    const int vector_bytes = 128;
+    const int elems_per_vec = vector_bytes / sizeof(Float16);
+
+    for (size_t i = 0; i < N; i += elems_per_vec) {
+      HVX_Vector vin = q6op_V_vldu_A(&tmp_in[i]);
+      HVX_Vector vout = Q6_Vhf_vmpy_VhfVhf(vin, v_mul);
+      q6op_vstu_AV(&tmp_out[i], vout);
+    }
+#else
+    // 2. Fallback scalar multiply
+    for (size_t i = 0; i < N; ++i) {
+      tmp_out[i] = static_cast<Float16>(tmp_in[i] * static_cast<Float16>(3.0f));
+    }
+#endif
+
+    // 3. Convert FP16 -> FP32
+    for (size_t i = 0; i < N; ++i) {
+      p_output[i] = static_cast<float>(tmp_out[i]);
+    }
+
+    return GraphStatus::Success;
+  } else if (input_intfc.dtype == DType::QUInt8) {
+    // printf("[QNN ExecuTorch Op Package test] input is QUInt8\n");
+    const uint8_t* p_input = static_cast<const uint8_t*>(in_0.raw_data_const());
+    uint8_t* p_output = static_cast<uint8_t*>(out_0.raw_data());
+    const int multiplier = 3 * input_intfc.scale / out_intfc.scale;
     for (size_t i = 0; i < input_num_elements; ++i) {
       p_output[i] = multiplier * p_input[i];
 
@@ -177,59 +223,6 @@ GraphStatus examplecustomopImpl(TensorType& out_0, const TensorType& in_0)
           i,
           p_output[i]);
     }
-  } else if (input_intfc.dtype == DType::QUInt8) {
-    // const uint8_t* p_input = static_cast<const
-    // uint8_t*>(in_0.raw_data_const()); uint8_t* p_output =
-    // static_cast<uint8_t*>(out_0.raw_data()); const int multiplier = 3 *
-    // input_intfc.scale / out_intfc.scale; for (size_t i = 0; i <
-    // input_num_elements; ++i) {
-    //   p_output[i] = multiplier * p_input[i];
-
-    //   FARF(
-    //       ALWAYS,
-    //       "[QNN ExecuTorch Op Package test]"
-    //       "input0[%zu]=%f, multiplier=%d, output[%zu]=%f",
-    //       i,
-    //       p_input[i],
-    //       multiplier,
-    //       i,
-    //       p_output[i]);
-    // }
-
-    const uint8_t* p_input = static_cast<const uint8_t*>(in_0.raw_data_const());
-    uint8_t* p_output = static_cast<uint8_t*>(out_0.raw_data());
-    const float multiplier_f = 3.0f * input_intfc.scale / out_intfc.scale;
-    const int multiplier =
-        static_cast<int>(multiplier_f * 128.0f);  // fixed-point scale
-
-    const HVX_Vector* in_vec = reinterpret_cast<const HVX_Vector*>(p_input);
-    HVX_Vector* out_vec = reinterpret_cast<HVX_Vector*>(p_output);
-
-    HVX_Vector v_mult = Q6_V_vsplat_R(multiplier & 0xFF);
-    HVX_Vector vzero = Q6_V_vzero();
-
-    const size_t vec_elems = 128;  // 128 bytes per HVX vector
-    const size_t nvecs = input_num_elements / vec_elems;
-
-    for (size_t i = 0; i < nvecs; ++i) {
-      HVX_Vector vin = Q6_V_vldu_A(in_vec + i);
-      HVX_Vector vout;
-
-#if defined(__HEXAGON_ARCH__)
-      // use available multiply intrinsic
-      vout = Q6_Vub_vmpy_VubRb_s1_rnd_sat(vin, v_mult);
-#else
-      // fallback scalar multiply for x86 simulation
-      alignas(128) uint8_t tmp_in[128], tmp_out[128];
-      memcpy(tmp_in, p_input + i * 128, 128);
-      for (int j = 0; j < 128; ++j)
-        tmp_out[j] = std::min(255, (tmp_in[j] * multiplier) >> 7);
-      memcpy(p_output + i * 128, tmp_out, 128);
-      continue;
-#endif
-
-      Q6_V_vstu_A(out_vec + i, vout);
-    }
   }
 
   return GraphStatus::Success;
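
Note on the new FP16 path: the HVX loop advances 64 fp16 elements per iteration, so its final unaligned load and store run past the end of tmp_in/tmp_out whenever N is not a multiple of 64. A minimal sketch of one way to guard that, reusing the helpers the diff itself uses; the tail-handling code is an assumption for illustration, not part of the commit:

    // Process whole 128-byte vectors first, then finish the remainder in
    // scalar code. N, tmp_in, tmp_out, v_mul, and elems_per_vec are assumed
    // to be declared exactly as in the diff above.
    const size_t full = N - (N % elems_per_vec);
    for (size_t i = 0; i < full; i += elems_per_vec) {
      HVX_Vector vin = q6op_V_vldu_A(&tmp_in[i]);
      HVX_Vector vout = Q6_Vhf_vmpy_VhfVhf(vin, v_mul);
      q6op_vstu_AV(&tmp_out[i], vout);
    }
    for (size_t i = full; i < N; ++i) {  // scalar tail, at most 63 elements
      tmp_out[i] = static_cast<Float16>(tmp_in[i] * static_cast<Float16>(3.0f));
    }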
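Note on the QUInt8 path: `3 * input_intfc.scale / out_intfc.scale` truncates the requantization ratio to an int, and `multiplier * p_input[i]` can wrap when narrowed back to uint8_t. A hedged sketch of a rounded, saturating scalar variant (an illustration only; it assumes zero-points of 0, as the commit's code also implies, and needs <algorithm> for std::min):

    // Keep the requantization ratio in floating point, round to nearest,
    // and clamp to the uint8 range before storing.
    const float ratio = 3.0f * input_intfc.scale / out_intfc.scale;
    for (size_t i = 0; i < input_num_elements; ++i) {
      const int v = static_cast<int>(ratio * p_input[i] + 0.5f);  // round to nearest
      p_output[i] = static_cast<uint8_t>(std::min(v, 255));       // saturate
    }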