@@ -58,6 +58,36 @@ void dequantize(
5858 }
5959}
6060
61+ // Requantize the int8_t/uint8_t in value to a uint8_t/int8_t out value.
62+ // The scale and zero_point for requantization are in the args.
63+ template <typename IT, typename OT>
64+ OT requantize (
65+ const IT in,
66+ float in_scale,
67+ int32_t in_zero_point,
68+ float inv_out_scale,
69+ int32_t out_zero_point) {
70+ float dequant = dequantize<IT>(in, in_scale, in_zero_point);
71+ return quantize<OT>(dequant, inv_out_scale, out_zero_point);
72+ }
73+
// Requantize an array of quantized values element by element.
//
// Reads `size` elements from `in`, requantizes each one with the scalar
// requantize<IT, OT>() using the given input/output quantization parameters,
// and writes the results to the matching positions of `out`.
//
// Both pointers are declared __restrict__, so callers must not pass
// overlapping buffers. A `size` of 0 leaves `out` untouched.
template <typename IT, typename OT>
void requantize(
    OT* __restrict__ out,
    const IT* __restrict__ in,
    float in_scale,
    int32_t in_zero_point,
    float inv_out_scale,
    int32_t out_zero_point,
    size_t size) {
  const IT* src = in;
  OT* dst = out;
  // Walk both buffers in lockstep, front to back.
  for (size_t remaining = size; remaining != 0; --remaining, ++src, ++dst) {
    *dst = requantize<IT, OT>(
        *src, in_scale, in_zero_point, inv_out_scale, out_zero_point);
  }
}
90+
6191// explicit template instantiation
6292
6393#define typed_quantize_val (dtype ) \
@@ -106,6 +136,58 @@ typed_dequantize_vec(uint16_t);
106136typed_dequantize_vec (int32_t );
107137#undef typed_dequantize_vec
108138
139+ #define typed_requantize_val (itype, otype ) \
140+ template otype requantize ( \
141+ const itype in, \
142+ float in_scale, \
143+ int32_t in_zero_point, \
144+ float inv_out_scale, \
145+ int32_t out_zero_point);
146+ typed_requantize_val (int8_t , int8_t );
147+ typed_requantize_val (int8_t , uint8_t );
148+ typed_requantize_val (int8_t , int16_t );
149+ typed_requantize_val (int8_t , uint16_t );
150+ typed_requantize_val (uint8_t , int8_t );
151+ typed_requantize_val (uint8_t , uint8_t );
152+ typed_requantize_val (uint8_t , int16_t );
153+ typed_requantize_val (uint8_t , uint16_t );
154+ typed_requantize_val (int16_t , int8_t );
155+ typed_requantize_val (int16_t , uint8_t );
156+ typed_requantize_val (int16_t , int16_t );
157+ typed_requantize_val (int16_t , uint16_t );
158+ typed_requantize_val (uint16_t , int8_t );
159+ typed_requantize_val (uint16_t , uint8_t );
160+ typed_requantize_val (uint16_t , int16_t );
161+ typed_requantize_val (uint16_t , uint16_t );
162+ #undef typed_requantize_val
163+
164+ #define typed_requantize_vec (itype, otype ) \
165+ template void requantize ( \
166+ otype* __restrict__ out, \
167+ const itype* __restrict__ in, \
168+ float in_scale, \
169+ int32_t in_zero_point, \
170+ float inv_out_scale, \
171+ int32_t out_zero_point, \
172+ size_t size);
173+ typed_requantize_vec (int8_t , int8_t );
174+ typed_requantize_vec (int8_t , uint8_t );
175+ typed_requantize_vec (int8_t , int16_t );
176+ typed_requantize_vec (int8_t , uint16_t );
177+ typed_requantize_vec (uint8_t , int8_t );
178+ typed_requantize_vec (uint8_t , uint8_t );
179+ typed_requantize_vec (uint8_t , int16_t );
180+ typed_requantize_vec (uint8_t , uint16_t );
181+ typed_requantize_vec (int16_t , int8_t );
182+ typed_requantize_vec (int16_t , uint8_t );
183+ typed_requantize_vec (int16_t , int16_t );
184+ typed_requantize_vec (int16_t , uint16_t );
185+ typed_requantize_vec (uint16_t , int8_t );
186+ typed_requantize_vec (uint16_t , uint8_t );
187+ typed_requantize_vec (uint16_t , int16_t );
188+ typed_requantize_vec (uint16_t , uint16_t );
189+ #undef typed_requantize_vec
190+
109191}; // namespace kernels
110192}; // namespace reference
111193}; // namespace impl
0 commit comments