Skip to content

Commit 6ca67e1

Browse files
committed
AVX2: Update decompose approach used in use_hint
Edit some comments while we're at it. Signed-off-by: jammychiou1 <[email protected]>
1 parent c4e5647 commit 6ca67e1

File tree

4 files changed

+44
-50
lines changed

4 files changed

+44
-50
lines changed

dev/x86_64/src/poly_use_hint_32_avx2.c

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ void mld_poly_use_hint_32_avx2(__m256i *b, const __m256i *a,
3838
{
3939
unsigned int i;
4040
__m256i f, f0, f1, h, t;
41-
const __m256i q =
42-
_mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]);
43-
const __m256i hq = _mm256_srli_epi32(q, 1);
44-
/* check-magic: 1025 == round((2**22*128) / ((MLDSA_Q - 1) / 16)) */
41+
const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2);
42+
/* check-magic: 1025 == floor(2**22 / 4092) */
4543
const __m256i v = _mm256_set1_epi32(1025);
4644
const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2);
4745
const __m256i off = _mm256_set1_epi32(127);
@@ -54,26 +52,27 @@ void mld_poly_use_hint_32_avx2(__m256i *b, const __m256i *a,
5452
f = _mm256_load_si256(&a[i]);
5553
h = _mm256_load_si256(&hint[i]);
5654

57-
/* Reference: The reference avx2 implementation calls poly_decompose to
58-
* compute all a1, a0 before the loop.
55+
/* Reference:
56+
* - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop.
57+
* - Our implementation of decompose() is slightly different from that in
58+
* @[REF_AVX2]. See poly_decompose_32_avx2.c for more information.
5959
*/
60-
/* decompose */
60+
/* f1, f0 = decompose(f) */
6161
f1 = _mm256_add_epi32(f, off);
6262
f1 = _mm256_srli_epi32(f1, 7);
6363
f1 = _mm256_mulhi_epu16(f1, v);
6464
f1 = _mm256_mulhrs_epi16(f1, shift);
65-
f1 = _mm256_and_si256(f1, mask);
65+
t = _mm256_cmpgt_epi32(f, q_bound);
6666
f0 = _mm256_mullo_epi32(f1, alpha);
6767
f0 = _mm256_sub_epi32(f, f0);
68-
f = _mm256_cmpgt_epi32(f0, hq);
69-
f = _mm256_and_si256(f, q);
70-
f0 = _mm256_sub_epi32(f0, f);
68+
f1 = _mm256_andnot_si256(t, f1);
69+
f0 = _mm256_add_epi32(f0, t);
7170

7271
/* Reference: The reference avx2 implementation checks a0 >= 0, which is
7372
* different from the specification and the reference C implementation. We
7473
* follow the specification and check a0 > 0.
7574
*/
76-
/* t = (a0 > 0) ? h : -h */
75+
/* t = (f0 > 0) ? h : -h */
7776
f0 = _mm256_cmpgt_epi32(f0, zero);
7877
t = _mm256_blendv_epi32(h, zero, f0);
7978
t = _mm256_slli_epi32(t, 1);

dev/x86_64/src/poly_use_hint_88_avx2.c

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a,
3838
{
3939
unsigned int i;
4040
__m256i f, f0, f1, h, t;
41-
const __m256i q =
42-
_mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]);
43-
const __m256i hq = _mm256_srli_epi32(q, 1);
44-
/* check-magic: 11275 == round((2**24*128) / ((MLDSA_Q - 1) / 44)) */
41+
const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2);
42+
/* check-magic: 11275 == floor(2**24 / 1488) */
4543
const __m256i v = _mm256_set1_epi32(11275);
4644
const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2);
4745
const __m256i off = _mm256_set1_epi32(127);
@@ -54,27 +52,27 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a,
5452
f = _mm256_load_si256(&a[i]);
5553
h = _mm256_load_si256(&hint[i]);
5654

57-
/* Reference: The reference avx2 implementation calls poly_decompose to
58-
* compute all a1, a0 before the loop.
55+
/* Reference:
56+
* - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop.
57+
* - Our implementation of decompose() is slightly different from that in
58+
* @[REF_AVX2]. See poly_decompose_88_avx2.c for more information.
5959
*/
60-
/* decompose */
60+
/* f1, f0 = decompose(f) */
6161
f1 = _mm256_add_epi32(f, off);
6262
f1 = _mm256_srli_epi32(f1, 7);
6363
f1 = _mm256_mulhi_epu16(f1, v);
6464
f1 = _mm256_mulhrs_epi16(f1, shift);
65-
t = _mm256_sub_epi32(max, f1);
66-
f1 = _mm256_blendv_epi32(f1, zero, t);
65+
t = _mm256_cmpgt_epi32(f, q_bound);
6766
f0 = _mm256_mullo_epi32(f1, alpha);
6867
f0 = _mm256_sub_epi32(f, f0);
69-
f = _mm256_cmpgt_epi32(f0, hq);
70-
f = _mm256_and_si256(f, q);
71-
f0 = _mm256_sub_epi32(f0, f);
68+
f1 = _mm256_andnot_si256(t, f1);
69+
f0 = _mm256_add_epi32(f0, t);
7270

7371
/* Reference: The reference avx2 implementation checks a0 >= 0, which is
7472
* different from the specification and the reference C implementation. We
7573
* follow the specification and check a0 > 0.
7674
*/
77-
/* t = (a0 > 0) ? h : -h */
75+
/* t = (f0 > 0) ? h : -h */
7876
f0 = _mm256_cmpgt_epi32(f0, zero);
7977
t = _mm256_blendv_epi32(h, zero, f0);
8078
t = _mm256_slli_epi32(t, 1);

