Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions dev/aarch64_clean/src/poly_decompose_32_asm.S
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,37 @@
.macro decompose32 a1, a, temp
// range: 0 <= a <= Q-1 = 32*GAMMA2

/* check-magic: 523776 == 2 * intdiv(MLDSA_Q - 1, 32) */
/* check-magic: 1074791425 == floor(2**49 / 523776) */
/* check-magic: 575897802350002176 == 1 / (1 / 523776 - 1074791425 / 2**49) */
// Compute a1 = round-(a / (2*GAMMA2)) = round-(a / 523776) ≈
// round(a * 1074791425 / 2^49), where round-() denotes "round half
// down". This is exact for 0 <= a < Q. Note that half is rounded down
// since 1074791425 / 2^49 ≲ 1 / 523776.
// down". This is exact for 0 <= a < Q. We'll prove this in the
// following paragraphs, in which we denote 2*GAMMA2 as B to avoid
// clutter.
//
// Consider the (signed) error a * (1 / B - 1074791425 / 2^49) between
// a / B and the (under-)approximation a * 1074791425 / 2^49. Because
// eps := 1 / B - 1074791425 / 2^49 is 1 / 575897802350002176 ≈
// 2^(-58.99) < 2^(-58), we have 0 <= a * eps < 2^23 * 2^(-58) =
// 1 / 2^35 < 1 / 2^19 < 1 / B (note that a is non-negative).
//
// On the other hand, 1 / B is the spacing between the integral
// multiples of 1 / B, which include all rounding boundaries n + 0.5
// (since B is even). Hence, if a / B is not of the form n + 0.5, then
// it is at least 1 / B away from the nearest rounding boundary, so
// moving from a / B to a * 1074791425 / 2^49 (a shift of less than
// 1 / B) does not affect the rounding result, no matter the type of
// rounding used on either side.
// In particular, we have round-(a / B) = round(a * 1074791425 / 2^49)
// as claimed.
//
// As for the remaining case where a / B _is_ of the form n + 0.5,
// because a * 1074791425 / 2^49 is slightly but strictly below a / B =
// n + 0.5 (note that a and thus the error a * eps cannot be 0 here), it
// is always rounded down to n. More precisely, we have round-(a / B) =
// round(a * 1074791425 / 2^49), where the round-down on the LHS is
// essential, and on the RHS the type of rounding again does not matter.
// This concludes the proof.
sqdmulh \a1\().4s, \a\().4s, barrett_const.4s
srshr \a1\().4s, \a1\().4s, #18
// range: 0 <= a1 <= 16
Expand Down
31 changes: 29 additions & 2 deletions dev/aarch64_clean/src/poly_decompose_88_asm.S
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,37 @@
.macro decompose88 a1, a, temp
// range: 0 <= a <= Q-1 = 88*GAMMA2

/* check-magic: 190464 == 2 * intdiv(MLDSA_Q - 1, 88) */
/* check-magic: 1477838209 == floor(2**48 / 190464) */
/* check-magic: 26177172834091008 == 35 / (1 / 190464 - 1477838209 / 2**48) */
// Compute a1 = round-(a / (2*GAMMA2)) = round-(a / 190464) ≈
// round(a * 1477838209 / 2^48), where round-() denotes "round half
// down". This is exact for 0 <= a < Q. Note that half is rounded down
// since 1477838209 / 2^48 ≲ 1 / 190464.
// down". This is exact for 0 <= a < Q. We'll prove this in the
// following paragraphs, in which we denote 2*GAMMA2 as B to avoid
// clutter.
//
// Consider the (signed) error a * (1 / B - 1477838209 / 2^48) between
// a / B and the (under-)approximation a * 1477838209 / 2^48. Because
// eps := 1 / B - 1477838209 / 2^48 is 35 / 26177172834091008 ≈
// 2^(-49.41) < 2^(-49), we have 0 <= a * eps < 2^23 * 2^(-49) =
// 1 / 2^26 < 1 / 2^18 < 1 / B (note that a is non-negative).
//
// On the other hand, 1 / B is the spacing between the integral
// multiples of 1 / B, which include all rounding boundaries n + 0.5
// (since B is even). Hence, if a / B is not of the form n + 0.5, then
// it is at least 1 / B away from the nearest rounding boundary, so
// moving from a / B to a * 1477838209 / 2^48 (a shift of less than
// 1 / B) does not affect the rounding result, no matter the type of
// rounding used on either side.
// In particular, we have round-(a / B) = round(a * 1477838209 / 2^48)
// as claimed.
//
// As for the remaining case where a / B _is_ of the form n + 0.5,
// because a * 1477838209 / 2^48 is slightly but strictly below a / B =
// n + 0.5 (note that a and thus the error a * eps cannot be 0 here), it
// is always rounded down to n. More precisely, we have round-(a / B) =
// round(a * 1477838209 / 2^48), where the round-down on the LHS is
// essential, and on the RHS the type of rounding again does not matter.
// This concludes the proof.
sqdmulh \a1\().4s, \a\().4s, barrett_const.4s
srshr \a1\().4s, \a1\().4s, #17
// range: 0 <= a1 <= 44
Expand Down
10 changes: 6 additions & 4 deletions dev/x86_64/src/poly_decompose_32_avx2.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void mld_poly_decompose_32_avx2(int32_t *a1, int32_t *a0)
* range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B
*/

