@@ -29,6 +29,7 @@ namespace {
2929 */  
3030void  check_dequantize_args (
3131    const  Tensor& input,
32+     int64_t  zero_point,
3233    int64_t  quant_min,
3334    int64_t  quant_max,
3435    ScalarType dtype,
@@ -39,6 +40,18 @@ void check_dequantize_args(
3940      " input.scalar_type() %" "  is not char type" 
4041      static_cast <int8_t >(input.scalar_type ()));
4142
43+   //  Check zp range
44+   ET_CHECK_MSG (
45+       zero_point >= quant_min,
46+       " zero_point must be %" "  <= quant_min %" 
47+       zero_point,
48+       quant_min);
49+   ET_CHECK_MSG (
50+       zero_point <= quant_max,
51+       " zero_point must be %" "  >= quant_max %" 
52+       zero_point,
53+       quant_max);
54+ 
4255  //  Check output dtype is float
4356  ET_CHECK_MSG (
4457      out.scalar_type () == ScalarType::Float,
@@ -73,18 +86,10 @@ void check_dequantize_args(
7386/* *
7487 * Scalar implementation of quantization for a single value. 
7588 */  
76- template  <typename  K, typename  T>
77- T dequantize_val (
78-     float  scale,
79-     int32_t  zero_point,
80-     K value,
81-     int64_t  quant_min,
82-     int64_t  quant_max) {
83-   (void )quant_min;
84-   (void )quant_max;
85-   return  static_cast <T>((static_cast <int32_t >(value) - zero_point) * scale);
89+ template  <typename  Q, typename  F>
90+ F dequantize_val (float  scale, int32_t  zero_point, Q qvalue) {
91+   return  static_cast <F>((static_cast <int32_t >(qvalue) - zero_point) * scale);
8692}
87- 
8893} //  namespace
8994
9095Tensor& dequantize_per_tensor_out (
@@ -106,29 +111,71 @@ Tensor& dequantize_per_tensor_out(
106111      " Failed to resize out Tensor in dequantize_per_tensor_out" 
107112
108113  //  Validate input parameters
109-   check_dequantize_args (input, quant_min, quant_max, dtype, out);
114+   check_dequantize_args (input, zero_point,  quant_min, quant_max, dtype, out);
110115
111-   //  Pre-compute inverse scale for better performance
112116  int32_t  zp = static_cast <int32_t >(zero_point);
113-   int32_t  qmin = static_cast <int32_t >(quant_min);
114-   int32_t  qmax = static_cast <int32_t >(quant_max);
115117
116118  //  Get pointers to input and output data
117119  const  int8_t * input_data = input.const_data_ptr <int8_t >();
118120  float * out_data = out.mutable_data_ptr <float >();
119121  const  size_t  numel = input.numel ();
120122
123+   size_t  i = 0 ;
121124#if  defined(HAS_HELIUM_SIMD)
122- //  Helium MVE implementation for float32 to int8 quantization
123- #Error " Implement MVE version!" 
124- #else 
125-   //  Scalar implementation for float32 to int8 quantization
126-   for  (size_t  i = 0 ; i < numel; i++) {
127-     out_data[i] =
128-         dequantize_val<int8_t , float >(scale, zp, input_data[i], qmin, qmax);
125+   //  Helium MVE implementation for int8 to float quantization
126+   static  uint8x16_t  voffset{
127+       0x0 ,
128+       0x8 ,
129+       0x4 ,
130+       0xC ,
131+       0x1 ,
132+       0x9 ,
133+       0x5 ,
134+       0xD ,
135+       0x2 ,
136+       0xA ,
137+       0x6 ,
138+       0xE ,
139+       0x3 ,
140+       0xB ,
141+       0x7 ,
142+       0xF };
143+ 
144+   int16x8_t  vzp = vdupq_n_s16 (static_cast <int16_t >(zp));
145+   float32x4_t  vscale = vdupq_n_f32 (static_cast <float >(scale));
146+ 
147+   for  (; i + 15  < numel; i += 16 ) {
148+     int8x16_t  in_084C195D2A6E3B7F =
149+         vldrbq_gather_offset_s8 (input_data, voffset);
150+ 
151+     int16x8_t  in_04152637 = vsubq_s16 (vmovlbq_s8 (in_084C195D2A6E3B7F), vzp);
152+     int16x8_t  in_8C9DAEBF = vsubq_s16 (vmovltq_s8 (in_084C195D2A6E3B7F), vzp);
153+ 
154+     float32x4_t  inf_0123 = vcvtq_f32_s32 (vmovlbq_s16 (in_04152637));
155+     float32x4_t  inf_4567 = vcvtq_f32_s32 (vmovltq_s16 (in_04152637));
156+     float32x4_t  inf_89AB = vcvtq_f32_s32 (vmovlbq_s16 (in_8C9DAEBF));
157+     float32x4_t  inf_CDEF = vcvtq_f32_s32 (vmovltq_s16 (in_8C9DAEBF));
158+ 
159+     float32x4_t  out_0123 = vmulq_f32 (inf_0123, vscale);
160+     float32x4_t  out_4567 = vmulq_f32 (inf_4567, vscale);
161+     float32x4_t  out_89AB = vmulq_f32 (inf_89AB, vscale);
162+     float32x4_t  out_CDEF = vmulq_f32 (inf_CDEF, vscale);
163+ 
164+     vstrwq_f32 (out_data + 0 , out_0123);
165+     vstrwq_f32 (out_data + 4 , out_4567);
166+     vstrwq_f32 (out_data + 8 , out_89AB);
167+     vstrwq_f32 (out_data + 12 , out_CDEF);
168+ 
169+     input_data += 16 ;
170+     out_data += 16 ;
129171  }
130- #endif 
172+ #endif   //  defined(HAS_HELIUM_SIMD) 
131173
174+   for  (; i < numel; i++) {
175+     *out_data = dequantize_val<int8_t , float >(scale, zp, *input_data);
176+     input_data++;
177+     out_data++;
178+   }
132179  return  out;
133180}
134181
0 commit comments