@@ -32,7 +32,9 @@ internal unsafe ref struct BigInteger
32
32
private const int MaxBits = BitsForLongestBinaryMantissa + BitsForLongestDigitSequence + BitsPerBlock ;
33
33
34
34
private const int BitsPerBlock = sizeof ( int ) * 8 ;
35
- private const int MaxBlockCount = ( MaxBits + ( BitsPerBlock - 1 ) ) / BitsPerBlock ;
35
+
36
+ // We need one extra block to make our shift left algorithm significantly simpler
37
+ private const int MaxBlockCount = ( ( MaxBits + ( BitsPerBlock - 1 ) ) / BitsPerBlock ) + 1 ;
36
38
37
39
private static readonly uint [ ] s_Pow10UInt32Table = new uint [ ]
38
40
{
@@ -300,7 +302,8 @@ internal unsafe ref struct BigInteger
300
302
0xD9D61A05 ,
301
303
0x00000325 ,
302
304
303
- // 9 Trailing blocks to ensure MaxBlockCount
305
+ // 10 Trailing blocks to ensure MaxBlockCount
306
+ 0x00000000 ,
304
307
0x00000000 ,
305
308
0x00000000 ,
306
309
0x00000000 ,
@@ -353,18 +356,33 @@ public static void Add(ref BigInteger lhs, uint value, ref BigInteger result)
353
356
while ( index < lhsLength )
354
357
{
355
358
ulong sum = ( ulong ) ( lhs . _blocks [ index ] ) + carry ;
356
- lhs . _blocks [ index ] = ( uint ) ( sum ) ;
359
+ result . _blocks [ index ] = ( uint ) ( sum ) ;
357
360
carry = ( uint ) ( sum >> 32 ) ;
358
361
359
362
index ++ ;
360
363
}
361
364
365
+ int resultLength = lhsLength ;
366
+
362
367
if ( carry != 0 )
363
368
{
364
- Debug . Assert ( unchecked ( ( uint ) ( lhsLength ) ) + 1 <= MaxBlockCount ) ;
369
+ Debug . Assert ( unchecked ( ( uint ) ( resultLength ) ) < MaxBlockCount ) ;
370
+
371
+ if ( unchecked ( ( uint ) ( resultLength ) ) >= MaxBlockCount )
372
+ {
373
+ // We shouldn't reach here, and the above assert will help flag this
374
+ // during testing, but we'll ensure that we return a safe value of
375
+ // zero in the case we end up overflowing in any way.
376
+
377
+ result . SetZero ( ) ;
378
+ return ;
379
+ }
380
+
365
381
result . _blocks [ index ] = carry ;
366
- result . _length = ( lhsLength + 1 ) ;
382
+ resultLength += 1 ;
367
383
}
384
+
385
+ result . _length = resultLength ;
368
386
}
369
387
370
388
public static void Add ( ref BigInteger lhs , ref BigInteger rhs , out BigInteger result )
@@ -409,15 +427,30 @@ public static void Add(ref BigInteger lhs, ref BigInteger rhs, out BigInteger re
409
427
resultIndex ++ ;
410
428
}
411
429
430
+ int resultLength = largeLength ;
431
+
412
432
// If there's still a carry, append a new block
413
433
if ( carry != 0 )
414
434
{
415
435
Debug . Assert ( carry == 1 ) ;
416
- Debug . Assert ( ( resultIndex == largeLength ) && ( largeLength < MaxBlockCount ) ) ;
436
+ Debug . Assert ( resultIndex == resultLength ) ;
437
+ Debug . Assert ( unchecked ( ( uint ) ( resultLength ) ) < MaxBlockCount ) ;
438
+
439
+ if ( unchecked ( ( uint ) ( resultLength ) ) >= MaxBlockCount )
440
+ {
441
+ // We shouldn't reach here, and the above assert will help flag this
442
+ // during testing, but we'll ensure that we return a safe value of
443
+ // zero in the case we end up overflowing in any way.
444
+
445
+ result . SetZero ( ) ;
446
+ return ;
447
+ }
417
448
418
449
result . _blocks [ resultIndex ] = 1 ;
419
- result . _length += 1 ;
450
+ resultLength += 1 ;
420
451
}
452
+
453
+ result . _length = resultLength ;
421
454
}
422
455
423
456
public static int Compare ( ref BigInteger lhs , ref BigInteger rhs )
@@ -765,12 +798,27 @@ public static void Multiply(ref BigInteger lhs, uint value, ref BigInteger resul
765
798
index ++ ;
766
799
}
767
800
801
+ int resultLength = lhsLength ;
802
+
768
803
if ( carry != 0 )
769
804
{
770
- Debug . Assert ( unchecked ( ( uint ) ( lhsLength ) ) + 1 <= MaxBlockCount ) ;
805
+ Debug . Assert ( unchecked ( ( uint ) ( resultLength ) ) < MaxBlockCount ) ;
806
+
807
+ if ( unchecked ( ( uint ) ( resultLength ) ) >= MaxBlockCount )
808
+ {
809
+ // We shouldn't reach here, and the above assert will help flag this
810
+ // during testing, but we'll ensure that we return a safe value of
811
+ // zero in the case we end up overflowing in any way.
812
+
813
+ result . SetZero ( ) ;
814
+ return ;
815
+ }
816
+
771
817
result . _blocks [ index ] = carry ;
772
- result . _length = ( lhsLength + 1 ) ;
818
+ resultLength += 1 ;
773
819
}
820
+
821
+ result . _length = resultLength ;
774
822
}
775
823
776
824
public static void Multiply ( ref BigInteger lhs , ref BigInteger rhs , ref BigInteger result )
@@ -804,6 +852,16 @@ public static void Multiply(ref BigInteger lhs, ref BigInteger rhs, ref BigInteg
804
852
805
853
int maxResultLength = smallLength + largeLength ;
806
854
Debug . Assert ( unchecked ( ( uint ) ( maxResultLength ) ) <= MaxBlockCount ) ;
855
+
856
+ if ( unchecked ( ( uint ) ( maxResultLength ) ) > MaxBlockCount )
857
+ {
858
+ // We shouldn't reach here, and the above assert will help flag this
859
+ // during testing, but we'll ensure that we return a safe value of
860
+ // zero in the case we end up overflowing in any way.
861
+
862
+ result . SetZero ( ) ;
863
+ return ;
864
+ }
807
865
808
866
// Zero out result internal blocks.
809
867
Buffer . ZeroMemory ( ( byte * ) ( result . GetBlocksPointer ( ) ) , ( maxResultLength * sizeof ( uint ) ) ) ;
@@ -1072,6 +1130,18 @@ public void ExtendBlock(uint blockValue)
1072
1130
{
1073
1131
_blocks [ _length ] = blockValue ;
1074
1132
_length ++ ;
1133
+
1134
+ Debug . Assert ( unchecked ( ( uint ) ( _length ) ) <= MaxBlockCount ) ;
1135
+
1136
+ if ( unchecked ( ( uint ) ( _length ) ) > MaxBlockCount )
1137
+ {
1138
+ // We shouldn't reach here, and the above assert will help flag this
1139
+ // during testing, but we'll ensure that we return a safe value of
1140
+ // zero in the case we end up overflowing in any way.
1141
+
1142
+ SetZero ( ) ;
1143
+ return ;
1144
+ }
1075
1145
}
1076
1146
1077
1147
public void ExtendBlocks ( uint blockValue , uint blockCount )
@@ -1081,7 +1151,19 @@ public void ExtendBlocks(uint blockValue, uint blockCount)
1081
1151
if ( blockCount == 1 )
1082
1152
{
1083
1153
ExtendBlock ( blockValue ) ;
1154
+ return ;
1155
+ }
1156
+
1157
+ int resultLength = _length + ( int ) ( blockCount ) ;
1158
+ Debug . Assert ( unchecked ( ( uint ) ( resultLength ) ) <= MaxBlockCount ) ;
1084
1159
1160
+ if ( unchecked ( ( uint ) ( resultLength ) ) > MaxBlockCount )
1161
+ {
1162
+ // We shouldn't reach here, and the above assert will help flag this
1163
+ // during testing, but we'll ensure that we return a safe value of
1164
+ // zero in the case we end up overflowing in any way.
1165
+
1166
+ SetZero ( ) ;
1085
1167
return ;
1086
1168
}
1087
1169
@@ -1149,9 +1231,20 @@ public void Multiply10()
1149
1231
1150
1232
if ( carry != 0 )
1151
1233
{
1152
- Debug . Assert ( unchecked ( ( uint ) ( _length ) ) + 1 <= MaxBlockCount ) ;
1234
+ Debug . Assert ( unchecked ( ( uint ) ( length ) ) < MaxBlockCount ) ;
1235
+
1236
+ if ( unchecked ( ( uint ) ( length ) ) >= MaxBlockCount )
1237
+ {
1238
+ // We shouldn't reach here, and the above assert will help flag this
1239
+ // during testing, but we'll ensure that we return a safe value of
1240
+ // zero in the case we end up overflowing in any way.
1241
+
1242
+ SetZero ( ) ;
1243
+ return ;
1244
+ }
1245
+
1153
1246
_blocks [ index ] = ( uint ) ( carry ) ;
1154
- _length += 1 ;
1247
+ _length = length + 1 ;
1155
1248
}
1156
1249
}
1157
1250
@@ -1214,10 +1307,30 @@ public void ShiftLeft(uint shift)
1214
1307
int readIndex = ( length - 1 ) ;
1215
1308
int writeIndex = readIndex + ( int ) ( blocksToShift ) ;
1216
1309
1310
+ if ( unchecked ( ( uint ) ( writeIndex ) ) >= MaxBlockCount )
1311
+ {
1312
+ // We shouldn't reach here, and the above assert will help flag this
1313
+ // during testing, but we'll ensure that we return a safe value of
1314
+ // zero in the case we end up overflowing in any way.
1315
+
1316
+ SetZero ( ) ;
1317
+ return ;
1318
+ }
1319
+
1217
1320
// Check if the shift is block aligned
1218
1321
if ( remainingBitsToShift == 0 )
1219
1322
{
1220
- Debug . Assert ( writeIndex < MaxBlockCount ) ;
1323
+ Debug . Assert ( unchecked ( ( uint ) ( length ) ) < MaxBlockCount ) ;
1324
+
1325
+ if ( unchecked ( ( uint ) ( length ) ) >= MaxBlockCount )
1326
+ {
1327
+ // We shouldn't reach here, and the above assert will help flag this
1328
+ // during testing, but we'll ensure that we return a safe value of
1329
+ // zero in the case we end up overflowing in any way.
1330
+
1331
+ SetZero ( ) ;
1332
+ return ;
1333
+ }
1221
1334
1222
1335
while ( readIndex >= 0 )
1223
1336
{
@@ -1234,8 +1347,19 @@ public void ShiftLeft(uint shift)
1234
1347
else
1235
1348
{
1236
1349
// We need an extra block for the partial shift
1350
+
1237
1351
writeIndex ++ ;
1238
- Debug . Assert ( writeIndex < MaxBlockCount ) ;
1352
+ Debug . Assert ( unchecked ( ( uint ) ( length ) ) < MaxBlockCount ) ;
1353
+
1354
+ if ( unchecked ( ( uint ) ( length ) ) >= MaxBlockCount )
1355
+ {
1356
+ // We shouldn't reach here, and the above assert will help flag this
1357
+ // during testing, but we'll ensure that we return a safe value of
1358
+ // zero in the case we end up overflowing in any way.
1359
+
1360
+ SetZero ( ) ;
1361
+ return ;
1362
+ }
1239
1363
1240
1364
// Set the length to hold the shifted blocks
1241
1365
_length = writeIndex + 1 ;
0 commit comments