Skip to content

Commit 9e8583a

Browse files
committed
Port test/test_bounds.py to confirm bound for Montgomery
Signed-off-by: jammychiou1 <[email protected]>
1 parent 502d1de commit 9e8583a

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

dev/x86_64/src/intt.S

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ vpblendd $0xAA,%ymm9,%ymm7,%ymm7
323323
* int32_t (thus |a| <= R/2), we still have |montmul(a, b)| <= 3q/4. This can be
324324
* strengthened to |montmul_pos(a, b)| <= floor(3q/4) < ceil(3q/4) since LHS is
325325
* an integer and 3q/4 isn't.
326+
*
327+
* See test/test_bounds.py for more empirical evidence (and some minor technical
328+
* details).
326329
*/
327330

328331
/* 4, 5, 6, 7: abs bound < ceil(3q/4) */

test/test_bounds.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) The mlkem-native project authors
2+
# Copyright (c) The mldsa-native project authors
3+
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
5+
#
6+
# The purpose of this script is to provide either brute-force proof
7+
# or empirical evidence to arithmetic bounds for the modular
8+
# arithmetic primitives used in this repository.
9+
#
10+
11+
import random
12+
from functools import lru_cache
13+
from fractions import Fraction
14+
from math import ceil
15+
16+
# Global constants
17+
R = 2**32
18+
Q = 8380417
19+
Qinv = pow(Q, -1, R)
20+
NQinv = pow(-Q, -1, R)
21+
22+
23+
#
24+
# Montgomery multiplication
25+
#
26+
27+
28+
def lift_signed_i32(x):
29+
"""Returns signed canonical representative modulo R=2^32."""
30+
x = x % R
31+
if x >= R // 2:
32+
x -= R
33+
return x
34+
35+
36+
@lru_cache(maxsize=None)
37+
def montmul_neg_twiddle(b):
38+
return (b * NQinv) % R
39+
40+
41+
@lru_cache(maxsize=None)
42+
def montmul_pos_twiddle(b):
43+
return (b * Qinv) % R
44+
45+
46+
def montmul_neg(a, b):
47+
b_twiddle = montmul_neg_twiddle(b)
48+
return (a * b + Q * lift_signed_i32(a * b_twiddle)) // R
49+
50+
51+
def montmul_pos(a, b):
52+
b_twiddle = montmul_pos_twiddle(b)
53+
return (a * b - Q * lift_signed_i32(a * b_twiddle)) // R
54+
55+
56+
#
57+
# Generic test functions
58+
#
59+
60+
61+
def test_random(f, test_name, num_tests=10000000, bound_a=R // 2, bound_b=Q // 2):
62+
print(f"Randomly checking {test_name} ({num_tests} tests)...")
63+
for i in range(num_tests):
64+
if i % 100000 == 0:
65+
print(f"... run {i} tests ({((i * 1000) // num_tests)/10}%)")
66+
a = random.randrange(-bound_a, bound_a)
67+
b = random.randrange(-bound_b, bound_b)
68+
f(a, b)
69+
70+
71+
#
72+
# Test bound on "Montgomery multiplication with signed canonical constant", as
73+
# used in AVX2 [I]NTT
74+
#
75+
76+
"""
77+
In @[Survey_Hwang23, Section 2.2], the author noted the bound*
78+
79+
|montmul(a, b)| <= (q/2) (1 + |a|/R).
80+
81+
In particular, knowing that a fits inside int32_t (thus |a| <= R/2) already
82+
implies |montmul(a, b)| <= 3q/4 < ceil(3q/4).
83+
84+
(*) Strictly speaking, they considered the negative/additive variant
85+
montmul_neg(a, b), but the exact same bound and proof also work for the
86+
positive/subtractive variant montmul_pos(a, b).
87+
"""
88+
89+
90+
def montmul_pos_const_bound(a):
91+
return Fraction(Q, 2) * (1 + Fraction(abs(a), R))
92+
93+
94+
def montmul_pos_const_bound_test(a, b):
95+
ab = montmul_pos(a, b)
96+
bound = montmul_pos_const_bound(a)
97+
if abs(ab) > bound:
98+
print(f"montmul_pos_const_bound_test failure for (a,b)={(a,b)}")
99+
print(f"montmul_pos(a,b): {ab}")
100+
print(f"bound: {bound}")
101+
assert False
102+
103+
104+
def montmul_pos_const_bound_test_random():
105+
test_random(
106+
montmul_pos_const_bound_test,
107+
"bound on Montgomery multiplication with constant, as used in AVX2 [I]NTT",
108+
)
109+
110+
111+
def montmul_pos_const_bound_tight():
112+
"""
113+
This example shows that, unless we know more about a or b, the bound
114+
|montmul(a, b)| < ceil(3q/4) is the tightest exclusive bound.
115+
"""
116+
a_worst = -R // 2
117+
b_worst = -(Q - 3) // 2
118+
ab_worst = montmul_pos(a_worst, b_worst)
119+
bound = ceil(Fraction(3 * Q, 4))
120+
assert ab_worst == bound - 1
121+
122+
123+
montmul_pos_const_bound_test_random()
124+
montmul_pos_const_bound_tight()

0 commit comments

Comments
 (0)