Skip to content

Commit 0d538ed

Browse files
committed
Simplification and refactoring to restore proof speed and stability.
1. Weaken post-condition and loop invariant in polyvecl_add(). The stronger post-condition was unnecessary. 2. Simplify polyvec_matrix_expand(). Small performance loss here since batched_seeds[] is (re-) initialized every time. This is a bit slower but removes a loop statement entirely. 3. Refactor polyvec_pointwise_acc_montgomery() by splitting the core "sum of products" calculation into a distinct local function mld_pointwise_sum_of_products(). Add proof of the latter. Proof time for parameter set 87 is now 4 minutes (real time) and 40 minutes (user time) with 64 cores on an r7g instance. Signed-off-by: Rod Chapman <[email protected]>
1 parent 9291537 commit 0d538ed

File tree

5 files changed

+167
-54
lines changed

5 files changed

+167
-54
lines changed

mldsa/src/polyvec.c

Lines changed: 96 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -45,22 +45,14 @@ void mld_polyvec_matrix_expand(mld_polyvecl mat[MLDSA_K],
4545
* of the same parent object.
4646
*/
4747

48-
MLD_ALIGN uint8_t seed_ext[4][MLD_ALIGN_UP(MLDSA_SEEDBYTES + 2)];
49-
50-
for (j = 0; j < 4; j++)
51-
__loop__(
52-
assigns(j, object_whole(seed_ext))
53-
invariant(j <= 4)
54-
)
55-
{
56-
mld_memcpy(seed_ext[j], rho, MLDSA_SEEDBYTES);
57-
}
48+
MLD_ALIGN uint8_t single_seed[MLD_ALIGN_UP(MLDSA_SEEDBYTES + 2)];
5849

5950
#if !defined(MLD_CONFIG_SERIAL_FIPS202_ONLY)
51+
MLD_ALIGN uint8_t batched_seeds[4][MLD_ALIGN_UP(MLDSA_SEEDBYTES + 2)];
6052
/* Sample 4 matrix entries a time. */
6153
for (i = 0; i < (MLDSA_K * MLDSA_L / 4) * 4; i += 4)
6254
__loop__(
63-
assigns(i, j, object_whole(seed_ext), memory_slice(mat, MLDSA_K * sizeof(mld_polyvecl)))
55+
assigns(i, j, object_whole(batched_seeds), memory_slice(mat, MLDSA_K * sizeof(mld_polyvecl)))
6456
invariant(i <= (MLDSA_K * MLDSA_L / 4) * 4 && i % 4 == 0)
6557
/* vectors 0 .. i / MLDSA_L are completely sampled */
6658
invariant(forall(k1, 0, i / MLDSA_L, forall(l1, 0, MLDSA_L,
@@ -72,31 +64,38 @@ void mld_polyvec_matrix_expand(mld_polyvecl mat[MLDSA_K],
7264
{
7365
for (j = 0; j < 4; j++)
7466
__loop__(
75-
assigns(j, object_whole(seed_ext))
67+
assigns(j, object_whole(batched_seeds))
7668
invariant(j <= 4)
7769
)
7870
{
7971
uint8_t x = (uint8_t)((i + j) / MLDSA_L);
8072
uint8_t y = (uint8_t)((i + j) % MLDSA_L);
8173

82-
seed_ext[j][MLDSA_SEEDBYTES + 0] = y;
83-
seed_ext[j][MLDSA_SEEDBYTES + 1] = x;
74+
mld_memcpy(batched_seeds[j], rho, MLDSA_SEEDBYTES);
75+
batched_seeds[j][MLDSA_SEEDBYTES + 0] = y;
76+
batched_seeds[j][MLDSA_SEEDBYTES + 1] = x;
8477
}
8578

8679
mld_poly_uniform_4x(&mat[i / MLDSA_L].vec[i % MLDSA_L],
8780
&mat[(i + 1) / MLDSA_L].vec[(i + 1) % MLDSA_L],
8881
&mat[(i + 2) / MLDSA_L].vec[(i + 2) % MLDSA_L],
8982
&mat[(i + 3) / MLDSA_L].vec[(i + 3) % MLDSA_L],
90-
seed_ext);
83+
batched_seeds);
9184
}
85+
86+
/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
87+
mld_zeroize(batched_seeds, sizeof(batched_seeds));
88+
9289
#else /* !MLD_CONFIG_SERIAL_FIPS202_ONLY */
9390
i = 0;
9491
#endif /* MLD_CONFIG_SERIAL_FIPS202_ONLY */
9592

93+
mld_memcpy(single_seed, rho, MLDSA_SEEDBYTES);
94+
9695
/* For MLDSA_K=6, MLDSA_L=5, process the last two entries individually */
9796
while (i < MLDSA_K * MLDSA_L)
9897
__loop__(
99-
assigns(i, object_whole(seed_ext), memory_slice(mat, MLDSA_K * sizeof(mld_polyvecl)))
98+
assigns(i, object_whole(single_seed), memory_slice(mat, MLDSA_K * sizeof(mld_polyvecl)))
10099
invariant(i <= MLDSA_K * MLDSA_L)
101100
/* vectors 0 .. i / MLDSA_L are completely sampled */
102101
invariant(forall(k1, 0, i / MLDSA_L, forall(l1, 0, MLDSA_L,
@@ -110,27 +109,31 @@ void mld_polyvec_matrix_expand(mld_polyvecl mat[MLDSA_K],
110109
uint8_t y = (uint8_t)(i % MLDSA_L);
111110
mld_poly *this_poly = &mat[i / MLDSA_L].vec[i % MLDSA_L];
112111

113-
seed_ext[0][MLDSA_SEEDBYTES + 0] = y;
114-
seed_ext[0][MLDSA_SEEDBYTES + 1] = x;
112+
single_seed[MLDSA_SEEDBYTES + 0] = y;
113+
single_seed[MLDSA_SEEDBYTES + 1] = x;
115114

116-
mld_poly_uniform(this_poly, seed_ext[0]);
115+
mld_poly_uniform(this_poly, single_seed);
117116
i++;
118117
}
119118

120119
/*
121120
* The public matrix is generated in NTT domain. If the native backend
122-
* uses a custom order in NTT domain, permute A accordingly.
121+
* uses a custom order in NTT domain, permute A accordingly. This does
122+
* not affect the bounds on the coefficients, so we ignore this for CBMC
123+
* to simplify proof.
123124
*/
125+
#ifndef CBMC
124126
for (i = 0; i < MLDSA_K; i++)
125127
{
126128
for (j = 0; j < MLDSA_L; j++)
127129
{
128130
mld_poly_permute_bitrev_to_custom(mat[i].vec[j].coeffs);
129131
}
130132
}
133+
#endif /* !CBMC */
131134

132135
/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
133-
mld_zeroize(seed_ext, sizeof(seed_ext));
136+
mld_zeroize(single_seed, sizeof(single_seed));
134137
}
135138

136139
MLD_INTERNAL_API
@@ -235,7 +238,6 @@ void mld_polyvecl_add(mld_polyvecl *u, const mld_polyvecl *v)
235238
invariant(i <= MLDSA_L)
236239
invariant(forall(k0, i, MLDSA_L,
237240
forall(k1, 0, MLDSA_N, u->vec[k0].coeffs[k1] == loop_entry(*u).vec[k0].coeffs[k1])))
238-
invariant(forall(k4, 0, i, forall(k5, 0, MLDSA_N, u->vec[k4].coeffs[k5] == loop_entry(*u).vec[k4].coeffs[k5] + v->vec[k4].coeffs[k5])))
239241
invariant(forall(k6, 0, i, array_bound(u->vec[k6].coeffs, 0, MLDSA_N, INT32_MIN, REDUCE32_DOMAIN_MAX)))
240242
)
241243
{
@@ -303,87 +305,129 @@ void mld_polyvecl_pointwise_poly_montgomery(mld_polyvecl *r, const mld_poly *a,
303305
mld_assert_abs_bound_2d(r->vec, MLDSA_L, MLDSA_N, MLDSA_Q);
304306
}
305307

308+
#if defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4) && \
309+
MLD_CONFIG_PARAMETER_SET == 44
310+
306311
MLD_INTERNAL_API
307312
void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u,
308313
const mld_polyvecl *v)
309314
{
310-
#if defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4) && \
311-
MLD_CONFIG_PARAMETER_SET == 44
312315
/* TODO: proof */
313316
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
314317
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
315318
mld_polyvecl_pointwise_acc_montgomery_l4_native(
316319
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
317320
(const int32_t(*)[MLDSA_N])v->vec);
318321
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
322+
}
323+
319324
#elif defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5) && \
320325
MLD_CONFIG_PARAMETER_SET == 65
326+
327+
void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u,
328+
const mld_polyvecl *v)
329+
{
321330
/* TODO: proof */
322331
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
323332
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
324333
mld_polyvecl_pointwise_acc_montgomery_l5_native(
325334
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
326335
(const int32_t(*)[MLDSA_N])v->vec);
327336
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
337+
}
338+
328339
#elif defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7) && \
329340
MLD_CONFIG_PARAMETER_SET == 87
341+
void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u,
342+
const mld_polyvecl *v)
343+
{
330344
/* TODO: proof */
331345
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
332346
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
333347
mld_polyvecl_pointwise_acc_montgomery_l7_native(
334348
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
335349
(const int32_t(*)[MLDSA_N])v->vec);
336350
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
337-
#else /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \
338-
MLD_CONFIG_PARAMETER_SET == 44) && \
339-
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 && \
340-
MLD_CONFIG_PARAMETER_SET == 65) && \
341-
MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 && \
342-
MLD_CONFIG_PARAMETER_SET == 87 */
343-
unsigned int i, j;
344-
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
345-
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
346-
/* The first input is bounded by [0, Q-1] inclusive
347-
* The second input is bounded by [-9Q+1, 9Q-1] inclusive . Hence, we can
351+
}
352+
353+
#else /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \
354+
MLD_CONFIG_PARAMETER_SET == 44) && \
355+
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 && \
356+
MLD_CONFIG_PARAMETER_SET == 65) && \
357+
MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 && \
358+
MLD_CONFIG_PARAMETER_SET == 87 */
359+
360+
static int64_t mld_pointwise_sum_of_products(const mld_polyvecl *u,
361+
const mld_polyvecl *v,
362+
unsigned int i)
363+
__contract__(
364+
requires(memory_no_alias(u, sizeof(mld_polyvecl)))
365+
requires(memory_no_alias(v, sizeof(mld_polyvecl)))
366+
requires(i < MLDSA_N)
367+
requires(forall(l0, 0, MLDSA_L,
368+
array_bound(u->vec[l0].coeffs, 0, MLDSA_N, 0, MLDSA_Q)))
369+
requires(forall(l1, 0, MLDSA_L,
370+
array_abs_bound(v->vec[l1].coeffs, 0, MLDSA_N, MLD_NTT_BOUND)))
371+
ensures(return_value >= -(int64_t) MLDSA_L*(MLDSA_Q - 1)*(MLD_NTT_BOUND - 1))
372+
ensures(return_value <= (int64_t) MLDSA_L*(MLDSA_Q - 1)*(MLD_NTT_BOUND - 1))
373+
)
374+
{
375+
/* Input vector u is bounded by [0, Q-1] inclusive
376+
* Input vector v is bounded by [-9Q+1, 9Q-1] inclusive . Hence, we can
348377
* safely accumulate in 64-bits without intermediate reductions as
349378
* MLDSA_L * (MLD_NTT_BOUND-1) * (Q-1) < INT64_MAX
350379
*
351380
* The worst case is ML-DSA-87: 7 * (9Q-1) * (Q-1) < 2**52
352381
* (and likewise for negative values)
353382
*/
354383

384+
int64_t t = 0;
385+
unsigned int j;
386+
for (j = 0; j < MLDSA_L; j++)
387+
__loop__(
388+
assigns(j, t)
389+
invariant(j <= MLDSA_L)
390+
invariant(t >= -(int64_t)j*(MLDSA_Q - 1)*(MLD_NTT_BOUND - 1))
391+
invariant(t <= (int64_t)j*(MLDSA_Q - 1)*(MLD_NTT_BOUND - 1))
392+
)
393+
{
394+
const int64_t u64 = (int64_t)u->vec[j].coeffs[i];
395+
const int64_t v64 = (int64_t)v->vec[j].coeffs[i];
396+
/* Helper assertions for proof efficiency. Do not remove */
397+
mld_assert(u64 >= 0 && u64 < MLDSA_Q);
398+
mld_assert(v64 > -MLD_NTT_BOUND && v64 < MLD_NTT_BOUND);
399+
t += (u64 * v64);
400+
}
401+
return t;
402+
}
403+
404+
void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u,
405+
const mld_polyvecl *v)
406+
{
407+
unsigned int i;
408+
409+
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
410+
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
355411
for (i = 0; i < MLDSA_N; i++)
356412
__loop__(
357-
assigns(i, j, object_whole(w))
413+
assigns(i, object_whole(w))
358414
invariant(i <= MLDSA_N)
359415
invariant(array_abs_bound(w->coeffs, 0, i, MLDSA_Q))
360416
)
361417
{
362-
int64_t t = 0;
363-
int32_t r;
364-
for (j = 0; j < MLDSA_L; j++)
365-
__loop__(
366-
assigns(j, t)
367-
invariant(j <= MLDSA_L)
368-
invariant(t >= -(int64_t)j*(MLDSA_Q - 1)*(MLD_NTT_BOUND - 1))
369-
invariant(t <= (int64_t)j*(MLDSA_Q - 1)*(MLD_NTT_BOUND - 1))
370-
)
371-
{
372-
t += (int64_t)u->vec[j].coeffs[i] * v->vec[j].coeffs[i];
373-
}
374-
375-
r = mld_montgomery_reduce(t);
376-
w->coeffs[i] = r;
418+
w->coeffs[i] =
419+
mld_montgomery_reduce(mld_pointwise_sum_of_products(u, v, i));
377420
}
378421

379422
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
423+
}
424+
380425
#endif /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \
381426
MLD_CONFIG_PARAMETER_SET == 44) && \
382427
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 && \
383428
MLD_CONFIG_PARAMETER_SET == 65) && \
384429
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 && \
385430
MLD_CONFIG_PARAMETER_SET == 87) */
386-
}
387431

