Skip to content

Commit 76071d1

Browse files
committed
constant time mods (Sai)
1 parent 58940b9 commit 76071d1

File tree

1 file changed

+94
-48
lines changed
  • core/src/main/java/org/bouncycastle/pqc/crypto/ntruprime

1 file changed

+94
-48
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/ntruprime/Utils.java

Lines changed: 94 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package org.bouncycastle.pqc.crypto.ntruprime;
22

3-
import java.math.BigInteger;
43
import java.security.SecureRandom;
54

65
import org.bouncycastle.crypto.StreamCipher;
@@ -16,10 +15,10 @@ protected static int getRandomUnsignedInteger(SecureRandom random)
1615
{
1716
byte[] c = new byte[4];
1817
random.nextBytes(c);
19-
return (Byte.toUnsignedInt(c[0])
20-
+ (Byte.toUnsignedInt(c[1]) << 8)
21-
+ (Byte.toUnsignedInt(c[2]) << 16)
22-
+ (Byte.toUnsignedInt(c[3]) << 24));
18+
return (bToUnsignedInt(c[0])
19+
+ (bToUnsignedInt(c[1]) << 8)
20+
+ (bToUnsignedInt(c[2]) << 16)
21+
+ (bToUnsignedInt(c[3]) << 24));
2322
}
2423

2524
protected static void getRandomSmallPolynomial(SecureRandom random, byte[] g)
@@ -28,10 +27,9 @@ protected static void getRandomSmallPolynomial(SecureRandom random, byte[] g)
2827
g[i] = (byte)((((getRandomUnsignedInteger(random) & 0x3fffffff) * 3) >>> 30) - 1);
2928
}
3029

31-
// TODO - Check for constant time
3230
protected static int getModFreeze(int x, int n)
3331
{
34-
return Math.floorMod((x + ((n - 1) / 2)), n) - ((n - 1) / 2);
32+
return getSignedDivMod((x + ((n - 1) / 2)), n)[1] - ((n - 1) / 2);
3533
}
3634