/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
f1 = _mm256_add_epi32(f, off);
f1 = _mm256_srli_epi32(f1, 7);
/*
Expand All @@ -74,8 +74,8 @@ void mld_poly_decompose_32_avx2(int32_t *a1, int32_t *a0)

/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲
* 1 / 4092.
* for 0 <= f1' < 2^16. See mld_decompose() in mldsa/src/rounding.h for the
* proof.
*
* round(f1' * 1025 / 2^22) is in turn computed in 2 steps as
* round(floor(f1' * 1025 / 2^16) / 2^6). The mulhi computes f1'' =
Expand All @@ -87,7 +87,9 @@ void mld_poly_decompose_32_avx2(int32_t *a1, int32_t *a0)
*/
f1 = _mm256_mulhi_epu16(f1, v);
/*
* range: 0 <= f1'' < floor(2^16 * 1025 / 2^16) = 1025
* range: 0 <= f1'' = floor(f1' * 1025 / 2^16)
* <= f1' * 1025 / 2^16
* < 2^16 * 1025 / 2^16 = 1025
*
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
* is, no erroneous sign-extension occurs.
Expand Down
10 changes: 6 additions & 4 deletions dev/x86_64/src/poly_decompose_88_avx2.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void mld_poly_decompose_88_avx2(int32_t *a1, int32_t *a0)
* range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B
*/

/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
f1 = _mm256_add_epi32(f, off);
f1 = _mm256_srli_epi32(f1, 7);
/*
Expand All @@ -75,8 +75,8 @@ void mld_poly_decompose_88_avx2(int32_t *a1, int32_t *a0)

/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲
* 1 / 1488.
* for 0 <= f1' < 2^16. See mld_decompose() in mldsa/src/rounding.h for the
* proof.
*
* round(f1' * 11275 / 2^24) is in turn computed in 2 steps as
* round(floor(f1' * 11275 / 2^16) / 2^8). The mulhi computes f1'' =
Expand All @@ -88,7 +88,9 @@ void mld_poly_decompose_88_avx2(int32_t *a1, int32_t *a0)
*/
f1 = _mm256_mulhi_epu16(f1, v);
/*
* range: 0 <= f1'' < floor(2^16 * 11275 / 2^16) = 11275
* range: 0 <= f1'' = floor(f1' * 11275 / 2^16)
* <= f1' * 11275 / 2^16
* < 2^16 * 11275 / 2^16 = 11275
*
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
* is, no erroneous sign-extension occurs.
Expand Down
10 changes: 6 additions & 4 deletions mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void mld_poly_decompose_32_avx2(int32_t *a1, int32_t *a0)
* range: 0 <= f <= Q-1 = 32*GAMMA2 = 16*128*B
*/

/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
f1 = _mm256_add_epi32(f, off);
f1 = _mm256_srli_epi32(f1, 7);
/*
Expand All @@ -74,8 +74,8 @@ void mld_poly_decompose_32_avx2(int32_t *a1, int32_t *a0)

/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲
* 1 / 4092.
* for 0 <= f1' < 2^16. See mld_decompose() in mldsa/src/rounding.h for the
* proof.
*
* round(f1' * 1025 / 2^22) is in turn computed in 2 steps as
* round(floor(f1' * 1025 / 2^16) / 2^6). The mulhi computes f1'' =
Expand All @@ -87,7 +87,9 @@ void mld_poly_decompose_32_avx2(int32_t *a1, int32_t *a0)
*/
f1 = _mm256_mulhi_epu16(f1, v);
/*
* range: 0 <= f1'' < floor(2^16 * 1025 / 2^16) = 1025
* range: 0 <= f1'' = floor(f1' * 1025 / 2^16)
* <= f1' * 1025 / 2^16
* < 2^16 * 1025 / 2^16 = 1025
*
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
* is, no erroneous sign-extension occurs.
Expand Down
10 changes: 6 additions & 4 deletions mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void mld_poly_decompose_88_avx2(int32_t *a1, int32_t *a0)
* range: 0 <= f <= Q-1 = 88*GAMMA2 = 44*128*B
*/