388432
MLD_INTERNAL_API
389433
uint32_t mld_polyvecl_chknorm(const mld_polyvecl *v, int32_t bound)

mldsa/src/polyvec.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ __contract__(
9393
requires(forall(k0, 0, MLDSA_L, forall(k1, 0, MLDSA_N, (int64_t) u->vec[k0].coeffs[k1] + v->vec[k0].coeffs[k1] < REDUCE32_DOMAIN_MAX)))
9494
requires(forall(k2, 0, MLDSA_L, forall(k3, 0, MLDSA_N, (int64_t) u->vec[k2].coeffs[k3] + v->vec[k2].coeffs[k3] >= INT32_MIN)))
9595
assigns(object_whole(u))
96-
ensures(forall(k4, 0, MLDSA_L, forall(k5, 0, MLDSA_N, u->vec[k4].coeffs[k5] == old(*u).vec[k4].coeffs[k5] + v->vec[k4].coeffs[k5])))
9796
ensures(forall(k6, 0, MLDSA_L,
9897
array_bound(u->vec[k6].coeffs, 0, MLDSA_N, INT32_MIN, REDUCE32_DOMAIN_MAX)))
9998
);
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) The mldsa-native project authors
2+
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
3+
4+
include ../Makefile_params.common
5+
6+
HARNESS_ENTRY = harness
7+
HARNESS_FILE = pointwise_sum_of_products_harness
8+
9+
# This should be a unique identifier for this proof, and will appear on the
10+
# Litani dashboard. It can be human-readable and contain spaces if you wish.
11+
PROOF_UID = pointwise_sum_of_products
12+
13+
DEFINES +=
14+
INCLUDES +=
15+
16+
REMOVE_FUNCTION_BODY +=
17+
UNWINDSET +=
18+
19+
PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE).c
20+
PROJECT_SOURCES += $(SRCDIR)/mldsa/polyvec.c
21+
22+
CHECK_FUNCTION_CONTRACTS=mld_pointwise_sum_of_products
23+
USE_FUNCTION_CONTRACTS=
24+
APPLY_LOOP_CONTRACTS=on
25+
USE_DYNAMIC_FRAMES=1
26+
27+
# Disable any setting of EXTERNAL_SAT_SOLVER, and choose SMT backend instead
28+
EXTERNAL_SAT_SOLVER=
29+
CBMCFLAGS=--smt2 --slice-formula
30+
31+
FUNCTION_NAME = pointwise_sum_of_products
32+
33+
# If this proof is found to consume huge amounts of RAM, you can set the
34+
# EXPENSIVE variable. With new enough versions of the proof tools, this will
35+
# restrict the number of EXPENSIVE CBMC jobs running at once. See the
36+
# documentation in Makefile.common under the "Job Pools" heading for details.
37+
# EXPENSIVE = true
38+
39+
# This function is large enough to need...
40+
CBMC_OBJECT_BITS = 12
41+
42+
# If you require access to a file-local ("static") function or object to conduct
43+
# your proof, set the following (and do not include the original source file
44+
# ("mldsa/poly.c") in PROJECT_SOURCES).
45+
# REWRITTEN_SOURCES = $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i
46+
# include ../Makefile.common
47+
# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_SOURCE = $(SRCDIR)/mldsa/poly.c
48+
# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_FUNCTIONS = foo bar
49+
# $(PROOFDIR)/<__SOURCE_FILE_BASENAME__>.i_OBJECTS = baz
50+
# Care is required with variables on the left-hand side: REWRITTEN_SOURCES must
51+
# be set before including Makefile.common, but any use of variables on the
52+
# left-hand side requires those variables to be defined. Hence, _SOURCE,
53+
# _FUNCTIONS, _OBJECTS is set after including Makefile.common.
54+
55+
include ../Makefile.common
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright (c) The mldsa-native project authors
2+
// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
3+
4+
#include "polyvec.h"
5+
6+
int64_t mld_pointwise_sum_of_products(const mld_polyvecl *u,
7+
const mld_polyvecl *v, unsigned int i);
8+
9+
void harness(void)
10+
{
11+
mld_polyvecl *u, *v;
12+
unsigned int i;
13+
int64_t r;
14+
r = mld_pointwise_sum_of_products(u, v, i);
15+
}

proofs/cbmc/polyvecl_pointwise_acc_montgomery/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE).c
2020
PROJECT_SOURCES += $(SRCDIR)/mldsa/src/polyvec.c
2121

2222
CHECK_FUNCTION_CONTRACTS=$(MLD_NAMESPACE)polyvecl_pointwise_acc_montgomery
23-
USE_FUNCTION_CONTRACTS=mld_montgomery_reduce
23+
USE_FUNCTION_CONTRACTS=mld_montgomery_reduce mld_pointwise_sum_of_products
2424
APPLY_LOOP_CONTRACTS=on
2525
USE_DYNAMIC_FRAMES=1
2626

0 commit comments

Comments
 (0)