
Commit dfab00e

AhmedHussein535 authored and JaccovG committed
fix convert bit-exactness
1 parent 09322a6 commit dfab00e

3 files changed: +72 −77 lines changed


lib/src/bricks/impl/mli_prv_quant_vdsp.h

Lines changed: 1 addition & 5 deletions
@@ -580,11 +580,7 @@ MLI_FORCE_INLINE vNx4int_t ir_rnn_result_requantize(
     vNx4int_t shift_left = mli_math_max_fx(-total_shift, 0);
     vNx4int_t shift_right = mli_math_min_fx(mli_math_max_fx(total_shift, 0), max_int_shift);
 
-    vNx4int_t preshift = mli_math_max_fx(shift_right - max_int_shift, 0);
-    shift_right = shift_right - preshift;
-
-    vNx4int_t acc_shifted = mli_math_asr_fx(acc_scaled, preshift);
-    acc_shifted = mli_math_asr_rnd_fx(acc_shifted, shift_right);
+    vNx4int_t acc_shifted = mli_math_asr_rnd_fx(acc_scaled, shift_right);
     acc_shifted = mli_math_asl_fx(acc_shifted, shift_left);
     return acc_shifted;
 }
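Since shift_right is already clamped to max_int_shift on the retained line above, the removed pre-shift could only ever evaluate to zero, and the tail collapses to a single rounded right shift. A scalar sketch of the retained sequence, with plain integers standing in for the vNx4int_t vector ops, ROUND_UP rounding assumed and total_shift assumed to stay above -31:

#include <algorithm>
#include <cstdint>

// Plain-integer stand-in for the vNx4int_t requantize tail (not the library code).
static int32_t requantize_tail(int32_t acc_scaled, int total_shift) {
    const int max_int_shift = 31;
    int shift_left  = std::max(-total_shift, 0);
    int shift_right = std::min(std::max(total_shift, 0), max_int_shift);
    // One rounded right shift; the dropped pre-shift could only be zero once
    // shift_right is clamped to max_int_shift.
    int64_t rounded = ((int64_t)acc_scaled + (((int64_t)1 << shift_right) >> 1)) >> shift_right;
    return (int32_t)(rounded << shift_left);
}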

lib/src/helpers/src/impl/mli_hlp_convert_tensor_ref.h

Lines changed: 45 additions & 18 deletions
@@ -15,6 +15,37 @@
 #include "mli_prv_quant.h"
 #include "mli_prv_tensor.h"
 #include "mli_types.h"
+#include <math.h>
+
+template <typename in_T, typename out_T, typename acc_T>
+MLI_FORCE_INLINE void calc_convert(const MLI_PTR(in_T) src_tensor_arr,
+                                   MLI_OUT_PTR(out_T) dst_tensor_arr,
+                                   const int16_t in_zp, const int16_t scale,
+                                   const int16_t scale_shift, const int16_t out_zp) {
+    if (std::is_same<acc_T, int64_t>::value) {
+        const int mul_hi_shift = 32;
+        int32_t src_in_zp = mli_math_sub_fx<int32_t>(*src_tensor_arr, in_zp);
+        int32_t src_norm = mli_math_norm_fx<int32_t, int32_t>(src_in_zp);
+        src_in_zp = mli_math_asl_fx<int32_t>(src_in_zp, src_norm);
+
+        int32_t scale_norm = mli_math_norm_fx<int32_t, int32_t>((int32_t) scale);
+        int32_t scale32 = mli_math_asl_fx<int32_t>((int32_t) scale, scale_norm);
+
+        int64_t dst_acc = mli_math_mul_fx<int32_t, int64_t>(src_in_zp, scale32);
+        int32_t acc_hi = dst_acc >> mul_hi_shift;
+
+        int32_t dst_acc_shf_casted = mli_math_asr_rnd_fx<int32_t>(acc_hi, scale_shift + scale_norm + src_norm - mul_hi_shift);
+        int32_t dst_val = mli_math_add_fx<int32_t>(dst_acc_shf_casted, out_zp);
+        *dst_tensor_arr = mli_math_cast_fx<int32_t, out_T>(dst_val, 0);
+    } else {
+        int16_t src_in_zp = mli_math_sub_fx<int16_t>(*src_tensor_arr, in_zp);
+        acc_T dst_acc = mli_math_mul_fx<int16_t, acc_T>(src_in_zp, scale);
+        acc_T dst_acc_shf_casted = mli_math_asr_rnd_fx<acc_T>(dst_acc, scale_shift);
+        acc_T dst_val = mli_math_add_fx<acc_T>(dst_acc_shf_casted, out_zp);
+        *dst_tensor_arr = mli_math_cast_fx<acc_T, out_T>(dst_val, 0);
+    }
+}
+
 
 namespace mli {
 namespace hlp {
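Per element the new helper computes out = round((src - in_zp) * scale * 2^(-scale_shift)) + out_zp. A plain-arithmetic sketch of that same quantity, written as a readable reference rather than the library code, assuming ROUND_UP-style rounding and 0 <= scale_shift <= 62:

#include <cstdint>

static int32_t convert_one_ref(int32_t src, int16_t in_zp, int16_t scale,
                               int16_t scale_shift, int16_t out_zp) {
    int64_t acc = (int64_t)(src - in_zp) * scale;              // wide product, no precision loss
    int64_t rnd = (scale_shift > 0) ? ((int64_t)1 << (scale_shift - 1)) : 0;
    int32_t shifted = (int32_t)((acc + rnd) >> scale_shift);   // rounded requantization shift
    return shifted + out_zp;                                   // output zero point applied last
}

The int64_t branch above reaches the same value via normalization: shifting src_in_zp and scale left by src_norm and scale_norm keeps the significant bits in the upper half of the 64-bit product, and the extra left shifts are compensated in the final shift amount scale_shift + scale_norm + src_norm - mul_hi_shift.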
@@ -26,9 +57,6 @@ template <typename in_T, typename out_T, typename acc_T>
 mli_status compute_convert_quantized_data(const mli_tensor * src, mli_tensor * dst) {
     mli_prv_fx_init_dsp_ctrl();
 
-    /* If the accumulator is int64_t, so int32_t should be used for multiplying. */
-    typedef typename std::conditional<std::is_same<acc_T, int64_t>::value, int32_t, int16_t>::type mul_T;
-
     /* Get Generic Private Tensors */
     auto src_prv = mli_prv_get_generic_tensor<MLI_PTR(in_T)>(src);
     auto dst_prv = mli_prv_get_generic_tensor<MLI_OUT_PTR(out_T)>(dst);
@@ -63,10 +91,10 @@ mli_status compute_convert_quantized_data(const mli_tensor * src, mli_tensor * d
         /* Calculate scale and scaled zero point. */
         mli::krn::s8asym_quant_params params;
         mli::krn::define_requant_params(src, dst, &params, scale_idx);
-        const int16_t scale_shift = params.shift;
+        const int16_t scale_shift = mli_math_min_fx(params.shift, (int16_t) ((sizeof(acc_T) * 8) - 1));
         const int16_t scale = params.scale;
-        int16_t in_zp = mli_hlp_tensor_zero_offset(src, scale_idx);
-        int16_t out_zp = mli_hlp_tensor_zero_offset(dst, scale_idx);
+        const int16_t in_zp = mli_hlp_tensor_zero_offset(src, scale_idx);
+        const int16_t out_zp = mli_hlp_tensor_zero_offset(dst, scale_idx);
         /* Calculate borders across all dimensions for slice where this scale is applicable */
         int dim_start[MLI_MAX_RANK] = { 0 };
         int dim_end[MLI_MAX_RANK] = { 0 };
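The shift is now capped at the accumulator's bit width minus one (31 for int32_t, 63 for int64_t) before it reaches mli_math_asr_rnd_fx. A minimal sketch of that clamp, assuming acc_T is int32_t or int64_t:

#include <algorithm>
#include <cstdint>

// Hypothetical helper illustrating the clamp applied to params.shift above.
template <typename acc_T>
int16_t clamp_scale_shift(int16_t shift) {
    return std::min<int16_t>(shift, (int16_t)(sizeof(acc_T) * 8 - 1));
}
// clamp_scale_shift<int32_t>(40) == 31; clamp_scale_shift<int64_t>(40) == 40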
@@ -84,11 +112,8 @@ mli_status compute_convert_quantized_data(const mli_tensor * src, mli_tensor * d
                     const int dst_pos = POS(&dst_prv, dim0_idx, dim1_idx, dim2_idx, dim3_idx);
                     MLI_ASSERT(src_pos < src_tensor_size);
                     MLI_ASSERT(dst_pos < dst_tensor_size);
-                    mul_T src_in_zp = mli_math_sub_fx<mul_T>(src_tensor_arr[src_pos], in_zp);
-                    acc_T dst_acc = mli_math_mul_fx<mul_T, acc_T>(src_in_zp, scale);
-                    acc_T dst_acc_shf_casted = mli_math_asr_rnd_fx<acc_T>(dst_acc, scale_shift);
-                    acc_T dst_val = mli_math_add_fx<acc_T>(dst_acc_shf_casted, out_zp);
-                    dst_tensor_arr[dst_pos] = mli_math_cast_fx<acc_T, out_T>(dst_val, 0);
+                    calc_convert<in_T, out_T, acc_T>(&src_tensor_arr[src_pos], &dst_tensor_arr[dst_pos],
+                                                     in_zp, scale, scale_shift, out_zp);
                 }
             }
         }
@@ -137,7 +162,7 @@ mli_status convert_float_data(const mli_tensor * src, mli_tensor * dst, convert_
 
     const mli_tensor* tensor = nullptr;
     const mli_tensor* float_tensor = nullptr;
-
+
     /* Defining float_tensor and tensor depending on current conversion direction */
     if (mode == mli::hlp::QUANTIZE) {
         float_tensor = src;
@@ -171,14 +196,16 @@
     /* Transformation will be applied on slices across scales dimension (or all tensor) */
     for (int scale_idx = 0; scale_idx < scales_num; ++scale_idx) {
         /* Calculate current scale and zero offset */
-        float scale_val;
+        float scale_val = 1.0;
+        int8_t frac_bits = mli_hlp_tensor_scale_shift(tensor, scale_idx);
+        float scale = (float) mli_hlp_tensor_scale(tensor, scale_idx);
         if (mode == mli::hlp::QUANTIZE) {
-            scale_val = (float)((int64_t)1l << mli_hlp_tensor_scale_shift(tensor, scale_idx));
-            scale_val = scale_val / (float)mli_hlp_tensor_scale(tensor, scale_idx);
-        } else if (mode == mli::hlp::DEQUANTIZE) {
-            scale_val = (float)mli_hlp_tensor_scale(tensor, scale_idx);
-            scale_val = scale_val / (float)((int64_t)1l << mli_hlp_tensor_scale_shift(tensor, scale_idx));
+            scale = 1.0 / scale;
+            scale_val = ldexp(scale, ((int32_t) frac_bits));
+        } else {
+            scale_val = ldexp(scale, -((int32_t) frac_bits));
         }
+
         int16_t zero_offset = mli_hlp_tensor_zero_offset(tensor, scale_idx);
 
         /* Calculate borders across all dimensions for slice where this scale is applicable */
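Both directions now derive the float scale factor from the stored integer scale and its fractional-bit count with ldexp: quantization uses 2^frac_bits / scale, dequantization uses scale / 2^frac_bits. A standalone sketch with hypothetical inputs standing in for mli_hlp_tensor_scale() and mli_hlp_tensor_scale_shift():

#include <math.h>

// Hypothetical helper mirroring the two branches above (not the library code).
static float float_scale_factor(float stored_scale, int frac_bits, int quantize) {
    return quantize ? ldexpf(1.0f / stored_scale, frac_bits)   /* float -> fixed */
                    : ldexpf(stored_scale, -frac_bits);        /* fixed -> float */
}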

lib/src/helpers/src/impl/mli_hlp_convert_tensor_vdsp.h

Lines changed: 26 additions & 54 deletions
@@ -16,6 +16,7 @@
 #include "mli_prv_tensor.h"
 #include "mli_types.h"
 
+
 namespace mli {
 namespace hlp {
 namespace vdsp {
@@ -28,13 +29,12 @@ static MLI_FORCE_INLINE vNx4int_t calc_convert(
         const int16_t shift,
         const int16_t in_zp,
         const int16_t out_zp) {
-
-    int shift_right = mli_math_max_fx(shift, 0);
+    constexpr int max_shift = 31;
+    int shift_right = mli_math_min_fx(mli_math_max_fx(shift, 0), max_shift);
     int shift_left = mli_math_max_fx(-shift, 0);
 #ifdef ROUND_UP
     uint32_t one = 1u;
     int32_t offset = (one << shift_right) >> 1;
-    offset += (int32_t)out_zp << shift_right;
 #else
     #error Rounding mode not supported
 #endif
@@ -45,6 +45,7 @@ static MLI_FORCE_INLINE vNx4int_t calc_convert(
     dst_val = mli_math_add_fx<vNx4int_t>(dst_val, offset);
     dst_val = mli_math_asr_fx(dst_val, shift_right);
     dst_val = mli_math_asl_fx(dst_val, shift_left);
+    dst_val = mli_math_add_fx<vNx4int_t>(dst_val, (int32_t) out_zp);
 
     return dst_val;
 }
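In both of these vector overloads the output zero point is no longer pre-scaled into the rounding offset; it is added once, after the rounded shift, matching the order used by the reference helper. A scalar sketch of the reworked order, assuming 0 <= shift_right <= 31 and ROUND_UP rounding:

#include <cstdint>

// Plain-integer stand-in for the vector path: round, shift, then add the zero point.
static int32_t requant_out(int32_t acc, int shift_right, int shift_left, int16_t out_zp) {
    int32_t offset = (int32_t)((1u << shift_right) >> 1);   // rounding offset only
    int32_t v = (acc + offset) >> shift_right;               // rounded right shift
    v <<= shift_left;
    return v + out_zp;                                        // zero point applied after shifting
}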
@@ -55,13 +56,12 @@ static MLI_FORCE_INLINE vNx4int_t calc_convert(
         const int16_t shift,
         const int16_t in_zp,
         const int16_t out_zp) {
-
-    int shift_right = mli_math_max_fx(shift, 0);
+    constexpr int max_shift = 31;
+    int shift_right = mli_math_min_fx(mli_math_max_fx(shift, 0), max_shift);
     int shift_left = mli_math_max_fx(-shift, 0);
 #ifdef ROUND_UP
     uint32_t one = 1u;
     int32_t offset = (one << shift_right) >> 1;
-    offset += (int32_t)out_zp << shift_right;
 #else
     #error Rounding mode not supported
 #endif
@@ -70,6 +70,7 @@ static MLI_FORCE_INLINE vNx4int_t calc_convert(
     dst_val = mli_math_add_fx<vNx4int_t>(dst_val, offset);
     dst_val = mli_math_asr_fx(dst_val, shift_right);
     dst_val = mli_math_asl_fx(dst_val, shift_left);
+    dst_val = mli_math_add_fx<vNx4int_t>(dst_val, (int32_t) out_zp);
 
     return dst_val;
 }
@@ -80,60 +81,31 @@ static MLI_FORCE_INLINE vNx4int_t calc_convert(
         const int16_t shift,
         const int16_t in_zp,
         const int16_t out_zp) {
-
-    constexpr int mul_pre_shift = 16;
-
-    if( shift > mul_pre_shift ) {
-        constexpr int mul_hi_shift = 32;
-        int total_shift = shift - (mul_hi_shift - mul_pre_shift);
-        int shift_right = mli_math_max_fx(total_shift, 1);
-        int shift_left = mli_math_max_fx(1 - total_shift, 0);
-
-        vNx4int_t src_in_zp = mli_math_sub(input, (int32_t)in_zp);
-        src_in_zp = mli_math_asl_fx(src_in_zp, shift_left);
-        auto res = mli_math_mul_fx_high(src_in_zp, ((int32_t)scale << mul_pre_shift));
-        res = mli_math_asr_rnd_fx(res, shift_right);
-        res = mli_math_add_fx<vNx4int_t>(res, out_zp);
-
-        return res;
-    } else {
-        /* input = 2^16 * (input_hi) + input_lo
-         * input * scale = (2^16 * (input_hi) + input_lo) * scale
-         *               = 2^16 * (input_hi * scale) + (input_lo * scale)
-         * input * scale * 2^(-shift) = (2^16 * (input_hi * scale) + (input_lo * scale)) * (2^(-shift))
-         *                            = (input_hi * scale) * 2^(-(shift - 16)) + (input_lo * scale)) * (2^(-shift)
-         *                            = res_hi + res_lo
-         * where res_hi = (input_hi * scale) * 2^(-(shift - 16))
-         * and res_lo = (input_lo * scale)) * (2^(-shift
-         */
-        int shift_hi = shift - mul_pre_shift;
-        int shift_hi_right = mli_math_max_fx( shift_hi, 0);
-        int shift_hi_left = mli_math_max_fx(-shift_hi, 0);
-        int shift_lo_right = mli_math_max_fx( shift, 0);
-        int shift_lo_left = mli_math_max_fx(-shift, 0);
-        vNx4int_t src_in_zp = mli_math_sub(input, (int32_t)in_zp);
-        auto input_lo = to_vNx4ushort_t(src_in_zp & 0xFFFF);
-        auto input_hi = to_vNx4short_t(src_in_zp >> mul_pre_shift);
-        auto res_lo = mli_math_mul_su_fx<vNx4short_t, vNx4ushort_t, vNx4accint_t>(scale, input_lo);
-        res_lo = mli_math_asl_fx(res_lo, shift_lo_left);
-        res_lo = mli_math_asr_rnd_fx(res_lo, shift_lo_right);
-        auto res_hi = mli_math_mul_fx<vNx4short_t, vNx4accint_t>(input_hi, scale);
-        res_hi = mli_math_asl_fx(res_hi, shift_hi_left);
-        res_hi = mli_math_asr_fx(res_hi, shift_hi_right);
-
-        auto res = mli_math_add(res_lo, res_hi);
-        res = mli_math_add(res, (vNx4int_t)out_zp);
-
-        return mli_math_acc_cast_fx<vNx4int_t, vNx4accint_t>(res);
-    }
+    constexpr int mul_hi_shift = 32;
+    constexpr int max_int_shift = 31;
+
+    vNx4int_t src_in_zp = mli_math_sub(input, (int32_t)in_zp);
+    vNx4int_t src_norm = mli_math_norm_fx<vNx4int_t, vNx4int_t>(src_in_zp);
+    src_in_zp = mli_math_asl_fx<vNx4int_t, vNx4int_t>(src_in_zp, src_norm);
+
+    int32_t scale_norm = mli_math_norm_fx<int32_t, int32_t>((int32_t) scale);
+    int32_t scale_shifted = ((int32_t) scale) << scale_norm;
+    vNx4int_t res = mli_math_mul_fx_high(src_in_zp, scale_shifted);
+    vNx4int_t total_shift = mli_math_add_fx<vNx4int_t>(src_norm, (scale_norm - mul_hi_shift + shift));
+    vNx4int_t shift_left = mli_math_max_fx(-total_shift, 0);
+    vNx4int_t shift_right = mli_math_min_fx(mli_math_max_fx(total_shift, 0), max_int_shift);
+    vNx4int_t res_shifted = mli_math_asr_rnd_fx(res, shift_right);
+    res_shifted = mli_math_asl_fx(res_shifted, shift_left);
+    res_shifted = mli_math_add_fx<vNx4int_t>(res_shifted, (int32_t) out_zp);
+    return res_shifted;
 }
 
 template <typename out_T>
 static MLI_FORCE_INLINE void store_convert(
         MLI_OUT_PTR(out_T) out_ptr,
         vNx4int_t output,
         int remaining_part = 0) {
-
+
     typedef decltype(mli_prv_load_nx4_samples(out_ptr)) cast_type;
 
     if (remaining_part) {
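The rewritten 32-bit path follows the same scheme as the reference helper: normalize the zero-point-adjusted input and the scale, take the high 32 bits of the product, and apply the residual shift src_norm + scale_norm + shift - 32 (clamped to 31). A small self-checking sketch of that shift bookkeeping with plain integers and made-up values (rounding ignored, norms chosen to be exact for these values):

#include <cassert>
#include <cstdint>

int main() {
    int32_t x = 1000;            // input minus in_zp
    int32_t scale = 23000;       // int16 scale widened to 32 bits
    int shift = 20;              // requested requantization shift

    int x_norm = 21;             // redundant sign bits of x (1000 fits in 10 bits)
    int s_norm = 16;             // redundant sign bits of scale (23000 fits in 15 bits)
    int64_t prod = (int64_t)(x << x_norm) * (scale << s_norm);
    int32_t hi = (int32_t)(prod >> 32);              // high-word multiply result
    int total_shift = x_norm + s_norm - 32 + shift;  // residual shift on the high word

    int64_t expected = ((int64_t)x * scale) >> shift;  // reference without normalization
    assert((hi >> total_shift) == (int32_t)expected);
    return 0;
}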
@@ -165,7 +137,7 @@ static MLI_FORCE_INLINE void store_convert(
         const int out_stride,
         vNx4int_t output,
         int remaining_part = 0) {
-
+
     typedef decltype(mli_prv_load_nx4_samples(out_ptr)) cast_type;
 
     if (remaining_part) {