mldsa/native/x86_64/src/poly_use_hint_32_avx2.c

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ void mld_poly_use_hint_32_avx2(__m256i *b, const __m256i *a,
3838
{
3939
unsigned int i;
4040
__m256i f, f0, f1, h, t;
41-
const __m256i q =
42-
_mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]);
43-
const __m256i hq = _mm256_srli_epi32(q, 1);
44-
/* check-magic: 1025 == round((2**22*128) / ((MLDSA_Q - 1) / 16)) */
41+
const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2);
42+
/* check-magic: 1025 == floor(2**22 / 4092) */
4543
const __m256i v = _mm256_set1_epi32(1025);
4644
const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2);
4745
const __m256i off = _mm256_set1_epi32(127);
@@ -54,26 +52,27 @@ void mld_poly_use_hint_32_avx2(__m256i *b, const __m256i *a,
5452
f = _mm256_load_si256(&a[i]);
5553
h = _mm256_load_si256(&hint[i]);
5654

57-
/* Reference: The reference avx2 implementation calls poly_decompose to
58-
* compute all a1, a0 before the loop.
55+
/* Reference:
56+
* - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop.
57+
* - Our implementation of decompose() is slightly different from that in
58+
* @[REF_AVX2]. See poly_decompose_32_avx2.c for more information.
5959
*/
60-
/* decompose */
60+
/* f1, f0 = decompose(f) */
6161
f1 = _mm256_add_epi32(f, off);
6262
f1 = _mm256_srli_epi32(f1, 7);
6363
f1 = _mm256_mulhi_epu16(f1, v);
6464
f1 = _mm256_mulhrs_epi16(f1, shift);
65-
f1 = _mm256_and_si256(f1, mask);
65+
t = _mm256_cmpgt_epi32(f, q_bound);
6666
f0 = _mm256_mullo_epi32(f1, alpha);
6767
f0 = _mm256_sub_epi32(f, f0);
68-
f = _mm256_cmpgt_epi32(f0, hq);
69-
f = _mm256_and_si256(f, q);
70-
f0 = _mm256_sub_epi32(f0, f);
68+
f1 = _mm256_andnot_si256(t, f1);
69+
f0 = _mm256_add_epi32(f0, t);
7170

7271
/* Reference: The reference avx2 implementation checks a0 >= 0, which is
7372
* different from the specification and the reference C implementation. We
7473
* follow the specification and check a0 > 0.
7574
*/
76-
/* t = (a0 > 0) ? h : -h */
75+
/* t = (f0 > 0) ? h : -h */
7776
f0 = _mm256_cmpgt_epi32(f0, zero);
7877
t = _mm256_blendv_epi32(h, zero, f0);
7978
t = _mm256_slli_epi32(t, 1);

mldsa/native/x86_64/src/poly_use_hint_88_avx2.c

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a,
3838
{
3939
unsigned int i;
4040
__m256i f, f0, f1, h, t;
41-
const __m256i q =
42-
_mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]);
43-
const __m256i hq = _mm256_srli_epi32(q, 1);
44-
/* check-magic: 11275 == round((2**24*128) / ((MLDSA_Q - 1) / 44)) */
41+
const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2);
42+
/* check-magic: 11275 == floor(2**24 / 1488) */
4543
const __m256i v = _mm256_set1_epi32(11275);
4644
const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2);
4745
const __m256i off = _mm256_set1_epi32(127);
@@ -54,27 +52,27 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a,
5452
f = _mm256_load_si256(&a[i]);
5553
h = _mm256_load_si256(&hint[i]);
5654

57-
/* Reference: The reference avx2 implementation calls poly_decompose to
58-
* compute all a1, a0 before the loop.
55+
/* Reference:
56+
* - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop.
57+
* - Our implementation of decompose() is slightly different from that in
58+
* @[REF_AVX2]. See poly_decompose_88_avx2.c for more information.
5959
*/
60-
/* decompose */
60+
/* f1, f0 = decompose(f) */
6161
f1 = _mm256_add_epi32(f, off);
6262
f1 = _mm256_srli_epi32(f1, 7);
6363
f1 = _mm256_mulhi_epu16(f1, v);
6464
f1 = _mm256_mulhrs_epi16(f1, shift);
65-
t = _mm256_sub_epi32(max, f1);
66-
f1 = _mm256_blendv_epi32(f1, zero, t);
65+
t = _mm256_cmpgt_epi32(f, q_bound);
6766
f0 = _mm256_mullo_epi32(f1, alpha);
6867
f0 = _mm256_sub_epi32(f, f0);
69-
f = _mm256_cmpgt_epi32(f0, hq);
70-
f = _mm256_and_si256(f, q);
71-
f0 = _mm256_sub_epi32(f0, f);
68+
f1 = _mm256_andnot_si256(t, f1);
69+
f0 = _mm256_add_epi32(f0, t);
7270

7371
/* Reference: The reference avx2 implementation checks a0 >= 0, which is
7472
* different from the specification and the reference C implementation. We
7573
* follow the specification and check a0 > 0.
7674
*/
77-
/* t = (a0 > 0) ? h : -h */
75+
/* t = (f0 > 0) ? h : -h */
7876
f0 = _mm256_cmpgt_epi32(f0, zero);
7977
t = _mm256_blendv_epi32(h, zero, f0);
8078
t = _mm256_slli_epi32(t, 1);

0 commit comments

Comments
 (0)