diff --git a/include/mbedtls/bignum.h b/include/mbedtls/bignum.h
index 6187856713dc..86fa62574b7c 100644
--- a/include/mbedtls/bignum.h
+++ b/include/mbedtls/bignum.h
@@ -1045,6 +1045,7 @@ int mbedtls_mpi_is_prime_ext(const mbedtls_mpi *X, int rounds,
typedef enum {
MBEDTLS_MPI_GEN_PRIME_FLAG_DH = 0x0001, /**< (X-1)/2 is prime too */
MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR = 0x0002, /**< lower error rate from 2-80 to 2-128 */
+ MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4 = 0x0004, /**< generate a prime that's 3 mod 4 (ie, low bits are 11) */
} mbedtls_mpi_gen_prime_flag_t;
/**
diff --git a/library/bignum.c b/library/bignum.c
index f6b8f9998121..77b3ba767b95 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -2304,7 +2304,12 @@ int mbedtls_mpi_gen_prime(mbedtls_mpi *X, size_t nbits, int flags,
if (k > nbits) {
MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(X, k - nbits));
}
- X->p[0] |= 1;
+
+ if ((flags & MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4) != 0) {
+ X->p[0] |= 3;
+ } else {
+ X->p[0] |= 1;
+ }
if ((flags & MBEDTLS_MPI_GEN_PRIME_FLAG_DH) == 0) {
ret = mbedtls_mpi_is_prime_ext(X, rounds, f_rng, p_rng);
diff --git a/library/rsa.c b/library/rsa.c
index 08267dbfce17..45f6f405d098 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -1035,6 +1035,100 @@ size_t mbedtls_rsa_get_len(const mbedtls_rsa_context *ctx)
#if defined(MBEDTLS_GENPRIME)
+/* Part of the keypair generation routine, extracted for readability:
+ *
+ * check GCD( E, (P-1)*(Q-1) ) == 1 (FIPS 186-4 §B.3.1 criterion 2(a))
+ * compute D = E^-1 mod LCM(P-1, Q-1) (FIPS 186-4 §B.3.1 criterion 3(b))
+ *
+ * This is done in a single step as E^-1 mod LCM(P-1, Q-1) only exists
+ * if GCD( E, (P-1)*(Q-1) ) == 1 (which is equivalent to saying that
+ * E is coprime to LCM(P-1, Q-1), ie they have no prime factor in common).
+ *
+ * Input: a partial RSA context with only P, Q, E set.
+ * Output:
+ * - On success, D is set in the context, P, Q and E are unchanged.
+ * - On failure, P and Q may no longer hold their original values!
+ * - If GCD( E, (P-1) * (Q-1) ) != 1 return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE.
+ * - On other errors (allocation etc) return a specific error code.
+ *
+ * Pre-conditions that must be ensured by the caller:
+ * - P > Q
+ * - P and Q have the same number of limbs
+ * - P and Q are both 3 mod 4
+ * - E is odd
+ */
+static int rsa_gen_key_check_e_compute_d(mbedtls_rsa_context *ctx)
+{
+ int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+ mbedtls_mpi G, L;
+
+ mbedtls_mpi_init(&G);
+ mbedtls_mpi_init(&L);
+
+ /* Compute that before we shift P, even though we know it won't shrink */
+ const size_t p_limbs = ctx->P.n;
+ const size_t d_limbs = 2 * p_limbs;
+
+ /* Since we can only compute modular inverse with odd modulus,
+ * and clearly P-1 and Q-1 hence their LCM is even,
+ * we'll first work with (P-1)/2 and (Q-1)/2 and their LCM,
+ * which we know are odd since P and Q are both 3 mod 4.
+ *
+ * More specifically, since E is odd, we have
+ * GCD( E, (P-1) * (Q-1) ) = GCD( E, (P-1)/2 * (Q-1)/2) )
+ * and that is 1 if and only if GCD( E, LCM((P-1)/2, (Q-1)/2) ) == 1.
+ *
+ * Also, setting L2 = LCM((P-1)/2, (Q-1)/2)
+ * and L = LCM(P-1, Q-1), we have L = 2 * L2 with L2 odd.
+ * So by the CRT, it's enough to compute E^-1 mod L2 and mod 2.
+ * But we know that the inverse mod 2 is 1 so we don't have to compute it.
+ *
+ * If D2 is the inverse mod L2, then the inverse mod L is either
+ * D2 or D2 + L2 (those are the only two numbers mod L that are equal to D2
+ * mod L2) and more specifically it's the one that's odd (ie 1 mod 2).
+ *
+ * We compute as much as possible in place.
+ */
+
+ /* Temporarily replace P, Q by (P-1)/2, (Q-1)/2 */
+ MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&ctx->P, 1));
+ MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&ctx->Q, 1));
+
+ /* Use LCM(a, b) = a * b / GCD(a, b) to compute L2.
+ * For the GCD computation we use the fact that Q < P */
+ MBEDTLS_MPI_CHK(mbedtls_mpi_gcd_modinv_odd(&G, NULL, &ctx->Q, &ctx->P));
+ MBEDTLS_MPI_CHK(mbedtls_mpi_div_mpi(&L, NULL, &ctx->P, &G));
+ MBEDTLS_MPI_CHK(mbedtls_mpi_mul_mpi(&L, &L, &ctx->Q));
+
+ /* Compute GCD(E, L2) and E^-1 mod L2 */
+ MBEDTLS_MPI_CHK(mbedtls_mpi_gcd_modinv_odd(&G, &ctx->D, &ctx->E, &L));
+
+ /* Reject if GCD(E, L2) != 1 */
+ if (mbedtls_mpi_cmp_int(&G, 1) != 0) {
+ ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
+ goto cleanup;
+ }
+
+ /* Now ctx->D holds D2. Update that to D.
+ * Note that D2 + L2 < L = LCM(P-1, Q-1) <= (P-1) * (Q-1) < P * Q */
+ MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&ctx->D, d_limbs));
+ MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&L, d_limbs));
+ unsigned d2_is_even = (ctx->D.p[0] & 1) ^ 1;
+ (void) mbedtls_mpi_core_add_if(ctx->D.p, L.p, d_limbs, d2_is_even);
+
+ /* Restore P,Q */
+ MBEDTLS_MPI_CHK(mbedtls_mpi_shift_l(&ctx->P, 1));
+ ctx->P.p[0] |= 1;
+ MBEDTLS_MPI_CHK(mbedtls_mpi_shift_l(&ctx->Q, 1));
+ ctx->Q.p[0] |= 1;
+
+cleanup:
+ mbedtls_mpi_free(&G);
+ mbedtls_mpi_free(&L);
+
+ return ret;
+}
+
/*
* Generate an RSA keypair
*
@@ -1048,7 +1142,15 @@ int mbedtls_rsa_gen_key(mbedtls_rsa_context *ctx,
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
mbedtls_mpi H;
- int prime_quality = 0;
+ /* We require our primes to be 3 mod 4; this ensures that (P-1)/2 and
+ * (Q-1)/2 are odd, enabling use of constant-time modinv, see
+ * rsa_gen_key_check_e_compute_d().
+ * Forcing this is allowed by FIPS 186-5, Appendix A.1.3:
+ * "a, b (Optional parameters) Numbers from the set {1, 3, 5, 7} that may be
+ * used to add the further requirements p ≡ a mod 8, q ≡ b mod 8."
+ * (We're only forcing the low 2 bits while FIPS allows forcing 3.)
+ */
+ int gen_prime_flags = MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4;
/*
* If the modulus is 1024 bit long or shorter, then the security strength of
@@ -1056,7 +1158,7 @@ int mbedtls_rsa_gen_key(mbedtls_rsa_context *ctx,
* rate of 2^-80 is sufficient.
*/
if (nbits > 1024) {
- prime_quality = MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR;
+ gen_prime_flags |= MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR;
}
mbedtls_mpi_init(&H);
@@ -1081,10 +1183,10 @@ int mbedtls_rsa_gen_key(mbedtls_rsa_context *ctx,
do {
MBEDTLS_MPI_CHK(mbedtls_mpi_gen_prime(&ctx->P, nbits >> 1,
- prime_quality, f_rng, p_rng));
+ gen_prime_flags, f_rng, p_rng));
MBEDTLS_MPI_CHK(mbedtls_mpi_gen_prime(&ctx->Q, nbits >> 1,
- prime_quality, f_rng, p_rng));
+ gen_prime_flags, f_rng, p_rng));
/* make sure the difference between p and q is not too small (FIPS 186-4 §B.3.3 step 5.4) */
MBEDTLS_MPI_CHK(mbedtls_mpi_sub_mpi(&H, &ctx->P, &ctx->Q));
@@ -1099,7 +1201,7 @@ int mbedtls_rsa_gen_key(mbedtls_rsa_context *ctx,
/* Compute D = E^-1 mod LCM(P-1, Q-1) (FIPS 186-4 §B.3.1 criterion 3(b))
* if it exists (FIPS 186-4 §B.3.1 criterion 2(a)) */
- ret = mbedtls_rsa_deduce_private_exponent(&ctx->P, &ctx->Q, &ctx->E, &ctx->D);
+ ret = rsa_gen_key_check_e_compute_d(ctx);
if (ret == MBEDTLS_ERR_MPI_NOT_ACCEPTABLE) {
mbedtls_mpi_lset(&ctx->D, 0); /* needed for the next call */
continue;
diff --git a/tests/suites/test_suite_bignum.function b/tests/suites/test_suite_bignum.function
index d348f05405e9..92297d8d2377 100644
--- a/tests/suites/test_suite_bignum.function
+++ b/tests/suites/test_suite_bignum.function
@@ -1449,20 +1449,36 @@ void mpi_gen_prime(int bits, int flags, int ref_ret)
mbedtls_mpi_init(&X);
- my_ret = mbedtls_mpi_gen_prime(&X, bits, flags,
- mbedtls_test_rnd_std_rand, NULL);
- TEST_ASSERT(my_ret == ref_ret);
+ /* Since this is not deterministic, repeat a few times
+ * when expecting success. */
+ for (size_t i = 0; i < 8; i++) {
+ my_ret = mbedtls_mpi_gen_prime(&X, bits, flags,
+ mbedtls_test_rnd_std_rand, NULL);
+ TEST_ASSERT(my_ret == ref_ret);
+
+ if (ref_ret != 0) {
+ /* no additional checks, no repeating */
+ break;
+ }
- if (ref_ret == 0) {
size_t actual_bits = mbedtls_mpi_bitlen(&X);
- TEST_ASSERT(actual_bits >= (size_t) bits);
- TEST_ASSERT(actual_bits <= (size_t) bits + 1);
- TEST_ASSERT(sign_is_valid(&X));
+ if (flags & MBEDTLS_MPI_GEN_PRIME_FLAG_DH) {
+ TEST_ASSERT(actual_bits >= (size_t) bits);
+ TEST_ASSERT(actual_bits <= (size_t) bits + 1);
+ } else {
+ TEST_EQUAL(actual_bits, (size_t) bits);
+ }
+ TEST_EQUAL(X.s, 1);
TEST_ASSERT(mbedtls_mpi_is_prime_ext(&X, 40,
mbedtls_test_rnd_std_rand,
NULL) == 0);
+
+ if (flags & MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4) {
+ TEST_EQUAL(X.p[0] & 3, 3);
+ }
+
if (flags & MBEDTLS_MPI_GEN_PRIME_FLAG_DH) {
/* X = ( X - 1 ) / 2 */
TEST_ASSERT(mbedtls_mpi_shift_r(&X, 1) == 0);
diff --git a/tests/suites/test_suite_bignum.misc.data b/tests/suites/test_suite_bignum.misc.data
index ab6088c5b15b..4e40846f2055 100644
--- a/tests/suites/test_suite_bignum.misc.data
+++ b/tests/suites/test_suite_bignum.misc.data
@@ -1873,6 +1873,10 @@ Test mbedtls_mpi_gen_prime (Larger)
depends_on:MBEDTLS_GENPRIME
mpi_gen_prime:128:0:0
+Test mbedtls_mpi_gen_prime (Lower error rate)
+depends_on:MBEDTLS_GENPRIME
+mpi_gen_prime:128:MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR:0
+
Test mbedtls_mpi_gen_prime (Safe)
depends_on:MBEDTLS_GENPRIME
mpi_gen_prime:128:MBEDTLS_MPI_GEN_PRIME_FLAG_DH:0
@@ -1881,13 +1885,29 @@ Test mbedtls_mpi_gen_prime (Safe with lower error rate)
depends_on:MBEDTLS_GENPRIME
mpi_gen_prime:128:MBEDTLS_MPI_GEN_PRIME_FLAG_DH | MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR:0
-Test mbedtls_mpi_gen_prime standard RSA #1 (lower error rate)
+Test mbedtls_mpi_gen_prime (Safe with lower error rate and 3 mod 4)
+depends_on:MBEDTLS_GENPRIME
+mpi_gen_prime:128:MBEDTLS_MPI_GEN_PRIME_FLAG_DH | MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR | MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4:0
+
+Test mbedtls_mpi_gen_prime (3 mod 4)
+depends_on:MBEDTLS_GENPRIME
+mpi_gen_prime:128:MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4:0
+
+Test mbedtls_mpi_gen_prime (3 mod 4 with lower error rate)
+depends_on:MBEDTLS_GENPRIME
+mpi_gen_prime:128:MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4 | MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR:0
+
+Test mbedtls_mpi_gen_prime (3 mod 4 and safe)
+depends_on:MBEDTLS_GENPRIME
+mpi_gen_prime:128:MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4 | MBEDTLS_MPI_GEN_PRIME_FLAG_DH:0
+
+Test mbedtls_mpi_gen_prime standard RSA #1 (lower error rate, 3 mod 4)
depends_on:MBEDTLS_GENPRIME
-mpi_gen_prime:1024:MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR:0
+mpi_gen_prime:1024:MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR | MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4:0
-Test mbedtls_mpi_gen_prime standard RSA #2 (lower error rate)
+Test mbedtls_mpi_gen_prime standard RSA #2 (lower error rate, 3 mod 4)
depends_on:MBEDTLS_GENPRIME
-mpi_gen_prime:1536:MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR:0
+mpi_gen_prime:1536:MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR | MBEDTLS_MPI_GEN_PRIME_FLAG_3MOD4:0
Test bit getting (Value bit 25)
mpi_get_bit:"2faa127":25:1