Commit d3c6070

Better explanation for Barrett division in decompose (C and AVX2)
Based on Hanno Becker's proposal, the new explanation explains how round-(f1' / B) can be replaced with rounding-mulhi, regardless of the type of rounding used in the mulhi. In addition, by bounding the approximation error to be strictly less than 1 / B, the exactness of the Barrett division is also justified.

To avoid excessive repetition, we prove the GAMMA2 = (Q-1)/88 case in the C implementation, remark how the same proof can be adapted to the GAMMA2 = (Q-1)/32 case, and finally refer to them when explaining the AVX2 implementation.

Signed-off-by: jammychiou1 <jammy.chiou1@gmail.com>
1 parent 1f24035 commit d3c6070

5 files changed: 39 additions, 14 deletions

dev/x86_64/src/poly_decompose_32_avx2.c

Lines changed: 2 additions & 2 deletions
@@ -74,8 +74,8 @@ void mld_poly_decompose_32_avx2(int32_t *a1, int32_t *a0)
 
 /*
  * Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact
- * for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲
- * 1 / 4092.
+ * for 0 <= f1' < 2^16. See mld_decompose() in mldsa/src/rounding.h for the
+ * proof.
  *
  * round(f1' * 1025 / 2^22) is in turn computed in 2 steps as
  * round(floor(f1' * 1025 / 2^16) / 2^6). The mulhi computes f1'' =

dev/x86_64/src/poly_decompose_88_avx2.c

Lines changed: 2 additions & 2 deletions
@@ -75,8 +75,8 @@ void mld_poly_decompose_88_avx2(int32_t *a1, int32_t *a0)
 
 /*
  * Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact
- * for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲
- * 1 / 1488.
+ * for 0 <= f1' < 2^16. See mld_decompose() in mldsa/src/rounding.h for the
+ * proof.
  *
  * round(f1' * 11275 / 2^24) is in turn computed in 2 steps as
  * round(floor(f1' * 11275 / 2^16) / 2^8). The mulhi computes f1'' =

mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c

Lines changed: 2 additions & 2 deletions
@@ -74,8 +74,8 @@ void mld_poly_decompose_32_avx2(int32_t *a1, int32_t *a0)
 
 /*
  * Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact
- * for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲
- * 1 / 4092.
+ * for 0 <= f1' < 2^16. See mld_decompose() in mldsa/src/rounding.h for the
+ * proof.
  *
  * round(f1' * 1025 / 2^22) is in turn computed in 2 steps as
  * round(floor(f1' * 1025 / 2^16) / 2^6). The mulhi computes f1'' =

mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c

Lines changed: 2 additions & 2 deletions
@@ -75,8 +75,8 @@ void mld_poly_decompose_88_avx2(int32_t *a1, int32_t *a0)
 
 /*
  * Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact
- * for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲
- * 1 / 1488.
+ * for 0 <= f1' < 2^16. See mld_decompose() in mldsa/src/rounding.h for the
+ * proof.
  *
  * round(f1' * 11275 / 2^24) is in turn computed in 2 steps as
  * round(floor(f1' * 11275 / 2^16) / 2^8). The mulhi computes f1'' =
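
The two-step evaluation mentioned in the AVX2 comments above can be sanity-checked with a scalar model. The sketch below is not the AVX2 code from the commit (the vector instructions are not shown in this excerpt). It assumes that "round" means floor(x + 1/2), and the function names are hypothetical. It compares the one-step form round(f1' * m / 2^s) with round(floor(f1' * m / 2^16) / 2^(s - 16)) for every 0 <= f1' < 2^16.

```c
#include <stdint.h>

/* One step: round(f * m / 2^s), taken here as floor((f * m + 2^(s-1)) / 2^s).
 * This rounding convention is an assumption of the sketch. */
static int32_t round_one_step(int32_t f, int32_t m, int s)
{
  return (int32_t)(((int64_t)f * m + ((int64_t)1 << (s - 1))) >> s);
}

/* Two steps: hi = floor(f * m / 2^16), then round(hi / 2^(s-16)).
 * Scalar stand-in for a "mulhi followed by a rounding shift" sequence. */
static int32_t round_two_step(int32_t f, int32_t m, int s)
{
  int32_t hi = (int32_t)(((int64_t)f * m) >> 16);
  int r = s - 16;
  return (hi + ((int32_t)1 << (r - 1))) >> r;
}

/* Number of inputs 0 <= f < 2^16 on which the two forms disagree. */
static int count_two_step_mismatches(int32_t m, int s)
{
  int bad = 0;
  for (int32_t f = 0; f < ((int32_t)1 << 16); f++)
  {
    if (round_one_step(f, m, s) != round_two_step(f, m, s))
    {
      bad++;
    }
  }
  return bad;
}
```

Calling count_two_step_mismatches(1025, 22) and count_two_step_mismatches(11275, 24) should report no disagreement, since floor((floor(y) + k) / n) = floor((y + k) / n) holds for integers k and n > 0.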

mldsa/src/rounding.h

Lines changed: 31 additions & 6 deletions
@@ -115,10 +115,32 @@ __contract__(
 #if MLD_CONFIG_PARAMETER_SET == 44
 /* check-magic: 1488 == 2 * intdiv(intdiv(MLDSA_Q - 1, 88), 128) */
 /* check-magic: 11275 == floor(2**24 / 1488) */
+/* check-magic: 1560281088 == 1 / (1 / 1488 - 11275 / 2**24) */
 /*
- * Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact
- * for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲
- * 1 / 1488.
+ * Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact for
+ * 0 <= f1' < 2^16.
+ *
+ * To see this, consider the (signed) error f1' * (1 / B - 11275 / 2^24)
+ * between f1' / B and the (under-)approximation f1' * 11275 / 2^24. Because
+ * eps := 1 / B - 11275 / 2^24 is 1 / 1560281088 ≈ 2^(-30.54) < 2^(-30), we
+ * have 0 <= f1' * eps < 2^16 * 2^(-30) = 1 / 2^14 < 1 / 2^11 < 1 / B (note
+ * that f1' is non-negative).
+ *
+ * On the other hand, 1 / B is the spacing between the integral multiples
+ * of 1 / B, which includes all rounding boundaries n + 0.5 (since B is even).
+ * Hence, if f1' / B is not of the form n + 0.5, then it is at least 1 / B
+ * away from the nearest rounding boundary, so moving from f1' / B to
+ * f1' * 11275 / 2^24 does not affect the rounding result, no matter the type
+ * of rounding used in either side. In particular, we have round-(f1' / B) =
+ * round(f1' * 11275 / 2^24) as claimed.
+ *
+ * As for the remaining case where f1' / B _is_ of the form n + 0.5, because
+ * f1' * 11275 / 2^24 is slightly but strictly below f1' / B = n + 0.5 (note
+ * that f1' and thus the error f1' * eps cannot be 0 here), it is always
+ * rounded down to n. More precisely, we have round-(f1' / B) =
+ * round(f1' * 11275 / 2^24), where the round-down on the LHS is essential,
+ * and on the RHS the type of rounding again does not matter. This concludes
+ * the proof.
  */
 *a1 = (*a1 * 11275 + (1 << 23)) >> 24;
 mld_assert(*a1 >= 0 && *a1 <= 44);
@@ -128,10 +150,13 @@ __contract__(
 #else /* MLD_CONFIG_PARAMETER_SET == 44 */
 /* check-magic: 4092 == 2 * intdiv(intdiv(MLDSA_Q - 1, 32), 128) */
 /* check-magic: 1025 == floor(2**22 / 4092) */
+/* check-magic: 4290772992 == 1 / (1 / 4092 - 1025 / 2**22) */
 /*
- * Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact
- * for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲
- * 1 / 4092.
+ * Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact for
+ * 0 <= f1' < 2^16. Following the same argument above, it suffices to show
+ * that f1' * eps < 1 / B, where eps := 1 / B - 1025 / 2^22. Indeed, we have
+ * eps = 1 / 4290772992 ≈ 2^(-31.99) < 2^(-31), therefore f1' * eps <
+ * 2^16 * 2^(-31) = 1 / 2^15 < 1 / 2^12 < 1 / B.
  */
 *a1 = (*a1 * 1025 + (1 << 21)) >> 22;
 mld_assert(*a1 >= 0 && *a1 <= 16);
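
The exactness claim proved in the comments above can also be cross-checked by brute force. The following is a minimal standalone sketch, not part of the commit, and its helper names are hypothetical. It models round-(f1' / B) as round-to-nearest with ties rounded down, and compares it with the Barrett expression from the diff for every 0 <= f1' < 2^16 and both (B, multiplier, shift) triples.

```c
#include <assert.h>
#include <stdint.h>

/* Reference round-(f / B): nearest integer, ties rounded down.
 * For even B and f >= 0, ceil(f / B - 1/2) = floor((f + B/2 - 1) / B). */
static int32_t round_half_down_div(int32_t f, int32_t B)
{
  return (f + B / 2 - 1) / B;
}

/* Barrett form used in the diff: (f * m + 2^(s-1)) >> s.
 * A 64-bit intermediate is used only to keep the sketch obviously safe. */
static int32_t barrett_round(int32_t f, int32_t m, int s)
{
  return (int32_t)(((int64_t)f * m + ((int64_t)1 << (s - 1))) >> s);
}

/* Number of inputs 0 <= f < 2^16 on which the two disagree. */
static int count_mismatches(int32_t B, int32_t m, int s)
{
  int bad = 0;
  for (int32_t f = 0; f < ((int32_t)1 << 16); f++)
  {
    if (barrett_round(f, m, s) != round_half_down_div(f, B))
    {
      bad++;
    }
  }
  return bad;
}

/* Checks both parameter choices that appear in the diff. */
static void check_barrett_constants(void)
{
  assert(count_mismatches(1488, 11275, 24) == 0); /* GAMMA2 = (Q-1)/88 */
  assert(count_mismatches(4092, 1025, 22) == 0);  /* GAMMA2 = (Q-1)/32 */
}
```

The tie case is the interesting one: for B = 1488 and f1' = 744, f1' / B is exactly 0.5, and barrett_round(744, 11275, 24) yields 0 because 744 * 11275 + 2^23 = 16777208 falls just short of 2^24. This matches the proof's observation that the under-approximation pushes every tie strictly below its rounding boundary.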
