Skip to content

Commit 50b2bb5

Browse files
hanno-beckermkannwischer
authored andcommitted
check-magic: Remember all explained magic constants
Previously, scripts/check-magic would remember only the last explained magic constant, preventing, for example, the explanation of multiple magic constants ahead of a comment block referring to all of them. Moreover, check-magic would only lazily evaluate a provided explanation when actually finding a magic value magic the LHS of the proposed explanation. In particular, a _wrong_ explanation would only be caught if, in the rest of the file under consideration, some matching magic constant would be found. This led to an ambiguous and unused explanation being overlooked in params.h. This commit makes check-magic more general so that - it always checks magic value explanations when they are provided, regardless of whether they are needed or not; and, - it remembers all magic values explained so far. Moreover, the `round` function is instrumented to fail if it is called on an odd multiple of 1/2 -- in this case, the rounding is ambiguous (do we want round-half-down or round-half-up?); this was the source of the mismatch in params.h previously. We also add support for `intdiv(a,b)` to an integer division which we want to assert to be without residue. This can be used instead of `//` to additionally check that the division is indeed integral. Signed-off-by: Hanno Becker <[email protected]>
1 parent fba1a38 commit 50b2bb5

File tree

2 files changed

+39
-13
lines changed

2 files changed

+39
-13
lines changed

mldsa/src/params.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#define MLDSA_RNDBYTES 32
1414
#define MLDSA_N 256
1515
#define MLDSA_Q 8380417
16-
/* check-magic: 4190209 == round(MLDSA_Q/2) */
1716
#define MLDSA_Q_HALF ((MLDSA_Q + 1) / 2)
1817
#define MLDSA_D 13
1918

scripts/check-magic

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
#
99

1010
import re
11+
import math
1112
import pathlib
1213

13-
from sympy import simplify, sympify, Function
14+
from sympy import simplify, sympify, Function, Rational
1415

1516
def get_c_source_files():
1617
return get_files("mldsa/**/*.c")
@@ -21,6 +22,17 @@ def get_header_files():
2122
def get_files(pattern):
2223
return list(map(str, pathlib.Path().glob(pattern)))
2324

25+
# Standard color definitions
26+
GREEN="\033[32m"
27+
RED="\033[31m"
28+
BLUE="\033[94m"
29+
BOLD="\033[1m"
30+
NORMAL="\033[0m"
31+
32+
CHECKED = f"{GREEN}{NORMAL}"
33+
FAIL = f"{RED}{NORMAL}"
34+
REMEMBERED = f"{BLUE}{NORMAL}"
35+
2436
def check_magic_numbers():
2537
mldsa_q = 8380417
2638
exceptions = [mldsa_q]
@@ -62,9 +74,21 @@ def check_magic_numbers():
6274
y = int(y)
6375
m = int(m)
6476
return signed_mod(pow(x,y,m),m)
77+
def safe_round(x):
78+
if x - math.floor(x) == Rational(1, 2):
79+
raise ValueError(f"Ambiguous rounding: {x} is an odd multiple of 0.5 and it is unclear if round-up or round-down is desired")
80+
return round(x)
81+
def safe_floordiv(x, y):
82+
x = int(x)
83+
y = int(y)
84+
if x % y != 0:
85+
raise ValueError(f"Non-integral division: {x} // {y} has remainder {x % y}")
86+
return x // y
6587
locals_dict = {'signed_mod': signed_mod,
6688
'unsigned_mod': unsigned_mod,
67-
'pow': pow_mod }
89+
'pow': pow_mod,
90+
'round': safe_round,
91+
'intdiv': safe_floordiv }
6892
locals_dict.update(known_magics)
6993
return sympify(m, locals=locals_dict)
7094

@@ -80,6 +104,7 @@ def check_magic_numbers():
80104
enabled = True
81105
magic_dict = {'MLDSA_Q': mldsa_q, 'REDUCE32_DOMAIN_MAX': 2143289343}
82106
magic_expr = None
107+
verified_magics = {}
83108
for i, l in enumerate(content):
84109
if enabled is True and disable_marker in l:
85110
enabled = False
@@ -92,6 +117,12 @@ def check_magic_numbers():
92117
l, g = get_magic(l)
93118
if g is not None:
94119
magic_val, magic_expr = g
120+
magic_val_check = evaluate_magic(magic_expr, magic_dict)
121+
if magic_val != magic_val_check:
122+
print(f"{FAIL}:{filename}:{i+1}: Mismatching magic annotation: {magic_val} != {magic_expr} (= {magic_val_check})")
123+
exit(1)
124+
print(f"{REMEMBERED}:{filename}:{i+1}: Verified explanation {magic_val} == {magic_expr}")
125+
verified_magics[magic_val] = magic_expr
95126

96127
found = next(re.finditer(pattern, l), None)
97128
if found is None:
@@ -101,16 +132,12 @@ def check_magic_numbers():
101132
if is_exception(filename, l, magic):
102133
continue
103134

104-
if magic_expr is not None:
105-
val = evaluate_magic(magic_expr, magic_dict)
106-
if magic_val != val:
107-
raise Exception(f"{filename}:{i}: Mismatching magic annotation: {magic_val} != {val}")
108-
if val == magic:
109-
print(f"[OK] {filename}:{i}: Verified magic constant {magic} == {magic_expr}")
110-
else:
111-
raise Exception(f"{filename}:{i}: Magic constant mismatch {magic} != {magic_expr}")
112-
else:
113-
raise Exception(f"{filename}:{i}: No explanation for magic value {magic}")
135+
explanation = verified_magics.get(magic, None)
136+
if explanation is None:
137+
print(f"{FAIL}:{filename}:{i+1}: No explanation for magic value {magic}")
138+
exit(1)
139+
140+
print(f"{CHECKED}:{filename}:{i+1}: {magic} previously explained as {explanation}")
114141

115142
# If this is a #define's clause, remember it
116143
define = get_define(l)

0 commit comments

Comments
 (0)