@@ -29,6 +29,7 @@ namespace {
2929 */
3030void check_dequantize_args (
3131 const Tensor& input,
32+ int64_t zero_point,
3233 int64_t quant_min,
3334 int64_t quant_max,
3435 ScalarType dtype,
@@ -39,6 +40,18 @@ void check_dequantize_args(
3940 " input.scalar_type() %" PRId8 " is not char type" ,
4041 static_cast <int8_t >(input.scalar_type ()));
4142
43+ // Check that zero_point lies within [quant_min, quant_max]
44+ ET_CHECK_MSG (
45+ zero_point >= quant_min,
46+ " zero_point %" PRId64 " must be >= quant_min %" PRId64,
47+ zero_point,
48+ quant_min);
49+ ET_CHECK_MSG (
50+ zero_point <= quant_max,
51+ " zero_point %" PRId64 " must be <= quant_max %" PRId64,
52+ zero_point,
53+ quant_max);
54+
4255 // Check output dtype is float
4356 ET_CHECK_MSG (
4457 out.scalar_type () == ScalarType::Float,
@@ -73,18 +86,10 @@ void check_dequantize_args(
7386/* *
7487 * Scalar implementation of dequantization for a single value.
7588 */
76- template <typename K, typename T>
77- T dequantize_val (
78- float scale,
79- int32_t zero_point,
80- K value,
81- int64_t quant_min,
82- int64_t quant_max) {
83- (void )quant_min;
84- (void )quant_max;
85- return static_cast <T>((static_cast <int32_t >(value) - zero_point) * scale);
89+ template <typename Q, typename F>
90+ F dequantize_val (float scale, int32_t zero_point, Q qvalue) {
91+ return static_cast <F>((static_cast <int32_t >(qvalue) - zero_point) * scale);
8692}
87-
8893} // namespace
8994
9095Tensor& dequantize_per_tensor_out (
@@ -106,29 +111,71 @@ Tensor& dequantize_per_tensor_out(
106111 " Failed to resize out Tensor in dequantize_per_tensor_out" );
107112
108113 // Validate input parameters
109- check_dequantize_args (input, quant_min, quant_max, dtype, out);
114+ check_dequantize_args (input, zero_point, quant_min, quant_max, dtype, out);
110115
111- // Pre-compute inverse scale for better performance
112116 int32_t zp = static_cast <int32_t >(zero_point);
113- int32_t qmin = static_cast <int32_t >(quant_min);
114- int32_t qmax = static_cast <int32_t >(quant_max);
115117
116118 // Get pointers to input and output data
117119 const int8_t * input_data = input.const_data_ptr <int8_t >();
118120 float * out_data = out.mutable_data_ptr <float >();
119121 const size_t numel = input.numel ();
120122
123+ size_t i = 0 ;
121124#if defined(HAS_HELIUM_SIMD)
122- // Helium MVE implementation for float32 to int8 quantization
123- #Error " Implement MVE version!"
124- #else
125- // Scalar implementation for float32 to int8 quantization
126- for (size_t i = 0 ; i < numel; i++) {
127- out_data[i] =
128- dequantize_val<int8_t , float >(scale, zp, input_data[i], qmin, qmax);
125+ // Helium MVE implementation for int8 to float dequantization
126+ static uint8x16_t voffset{
127+ 0x0 ,
128+ 0x8 ,
129+ 0x4 ,
130+ 0xC ,
131+ 0x1 ,
132+ 0x9 ,
133+ 0x5 ,
134+ 0xD ,
135+ 0x2 ,
136+ 0xA ,
137+ 0x6 ,
138+ 0xE ,
139+ 0x3 ,
140+ 0xB ,
141+ 0x7 ,
142+ 0xF };
143+
144+ int16x8_t vzp = vdupq_n_s16 (static_cast <int16_t >(zp));
145+ float32x4_t vscale = vdupq_n_f32 (static_cast <float >(scale));
146+
147+ for (; i + 15 < numel; i += 16 ) {
148+ int8x16_t in_084C195D2A6E3B7F =
149+ vldrbq_gather_offset_s8 (input_data, voffset);
150+
151+ int16x8_t in_04152637 = vsubq_s16 (vmovlbq_s8 (in_084C195D2A6E3B7F), vzp);
152+ int16x8_t in_8C9DAEBF = vsubq_s16 (vmovltq_s8 (in_084C195D2A6E3B7F), vzp);
153+
154+ float32x4_t inf_0123 = vcvtq_f32_s32 (vmovlbq_s16 (in_04152637));
155+ float32x4_t inf_4567 = vcvtq_f32_s32 (vmovltq_s16 (in_04152637));
156+ float32x4_t inf_89AB = vcvtq_f32_s32 (vmovlbq_s16 (in_8C9DAEBF));
157+ float32x4_t inf_CDEF = vcvtq_f32_s32 (vmovltq_s16 (in_8C9DAEBF));
158+
159+ float32x4_t out_0123 = vmulq_f32 (inf_0123, vscale);
160+ float32x4_t out_4567 = vmulq_f32 (inf_4567, vscale);
161+ float32x4_t out_89AB = vmulq_f32 (inf_89AB, vscale);
162+ float32x4_t out_CDEF = vmulq_f32 (inf_CDEF, vscale);
163+
164+ vstrwq_f32 (out_data + 0 , out_0123);
165+ vstrwq_f32 (out_data + 4 , out_4567);
166+ vstrwq_f32 (out_data + 8 , out_89AB);
167+ vstrwq_f32 (out_data + 12 , out_CDEF);
168+
169+ input_data += 16 ;
170+ out_data += 16 ;
129171 }
130- #endif
172+ #endif // defined(HAS_HELIUM_SIMD)
131173
174+ for (; i < numel; i++) {
175+ *out_data = dequantize_val<int8_t , float >(scale, zp, *input_data);
176+ input_data++;
177+ out_data++;
178+ }
132179 return out;
133180}
134181
0 commit comments