3735
protected static boolean isInvertiblePolynomialInR3(byte[] g, byte[] ginv, int p)
@@ -58,7 +56,7 @@ protected static boolean isInvertiblePolynomialInR3(byte[] g, byte[] ginv, int p
5856
v[0] = 0;
5957

6058
sign = -h[0] * f[0];
61-
swap = ((-delta < 0) ? -1 : 0) & (((int)h[0] != 0) ? -1 : 0);
59+
swap = checkLessThanZero(-delta) & checkNotEqualToZero(h[0]);
6260
delta ^= swap & (delta ^ -delta);
6361
delta += 1;
6462

@@ -176,7 +174,7 @@ protected static void getOneThirdInverseInRQ(short[] finv3, byte[] f, int p, int
176174
System.arraycopy(v, 0, v, 1, p);
177175
v[0] = 0;
178176

179-
swap = ((-delta < 0) ? -1 : 0) & ((g[0] != 0) ? -1 : 0);
177+
swap = checkLessThanZero(-delta) & checkNotEqualToZero(g[0]);
180178
delta ^= swap & (delta ^ -delta);
181179
delta += 1;
182180

@@ -327,22 +325,15 @@ protected static void expand(int[] L, byte[] k)
327325
byte[] nonce = new byte[16];
328326
generateAES256CTRStream(aesInput, aesOutput, nonce, k);
329327
for (int i = 0; i < L.length; i++)
330-
L[i] = (Byte.toUnsignedInt(aesOutput[i * 4])
331-
+ (Byte.toUnsignedInt(aesOutput[(i * 4) + 1]) << 8)
332-
+ (Byte.toUnsignedInt(aesOutput[(i * 4) + 2]) << 16)
333-
+ (Byte.toUnsignedInt(aesOutput[(i * 4) + 3]) << 24));
328+
L[i] = (bToUnsignedInt(aesOutput[i * 4])
329+
+ (bToUnsignedInt(aesOutput[(i * 4) + 1]) << 8)
330+
+ (bToUnsignedInt(aesOutput[(i * 4) + 2]) << 16)
331+
+ (bToUnsignedInt(aesOutput[(i * 4) + 3]) << 24));
334332
}
335333

336-
// TODO - Check for constant time
337-
private static int getUnsignedDiv(int x, int n)
338-
{
339-
return BigInteger.valueOf((x < 0) ? x + 4294967296L : x).divide(BigInteger.valueOf(n)).intValueExact();
340-
}
341-
342-
// TODO - Check for constant time
343334
private static int getUnsignedMod(int x, int n)
344335
{
345-
return BigInteger.valueOf((x < 0) ? x + 4294967296L : x).mod(BigInteger.valueOf(n)).intValueExact();
336+
return getUnsignedDivMod(x, n)[1];
346337
}
347338

348339
protected static void generatePolynomialInRQFromSeed(short[] G, byte[] seed, int p, int q)
@@ -393,9 +384,9 @@ private static void decode(short[] out, byte[] S, short[] M, int len, int start,
393384
if (M[0] == 1)
394385
out[start] = 0;
395386
else if (M[0] <= 256)
396-
out[start] = (short)getUnsignedMod(Byte.toUnsignedInt(S[sIndex]), M[0]);
387+
out[start] = (short)getUnsignedMod(bToUnsignedInt(S[sIndex]), M[0]);
397388
else
398-
out[start] = (short)getUnsignedMod(Byte.toUnsignedInt(S[sIndex]) + (S[sIndex + 1] << 8), M[0]);
389+
out[start] = (short)getUnsignedMod(bToUnsignedInt(S[sIndex]) + (S[sIndex + 1] << 8), M[0]);
399390
}
400391

401392
if (len > 1)
@@ -412,14 +403,14 @@ else if (M[0] <= 256)
412403
if (m > (256 * 16383))
413404
{
414405
bottomt[i / 2] = 256 * 256;
415-
bottomr[i / 2] = (short)(Byte.toUnsignedInt(S[sIndex]) + (256 * Byte.toUnsignedInt(S[sIndex + 1])));
406+
bottomr[i / 2] = (short)(bToUnsignedInt(S[sIndex]) + (256 * bToUnsignedInt(S[sIndex + 1])));
416407
sIndex += 2;
417408
M2[i / 2] = (short)((((m + 255) >>> 8) + 255) >>> 8);
418409
}
419410
else if (m >= 16384)
420411
{
421412
bottomt[i / 2] = 256;
422-
bottomr[i / 2] = (short)Byte.toUnsignedInt(S[sIndex]);
413+
bottomr[i / 2] = (short)bToUnsignedInt(S[sIndex]);
423414
sIndex += 1;
424415
M2[i / 2] = (short)((m + 255) >>> 8);
425416
}
@@ -437,14 +428,11 @@ else if (m >= 16384)
437428

438429
for (i = 0; i < len - 1; i += 2)
439430
{
440-
int r = Short.toUnsignedInt(bottomr[i / 2]);
441-
int r0, r1;
442-
r += bottomt[i / 2] * Short.toUnsignedInt(R2[i / 2]);
443-
r0 = getUnsignedMod(r, M[i]);
444-
r1 = getUnsignedDiv(r, M[i]);
445-
r1 = getUnsignedMod(r1, M[i + 1]);
446-
out[start++] = (short)r0;
447-
out[start++] = (short)r1;
431+
int r = sToUnsignedInt(bottomr[i / 2]);
432+
r += bottomt[i / 2] * sToUnsignedInt(R2[i / 2]);
433+
int[] r01 = getUnsignedDivMod(r, M[i]);
434+
out[start++] = (short)r01[1];
435+
out[start++] = (short)getUnsignedMod(r01[0], M[i + 1]);
448436
}
449437
if (i < len)
450438
out[start] = R2[i / 2];
@@ -515,14 +503,14 @@ protected static void getDecodedSmallPolynomial(byte[] sp, byte[] encSP, int p)
515503
for (int i = 0; i < p / 4; i++)
516504
{
517505
x = encSP[encSPIndex++];
518-
sp[spIndex++] = (byte)((Byte.toUnsignedInt(x) & 3) - 1); x >>>= 2;
519-
sp[spIndex++] = (byte)((Byte.toUnsignedInt(x) & 3) - 1); x >>>= 2;
520-
sp[spIndex++] = (byte)((Byte.toUnsignedInt(x) & 3) - 1); x >>>= 2;
521-
sp[spIndex++] = (byte)((Byte.toUnsignedInt(x) & 3) - 1);
506+
sp[spIndex++] = (byte)((bToUnsignedInt(x) & 3) - 1); x >>>= 2;
507+
sp[spIndex++] = (byte)((bToUnsignedInt(x) & 3) - 1); x >>>= 2;
508+
sp[spIndex++] = (byte)((bToUnsignedInt(x) & 3) - 1); x >>>= 2;
509+
sp[spIndex++] = (byte)((bToUnsignedInt(x) & 3) - 1);
522510
}
523511

524512
x = encSP[encSPIndex];
525-
sp[spIndex] = (byte)((Byte.toUnsignedInt(x) & 3) - 1);
513+
sp[spIndex] = (byte)((bToUnsignedInt(x) & 3) - 1);
526514
}
527515

