diff --git a/dev/x86_64/src/intt.S b/dev/x86_64/src/intt.S index f45d0fd87..4e9979588 100644 --- a/dev/x86_64/src/intt.S +++ b/dev/x86_64/src/intt.S @@ -43,6 +43,13 @@ vpsrlq $32,%ymm\r0,%ymm\r0 vpblendd $0xAA,%ymm\r1,%ymm\r0,%ymm\r3 .endm +/* + * Compute l + h, montmul(h - l, zh) then store the results back to l, h + * respectively. + * + * The abs bound of "Montgomery multiplication with signed canonical constant" + * is ceil(3q/4) (see the end of this file). + */ .macro butterfly l,h,zl0=1,zl1=1,zh0=2,zh1=2 vpsubd %ymm\l,%ymm\h,%ymm12 vpaddd %ymm\h,%ymm\l,%ymm\l @@ -74,6 +81,8 @@ vmovdqa 256*\off+160(%rdi),%ymm9 vmovdqa 256*\off+192(%rdi),%ymm10 vmovdqa 256*\off+224(%rdi),%ymm11 +/* All: abs bound < q */ + /* level 0 */ vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+296-8*\off-8)*4(%rsi),%ymm3 vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+296-8*\off-8)*4(%rsi),%ymm15 @@ -99,6 +108,19 @@ vmovshdup %ymm3,%ymm1 vmovshdup %ymm15,%ymm2 butterfly 10,11,1,3,2,15 +/* 4, 6, 8, 10: abs bound < 2q; 5, 7, 9, 11: abs bound < ceil(3q/4) */ +/* + * Note that since 2^31 / q > 256, the sum of all 256 coefficients does not + * overflow. This allows us to greatly simplify the range analysis by relaxing + * and unifying the bounds of all coefficients on the same layer. As a concrete + * example, here we relax the bounds on 5, 7, 9, 11 and conclude that + * + * All: abs bound < 2q + * + * In all but last of the following layers, we do the same relaxation without + * explicit mention. + */ + /* level 1 */ vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+168-8*\off-8)*4(%rsi),%ymm3 vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+168-8*\off-8)*4(%rsi),%ymm15 @@ -114,6 +136,8 @@ vmovshdup %ymm15,%ymm2 butterfly 8,10,1,3,2,15 butterfly 9,11,1,3,2,15 +/* All: abs bound < 4q */ + /* level 2 */ vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+104-8*\off-8)*4(%rsi),%ymm3 vpermq $0x1B,(MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+104-8*\off-8)*4(%rsi),%ymm15 @@ -124,6 +148,8 @@ butterfly 5,9,1,3,2,15 butterfly 6,10,1,3,2,15 butterfly 7,11,1,3,2,15 +/* All: abs bound < 8q */ + /* level 3 */ shuffle2 4,5,3,5 shuffle2 6,7,4,7 @@ -137,6 +163,8 @@ butterfly 4,7 butterfly 6,9 butterfly 8,11 +/* All: abs bound < 16q */ + /* level 4 */ shuffle4 3,4,10,4 shuffle4 6,8,3,8 @@ -150,6 +178,8 @@ butterfly 3,8 butterfly 6,7 butterfly 5,11 +/* All: abs bound < 32q */ + /* level 5 */ shuffle8 10,3,9,3 shuffle8 6,5,10,5 @@ -163,6 +193,8 @@ butterfly 10,5 butterfly 6,8 butterfly 4,11 +/* All: abs bound < 64q */ + vmovdqa %ymm9,256*\off+ 0(%rdi) vmovdqa %ymm10,256*\off+ 32(%rdi) vmovdqa %ymm6,256*\off+ 64(%rdi) @@ -194,6 +226,8 @@ vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+2)*4(%rsi),%ymm2 butterfly 8,10 butterfly 9,11 +/* All: abs bound < 128q */ + /* level 7 */ vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+0)*4(%rsi),%ymm1 vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+0)*4(%rsi),%ymm2 @@ -203,11 +237,26 @@ butterfly 5,9 butterfly 6,10 butterfly 7,11 +/* 4, 5, 6, 7: abs bound < 256q; 8, 9, 10, 11: abs bound < ceil(3q/4) */ + vmovdqa %ymm8,512+32*\off(%rdi) vmovdqa %ymm9,640+32*\off(%rdi) vmovdqa %ymm10,768+32*\off(%rdi) vmovdqa %ymm11,896+32*\off(%rdi) +/* + * In order to (a) remove the factor of 256 arising from the 256-point INTT + * butterflies and (b) transform the output into Montgomery domain, we need to + * multiply all coefficients by 2^32/256. + * + * For ymm{8,9,10,11}, the scaling has been merged into the last butterfly, so + * only ymm{4,5,6,7} need to be scaled explicitly. 
+ * + * The scaling is achieved by computing montmul(-, MLD_AVX2_DIV). + * + * 4, 5, 6, 7: abs bound < 256q + */ + vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_8XDIV_QINV)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_8XDIV)*4(%rsi),%ymm2 vpmuldq %ymm1,%ymm4,%ymm12 @@ -256,6 +305,31 @@ vmovshdup %ymm7,%ymm7 vpblendd $0xAA,%ymm8,%ymm6,%ymm6 vpblendd $0xAA,%ymm9,%ymm7,%ymm7 +/* + * The bound ceil(3q/4) for this scaling, as well as any other "Montgomery + * multiplication with signed canonical constant", is justified as follows. + * + * In Section 2.2 of https://eprint.iacr.org/2023/1962, they showed a bound that + * works for any variable input a, as long as the constant b is signed + * canonical: + * + * |montmul(a, b)| <= (|a| (q/2) + (R/2) q) / R = (q/2) (1 + |a|/R). + * + * Therefore, even if we know nothing about a except that it fits inside + * int32_t (thus |a| <= R/2), we still have |montmul(a, b)| <= 3q/4. This can be + * strengthened to |montmul_pos(a, b)| <= floor(3q/4) < ceil(3q/4) since LHS is + * an integer and 3q/4 isn't. + * + * See test/test_bounds.py for more empirical evidence (and some minor technical + * details). + * + * TODO: Use proper citation. Currently, citations within asm can cause linter + * to complain about unused citation, because comments are not preserved + * after simpasm. + */ + +/* 4, 5, 6, 7: abs bound < ceil(3q/4) */ + vmovdqa %ymm4, 0+32*\off(%rdi) vmovdqa %ymm5,128+32*\off(%rdi) vmovdqa %ymm6,256+32*\off(%rdi) diff --git a/dev/x86_64/src/ntt.S b/dev/x86_64/src/ntt.S index 8fae4ccbc..0329a0dfc 100644 --- a/dev/x86_64/src/ntt.S +++ b/dev/x86_64/src/ntt.S @@ -44,6 +44,17 @@ vpsrlq $32,%ymm\r0,%ymm\r0 vpblendd $0xAA,%ymm\r1,%ymm\r0,%ymm\r3 .endm +/* + * Compute l + montmul(h, zh), l - montmul(h, zh) then store the results back to + * l, h respectively. + * + * Although the abs bound of "Montgomery multiplication with signed canonical + * constant" is ceil(3q/4) (see the end of dev/x86_64/src/intt.S), we use the + * more convenient bound q here. + * + * In conclusion, the magnitudes of all coefficients grow by at most q after + * each layer. + */ .macro butterfly l,h,zl0=1,zl1=1,zh0=2,zh1=2 vpmuldq %ymm\zl0,%ymm\h,%ymm13 vmovshdup %ymm\h,%ymm12 @@ -56,16 +67,30 @@ vpmuldq %ymm0,%ymm13,%ymm13 vpmuldq %ymm0,%ymm14,%ymm14 vmovshdup %ymm\h,%ymm\h -vpblendd $0xAA,%ymm12,%ymm\h,%ymm\h +vpblendd $0xAA,%ymm12,%ymm\h,%ymm\h /* mulhi(h, zh) */ -vpsubd %ymm\h,%ymm\l,%ymm12 -vpaddd %ymm\h,%ymm\l,%ymm\l +/* + * Originally, mulhi(h, zh) should be subtracted by mulhi(q, mullo(h, zl)) in + * order to complete computing + * + * montmul(h, zh) = mulhi(h, zh) - mulhi(q, mullo(h, zl)). + * + * Here, since mulhi(q, mullo(h, zl)) has not been computed yet, this task is + * delayed until after add/sub. 
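+ *
+ * Concretely, the add/sub below is performed on l +- mulhi(h, zh) first, and
+ * the correction term is applied afterwards:
+ *
+ *   (l - mulhi(h, zh)) + mulhi(q, mullo(h, zl)) = l - montmul(h, zh)
+ *   (l + mulhi(h, zh)) - mulhi(q, mullo(h, zl)) = l + montmul(h, zh)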
+ */ +vpsubd %ymm\h,%ymm\l,%ymm12 /* l - mulhi(h, zh) + * = l - montmul(h, zh) + * - mulhi(q, mullo(h, zl)) */ +vpaddd %ymm\h,%ymm\l,%ymm\l /* l + mulhi(h, zh) + * = l + montmul(h, zh) + * + mulhi(q, mullo(h, zl)) */ vmovshdup %ymm13,%ymm13 -vpblendd $0xAA,%ymm14,%ymm13,%ymm13 +vpblendd $0xAA,%ymm14,%ymm13,%ymm13 /* mulhi(q, mullo(h, zl)) */ -vpaddd %ymm13,%ymm12,%ymm\h -vpsubd %ymm13,%ymm\l,%ymm\l +/* Finish the delayed task mentioned above */ +vpaddd %ymm13,%ymm12,%ymm\h /* l - montmul(h, zh) */ +vpsubd %ymm13,%ymm\l,%ymm\l /* l + montmul(h, zh) */ .endm .macro levels0t1 off @@ -82,11 +107,15 @@ vmovdqa 640+32*\off(%rdi),%ymm9 vmovdqa 768+32*\off(%rdi),%ymm10 vmovdqa 896+32*\off(%rdi),%ymm11 +/* All: abs bound < q */ + butterfly 4,8 butterfly 5,9 butterfly 6,10 butterfly 7,11 +/* All: abs bound < 2q */ + /* level 1 */ vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+2)*4(%rsi),%ymm1 vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+2)*4(%rsi),%ymm2 @@ -98,6 +127,8 @@ vpbroadcastd (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+3)*4(%rsi),%ymm2 butterfly 8,10 butterfly 9,11 +/* All: abs bound < 3q */ + vmovdqa %ymm4, 0+32*\off(%rdi) vmovdqa %ymm5,128+32*\off(%rdi) vmovdqa %ymm6,256+32*\off(%rdi) @@ -132,6 +163,8 @@ shuffle8 5,9,4,9 shuffle8 6,10,5,10 shuffle8 7,11,6,11 +/* All: abs bound < 4q */ + /* level 3 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+8+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+8+8*\off)*4(%rsi),%ymm2 @@ -146,6 +179,8 @@ shuffle4 8,10,3,10 shuffle4 4,6,8,6 shuffle4 9,11,4,11 +/* All: abs bound < 5q */ + /* level 4 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+40+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+40+8*\off)*4(%rsi),%ymm2 @@ -160,6 +195,8 @@ shuffle2 5,6,7,6 shuffle2 3,4,5,4 shuffle2 10,11,3,11 +/* All: abs bound < 6q */ + /* level 5 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+72+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+72+8*\off)*4(%rsi),%ymm2 @@ -171,6 +208,8 @@ butterfly 8,4,1,10,2,15 butterfly 7,3,1,10,2,15 butterfly 6,11,1,10,2,15 +/* All: abs bound < 7q */ + /* level 6 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+104+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+104+8*\off)*4(%rsi),%ymm2 @@ -186,6 +225,8 @@ vmovshdup %ymm2,%ymm15 butterfly 5,3,1,10,2,15 butterfly 4,11,1,10,2,15 +/* All: abs bound < 8q */ + /* level 7 */ vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV+168+8*\off)*4(%rsi),%ymm1 vmovdqa (MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS+168+8*\off)*4(%rsi),%ymm2 @@ -211,6 +252,8 @@ vpsrlq $32,%ymm1,%ymm10 vmovshdup %ymm2,%ymm15 butterfly 3,11,1,10,2,15 +/* All: abs bound < 9q */ + vmovdqa %ymm9,256*\off+ 0(%rdi) vmovdqa %ymm8,256*\off+ 32(%rdi) vmovdqa %ymm7,256*\off+ 64(%rdi) diff --git a/dev/x86_64/src/pointwise.S b/dev/x86_64/src/pointwise.S index 8bd73616f..6445054d5 100644 --- a/dev/x86_64/src/pointwise.S +++ b/dev/x86_64/src/pointwise.S @@ -61,6 +61,7 @@ _looptop1: vpsrlq ymm11, ymm10, 32 vpsrlq ymm13, ymm12, 32 vmovshdup ymm15, ymm14 + /* All: abs bound < 9q */ // Multiply vpmuldq ymm2, ymm2, ymm10 @@ -69,6 +70,7 @@ _looptop1: vpmuldq ymm5, ymm5, ymm13 vpmuldq ymm6, ymm6, ymm14 vpmuldq ymm7, ymm7, ymm15 + /* All: abs bound < 81q^2 < 81*2^46 < 2^53 = 2^21R < qR/2 */ // Reduce vpmuldq ymm10, ymm0, ymm2 @@ -92,6 +94,11 @@ _looptop1: vpsrlq ymm2, ymm2, 32 vpsrlq ymm4, ymm4, 32 vmovshdup ymm6, ymm6 + /* + * All coefficients are Montgomery-reduced. 
This results in the bound + * + * All: abs bound <= "input abs bound"/R + q/2 < (qR/2)/R + q/2 = q + */ // Store vpblendd ymm2, ymm2, ymm3, 0xAA diff --git a/dev/x86_64/src/pointwise_acc_l4.S b/dev/x86_64/src/pointwise_acc_l4.S index e64881ccb..44c2b62b0 100644 --- a/dev/x86_64/src/pointwise_acc_l4.S +++ b/dev/x86_64/src/pointwise_acc_l4.S @@ -37,12 +37,17 @@ vpsrlq ymm9, ymm8, 32 vmovshdup ymm11, ymm10 vmovshdup ymm13, ymm12 + /* + * 6, 7, 8, 9: from the first input polynomial, abs bound < q + * 10, 11, 12, 13: from the second input polynomial, abs bound < 9q + */ // Multiply vpmuldq ymm6, ymm6, ymm10 vpmuldq ymm7, ymm7, ymm11 vpmuldq ymm8, ymm8, ymm12 vpmuldq ymm9, ymm9, ymm13 + /* All: abs bound < 9q^2 */ .endm .macro acc @@ -80,15 +85,19 @@ _looptop2: vmovdqa ymm3, ymm7 vmovdqa ymm4, ymm8 vmovdqa ymm5, ymm9 + /* All: abs bound < 9q^2 */ pointwise 1024 acc + /* All: abs bound < 18q^2 */ pointwise 2048 acc + /* All: abs bound < 27q^2 */ pointwise 3072 acc + /* All: abs bound < 36q^2 < 36*2^46 < 2^52 = 2^20R < qR/2 */ // Reduce vpmuldq ymm6, ymm0, ymm2 @@ -105,6 +114,11 @@ _looptop2: vpsubq ymm5, ymm5, ymm9 vpsrlq ymm2, ymm2, 32 vmovshdup ymm4, ymm4 + /* + * All coefficients are Montgomery-reduced. This results in the bound + * + * All: abs bound <= "input abs bound"/R + q/2 < (qR/2)/R + q/2 = q + */ // Store vpblendd ymm2, ymm2, ymm3, 0xAA diff --git a/dev/x86_64/src/pointwise_acc_l5.S b/dev/x86_64/src/pointwise_acc_l5.S index db7348f19..020f517eb 100644 --- a/dev/x86_64/src/pointwise_acc_l5.S +++ b/dev/x86_64/src/pointwise_acc_l5.S @@ -37,12 +37,17 @@ vpsrlq ymm9, ymm8, 32 vmovshdup ymm11, ymm10 vmovshdup ymm13, ymm12 + /* + * 6, 7, 8, 9: from the first input polynomial, abs bound < q + * 10, 11, 12, 13: from the second input polynomial, abs bound < 9q + */ // Multiply vpmuldq ymm6, ymm6, ymm10 vpmuldq ymm7, ymm7, ymm11 vpmuldq ymm8, ymm8, ymm12 vpmuldq ymm9, ymm9, ymm13 + /* All: abs bound < 9q^2 */ .endm .macro acc @@ -80,18 +85,23 @@ _looptop2: vmovdqa ymm3, ymm7 vmovdqa ymm4, ymm8 vmovdqa ymm5, ymm9 + /* All: abs bound < 9q^2 */ pointwise 1024 acc + /* All: abs bound < 18q^2 */ pointwise 2048 acc + /* All: abs bound < 27q^2 */ pointwise 3072 acc + /* All: abs bound < 36q^2 */ pointwise 4096 acc + /* All: abs bound < 45q^2 < 45*2^46 < 2^52 = 2^20R < qR/2 */ // Reduce vpmuldq ymm6, ymm0, ymm2 @@ -108,6 +118,11 @@ _looptop2: vpsubq ymm5, ymm5, ymm9 vpsrlq ymm2, ymm2, 32 vmovshdup ymm4, ymm4 + /* + * All coefficients are Montgomery-reduced. 
This results in the bound + * + * All: abs bound <= "input abs bound"/R + q/2 < (qR/2)/R + q/2 = q + */ // Store vpblendd ymm2, ymm2, ymm3, 0xAA diff --git a/dev/x86_64/src/pointwise_acc_l7.S b/dev/x86_64/src/pointwise_acc_l7.S index bae230d75..835e87b05 100644 --- a/dev/x86_64/src/pointwise_acc_l7.S +++ b/dev/x86_64/src/pointwise_acc_l7.S @@ -37,12 +37,17 @@ vpsrlq ymm9, ymm8, 32 vmovshdup ymm11, ymm10 vmovshdup ymm13, ymm12 + /* + * 6, 7, 8, 9: from the first input polynomial, abs bound < q + * 10, 11, 12, 13: from the second input polynomial, abs bound < 9q + */ // Multiply vpmuldq ymm6, ymm6, ymm10 vpmuldq ymm7, ymm7, ymm11 vpmuldq ymm8, ymm8, ymm12 vpmuldq ymm9, ymm9, ymm13 + /* All: abs bound < 9q^2 */ .endm .macro acc @@ -80,24 +85,31 @@ _looptop2: vmovdqa ymm3, ymm7 vmovdqa ymm4, ymm8 vmovdqa ymm5, ymm9 + /* All: abs bound < 9q^2 */ pointwise 1024 acc + /* All: abs bound < 18q^2 */ pointwise 2048 acc + /* All: abs bound < 27q^2 */ pointwise 3072 acc + /* All: abs bound < 36q^2 */ pointwise 4096 acc + /* All: abs bound < 45q^2 */ pointwise 5120 acc + /* All: abs bound < 54q^2 */ pointwise 6144 acc + /* All: abs bound < 63q^2 < 63*2^46 < 2^52 = 2^20R < qR/2 */ // Reduce vpmuldq ymm6, ymm0, ymm2 @@ -114,6 +126,11 @@ _looptop2: vpsubq ymm5, ymm5, ymm9 vpsrlq ymm2, ymm2, 32 vmovshdup ymm4, ymm4 + /* + * All coefficients are Montgomery-reduced. This results in the bound + * + * All: abs bound <= "input abs bound"/R + q/2 < (qR/2)/R + q/2 = q + */ // Store vpblendd ymm2, ymm2, ymm3, 0xAA diff --git a/mldsa/src/ntt.h b/mldsa/src/ntt.h index 9cbd0a19f..e64d57a12 100644 --- a/mldsa/src/ntt.h +++ b/mldsa/src/ntt.h @@ -21,8 +21,8 @@ /* Absolute exclusive upper bound for the output of the forward NTT */ #define MLD_NTT_BOUND (9 * MLDSA_Q) -/* Absolute exclusive upper bound for the output of the inverse NTT*/ -#define MLD_INTT_BOUND (MLDSA_Q * 3 / 4) +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define MLD_INTT_BOUND (MLDSA_Q * 3 / 4 + 1) /* ceil(3 * MLDSA_Q / 4) */ #define mld_ntt MLD_NAMESPACE(ntt) /************************************************* diff --git a/test/test_bounds.py b/test/test_bounds.py new file mode 100644 index 000000000..505d8b95a --- /dev/null +++ b/test/test_bounds.py @@ -0,0 +1,124 @@ +# Copyright (c) The mlkem-native project authors +# Copyright (c) The mldsa-native project authors +# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +# +# The purpose of this script is to provide either brute-force proof +# or empirical evidence to arithmetic bounds for the modular +# arithmetic primitives used in this repository. 
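+# The script uses only the Python standard library and runs all of its checks
+# when executed directly, e.g. `python3 test/test_bounds.py`.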
+# + +import random +from functools import lru_cache +from fractions import Fraction +from math import ceil + +# Global constants +R = 2**32 +Q = 8380417 +Qinv = pow(Q, -1, R) +NQinv = pow(-Q, -1, R) + + +# +# Montgomery multiplication +# + + +def lift_signed_i32(x): + """Returns signed canonical representative modulo R=2^32.""" + x = x % R + if x >= R // 2: + x -= R + return x + + +@lru_cache(maxsize=None) +def montmul_neg_twiddle(b): + return (b * NQinv) % R + + +@lru_cache(maxsize=None) +def montmul_pos_twiddle(b): + return (b * Qinv) % R + + +def montmul_neg(a, b): + b_twiddle = montmul_neg_twiddle(b) + return (a * b + Q * lift_signed_i32(a * b_twiddle)) // R + + +def montmul_pos(a, b): + b_twiddle = montmul_pos_twiddle(b) + return (a * b - Q * lift_signed_i32(a * b_twiddle)) // R + + +# +# Generic test functions +# + + +def test_random(f, test_name, num_tests=10000000, bound_a=R // 2, bound_b=Q // 2): + print(f"Randomly checking {test_name} ({num_tests} tests)...") + for i in range(num_tests): + if i % 100000 == 0: + print(f"... run {i} tests ({((i * 1000) // num_tests)/10}%)") + a = random.randrange(-bound_a, bound_a) + b = random.randrange(-bound_b, bound_b) + f(a, b) + + +# +# Test bound on "Montgomery multiplication with signed canonical constant", as +# used in AVX2 [I]NTT +# + +""" +In @[Survey_Hwang23, Section 2.2], the author noted the bound* + + |montmul(a, b)| <= (q/2) (1 + |a|/R). + +In particular, knowing that a fits inside int32_t (thus |a| <= R/2) already +implies |montmul(a, b)| <= 3q/4 < ceil(3q/4). + +(*) Strictly speaking, they considered the negative/additive variant + montmul_neg(a, b), but the exact same bound and proof also work for the + positive/subtractive variant montmul_pos(a, b). +""" + + +def montmul_pos_const_bound(a): + return Fraction(Q, 2) * (1 + Fraction(abs(a), R)) + + +def montmul_pos_const_bound_test(a, b): + ab = montmul_pos(a, b) + bound = montmul_pos_const_bound(a) + if abs(ab) > bound: + print(f"montmul_pos_const_bound_test failure for (a,b)={(a,b)}") + print(f"montmul_pos(a,b): {ab}") + print(f"bound: {bound}") + assert False + + +def montmul_pos_const_bound_test_random(): + test_random( + montmul_pos_const_bound_test, + "bound on Montgomery multiplication with constant, as used in AVX2 [I]NTT", + ) + + +def montmul_pos_const_bound_tight(): + """ + This example shows that, unless we know more about a or b, the bound + |montmul(a, b)| < ceil(3q/4) is the tightest exclusive bound. + """ + a_worst = -R // 2 + b_worst = -(Q - 3) // 2 + ab_worst = montmul_pos(a_worst, b_worst) + bound = ceil(Fraction(3 * Q, 4)) + assert ab_worst == bound - 1 + + +montmul_pos_const_bound_test_random() +montmul_pos_const_bound_tight()