/* Compute f1' = ceil(f / 128) as floor((f + 127) >> 7) */
/* Compute f1' = ceil(f / 128) as floor((f + 127) / 2^7) */
f1 = _mm256_add_epi32(f, off);
f1 = _mm256_srli_epi32(f1, 7);
/*
Expand All @@ -75,8 +75,8 @@ void mld_poly_decompose_88_avx2(int32_t *a1, int32_t *a0)

/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲
* 1 / 1488.
* for 0 <= f1' < 2^16. See mld_decompose() in mldsa/src/rounding.h for the
* proof.
*
* round(f1' * 11275 / 2^24) is in turn computed in 2 steps as
* round(floor(f1' * 11275 / 2^16) / 2^8). The mulhi computes f1'' =
Expand All @@ -88,7 +88,9 @@ void mld_poly_decompose_88_avx2(int32_t *a1, int32_t *a0)
*/
f1 = _mm256_mulhi_epu16(f1, v);
/*
* range: 0 <= f1'' < floor(2^16 * 11275 / 2^16) = 11275
* range: 0 <= f1'' = floor(f1' * 11275 / 2^16)
* <= f1' * 11275 / 2^16
* < 2^16 * 11275 / 2^16 = 11275
*
* Because 0 <= f1'' < 2^15, the multiplication in mulhrs is unsigned, that
* is, no erroneous sign-extension occurs.
Expand Down
37 changes: 31 additions & 6 deletions mldsa/src/rounding.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,32 @@ __contract__(
#if MLD_CONFIG_PARAMETER_SET == 44
/* check-magic: 1488 == 2 * intdiv(intdiv(MLDSA_Q - 1, 88), 128) */
/* check-magic: 11275 == floor(2**24 / 1488) */
/* check-magic: 1560281088 == 1 / (1 / 1488 - 11275 / 2**24) */
/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 11275 / 2^24 ≲
* 1 / 1488.
* Compute f1 = round-(f1' / B) ≈ round(f1' * 11275 / 2^24). This is exact for
* 0 <= f1' < 2^16.
*
* To see this, consider the (signed) error f1' * (1 / B - 11275 / 2^24)
* between f1' / B and the (under-)approximation f1' * 11275 / 2^24. Because
* eps := 1 / B - 11275 / 2^24 is 1 / 1560281088 ≈ 2^(-30.54) < 2^(-30), we
* have 0 <= f1' * eps < 2^16 * 2^(-30) = 1 / 2^14 < 1 / 2^11 < 1 / B (note
* that f1' is non-negative).
*
* On the other hand, 1 / B is the spacing between the integral multiples
* of 1 / B, which include all rounding boundaries n + 0.5 (since B is even).
* Hence, if f1' / B is not of the form n + 0.5, then it is at least 1 / B
* away from the nearest rounding boundary, so moving from f1' / B to
* f1' * 11275 / 2^24 (a shift of less than 1 / B) does not affect the
* rounding result, no matter the type of rounding used on either side.
* In particular, we have round-(f1' / B) = round(f1' * 11275 / 2^24) as
* claimed.
*
* As for the remaining case where f1' / B _is_ of the form n + 0.5, because
* f1' * 11275 / 2^24 is slightly but strictly below f1' / B = n + 0.5 (note
* that f1' and thus the error f1' * eps cannot be 0 here), it is always
* rounded down to n. More precisely, we have round-(f1' / B) =
* round(f1' * 11275 / 2^24), where the round-down on the LHS is essential,
* and on the RHS the type of rounding again does not matter. This concludes
* the proof.
*/
*a1 = (*a1 * 11275 + (1 << 23)) >> 24;
mld_assert(*a1 >= 0 && *a1 <= 44);
Expand All @@ -128,10 +150,13 @@ __contract__(
#else /* MLD_CONFIG_PARAMETER_SET == 44 */
/* check-magic: 4092 == 2 * intdiv(intdiv(MLDSA_Q - 1, 32), 128) */
/* check-magic: 1025 == floor(2**22 / 4092) */
/* check-magic: 4290772992 == 1 / (1 / 4092 - 1025 / 2**22) */
/*
* Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact
* for 0 <= f1' < 2^16. Note that half is rounded down since 1025 / 2^22 ≲
* 1 / 4092.
* Compute f1 = round-(f1' / B) ≈ round(f1' * 1025 / 2^22). This is exact for
* 0 <= f1' < 2^16. By the same argument as above, it suffices to show that
* 0 <= f1' * eps < 1 / B, where eps := 1 / B - 1025 / 2^22. Indeed, we have
* eps = 1 / 4290772992 ≈ 2^(-31.99) < 2^(-31), therefore 0 <= f1' * eps <
* 2^16 * 2^(-31) = 1 / 2^15 < 1 / 2^12 < 1 / B.
*/
*a1 = (*a1 * 1025 + (1 << 21)) >> 22;
mld_assert(*a1 >= 0 && *a1 <= 16);
Expand Down