From 4521179a0feee25b1562d7a52521187ae41eb1a5 Mon Sep 17 00:00:00 2001 From: jammychiou1 Date: Sun, 26 Oct 2025 22:57:05 +0800 Subject: [PATCH 1/4] Add bounds reasoning comments to AVX2 ntt/intt Signed-off-by: jammychiou1 --- dev/x86_64/src/intt.S | 51 ++++++++++++++++++++++++++++++++++++++++ dev/x86_64/src/ntt.S | 54 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 99 insertions(+), 6 deletions(-) diff --git a/dev/x86_64/src/intt.S b/dev/x86_64/src/intt.S index f45d0fd87..85acfb1b9 100644 --- a/dev/x86_64/src/intt.S +++ b/dev/x86_64/src/intt.S @@ -43,6 +43,12 @@ vpsrlq $32,%ymm\r0,%ymm\r0 vpblendd $0xAA,%ymm\r1,%ymm\r0,%ymm\r3 .endm +/* + * Compute l + h, montmul(h - l, zh) then store the results back to l, h + * respectively. + * + * The general abs bound of Montgomery multiplication is 3q/4. + */ .macro butterfly l,h,zl0=1,zl1=1,zh0=2,zh1=2 vpsubd %ymm\l,%ymm\h,%ymm12 vpaddd %ymm\h,%ymm\l,%ymm\l @@ -74,6 +80,8 @@ vmovdqa 256*\off+160(%rdi),%ymm9 vmovdqa 256*\off+192(%rdi),%ymm10 vmovdqa 256*\off+224(%rdi),%ymm11 +/* All: abs bound < q */ + /* level 0 */ vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+296-8*\off-8)*4(%rsi),%ymm3 vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+296-8*\off-8)*4(%rsi),%ymm15 @@ -99,6 +107,19 @@ vmovshdup %ymm3,%ymm1 vmovshdup %ymm15,%ymm2 butterfly 10,11,1,3,2,15 +/* 4, 6, 8, 10: abs bound < 2q; 5, 7, 9, 11: abs bound < 3q/4 */ +/* + * Note that since 2^31 / q > 256, the sum of all 256 coefficients does not + * overflow. This allows us to greatly simplify the range analysis by relaxing + * and unifying the bounds of all coefficients on the same layer. As a concrete + * example, here we relax the bounds on 5, 7, 9, 11 and conclude that + * + * All: abs bound < 2q + * + * In all but last of the following layers, we do the same relaxation without + * explicit mention. 
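+ * (For reference: q = 8380417 < 2^23, so 256*q < 2^31, which is the
+ * no-overflow fact used above.)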
+ */ + /* level 1 */ vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+168-8*\off-8)*4(%rsi),%ymm3 vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+168-8*\off-8)*4(%rsi),%ymm15 @@ -114,6 +135,8 @@ vmovshdup %ymm15,%ymm2 butterfly 8,10,1,3,2,15 butterfly 9,11,1,3,2,15 +/* All: abs bound < 4q */ + /* level 2 */ vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+104-8*\off-8)*4(%rsi),%ymm3 vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+104-8*\off-8)*4(%rsi),%ymm15 @@ -124,6 +147,8 @@ butterfly 5,9,1,3,2,15 butterfly 6,10,1,3,2,15 butterfly 7,11,1,3,2,15 +/* All: abs bound < 8q */ + /* level 3 */ shuffle2 4,5,3,5 shuffle2 6,7,4,7 @@ -137,6 +162,8 @@ butterfly 4,7 butterfly 6,9 butterfly 8,11 +/* All: abs bound < 16q */ + /* level 4 */ shuffle4 3,4,10,4 shuffle4 6,8,3,8 @@ -150,6 +177,8 @@ butterfly 3,8 butterfly 6,7 butterfly 5,11 +/* All: abs bound < 32q */ + /* level 5 */ shuffle8 10,3,9,3 shuffle8 6,5,10,5 @@ -163,6 +192,8 @@ butterfly 10,5 butterfly 6,8 butterfly 4,11 +/* All: abs bound < 64q */ + vmovdqa %ymm9,256*\off+ 0(%rdi) vmovdqa %ymm10,256*\off+ 32(%rdi) vmovdqa %ymm6,256*\off+ 64(%rdi) @@ -194,6 +225,8 @@ vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+2)*4(%rsi),%ymm2 butterfly 8,10 butterfly 9,11 +/* All: abs bound < 128q */ + /* level 7 */ vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+0)*4(%rsi),%ymm1 vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+0)*4(%rsi),%ymm2 @@ -203,11 +236,27 @@ butterfly 5,9 butterfly 6,10 butterfly 7,11 +/* 4, 5, 6, 7: abs bound < 256q; 8, 9, 10, 11: abs bound < 3q/4 */ + vmovdqa %ymm8,512+32*\off(%rdi) vmovdqa %ymm9,640+32*\off(%rdi) vmovdqa %ymm10,768+32*\off(%rdi) vmovdqa %ymm11,896+32*\off(%rdi) +/* + * In order to (a) remove the factor of 256 arising from the 256-point intt + * butterflies and (b) transform the output into Montgomery domain, we need to + * multiply all coefficients by 2^32/256. + * + * For ymm{8,9,10,11}, the scaling has been merged into the last butterfly, so + * only ymm{4,5,6,7} need to be scaled explicitly. + * + * The scaling is achieved by computing montmul(-, MLD_AVX2_DIV), so the output + * will have an abs bound of 3q/4. + * + * 4, 5, 6, 7: abs bound < 256q + */ + vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_8XDIV_QINV)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_8XDIV)*4(%rsi),%ymm2 vpmuldq %ymm1,%ymm4,%ymm12 @@ -256,6 +305,8 @@ vmovshdup %ymm7,%ymm7 vpblendd $0xAA,%ymm8,%ymm6,%ymm6 vpblendd $0xAA,%ymm9,%ymm7,%ymm7 +/* 4, 5, 6, 7: abs bound < 3q/4 */ + vmovdqa %ymm4, 0+32*\off(%rdi) vmovdqa %ymm5,128+32*\off(%rdi) vmovdqa %ymm6,256+32*\off(%rdi) diff --git a/dev/x86_64/src/ntt.S b/dev/x86_64/src/ntt.S index 8fae4ccbc..0f1d1dd58 100644 --- a/dev/x86_64/src/ntt.S +++ b/dev/x86_64/src/ntt.S @@ -44,6 +44,16 @@ vpsrlq $32,%ymm\r0,%ymm\r0 vpblendd $0xAA,%ymm\r1,%ymm\r0,%ymm\r3 .endm +/* + * Compute l + montmul(h, zh), l - montmul(h, zh) then store the results back to + * l, h respectively. + * + * Although the general abs bound of Montgomery multiplication is 3q/4, we use + * the more convenient bound q here. + * + * In conclusion, the magnitudes of all coefficients grow by at most q after + * each layer. 
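+ * (Concretely, |montmul(h, zh)| <= |h*zh|/R + q/2, which stays below 3q/4 < q
+ * assuming |h| < R/2 and |zh| <= q/2 for the precomputed twiddles; hence each
+ * add/sub grows a coefficient's magnitude by less than q.)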
+ */ .macro butterfly l,h,zl0=1,zl1=1,zh0=2,zh1=2 vpmuldq %ymm\zl0,%ymm\h,%ymm13 vmovshdup %ymm\h,%ymm12 @@ -56,16 +66,30 @@ vpmuldq %ymm0,%ymm13,%ymm13 vpmuldq %ymm0,%ymm14,%ymm14 vmovshdup %ymm\h,%ymm\h -vpblendd $0xAA,%ymm12,%ymm\h,%ymm\h +vpblendd $0xAA,%ymm12,%ymm\h,%ymm\h /* mulhi(h * zh) */ -vpsubd %ymm\h,%ymm\l,%ymm12 -vpaddd %ymm\h,%ymm\l,%ymm\l +/* + * Originally, mulhi(h * zh) should be subtracted by mulhi(q * mullo(h * zl)) + * in order to complete computing + * + * montmul(h, zh) = mulhi(h * zh) - mulhi(q * mullo(h * zl)). + * + * Here, since mulhi(q * mullo(h * zl)) has not been computed yet, this task is + * delayed until after add/sub. + */ +vpsubd %ymm\h,%ymm\l,%ymm12 /* l - mulhi(h * zh) + * = l - montmul(h, zh) + * - mulhi(q * mullo(h * zl)) */ +vpaddd %ymm\h,%ymm\l,%ymm\l /* l + mulhi(h * zh) + * = l + montmul(h, zh) + * + mulhi(q * mullo(h * zl)) */ vmovshdup %ymm13,%ymm13 -vpblendd $0xAA,%ymm14,%ymm13,%ymm13 +vpblendd $0xAA,%ymm14,%ymm13,%ymm13 /* mulhi(q * mullo(h * zl)) */ -vpaddd %ymm13,%ymm12,%ymm\h -vpsubd %ymm13,%ymm\l,%ymm\l +/* Finish the delayed task mentioned above */ +vpaddd %ymm13,%ymm12,%ymm\h /* l - montmul(h, zh) */ +vpsubd %ymm13,%ymm\l,%ymm\l /* l + montmul(h, zh) */ .endm .macro levels0t1 off @@ -82,11 +106,15 @@ vmovdqa 640+32*\off(%rdi),%ymm9 vmovdqa 768+32*\off(%rdi),%ymm10 vmovdqa 896+32*\off(%rdi),%ymm11 +/* All: abs bound < q */ + butterfly 4,8 butterfly 5,9 butterfly 6,10 butterfly 7,11 +/* All: abs bound < 2q */ + /* level 1 */ vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+2)*4(%rsi),%ymm1 vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+2)*4(%rsi),%ymm2 @@ -98,6 +126,8 @@ vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+3)*4(%rsi),%ymm2 butterfly 8,10 butterfly 9,11 +/* All: abs bound < 3q */ + vmovdqa %ymm4, 0+32*\off(%rdi) vmovdqa %ymm5,128+32*\off(%rdi) vmovdqa %ymm6,256+32*\off(%rdi) @@ -132,6 +162,8 @@ shuffle8 5,9,4,9 shuffle8 6,10,5,10 shuffle8 7,11,6,11 +/* All: abs bound < 4q */ + /* level 3 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+8+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+8+8*\off)*4(%rsi),%ymm2 @@ -146,6 +178,8 @@ shuffle4 8,10,3,10 shuffle4 4,6,8,6 shuffle4 9,11,4,11 +/* All: abs bound < 5q */ + /* level 4 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+40+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+40+8*\off)*4(%rsi),%ymm2 @@ -160,6 +194,8 @@ shuffle2 5,6,7,6 shuffle2 3,4,5,4 shuffle2 10,11,3,11 +/* All: abs bound < 6q */ + /* level 5 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+72+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+72+8*\off)*4(%rsi),%ymm2 @@ -171,6 +207,8 @@ butterfly 8,4,1,10,2,15 butterfly 7,3,1,10,2,15 butterfly 6,11,1,10,2,15 +/* All: abs bound < 7q */ + /* level 6 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+104+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+104+8*\off)*4(%rsi),%ymm2 @@ -186,6 +224,8 @@ vmovshdup %ymm2,%ymm15 butterfly 5,3,1,10,2,15 butterfly 4,11,1,10,2,15 +/* All: abs bound < 8q */ + /* level 7 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+168+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+168+8*\off)*4(%rsi),%ymm2 @@ -211,6 +251,8 @@ vpsrlq $32,%ymm1,%ymm10 vmovshdup %ymm2,%ymm15 butterfly 3,11,1,10,2,15 +/* All: abs bound < 9q */ + vmovdqa %ymm9,256*\off+ 0(%rdi) vmovdqa %ymm8,256*\off+ 32(%rdi) vmovdqa %ymm7,256*\off+ 64(%rdi) From a12ee244ab93dba8f1ce6409934994685b4feac4 Mon Sep 17 00:00:00 2001 From: jammychiou1 Date: Mon, 27 Oct 2025 
11:26:57 +0800 Subject: [PATCH 2/4] Add bounds reasoning comments to AVX2 basemul Signed-off-by: jammychiou1 --- dev/x86_64/src/pointwise.S | 7 +++++++ dev/x86_64/src/pointwise_acc_l4.S | 14 ++++++++++++++ dev/x86_64/src/pointwise_acc_l5.S | 15 +++++++++++++++ dev/x86_64/src/pointwise_acc_l7.S | 17 +++++++++++++++++ 4 files changed, 53 insertions(+) diff --git a/dev/x86_64/src/pointwise.S b/dev/x86_64/src/pointwise.S index 8bd73616f..6445054d5 100644 --- a/dev/x86_64/src/pointwise.S +++ b/dev/x86_64/src/pointwise.S @@ -61,6 +61,7 @@ _looptop1: vpsrlq ymm11, ymm10, 32 vpsrlq ymm13, ymm12, 32 vmovshdup ymm15, ymm14 + /* All: abs bound < 9q */ // Multiply vpmuldq ymm2, ymm2, ymm10 @@ -69,6 +70,7 @@ _looptop1: vpmuldq ymm5, ymm5, ymm13 vpmuldq ymm6, ymm6, ymm14 vpmuldq ymm7, ymm7, ymm15 + /* All: abs bound < 81q^2 < 81*2^46 < 2^53 = 2^21R < qR/2 */ // Reduce vpmuldq ymm10, ymm0, ymm2 @@ -92,6 +94,11 @@ _looptop1: vpsrlq ymm2, ymm2, 32 vpsrlq ymm4, ymm4, 32 vmovshdup ymm6, ymm6 + /* + * All coefficients are Montgomery-reduced. This results in the bound + * + * All: abs bound <= "input abs bound"/R + q/2 < (qR/2)/R + q/2 = q + */ // Store vpblendd ymm2, ymm2, ymm3, 0xAA diff --git a/dev/x86_64/src/pointwise_acc_l4.S b/dev/x86_64/src/pointwise_acc_l4.S index e64881ccb..44c2b62b0 100644 --- a/dev/x86_64/src/pointwise_acc_l4.S +++ b/dev/x86_64/src/pointwise_acc_l4.S @@ -37,12 +37,17 @@ vpsrlq ymm9, ymm8, 32 vmovshdup ymm11, ymm10 vmovshdup ymm13, ymm12 + /* + * 6, 7, 8, 9: from the first input polynomial, abs bound < q + * 10, 11, 12, 13: from the second input polynomial, abs bound < 9q + */ // Multiply vpmuldq ymm6, ymm6, ymm10 vpmuldq ymm7, ymm7, ymm11 vpmuldq ymm8, ymm8, ymm12 vpmuldq ymm9, ymm9, ymm13 + /* All: abs bound < 9q^2 */ .endm .macro acc @@ -80,15 +85,19 @@ _looptop2: vmovdqa ymm3, ymm7 vmovdqa ymm4, ymm8 vmovdqa ymm5, ymm9 + /* All: abs bound < 9q^2 */ pointwise 1024 acc + /* All: abs bound < 18q^2 */ pointwise 2048 acc + /* All: abs bound < 27q^2 */ pointwise 3072 acc + /* All: abs bound < 36q^2 < 36*2^46 < 2^52 = 2^20R < qR/2 */ // Reduce vpmuldq ymm6, ymm0, ymm2 @@ -105,6 +114,11 @@ _looptop2: vpsubq ymm5, ymm5, ymm9 vpsrlq ymm2, ymm2, 32 vmovshdup ymm4, ymm4 + /* + * All coefficients are Montgomery-reduced. This results in the bound + * + * All: abs bound <= "input abs bound"/R + q/2 < (qR/2)/R + q/2 = q + */ // Store vpblendd ymm2, ymm2, ymm3, 0xAA diff --git a/dev/x86_64/src/pointwise_acc_l5.S b/dev/x86_64/src/pointwise_acc_l5.S index db7348f19..020f517eb 100644 --- a/dev/x86_64/src/pointwise_acc_l5.S +++ b/dev/x86_64/src/pointwise_acc_l5.S @@ -37,12 +37,17 @@ vpsrlq ymm9, ymm8, 32 vmovshdup ymm11, ymm10 vmovshdup ymm13, ymm12 + /* + * 6, 7, 8, 9: from the first input polynomial, abs bound < q + * 10, 11, 12, 13: from the second input polynomial, abs bound < 9q + */ // Multiply vpmuldq ymm6, ymm6, ymm10 vpmuldq ymm7, ymm7, ymm11 vpmuldq ymm8, ymm8, ymm12 vpmuldq ymm9, ymm9, ymm13 + /* All: abs bound < 9q^2 */ .endm .macro acc @@ -80,18 +85,23 @@ _looptop2: vmovdqa ymm3, ymm7 vmovdqa ymm4, ymm8 vmovdqa ymm5, ymm9 + /* All: abs bound < 9q^2 */ pointwise 1024 acc + /* All: abs bound < 18q^2 */ pointwise 2048 acc + /* All: abs bound < 27q^2 */ pointwise 3072 acc + /* All: abs bound < 36q^2 */ pointwise 4096 acc + /* All: abs bound < 45q^2 < 45*2^46 < 2^52 = 2^20R < qR/2 */ // Reduce vpmuldq ymm6, ymm0, ymm2 @@ -108,6 +118,11 @@ _looptop2: vpsubq ymm5, ymm5, ymm9 vpsrlq ymm2, ymm2, 32 vmovshdup ymm4, ymm4 + /* + * All coefficients are Montgomery-reduced. 
This results in the bound + * + * All: abs bound <= "input abs bound"/R + q/2 < (qR/2)/R + q/2 = q + */ // Store vpblendd ymm2, ymm2, ymm3, 0xAA diff --git a/dev/x86_64/src/pointwise_acc_l7.S b/dev/x86_64/src/pointwise_acc_l7.S index bae230d75..835e87b05 100644 --- a/dev/x86_64/src/pointwise_acc_l7.S +++ b/dev/x86_64/src/pointwise_acc_l7.S @@ -37,12 +37,17 @@ vpsrlq ymm9, ymm8, 32 vmovshdup ymm11, ymm10 vmovshdup ymm13, ymm12 + /* + * 6, 7, 8, 9: from the first input polynomial, abs bound < q + * 10, 11, 12, 13: from the second input polynomial, abs bound < 9q + */ // Multiply vpmuldq ymm6, ymm6, ymm10 vpmuldq ymm7, ymm7, ymm11 vpmuldq ymm8, ymm8, ymm12 vpmuldq ymm9, ymm9, ymm13 + /* All: abs bound < 9q^2 */ .endm .macro acc @@ -80,24 +85,31 @@ _looptop2: vmovdqa ymm3, ymm7 vmovdqa ymm4, ymm8 vmovdqa ymm5, ymm9 + /* All: abs bound < 9q^2 */ pointwise 1024 acc + /* All: abs bound < 18q^2 */ pointwise 2048 acc + /* All: abs bound < 27q^2 */ pointwise 3072 acc + /* All: abs bound < 36q^2 */ pointwise 4096 acc + /* All: abs bound < 45q^2 */ pointwise 5120 acc + /* All: abs bound < 54q^2 */ pointwise 6144 acc + /* All: abs bound < 63q^2 < 63*2^46 < 2^52 = 2^20R < qR/2 */ // Reduce vpmuldq ymm6, ymm0, ymm2 @@ -114,6 +126,11 @@ _looptop2: vpsubq ymm5, ymm5, ymm9 vpsrlq ymm2, ymm2, 32 vmovshdup ymm4, ymm4 + /* + * All coefficients are Montgomery-reduced. This results in the bound + * + * All: abs bound <= "input abs bound"/R + q/2 < (qR/2)/R + q/2 = q + */ // Store vpblendd ymm2, ymm2, ymm3, 0xAA From 618d4551a9f93a875af13026631310453858106a Mon Sep 17 00:00:00 2001 From: jammychiou1 Date: Sun, 2 Nov 2025 18:15:58 +0800 Subject: [PATCH 3/4] AVX2: Redo decompose_{32,88} with an approach that's easier to explain The new approach is adapted from our Neon implementation. See for more information on the idea. Bounds reasoning comments are also added. Signed-off-by: jammychiou1 --- dev/x86_64/src/poly_decompose_32_avx2.c | 65 ++++++++++++++++--- dev/x86_64/src/poly_decompose_88_avx2.c | 65 +++++++++++++++---- .../x86_64/src/poly_decompose_32_avx2.c | 65 ++++++++++++++++--- .../x86_64/src/poly_decompose_88_avx2.c | 65 +++++++++++++++---- 4 files changed, 218 insertions(+), 42 deletions(-) diff --git a/dev/x86_64/src/poly_decompose_32_avx2.c b/dev/x86_64/src/poly_decompose_32_avx2.c index 89bf93e2b..2c2e20a72 100644 --- a/dev/x86_64/src/poly_decompose_32_avx2.c +++ b/dev/x86_64/src/poly_decompose_32_avx2.c @@ -36,30 +36,75 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a) { unsigned int i; - __m256i f, f0, f1; - const __m256i q = - _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]); - const __m256i hq = _mm256_srli_epi32(q, 1); - /* check-magic: 1025 == round((2**22*128) / ((MLDSA_Q - 1) / 16)) */ + __m256i f, f0, f1, t; + const __m256i q_bound = _mm256_set1_epi32(31 * MLDSA_GAMMA2); + /* check-magic: 1025 == floor(2**22 / 4092) */ const __m256i v = _mm256_set1_epi32(1025); const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2); const __m256i off = _mm256_set1_epi32(127); const __m256i shift = _mm256_set1_epi32(512); - const __m256i mask = _mm256_set1_epi32(15); for (i = 0; i < MLDSA_N / 8; i++) { f = _mm256_load_si256(&a[i]); + + /* check-magic: 4092 == 2 * ((MLDSA_Q-1) // 32) // 128 */ + /* + * The goal is to compute f1 = round-(f / (2*GAMMA2)), which can be computed + * alternatively as round-(f / (128B)) = round-(ceil(f / 128) / B) where + * B = 2*GAMMA2 / 128 = 4092. Here round-() denotes "round half down". 
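+ * The rewrite is exact because B is even: for every integer k, the condition
+ * 128B*k - 64B < f <= 128B*k + 64B is equivalent to
+ * B*k - B/2 < ceil(f / 128) <= B*k + B/2, both endpoints being integers.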
+ * + * range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B + */ + + /* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */ f1 = _mm256_add_epi32(f, off); f1 = _mm256_srli_epi32(f1, 7); + /* + * range: 0 <= f1' <= (Q-1)/128 = 16B + * + * Also, f1' <= (Q-1)/128 = 2^16 - 2^6 < 2^16 ensures that the odd-index + * 16-bit lanes are all 0, so no bits will be dropped in the input of the + * _mm256_mulhi_epu16() below. + */ + + /* check-magic: off */ + /* + * Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact + * for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲ + * 1 / 4092. + * + * The odd-index 16-bit lanes are still all 0 after this. As such, despite + * that the following steps use 32-bit lanes, the value of f1 is unaffected. + */ + /* check-magic: on */ f1 = _mm256_mulhi_epu16(f1, v); f1 = _mm256_mulhrs_epi16(f1, shift); - f1 = _mm256_and_si256(f1, mask); + /* range: 0 <= f1 <= 16 */ + + /* + * If f1 = 16, i.e. f > 31*GAMMA2, proceed as if f' = f - Q was given + * instead. (For f = 31*GAMMA2 + 1 thus f' = -GAMMA2, we still round it to 0 + * like other "wrapped around" cases.) + */ + + /* Check for wrap-around */ + t = _mm256_cmpgt_epi32(f, q_bound); + + /* Compute remainder f0 */ f0 = _mm256_mullo_epi32(f1, alpha); f0 = _mm256_sub_epi32(f, f0); - f = _mm256_cmpgt_epi32(f0, hq); - f = _mm256_and_si256(f, q); - f0 = _mm256_sub_epi32(f0, f); + /* + * range: -GAMMA2 < f0 <= GAMMA2 + * + * This holds since f1 = round-(f / (2*GAMMA2)) was computed exactly. + */ + + /* If wrap-around is required, set f1 = 0 and f0 -= 1 */ + f1 = _mm256_andnot_si256(t, f1); + f0 = _mm256_add_epi32(f0, t); + /* range: 0 <= f1 <= 15, -GAMMA2 <= f0 <= GAMMA2 */ + _mm256_store_si256(&a1[i], f1); _mm256_store_si256(&a0[i], f0); } diff --git a/dev/x86_64/src/poly_decompose_88_avx2.c b/dev/x86_64/src/poly_decompose_88_avx2.c index f17d663c9..fdf65f596 100644 --- a/dev/x86_64/src/poly_decompose_88_avx2.c +++ b/dev/x86_64/src/poly_decompose_88_avx2.c @@ -38,31 +38,74 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a) { unsigned int i; __m256i f, f0, f1, t; - const __m256i q = - _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]); - const __m256i hq = _mm256_srli_epi32(q, 1); - /* check-magic: 11275 == round((2**24*128) / ((MLDSA_Q - 1) / 44)) */ + const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2); + /* check-magic: 11275 == floor(2**24 / 1488) */ const __m256i v = _mm256_set1_epi32(11275); const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2); const __m256i off = _mm256_set1_epi32(127); const __m256i shift = _mm256_set1_epi32(128); - const __m256i max = _mm256_set1_epi32(43); - const __m256i zero = _mm256_setzero_si256(); for (i = 0; i < MLDSA_N / 8; i++) { f = _mm256_load_si256(&a[i]); + + /* check-magic: 1488 == 2 * ((MLDSA_Q-1) // 88) // 128 */ + /* + * The goal is to compute f1 = round-(f / (2*GAMMA2)), which can be computed + * alternatively as round-(f / (128B)) = round-(ceil(f / 128) / B) where + * B = 2*GAMMA2 / 128 = 1488. Here round-() denotes "round half down". + * + * range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B + */ + + /* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */ f1 = _mm256_add_epi32(f, off); f1 = _mm256_srli_epi32(f1, 7); + /* + * range: 0 <= f1' <= (Q-1)/128 = 44B + * + * Also, f1' <= (Q-1)/128 = 2^16 - 2^6 < 2^16 ensures that the odd-index + * 16-bit lanes are all 0, so no bits will be dropped in the input of the + * _mm256_mulhi_epu16() below. 
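+ * (_mm256_mulhi_epu16() operates on 16-bit lanes, so each 32-bit f1' has to
+ * fit entirely in its low 16-bit lane for the multiplication to see its full
+ * value.)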
+ */ + + /* check-magic: off */ + /* + * Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact + * for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲ + * 1 / 1488. + * + * The odd-index 16-bit lanes are still all 0 after this. As such, despite + * that the following steps use 32-bit lanes, the value of f1 is unaffected. + */ + /* check-magic: on */ f1 = _mm256_mulhi_epu16(f1, v); f1 = _mm256_mulhrs_epi16(f1, shift); - t = _mm256_sub_epi32(max, f1); - f1 = _mm256_blendv_epi32(f1, zero, t); + /* range: 0 <= f1 <= 44 */ + + /* + * If f1 = 44, i.e. f > 87*GAMMA2, proceed as if f' = f - Q was given + * instead. (For f = 87*GAMMA2 + 1 thus f' = -GAMMA2, we still round it to 0 + * like other "wrapped around" cases.) + */ + + /* Check for wrap-around */ + t = _mm256_cmpgt_epi32(f, q_bound); + + /* Compute remainder f0 */ f0 = _mm256_mullo_epi32(f1, alpha); f0 = _mm256_sub_epi32(f, f0); - f = _mm256_cmpgt_epi32(f0, hq); - f = _mm256_and_si256(f, q); - f0 = _mm256_sub_epi32(f0, f); + /* + * range: -GAMMA2 < f0 <= GAMMA2 + * + * This holds since f1 = round-(f / (2*GAMMA2)) was computed exactly. + */ + + /* If wrap-around is required, set f1 = 0 and f0 -= 1 */ + f1 = _mm256_andnot_si256(t, f1); + f0 = _mm256_add_epi32(f0, t); + /* range: 0 <= f1 <= 43, -GAMMA2 <= f0 <= GAMMA2 */ + _mm256_store_si256(&a1[i], f1); _mm256_store_si256(&a0[i], f0); } diff --git a/mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c b/mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c index 89bf93e2b..2c2e20a72 100644 --- a/mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c +++ b/mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c @@ -36,30 +36,75 @@ void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a) { unsigned int i; - __m256i f, f0, f1; - const __m256i q = - _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]); - const __m256i hq = _mm256_srli_epi32(q, 1); - /* check-magic: 1025 == round((2**22*128) / ((MLDSA_Q - 1) / 16)) */ + __m256i f, f0, f1, t; + const __m256i q_bound = _mm256_set1_epi32(31 * MLDSA_GAMMA2); + /* check-magic: 1025 == floor(2**22 / 4092) */ const __m256i v = _mm256_set1_epi32(1025); const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2); const __m256i off = _mm256_set1_epi32(127); const __m256i shift = _mm256_set1_epi32(512); - const __m256i mask = _mm256_set1_epi32(15); for (i = 0; i < MLDSA_N / 8; i++) { f = _mm256_load_si256(&a[i]); + + /* check-magic: 4092 == 2 * ((MLDSA_Q-1) // 32) // 128 */ + /* + * The goal is to compute f1 = round-(f / (2*GAMMA2)), which can be computed + * alternatively as round-(f / (128B)) = round-(ceil(f / 128) / B) where + * B = 2*GAMMA2 / 128 = 4092. Here round-() denotes "round half down". + * + * range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B + */ + + /* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */ f1 = _mm256_add_epi32(f, off); f1 = _mm256_srli_epi32(f1, 7); + /* + * range: 0 <= f1' <= (Q-1)/128 = 16B + * + * Also, f1' <= (Q-1)/128 = 2^16 - 2^6 < 2^16 ensures that the odd-index + * 16-bit lanes are all 0, so no bits will be dropped in the input of the + * _mm256_mulhi_epu16() below. + */ + + /* check-magic: off */ + /* + * Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact + * for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲ + * 1 / 4092. + * + * The odd-index 16-bit lanes are still all 0 after this. As such, despite + * that the following steps use 32-bit lanes, the value of f1 is unaffected. 
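+ * (Indeed, 1025 * 4092 = 2^22 - 4, so 1025 / 2^22 falls short of 1 / 4092 by
+ * only 4 / (4092 * 2^22).)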
+ */ + /* check-magic: on */ f1 = _mm256_mulhi_epu16(f1, v); f1 = _mm256_mulhrs_epi16(f1, shift); - f1 = _mm256_and_si256(f1, mask); + /* range: 0 <= f1 <= 16 */ + + /* + * If f1 = 16, i.e. f > 31*GAMMA2, proceed as if f' = f - Q was given + * instead. (For f = 31*GAMMA2 + 1 thus f' = -GAMMA2, we still round it to 0 + * like other "wrapped around" cases.) + */ + + /* Check for wrap-around */ + t = _mm256_cmpgt_epi32(f, q_bound); + + /* Compute remainder f0 */ f0 = _mm256_mullo_epi32(f1, alpha); f0 = _mm256_sub_epi32(f, f0); - f = _mm256_cmpgt_epi32(f0, hq); - f = _mm256_and_si256(f, q); - f0 = _mm256_sub_epi32(f0, f); + /* + * range: -GAMMA2 < f0 <= GAMMA2 + * + * This holds since f1 = round-(f / (2*GAMMA2)) was computed exactly. + */ + + /* If wrap-around is required, set f1 = 0 and f0 -= 1 */ + f1 = _mm256_andnot_si256(t, f1); + f0 = _mm256_add_epi32(f0, t); + /* range: 0 <= f1 <= 15, -GAMMA2 <= f0 <= GAMMA2 */ + _mm256_store_si256(&a1[i], f1); _mm256_store_si256(&a0[i], f0); } diff --git a/mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c b/mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c index f17d663c9..fdf65f596 100644 --- a/mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c +++ b/mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c @@ -38,31 +38,74 @@ void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a) { unsigned int i; __m256i f, f0, f1, t; - const __m256i q = - _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]); - const __m256i hq = _mm256_srli_epi32(q, 1); - /* check-magic: 11275 == round((2**24*128) / ((MLDSA_Q - 1) / 44)) */ + const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2); + /* check-magic: 11275 == floor(2**24 / 1488) */ const __m256i v = _mm256_set1_epi32(11275); const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2); const __m256i off = _mm256_set1_epi32(127); const __m256i shift = _mm256_set1_epi32(128); - const __m256i max = _mm256_set1_epi32(43); - const __m256i zero = _mm256_setzero_si256(); for (i = 0; i < MLDSA_N / 8; i++) { f = _mm256_load_si256(&a[i]); + + /* check-magic: 1488 == 2 * ((MLDSA_Q-1) // 88) // 128 */ + /* + * The goal is to compute f1 = round-(f / (2*GAMMA2)), which can be computed + * alternatively as round-(f / (128B)) = round-(ceil(f / 128) / B) where + * B = 2*GAMMA2 / 128 = 1488. Here round-() denotes "round half down". + * + * range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B + */ + + /* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */ f1 = _mm256_add_epi32(f, off); f1 = _mm256_srli_epi32(f1, 7); + /* + * range: 0 <= f1' <= (Q-1)/128 = 44B + * + * Also, f1' <= (Q-1)/128 = 2^16 - 2^6 < 2^16 ensures that the odd-index + * 16-bit lanes are all 0, so no bits will be dropped in the input of the + * _mm256_mulhi_epu16() below. + */ + + /* check-magic: off */ + /* + * Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact + * for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲ + * 1 / 1488. + * + * The odd-index 16-bit lanes are still all 0 after this. As such, despite + * that the following steps use 32-bit lanes, the value of f1 is unaffected. + */ + /* check-magic: on */ f1 = _mm256_mulhi_epu16(f1, v); f1 = _mm256_mulhrs_epi16(f1, shift); - t = _mm256_sub_epi32(max, f1); - f1 = _mm256_blendv_epi32(f1, zero, t); + /* range: 0 <= f1 <= 44 */ + + /* + * If f1 = 44, i.e. f > 87*GAMMA2, proceed as if f' = f - Q was given + * instead. (For f = 87*GAMMA2 + 1 thus f' = -GAMMA2, we still round it to 0 + * like other "wrapped around" cases.) 
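+ * (Since 0 <= f <= Q-1 = 88*GAMMA2, wrap-around (f > 87*GAMMA2) means that
+ * f' = f - Q lies in [-GAMMA2, -1].)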
+ */ + + /* Check for wrap-around */ + t = _mm256_cmpgt_epi32(f, q_bound); + + /* Compute remainder f0 */ f0 = _mm256_mullo_epi32(f1, alpha); f0 = _mm256_sub_epi32(f, f0); - f = _mm256_cmpgt_epi32(f0, hq); - f = _mm256_and_si256(f, q); - f0 = _mm256_sub_epi32(f0, f); + /* + * range: -GAMMA2 < f0 <= GAMMA2 + * + * This holds since f1 = round-(f / (2*GAMMA2)) was computed exactly. + */ + + /* If wrap-around is required, set f1 = 0 and f0 -= 1 */ + f1 = _mm256_andnot_si256(t, f1); + f0 = _mm256_add_epi32(f0, t); + /* range: 0 <= f1 <= 43, -GAMMA2 <= f0 <= GAMMA2 */ + _mm256_store_si256(&a1[i], f1); _mm256_store_si256(&a0[i], f0); } From 37040a988e3ba23835fd362fd5d95f61dafa2ea3 Mon Sep 17 00:00:00 2001 From: jammychiou1 Date: Sun, 2 Nov 2025 20:45:38 +0800 Subject: [PATCH 4/4] AVX2: Update decompose approach used in use_hint Edit some comments while we're at it. Signed-off-by: jammychiou1 --- dev/x86_64/src/poly_use_hint_32_avx2.c | 23 +++++++++--------- dev/x86_64/src/poly_use_hint_88_avx2.c | 24 +++++++++---------- .../native/x86_64/src/poly_use_hint_32_avx2.c | 23 +++++++++--------- .../native/x86_64/src/poly_use_hint_88_avx2.c | 24 +++++++++---------- 4 files changed, 44 insertions(+), 50 deletions(-) diff --git a/dev/x86_64/src/poly_use_hint_32_avx2.c b/dev/x86_64/src/poly_use_hint_32_avx2.c index ad5d71de2..ebc3ccf04 100644 --- a/dev/x86_64/src/poly_use_hint_32_avx2.c +++ b/dev/x86_64/src/poly_use_hint_32_avx2.c @@ -38,10 +38,8 @@ void mld_poly_use_hint_32_avx2(__m256i *b, const __m256i *a, { unsigned int i; __m256i f, f0, f1, h, t; - const __m256i q = - _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]); - const __m256i hq = _mm256_srli_epi32(q, 1); - /* check-magic: 1025 == round((2**22*128) / ((MLDSA_Q - 1) / 16)) */ + const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2); + /* check-magic: 1025 == floor(2**22 / 4092) */ const __m256i v = _mm256_set1_epi32(1025); const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2); const __m256i off = _mm256_set1_epi32(127); @@ -54,26 +52,27 @@ void mld_poly_use_hint_32_avx2(__m256i *b, const __m256i *a, f = _mm256_load_si256(&a[i]); h = _mm256_load_si256(&hint[i]); - /* Reference: The reference avx2 implementation calls poly_decompose to - * compute all a1, a0 before the loop. + /* Reference: + * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. + * - Our implementation of decompose() is slightly different from that in + * @[REF_AVX2]. See poly_decompose_32_avx2.c for more information. */ - /* decompose */ + /* f1, f2 = decompose(f) */ f1 = _mm256_add_epi32(f, off); f1 = _mm256_srli_epi32(f1, 7); f1 = _mm256_mulhi_epu16(f1, v); f1 = _mm256_mulhrs_epi16(f1, shift); - f1 = _mm256_and_si256(f1, mask); + t = _mm256_cmpgt_epi32(f, q_bound); f0 = _mm256_mullo_epi32(f1, alpha); f0 = _mm256_sub_epi32(f, f0); - f = _mm256_cmpgt_epi32(f0, hq); - f = _mm256_and_si256(f, q); - f0 = _mm256_sub_epi32(f0, f); + f1 = _mm256_andnot_si256(t, f1); + f0 = _mm256_add_epi32(f0, t); /* Reference: The reference avx2 implementation checks a0 >= 0, which is * different from the specification and the reference C implementation. We * follow the specification and check a0 > 0. */ - /* t = (a0 > 0) ? h : -h */ + /* t = (f0 > 0) ? 
h : -h */ f0 = _mm256_cmpgt_epi32(f0, zero); t = _mm256_blendv_epi32(h, zero, f0); t = _mm256_slli_epi32(t, 1); diff --git a/dev/x86_64/src/poly_use_hint_88_avx2.c b/dev/x86_64/src/poly_use_hint_88_avx2.c index a91fa80b9..1e902f28b 100644 --- a/dev/x86_64/src/poly_use_hint_88_avx2.c +++ b/dev/x86_64/src/poly_use_hint_88_avx2.c @@ -38,10 +38,8 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a, { unsigned int i; __m256i f, f0, f1, h, t; - const __m256i q = - _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]); - const __m256i hq = _mm256_srli_epi32(q, 1); - /* check-magic: 11275 == round((2**24*128) / ((MLDSA_Q - 1) / 44)) */ + const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2); + /* check-magic: 11275 == floor(2**24 / 1488) */ const __m256i v = _mm256_set1_epi32(11275); const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2); const __m256i off = _mm256_set1_epi32(127); @@ -54,27 +52,27 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a, f = _mm256_load_si256(&a[i]); h = _mm256_load_si256(&hint[i]); - /* Reference: The reference avx2 implementation calls poly_decompose to - * compute all a1, a0 before the loop. + /* Reference: + * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. + * - Our implementation of decompose() is slightly different from that in + * @[REF_AVX2]. See poly_decompose_88_avx2.c for more information. */ - /* decompose */ + /* f1, f2 = decompose(f) */ f1 = _mm256_add_epi32(f, off); f1 = _mm256_srli_epi32(f1, 7); f1 = _mm256_mulhi_epu16(f1, v); f1 = _mm256_mulhrs_epi16(f1, shift); - t = _mm256_sub_epi32(max, f1); - f1 = _mm256_blendv_epi32(f1, zero, t); + t = _mm256_cmpgt_epi32(f, q_bound); f0 = _mm256_mullo_epi32(f1, alpha); f0 = _mm256_sub_epi32(f, f0); - f = _mm256_cmpgt_epi32(f0, hq); - f = _mm256_and_si256(f, q); - f0 = _mm256_sub_epi32(f0, f); + f1 = _mm256_andnot_si256(t, f1); + f0 = _mm256_add_epi32(f0, t); /* Reference: The reference avx2 implementation checks a0 >= 0, which is * different from the specification and the reference C implementation. We * follow the specification and check a0 > 0. */ - /* t = (a0 > 0) ? h : -h */ + /* t = (f0 > 0) ? h : -h */ f0 = _mm256_cmpgt_epi32(f0, zero); t = _mm256_blendv_epi32(h, zero, f0); t = _mm256_slli_epi32(t, 1); diff --git a/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2.c b/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2.c index ad5d71de2..ebc3ccf04 100644 --- a/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2.c +++ b/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2.c @@ -38,10 +38,8 @@ void mld_poly_use_hint_32_avx2(__m256i *b, const __m256i *a, { unsigned int i; __m256i f, f0, f1, h, t; - const __m256i q = - _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]); - const __m256i hq = _mm256_srli_epi32(q, 1); - /* check-magic: 1025 == round((2**22*128) / ((MLDSA_Q - 1) / 16)) */ + const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2); + /* check-magic: 1025 == floor(2**22 / 4092) */ const __m256i v = _mm256_set1_epi32(1025); const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2); const __m256i off = _mm256_set1_epi32(127); @@ -54,26 +52,27 @@ void mld_poly_use_hint_32_avx2(__m256i *b, const __m256i *a, f = _mm256_load_si256(&a[i]); h = _mm256_load_si256(&hint[i]); - /* Reference: The reference avx2 implementation calls poly_decompose to - * compute all a1, a0 before the loop. + /* Reference: + * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. 
+ * - Our implementation of decompose() is slightly different from that in + * @[REF_AVX2]. See poly_decompose_32_avx2.c for more information. */ - /* decompose */ + /* f1, f2 = decompose(f) */ f1 = _mm256_add_epi32(f, off); f1 = _mm256_srli_epi32(f1, 7); f1 = _mm256_mulhi_epu16(f1, v); f1 = _mm256_mulhrs_epi16(f1, shift); - f1 = _mm256_and_si256(f1, mask); + t = _mm256_cmpgt_epi32(f, q_bound); f0 = _mm256_mullo_epi32(f1, alpha); f0 = _mm256_sub_epi32(f, f0); - f = _mm256_cmpgt_epi32(f0, hq); - f = _mm256_and_si256(f, q); - f0 = _mm256_sub_epi32(f0, f); + f1 = _mm256_andnot_si256(t, f1); + f0 = _mm256_add_epi32(f0, t); /* Reference: The reference avx2 implementation checks a0 >= 0, which is * different from the specification and the reference C implementation. We * follow the specification and check a0 > 0. */ - /* t = (a0 > 0) ? h : -h */ + /* t = (f0 > 0) ? h : -h */ f0 = _mm256_cmpgt_epi32(f0, zero); t = _mm256_blendv_epi32(h, zero, f0); t = _mm256_slli_epi32(t, 1); diff --git a/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c b/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c index a91fa80b9..1e902f28b 100644 --- a/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c +++ b/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c @@ -38,10 +38,8 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a, { unsigned int i; __m256i f, f0, f1, h, t; - const __m256i q = - _mm256_load_si256(&mld_qdata.vec[MLD_AVX2_BACKEND_DATA_OFFSET_8XQ / 8]); - const __m256i hq = _mm256_srli_epi32(q, 1); - /* check-magic: 11275 == round((2**24*128) / ((MLDSA_Q - 1) / 44)) */ + const __m256i q_bound = _mm256_set1_epi32(87 * MLDSA_GAMMA2); + /* check-magic: 11275 == floor(2**24 / 1488) */ const __m256i v = _mm256_set1_epi32(11275); const __m256i alpha = _mm256_set1_epi32(2 * MLDSA_GAMMA2); const __m256i off = _mm256_set1_epi32(127); @@ -54,27 +52,27 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a, f = _mm256_load_si256(&a[i]); h = _mm256_load_si256(&hint[i]); - /* Reference: The reference avx2 implementation calls poly_decompose to - * compute all a1, a0 before the loop. + /* Reference: + * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. + * - Our implementation of decompose() is slightly different from that in + * @[REF_AVX2]. See poly_decompose_88_avx2.c for more information. */ - /* decompose */ + /* f1, f2 = decompose(f) */ f1 = _mm256_add_epi32(f, off); f1 = _mm256_srli_epi32(f1, 7); f1 = _mm256_mulhi_epu16(f1, v); f1 = _mm256_mulhrs_epi16(f1, shift); - t = _mm256_sub_epi32(max, f1); - f1 = _mm256_blendv_epi32(f1, zero, t); + t = _mm256_cmpgt_epi32(f, q_bound); f0 = _mm256_mullo_epi32(f1, alpha); f0 = _mm256_sub_epi32(f, f0); - f = _mm256_cmpgt_epi32(f0, hq); - f = _mm256_and_si256(f, q); - f0 = _mm256_sub_epi32(f0, f); + f1 = _mm256_andnot_si256(t, f1); + f0 = _mm256_add_epi32(f0, t); /* Reference: The reference avx2 implementation checks a0 >= 0, which is * different from the specification and the reference C implementation. We * follow the specification and check a0 > 0. */ - /* t = (a0 > 0) ? h : -h */ + /* t = (f0 > 0) ? h : -h */ f0 = _mm256_cmpgt_epi32(f0, zero); t = _mm256_blendv_epi32(h, zero, f0); t = _mm256_slli_epi32(t, 1);