528516
protected static void scalarMultiplicationInRQ(short[] out, short[] in, int scalar, int q)
@@ -572,20 +560,14 @@ protected static void multiplicationInR3(byte[] h, byte[] finv3, byte[] g, int p
572560
protected static void checkForSmallPolynomial(byte[] r, byte[] ev, int p, int w)
573561
{
574562
int weight = 0;
575-
for (int i = 0; i != ev.length; i++)
576-
{
577-
weight += ev[i] & 1;
578-
}
563+
for (byte b : ev)
564+
weight += b & 1;
579565

580-
int mask = (weight == w) ? 0 : -1;
566+
int mask = checkNotEqualToZero(weight - w);
581567
for (int i = 0; i < w; i++)
582-
{
583568
r[i] = (byte)(((ev[i] ^ 1) & ~mask) ^ 1);
584-
}
585569
for (int i = w; i < p; i++)
586-
{
587570
r[i] = (byte)(ev[i] & ~mask);
588-
}
589571
}
590572

591573
protected static void updateDiffMask(byte[] encR, byte[] rho, int mask)
@@ -606,6 +588,70 @@ protected static void getTopDecodedPolynomial(byte[] out, byte[] in)
606588
protected static void right(byte[] out, short[] aB, byte[] T, int q, int w, int tau2, int tau3)
607589
{
608590
for (int i = 0; i < out.length; i++)
609-
out[i] = (byte)((getModFreeze(getModFreeze((tau3 * T[i]) - tau2, q) - aB[i] + (4 * w) + 1, q) < 0) ? 1 : 0);
591+
out[i] = (byte)(-checkLessThanZero(getModFreeze(getModFreeze((tau3 * T[i]) - tau2, q) - aB[i] + (4 * w) + 1, q)));
592+
}
593+
594+
private static int[] getUnsignedDivMod(int dividend, int n)
595+
{
596+
long x = Integer.toUnsignedLong(dividend);
597+
long v = Integer.toUnsignedLong(0x80000000);
598+
long q, qpart, mask;
599+
600+
v /= n;
601+
q = 0;
602+
603+
qpart = (x * v) >>> 31;
604+
x -= qpart * n;
605+
q += qpart;
606+
607+
qpart = (x * v) >>> 31;
608+
x -= qpart * n;
609+
q += qpart;
610+
611+
x -= n;
612+
q += 1;
613+
mask = -(x >>> 63);
614+
x += mask & n;
615+
q += mask;
616+
617+
return new int[]{Math.toIntExact(q), Math.toIntExact(x)};
618+
}
619+
620+
private static int[] getSignedDivMod(int x, int n)
621+
{
622+
int q, r, mask;
623+
624+
int[] div1 = getUnsignedDivMod(Math.toIntExact(0x80000000 + Integer.toUnsignedLong(x)), n);
625+
int[] div2 = getUnsignedDivMod(0x80000000, n);
626+
627+
q = Math.toIntExact(Integer.toUnsignedLong(div1[0]) - Integer.toUnsignedLong(div2[0]));
628+
r = Math.toIntExact(Integer.toUnsignedLong(div1[1]) - Integer.toUnsignedLong(div2[1]));
629+
mask = -(r >>> 31);
630+
r += mask & n;
631+
q += mask;
632+
633+
return new int[]{q, r};
634+
}
635+
636+
private static int checkLessThanZero(int x)
637+
{
638+
return -(int)(x >>> 31);
639+
}
640+
641+
private static int checkNotEqualToZero(int x)
642+
{
643+
long l = Integer.toUnsignedLong(x);
644+
l = -l;
645+
return -(int)(l >>> 63);
646+
}
647+
648+
static int bToUnsignedInt(byte b)
649+
{
650+
return b & 0xff;
651+
}
652+
653+
static int sToUnsignedInt(short s)
654+
{
655+
return s & 0xffff;
610656
}
611657
}

0 commit comments

Comments
 (0)