11package org .bouncycastle .pqc .crypto .ntruprime ;
22
3- import java .math .BigInteger ;
43import java .security .SecureRandom ;
54
65import org .bouncycastle .crypto .StreamCipher ;
@@ -16,10 +15,10 @@ protected static int getRandomUnsignedInteger(SecureRandom random)
1615 {
1716 byte [] c = new byte [4 ];
1817 random .nextBytes (c );
19- return (Byte . toUnsignedInt (c [0 ])
20- + (Byte . toUnsignedInt (c [1 ]) << 8 )
21- + (Byte . toUnsignedInt (c [2 ]) << 16 )
22- + (Byte . toUnsignedInt (c [3 ]) << 24 ));
18+ return (bToUnsignedInt (c [0 ])
19+ + (bToUnsignedInt (c [1 ]) << 8 )
20+ + (bToUnsignedInt (c [2 ]) << 16 )
21+ + (bToUnsignedInt (c [3 ]) << 24 ));
2322 }
2423
2524 protected static void getRandomSmallPolynomial (SecureRandom random , byte [] g )
@@ -28,10 +27,9 @@ protected static void getRandomSmallPolynomial(SecureRandom random, byte[] g)
2827 g [i ] = (byte )((((getRandomUnsignedInteger (random ) & 0x3fffffff ) * 3 ) >>> 30 ) - 1 );
2928 }
3029
31- // TODO - Check for constant time
3230 protected static int getModFreeze (int x , int n )
3331 {
34- return Math . floorMod ((x + ((n - 1 ) / 2 )), n ) - ((n - 1 ) / 2 );
32+ return getSignedDivMod ((x + ((n - 1 ) / 2 )), n )[ 1 ] - ((n - 1 ) / 2 );
3533 }
3634
3735 protected static boolean isInvertiblePolynomialInR3 (byte [] g , byte [] ginv , int p )
@@ -58,7 +56,7 @@ protected static boolean isInvertiblePolynomialInR3(byte[] g, byte[] ginv, int p
5856 v [0 ] = 0 ;
5957
6058 sign = -h [0 ] * f [0 ];
61- swap = (( -delta < 0 ) ? - 1 : 0 ) & ((( int ) h [0 ] != 0 ) ? - 1 : 0 );
59+ swap = checkLessThanZero ( -delta ) & checkNotEqualToZero ( h [0 ]);
6260 delta ^= swap & (delta ^ -delta );
6361 delta += 1 ;
6462
@@ -176,7 +174,7 @@ protected static void getOneThirdInverseInRQ(short[] finv3, byte[] f, int p, int
176174 System .arraycopy (v , 0 , v , 1 , p );
177175 v [0 ] = 0 ;
178176
179- swap = (( -delta < 0 ) ? - 1 : 0 ) & (( g [0 ] != 0 ) ? - 1 : 0 );
177+ swap = checkLessThanZero ( -delta ) & checkNotEqualToZero ( g [0 ]);
180178 delta ^= swap & (delta ^ -delta );
181179 delta += 1 ;
182180
@@ -327,22 +325,15 @@ protected static void expand(int[] L, byte[] k)
327325 byte [] nonce = new byte [16 ];
328326 generateAES256CTRStream (aesInput , aesOutput , nonce , k );
329327 for (int i = 0 ; i < L .length ; i ++)
330- L [i ] = (Byte . toUnsignedInt (aesOutput [i * 4 ])
331- + (Byte . toUnsignedInt (aesOutput [(i * 4 ) + 1 ]) << 8 )
332- + (Byte . toUnsignedInt (aesOutput [(i * 4 ) + 2 ]) << 16 )
333- + (Byte . toUnsignedInt (aesOutput [(i * 4 ) + 3 ]) << 24 ));
328+ L [i ] = (bToUnsignedInt (aesOutput [i * 4 ])
329+ + (bToUnsignedInt (aesOutput [(i * 4 ) + 1 ]) << 8 )
330+ + (bToUnsignedInt (aesOutput [(i * 4 ) + 2 ]) << 16 )
331+ + (bToUnsignedInt (aesOutput [(i * 4 ) + 3 ]) << 24 ));
334332 }
335333
336- // TODO - Check for constant time
337- private static int getUnsignedDiv (int x , int n )
338- {
339- return BigInteger .valueOf ((x < 0 ) ? x + 4294967296L : x ).divide (BigInteger .valueOf (n )).intValueExact ();
340- }
341-
342- // TODO - Check for constant time
343334 private static int getUnsignedMod (int x , int n )
344335 {
345- return BigInteger . valueOf (( x < 0 ) ? x + 4294967296L : x ). mod ( BigInteger . valueOf ( n )). intValueExact () ;
336+ return getUnsignedDivMod ( x , n )[ 1 ] ;
346337 }
347338
348339 protected static void generatePolynomialInRQFromSeed (short [] G , byte [] seed , int p , int q )
@@ -393,9 +384,9 @@ private static void decode(short[] out, byte[] S, short[] M, int len, int start,
393384 if (M [0 ] == 1 )
394385 out [start ] = 0 ;
395386 else if (M [0 ] <= 256 )
396- out [start ] = (short )getUnsignedMod (Byte . toUnsignedInt (S [sIndex ]), M [0 ]);
387+ out [start ] = (short )getUnsignedMod (bToUnsignedInt (S [sIndex ]), M [0 ]);
397388 else
398- out [start ] = (short )getUnsignedMod (Byte . toUnsignedInt (S [sIndex ]) + (S [sIndex + 1 ] << 8 ), M [0 ]);
389+ out [start ] = (short )getUnsignedMod (bToUnsignedInt (S [sIndex ]) + (S [sIndex + 1 ] << 8 ), M [0 ]);
399390 }
400391
401392 if (len > 1 )
@@ -412,14 +403,14 @@ else if (M[0] <= 256)
412403 if (m > (256 * 16383 ))
413404 {
414405 bottomt [i / 2 ] = 256 * 256 ;
415- bottomr [i / 2 ] = (short )(Byte . toUnsignedInt (S [sIndex ]) + (256 * Byte . toUnsignedInt (S [sIndex + 1 ])));
406+ bottomr [i / 2 ] = (short )(bToUnsignedInt (S [sIndex ]) + (256 * bToUnsignedInt (S [sIndex + 1 ])));
416407 sIndex += 2 ;
417408 M2 [i / 2 ] = (short )((((m + 255 ) >>> 8 ) + 255 ) >>> 8 );
418409 }
419410 else if (m >= 16384 )
420411 {
421412 bottomt [i / 2 ] = 256 ;
422- bottomr [i / 2 ] = (short )Byte . toUnsignedInt (S [sIndex ]);
413+ bottomr [i / 2 ] = (short )bToUnsignedInt (S [sIndex ]);
423414 sIndex += 1 ;
424415 M2 [i / 2 ] = (short )((m + 255 ) >>> 8 );
425416 }
@@ -437,14 +428,11 @@ else if (m >= 16384)
437428
438429 for (i = 0 ; i < len - 1 ; i += 2 )
439430 {
440- int r = Short .toUnsignedInt (bottomr [i / 2 ]);
441- int r0 , r1 ;
442- r += bottomt [i / 2 ] * Short .toUnsignedInt (R2 [i / 2 ]);
443- r0 = getUnsignedMod (r , M [i ]);
444- r1 = getUnsignedDiv (r , M [i ]);
445- r1 = getUnsignedMod (r1 , M [i + 1 ]);
446- out [start ++] = (short )r0 ;
447- out [start ++] = (short )r1 ;
431+ int r = sToUnsignedInt (bottomr [i / 2 ]);
432+ r += bottomt [i / 2 ] * sToUnsignedInt (R2 [i / 2 ]);
433+ int [] r01 = getUnsignedDivMod (r , M [i ]);
434+ out [start ++] = (short )r01 [1 ];
435+ out [start ++] = (short )getUnsignedMod (r01 [0 ], M [i + 1 ]);
448436 }
449437 if (i < len )
450438 out [start ] = R2 [i / 2 ];
@@ -515,14 +503,14 @@ protected static void getDecodedSmallPolynomial(byte[] sp, byte[] encSP, int p)
515503 for (int i = 0 ; i < p / 4 ; i ++)
516504 {
517505 x = encSP [encSPIndex ++];
518- sp [spIndex ++] = (byte )((Byte . toUnsignedInt (x ) & 3 ) - 1 ); x >>>= 2 ;
519- sp [spIndex ++] = (byte )((Byte . toUnsignedInt (x ) & 3 ) - 1 ); x >>>= 2 ;
520- sp [spIndex ++] = (byte )((Byte . toUnsignedInt (x ) & 3 ) - 1 ); x >>>= 2 ;
521- sp [spIndex ++] = (byte )((Byte . toUnsignedInt (x ) & 3 ) - 1 );
506+ sp [spIndex ++] = (byte )((bToUnsignedInt (x ) & 3 ) - 1 ); x >>>= 2 ;
507+ sp [spIndex ++] = (byte )((bToUnsignedInt (x ) & 3 ) - 1 ); x >>>= 2 ;
508+ sp [spIndex ++] = (byte )((bToUnsignedInt (x ) & 3 ) - 1 ); x >>>= 2 ;
509+ sp [spIndex ++] = (byte )((bToUnsignedInt (x ) & 3 ) - 1 );
522510 }
523511
524512 x = encSP [encSPIndex ];
525- sp [spIndex ] = (byte )((Byte . toUnsignedInt (x ) & 3 ) - 1 );
513+ sp [spIndex ] = (byte )((bToUnsignedInt (x ) & 3 ) - 1 );
526514 }
527515
528516 protected static void scalarMultiplicationInRQ (short [] out , short [] in , int scalar , int q )
@@ -572,20 +560,14 @@ protected static void multiplicationInR3(byte[] h, byte[] finv3, byte[] g, int p
572560 protected static void checkForSmallPolynomial (byte [] r , byte [] ev , int p , int w )
573561 {
574562 int weight = 0 ;
575- for (int i = 0 ; i != ev .length ; i ++)
576- {
577- weight += ev [i ] & 1 ;
578- }
563+ for (byte b : ev )
564+ weight += b & 1 ;
579565
580- int mask = (weight == w ) ? 0 : - 1 ;
566+ int mask = checkNotEqualToZero (weight - w );
581567 for (int i = 0 ; i < w ; i ++)
582- {
583568 r [i ] = (byte )(((ev [i ] ^ 1 ) & ~mask ) ^ 1 );
584- }
585569 for (int i = w ; i < p ; i ++)
586- {
587570 r [i ] = (byte )(ev [i ] & ~mask );
588- }
589571 }
590572
591573 protected static void updateDiffMask (byte [] encR , byte [] rho , int mask )
@@ -606,6 +588,70 @@ protected static void getTopDecodedPolynomial(byte[] out, byte[] in)
606588 protected static void right (byte [] out , short [] aB , byte [] T , int q , int w , int tau2 , int tau3 )
607589 {
608590 for (int i = 0 ; i < out .length ; i ++)
609- out [i ] = (byte )((getModFreeze (getModFreeze ((tau3 * T [i ]) - tau2 , q ) - aB [i ] + (4 * w ) + 1 , q ) < 0 ) ? 1 : 0 );
591+ out [i ] = (byte )(-checkLessThanZero (getModFreeze (getModFreeze ((tau3 * T [i ]) - tau2 , q ) - aB [i ] + (4 * w ) + 1 , q )));
592+ }
593+
594+ private static int [] getUnsignedDivMod (int dividend , int n )
595+ {
596+ long x = Integer .toUnsignedLong (dividend );
597+ long v = Integer .toUnsignedLong (0x80000000 );
598+ long q , qpart , mask ;
599+
600+ v /= n ;
601+ q = 0 ;
602+
603+ qpart = (x * v ) >>> 31 ;
604+ x -= qpart * n ;
605+ q += qpart ;
606+
607+ qpart = (x * v ) >>> 31 ;
608+ x -= qpart * n ;
609+ q += qpart ;
610+
611+ x -= n ;
612+ q += 1 ;
613+ mask = -(x >>> 63 );
614+ x += mask & n ;
615+ q += mask ;
616+
617+ return new int []{Math .toIntExact (q ), Math .toIntExact (x )};
618+ }
619+
620+ private static int [] getSignedDivMod (int x , int n )
621+ {
622+ int q , r , mask ;
623+
624+ int [] div1 = getUnsignedDivMod (Math .toIntExact (0x80000000 + Integer .toUnsignedLong (x )), n );
625+ int [] div2 = getUnsignedDivMod (0x80000000 , n );
626+
627+ q = Math .toIntExact (Integer .toUnsignedLong (div1 [0 ]) - Integer .toUnsignedLong (div2 [0 ]));
628+ r = Math .toIntExact (Integer .toUnsignedLong (div1 [1 ]) - Integer .toUnsignedLong (div2 [1 ]));
629+ mask = -(r >>> 31 );
630+ r += mask & n ;
631+ q += mask ;
632+
633+ return new int []{q , r };
634+ }
635+
636+ private static int checkLessThanZero (int x )
637+ {
638+ return -(int )(x >>> 31 );
639+ }
640+
641+ private static int checkNotEqualToZero (int x )
642+ {
643+ long l = Integer .toUnsignedLong (x );
644+ l = -l ;
645+ return -(int )(l >>> 63 );
646+ }
647+
648+ static int bToUnsignedInt (byte b )
649+ {
650+ return b & 0xff ;
651+ }
652+
653+ static int sToUnsignedInt (short s )
654+ {
655+ return s & 0xffff ;
610656 }
611657}
0 commit comments