Commit 618d455

AVX2: Redo decompose_{32,88} with an approach that's easier to explain
The new approach is adapted from our Neon implementation. See
<#411 (comment)> for more information on the idea. Bounds reasoning comments
are also added.

Signed-off-by: jammychiou1 <[email protected]>
1 parent a12ee24 commit 618d455

File tree

4 files changed (+218, -42 lines)

dev/x86_64/src/poly_decompose_32_avx2.c

Lines changed: 55 additions & 10 deletions
@@ -36,30 +36,75 @@
 void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
 {
   unsigned int i;
-  __m256i f, f0, f1;
-  const __m256i q =
-      _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]);
-  const __m256i hq = _mm256_srli_epi32(q, 1);
-  /* check-magic: 1025 == round((2**22*128) / ((MLDSA_Q - 1) / 16)) */
+  __m256i f, f0, f1, t;
+  const __m256i q_bound = _mm256_set1_epi32(31 * MLDSA_GAMMA2);
+  /* check-magic: 1025 == floor(2**22 / 4092) */
   const __m256i v = _mm256_set1_epi32(1025);
   const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2);
   const __m256i off = _mm256_set1_epi32(127);
   const __m256i shift = _mm256_set1_epi32(512);
-  const __m256i mask = _mm256_set1_epi32(15);
 
   for (i = 0; i < MLDSA_N / 8; i++)
   {
     f = _mm256_load_si256(&a[i]);
+
+    /* check-magic: 4092 == 2 * ((MLDSA_Q-1) // 32) // 128 */
+    /*
+     * The goal is to compute f1 = round-(f / (2*GAMMA2)), which can
+     * alternatively be computed as round-(f / (128*B)) =
+     * round-(ceil(f / 128) / B), where B = 2*GAMMA2 / 128 = 4092.
+     * Here round-() denotes "round half down".
+     *
+     * range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B
+     */
+
+    /* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
     f1 = _mm256_add_epi32(f, off);
     f1 = _mm256_srli_epi32(f1, 7);
+    /*
+     * range: 0 <= f1' <= (Q-1)/128 = 16*B
+     *
+     * Also, f1' <= (Q-1)/128 = 2^16 - 2^6 < 2^16 ensures that the odd-index
+     * 16-bit lanes are all 0, so no bits will be dropped in the input of
+     * the _mm256_mulhi_epu16() below.
+     */
+
+    /* check-magic: off */
+    /*
+     * Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is
+     * exact for 0 <= f1' < 2^16. Note that half is rounded down since
+     * 1025 / 2^22 ≲ 1 / 4092.
+     *
+     * The odd-index 16-bit lanes are still all 0 after this, so even
+     * though the following steps use 32-bit lanes, the value of f1 is
+     * unaffected.
+     */
+    /* check-magic: on */
     f1 = _mm256_mulhi_epu16(f1, v);
     f1 = _mm256_mulhrs_epi16(f1, shift);
-    f1 = _mm256_and_si256(f1, mask);
+    /* range: 0 <= f1 <= 16 */
+
+    /*
+     * If f1 = 16, i.e. f > 31*GAMMA2, proceed as if f' = f - Q had been
+     * given instead. (For f = 31*GAMMA2 + 1, and thus f' = -GAMMA2, we
+     * still round it to 0 like the other "wrapped around" cases.)
+     */
+
+    /* Check for wrap-around */
+    t = _mm256_cmpgt_epi32(f, q_bound);
+
+    /* Compute remainder f0 */
     f0 = _mm256_mullo_epi32(f1, alpha);
     f0 = _mm256_sub_epi32(f, f0);
-    f = _mm256_cmpgt_epi32(f0, hq);
-    f = _mm256_and_si256(f, q);
-    f0 = _mm256_sub_epi32(f0, f);
+    /*
+     * range: -GAMMA2 < f0 <= GAMMA2
+     *
+     * This holds since f1 = round-(f / (2*GAMMA2)) was computed exactly.
+     */
+
+    /* If wrap-around is required, set f1 = 0 and f0 -= 1 */
+    f1 = _mm256_andnot_si256(t, f1);
+    f0 = _mm256_add_epi32(f0, t);
+    /* range: 0 <= f1 <= 15, -GAMMA2 <= f0 <= GAMMA2 */
+
     _mm256_store_si256(&a1[i], f1);
     _mm256_store_si256(&a0[i], f0);
   }

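As a cross-check of the bounds reasoning in the comments above, here is a
standalone scalar model of one 32-bit lane of the new code path. It is not
part of the commit: it models the arithmetic the comments claim (a single
rounded multiply), not the exact mulhi/mulhrs intrinsic sequence, and it
assumes MLDSA_Q = 8380417 with the decompose_32 parameter
GAMMA2 = (MLDSA_Q - 1) / 32 = 261888. The helper name decompose_32_scalar is
hypothetical.

#include <stdint.h>
#include <stdio.h>

#define MLDSA_Q 8380417
#define GAMMA2 ((MLDSA_Q - 1) / 32) /* 261888 for the decompose_32 variant */

/* Scalar model of one 32-bit lane. Input: 0 <= f <= Q-1.
 * Output: f = f1*(2*GAMMA2) + f0 (mod Q) with 0 <= f1 <= 15 and
 * -GAMMA2 <= f0 <= GAMMA2. */
static void decompose_32_scalar(int32_t f, int32_t *r1, int32_t *r0)
{
  /* f1' = ceil(f / 128); at most (Q-1)/128 = 65472 < 2^16 */
  uint32_t f1p = ((uint32_t)f + 127) >> 7;
  /* round(f1' * 1025 / 2^22) = round-half-down(f1' / 4092); 0 <= f1 <= 16 */
  int32_t f1 = (int32_t)((f1p * 1025 + (1u << 21)) >> 22);
  /* remainder; -GAMMA2 < f0 <= GAMMA2, but f1 may still be 16 */
  int32_t f0 = f - f1 * 2 * GAMMA2;
  if (f > 31 * GAMMA2) /* wrap-around: act as if f - Q had been given */
  {
    f1 = 0;
    f0 -= 1;
  }
  *r1 = f1;
  *r0 = f0;
}

int main(void)
{
  int32_t f, f1, f0;
  for (f = 0; f < MLDSA_Q; f++)
  {
    decompose_32_scalar(f, &f1, &f0);
    if (f1 < 0 || f1 > 15 || f0 < -GAMMA2 || f0 > GAMMA2 ||
        (f1 * 2 * GAMMA2 + f0 - f) % MLDSA_Q != 0)
    {
      printf("mismatch at f = %d\n", f);
      return 1;
    }
  }
  printf("decompose_32 scalar model: all inputs OK\n");
  return 0;
}

The exhaustive loop in main() checks exactly the two final range claims from
the diff comments, plus the decomposition identity f1*(2*GAMMA2) + f0 = f
(mod Q), for every representative input.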
dev/x86_64/src/poly_decompose_88_avx2.c

Lines changed: 54 additions & 11 deletions
@@ -38,31 +38,74 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
 {
   unsigned int i;
   __m256i f, f0, f1, t;
-  const __m256i q =
-      _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]);
-  const __m256i hq = _mm256_srli_epi32(q, 1);
-  /* check-magic: 11275 == round((2**24*128) / ((MLDSA_Q - 1) / 44)) */
+  const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2);
+  /* check-magic: 11275 == floor(2**24 / 1488) */
   const __m256i v = _mm256_set1_epi32(11275);
   const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2);
   const __m256i off = _mm256_set1_epi32(127);
   const __m256i shift = _mm256_set1_epi32(128);
-  const __m256i max = _mm256_set1_epi32(43);
-  const __m256i zero = _mm256_setzero_si256();
 
   for (i = 0; i < MLDSA_N / 8; i++)
   {
     f = _mm256_load_si256(&a[i]);
+
+    /* check-magic: 1488 == 2 * ((MLDSA_Q-1) // 88) // 128 */
+    /*
+     * The goal is to compute f1 = round-(f / (2*GAMMA2)), which can
+     * alternatively be computed as round-(f / (128*B)) =
+     * round-(ceil(f / 128) / B), where B = 2*GAMMA2 / 128 = 1488.
+     * Here round-() denotes "round half down".
+     *
+     * range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B
+     */
+
+    /* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
     f1 = _mm256_add_epi32(f, off);
     f1 = _mm256_srli_epi32(f1, 7);
+    /*
+     * range: 0 <= f1' <= (Q-1)/128 = 44*B
+     *
+     * Also, f1' <= (Q-1)/128 = 2^16 - 2^6 < 2^16 ensures that the odd-index
+     * 16-bit lanes are all 0, so no bits will be dropped in the input of
+     * the _mm256_mulhi_epu16() below.
+     */
+
+    /* check-magic: off */
+    /*
+     * Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is
+     * exact for 0 <= f1' < 2^16. Note that half is rounded down since
+     * 11275 / 2^24 ≲ 1 / 1488.
+     *
+     * The odd-index 16-bit lanes are still all 0 after this, so even
+     * though the following steps use 32-bit lanes, the value of f1 is
+     * unaffected.
+     */
+    /* check-magic: on */
     f1 = _mm256_mulhi_epu16(f1, v);
     f1 = _mm256_mulhrs_epi16(f1, shift);
-    t = _mm256_sub_epi32(max, f1);
-    f1 = _mm256_blendv_epi32(f1, zero, t);
+    /* range: 0 <= f1 <= 44 */
+
+    /*
+     * If f1 = 44, i.e. f > 87*GAMMA2, proceed as if f' = f - Q had been
+     * given instead. (For f = 87*GAMMA2 + 1, and thus f' = -GAMMA2, we
+     * still round it to 0 like the other "wrapped around" cases.)
+     */
+
+    /* Check for wrap-around */
+    t = _mm256_cmpgt_epi32(f, q_bound);
+
+    /* Compute remainder f0 */
     f0 = _mm256_mullo_epi32(f1, alpha);
     f0 = _mm256_sub_epi32(f, f0);
-    f = _mm256_cmpgt_epi32(f0, hq);
-    f = _mm256_and_si256(f, q);
-    f0 = _mm256_sub_epi32(f0, f);
+    /*
+     * range: -GAMMA2 < f0 <= GAMMA2
+     *
+     * This holds since f1 = round-(f / (2*GAMMA2)) was computed exactly.
+     */
+
+    /* If wrap-around is required, set f1 = 0 and f0 -= 1 */
+    f1 = _mm256_andnot_si256(t, f1);
+    f0 = _mm256_add_epi32(f0, t);
+    /* range: 0 <= f1 <= 43, -GAMMA2 <= f0 <= GAMMA2 */
+
     _mm256_store_si256(&a1[i], f1);
     _mm256_store_si256(&a0[i], f0);
   }

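The check-magic annotations above can also be verified exhaustively. The
following standalone snippet (not part of the commit) brute-forces the
exactness property the comment relies on: round(f1' * 11275 / 2^24) equals
round-half-down(f1' / 1488) for every f1' < 2^16.

#include <stdint.h>
#include <stdio.h>

int main(void)
{
  uint32_t f1p;
  for (f1p = 0; f1p < 65536; f1p++)
  {
    /* round-half-down(f1p / 1488): add just under half before dividing */
    uint32_t ref = (2 * f1p + 1488 - 1) / (2 * 1488);
    /* round(f1p * 11275 / 2^24); max product 65535 * 11275 < 2^30 */
    uint32_t got = (f1p * 11275 + (1u << 23)) >> 24;
    if (ref != got)
    {
      printf("mismatch at f1' = %u: %u != %u\n", f1p, ref, got);
      return 1;
    }
  }
  printf("11275 / 2^24 rounds like round-half-down(x / 1488) for x < 2^16\n");
  return 0;
}

The margin is tight: at the exact half f1' = 43*1488 + 744 the scaled sum
16777216*44 - 696 falls just short of the next multiple of 2^24, so the half
indeed rounds down, which is why 11275 must be floor(2^24 / 1488) rather than
the nearest rounding.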
mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c

Lines changed: 55 additions & 10 deletions
@@ -36,30 +36,75 @@
 void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
 {
   unsigned int i;
-  __m256i f, f0, f1;
-  const __m256i q =
-      _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]);
-  const __m256i hq = _mm256_srli_epi32(q, 1);
-  /* check-magic: 1025 == round((2**22*128) / ((MLDSA_Q - 1) / 16)) */
+  __m256i f, f0, f1, t;
+  const __m256i q_bound = _mm256_set1_epi32(31 * MLDSA_GAMMA2);
+  /* check-magic: 1025 == floor(2**22 / 4092) */
   const __m256i v = _mm256_set1_epi32(1025);
   const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2);
   const __m256i off = _mm256_set1_epi32(127);
   const __m256i shift = _mm256_set1_epi32(512);
-  const __m256i mask = _mm256_set1_epi32(15);
 
   for (i = 0; i < MLDSA_N / 8; i++)
   {
     f = _mm256_load_si256(&a[i]);
+
+    /* check-magic: 4092 == 2 * ((MLDSA_Q-1) // 32) // 128 */
+    /*
+     * The goal is to compute f1 = round-(f / (2*GAMMA2)), which can
+     * alternatively be computed as round-(f / (128*B)) =
+     * round-(ceil(f / 128) / B), where B = 2*GAMMA2 / 128 = 4092.
+     * Here round-() denotes "round half down".
+     *
+     * range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B
+     */
+
+    /* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
     f1 = _mm256_add_epi32(f, off);
     f1 = _mm256_srli_epi32(f1, 7);
+    /*
+     * range: 0 <= f1' <= (Q-1)/128 = 16*B
+     *
+     * Also, f1' <= (Q-1)/128 = 2^16 - 2^6 < 2^16 ensures that the odd-index
+     * 16-bit lanes are all 0, so no bits will be dropped in the input of
+     * the _mm256_mulhi_epu16() below.
+     */
+
+    /* check-magic: off */
+    /*
+     * Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is
+     * exact for 0 <= f1' < 2^16. Note that half is rounded down since
+     * 1025 / 2^22 ≲ 1 / 4092.
+     *
+     * The odd-index 16-bit lanes are still all 0 after this, so even
+     * though the following steps use 32-bit lanes, the value of f1 is
+     * unaffected.
+     */
+    /* check-magic: on */
     f1 = _mm256_mulhi_epu16(f1, v);
     f1 = _mm256_mulhrs_epi16(f1, shift);
-    f1 = _mm256_and_si256(f1, mask);
+    /* range: 0 <= f1 <= 16 */
+
+    /*
+     * If f1 = 16, i.e. f > 31*GAMMA2, proceed as if f' = f - Q had been
+     * given instead. (For f = 31*GAMMA2 + 1, and thus f' = -GAMMA2, we
+     * still round it to 0 like the other "wrapped around" cases.)
+     */
+
+    /* Check for wrap-around */
+    t = _mm256_cmpgt_epi32(f, q_bound);
+
+    /* Compute remainder f0 */
     f0 = _mm256_mullo_epi32(f1, alpha);
     f0 = _mm256_sub_epi32(f, f0);
-    f = _mm256_cmpgt_epi32(f0, hq);
-    f = _mm256_and_si256(f, q);
-    f0 = _mm256_sub_epi32(f0, f);
+    /*
+     * range: -GAMMA2 < f0 <= GAMMA2
+     *
+     * This holds since f1 = round-(f / (2*GAMMA2)) was computed exactly.
+     */
+
+    /* If wrap-around is required, set f1 = 0 and f0 -= 1 */
+    f1 = _mm256_andnot_si256(t, f1);
+    f0 = _mm256_add_epi32(f0, t);
+    /* range: 0 <= f1 <= 15, -GAMMA2 <= f0 <= GAMMA2 */
+
     _mm256_store_si256(&a1[i], f1);
     _mm256_store_si256(&a0[i], f0);
   }

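The odd-lane argument in the comments is easy to observe in isolation. The
following standalone snippet (not part of the commit; compile with -mavx2)
feeds the maximal f1' = (Q-1)/128 = 65472 through the same
_mm256_mulhi_epu16 / _mm256_mulhrs_epi16 pair and shows that reading the
16-bit results as 32-bit lanes yields the expected f1 = 16.

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
  /* 32-bit lanes holding f1' < 2^16: the odd-index 16-bit lanes are 0 */
  __m256i f1 = _mm256_set1_epi32(65472); /* (Q-1)/128, the maximal f1' */
  const __m256i v = _mm256_set1_epi32(1025);
  const __m256i shift = _mm256_set1_epi32(512);
  int32_t out[8];

  /* floor(f1' * 1025 / 2^16) per 16-bit lane; 0 * 1025 keeps odd lanes 0 */
  f1 = _mm256_mulhi_epu16(f1, v);
  /* multiply by 512, round-shift right by 15: net effect round(x / 64) */
  f1 = _mm256_mulhrs_epi16(f1, shift);

  _mm256_storeu_si256((__m256i *)out, f1);
  printf("f1' = 65472 -> f1 = %d (expected 16)\n", out[0]);
  return 0;
}

Here the first multiply produces 1023 in every even 16-bit lane and 0 in
every odd one, and the rounding multiply maps 1023 to 16, so each 32-bit
lane holds exactly 16 with no contamination from the odd halves.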
mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c

Lines changed: 54 additions & 11 deletions
@@ -38,31 +38,74 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a)
 {
   unsigned int i;
   __m256i f, f0, f1, t;
-  const __m256i q =
-      _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]);
-  const __m256i hq = _mm256_srli_epi32(q, 1);
-  /* check-magic: 11275 == round((2**24*128) / ((MLDSA_Q - 1) / 44)) */
+  const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2);
+  /* check-magic: 11275 == floor(2**24 / 1488) */
   const __m256i v = _mm256_set1_epi32(11275);
   const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2);
   const __m256i off = _mm256_set1_epi32(127);
   const __m256i shift = _mm256_set1_epi32(128);
-  const __m256i max = _mm256_set1_epi32(43);
-  const __m256i zero = _mm256_setzero_si256();
 
   for (i = 0; i < MLDSA_N / 8; i++)
   {
     f = _mm256_load_si256(&a[i]);
+
+    /* check-magic: 1488 == 2 * ((MLDSA_Q-1) // 88) // 128 */
+    /*
+     * The goal is to compute f1 = round-(f / (2*GAMMA2)), which can
+     * alternatively be computed as round-(f / (128*B)) =
+     * round-(ceil(f / 128) / B), where B = 2*GAMMA2 / 128 = 1488.
+     * Here round-() denotes "round half down".
+     *
+     * range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B
+     */
+
+    /* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
     f1 = _mm256_add_epi32(f, off);
     f1 = _mm256_srli_epi32(f1, 7);
+    /*
+     * range: 0 <= f1' <= (Q-1)/128 = 44*B
+     *
+     * Also, f1' <= (Q-1)/128 = 2^16 - 2^6 < 2^16 ensures that the odd-index
+     * 16-bit lanes are all 0, so no bits will be dropped in the input of
+     * the _mm256_mulhi_epu16() below.
+     */
+
+    /* check-magic: off */
+    /*
+     * Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is
+     * exact for 0 <= f1' < 2^16. Note that half is rounded down since
+     * 11275 / 2^24 ≲ 1 / 1488.
+     *
+     * The odd-index 16-bit lanes are still all 0 after this, so even
+     * though the following steps use 32-bit lanes, the value of f1 is
+     * unaffected.
+     */
+    /* check-magic: on */
     f1 = _mm256_mulhi_epu16(f1, v);
     f1 = _mm256_mulhrs_epi16(f1, shift);
-    t = _mm256_sub_epi32(max, f1);
-    f1 = _mm256_blendv_epi32(f1, zero, t);
+    /* range: 0 <= f1 <= 44 */
+
+    /*
+     * If f1 = 44, i.e. f > 87*GAMMA2, proceed as if f' = f - Q had been
+     * given instead. (For f = 87*GAMMA2 + 1, and thus f' = -GAMMA2, we
+     * still round it to 0 like the other "wrapped around" cases.)
+     */
+
+    /* Check for wrap-around */
+    t = _mm256_cmpgt_epi32(f, q_bound);
+
+    /* Compute remainder f0 */
     f0 = _mm256_mullo_epi32(f1, alpha);
     f0 = _mm256_sub_epi32(f, f0);
-    f = _mm256_cmpgt_epi32(f0, hq);
-    f = _mm256_and_si256(f, q);
-    f0 = _mm256_sub_epi32(f0, f);
+    /*
+     * range: -GAMMA2 < f0 <= GAMMA2
+     *
+     * This holds since f1 = round-(f / (2*GAMMA2)) was computed exactly.
+     */
+
+    /* If wrap-around is required, set f1 = 0 and f0 -= 1 */
+    f1 = _mm256_andnot_si256(t, f1);
+    f0 = _mm256_add_epi32(f0, t);
+    /* range: 0 <= f1 <= 43, -GAMMA2 <= f0 <= GAMMA2 */
+
     _mm256_store_si256(&a1[i], f1);
     _mm256_store_si256(&a0[i], f0);
   }

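The wrap-around handling can likewise be checked at its single corner case.
This standalone snippet (not part of the commit; compile with -mavx2, with
GAMMA2 hard-coded to the decompose_88 value (MLDSA_Q - 1) / 88 = 95232) runs
the tail of the loop body for f = 87*GAMMA2 + 1 and confirms the commented
behavior f1 = 0, f0 = -GAMMA2.

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

#define GAMMA2 95232 /* (MLDSA_Q - 1) / 88 for the decompose_88 variant */

int main(void)
{
  /* Corner case from the comment: f = 87*GAMMA2 + 1, where f1 overshoots */
  __m256i f = _mm256_set1_epi32(87 * GAMMA2 + 1);
  __m256i f1 = _mm256_set1_epi32(44); /* value the rounding step produces */
  const __m256i alpha = _mm256_set1_epi32(2 * GAMMA2);
  const __m256i q_bound = _mm256_set1_epi32(87 * GAMMA2);
  int32_t r1[8], r0[8];

  /* t = all-ones mask in lanes where f > 87*GAMMA2 */
  __m256i t = _mm256_cmpgt_epi32(f, q_bound);
  /* f0 = f - f1*(2*GAMMA2) = f - (Q-1), i.e. -GAMMA2 + 1 here */
  __m256i f0 = _mm256_sub_epi32(f, _mm256_mullo_epi32(f1, alpha));
  /* wrap: andnot zeroes f1; adding the all-ones mask subtracts 1 from f0 */
  f1 = _mm256_andnot_si256(t, f1);
  f0 = _mm256_add_epi32(f0, t);

  _mm256_storeu_si256((__m256i *)r1, f1);
  _mm256_storeu_si256((__m256i *)r0, f0);
  printf("f1 = %d (expected 0), f0 = %d (expected %d)\n", r1[0], r0[0],
         -GAMMA2);
  return 0;
}

This also illustrates why the new code needs no separate constant for -1:
an all-ones comparison mask already is -1 in every 32-bit lane, so a single
_mm256_add_epi32 performs the conditional decrement.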