@@ -385,38 +385,112 @@ library Math {
385
385
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
386
386
* towards zero.
387
387
*
388
- * Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11).
388
+ * This method is based on Newton's method for computing square roots; the algorithm is restricted to only
389
+ * using integer operations.
389
390
*/
390
391
function sqrt (uint256 a ) internal pure returns (uint256 ) {
391
- if (a == 0 ) {
392
- return 0 ;
393
- }
394
-
395
- // For our first guess, we get the biggest power of 2 which is smaller than the square root of the target.
396
- //
397
- // We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have
398
- // `msb(a) <= a < 2*msb(a)`. This value can be written `msb(a)=2**k` with `k=log2(a)`.
399
- //
400
- // This can be rewritten `2**log2(a) <= a < 2**(log2(a) + 1)`
401
- // → `sqrt(2**k) <= sqrt(a) < sqrt(2**(k+1))`
402
- // → `2**(k/2) <= sqrt(a) < 2**((k+1)/2) <= 2**(k/2 + 1)`
403
- //
404
- // Consequently, `2**(log2(a) / 2)` is a good first approximation of `sqrt(a)` with at least 1 correct bit.
405
- uint256 result = 1 << (log2 (a) >> 1 );
406
-
407
- // At this point `result` is an estimation with one bit of precision. We know the true value is a uint128,
408
- // since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at
409
- // every iteration). We thus need at most 7 iteration to turn our partial result with one bit of precision
410
- // into the expected uint128 result.
411
392
unchecked {
412
- result = (result + a / result) >> 1 ;
413
- result = (result + a / result) >> 1 ;
414
- result = (result + a / result) >> 1 ;
415
- result = (result + a / result) >> 1 ;
416
- result = (result + a / result) >> 1 ;
417
- result = (result + a / result) >> 1 ;
418
- result = (result + a / result) >> 1 ;
419
- return min (result, a / result);
393
+ // Take care of easy edge cases when a == 0 or a == 1
394
+ if (a <= 1 ) {
395
+ return a;
396
+ }
397
+
398
+ // In this function, we use Newton's method to get a root of `f(x) := x² - a`. It involves building a
399
+ // sequence x_n that converges toward sqrt(a). For each iteration x_n, we also define the error between
400
+ // the current value and the exact root as `ε_n = | x_n - sqrt(a) |`.
401
+ //
402
+ // For our first estimation, we consider `e` such that `2**e` is the smallest power of 2 which is bigger than the square root
403
+ // of the target. (i.e. `2**(e-1) ≤ sqrt(a) < 2**e`). We know that `e ≤ 128` because `(2¹²⁸)² = 2²⁵⁶` is
404
+ // bigger than any uint256.
405
+ //
406
+ // By noticing that
407
+ // `2**(e-1) ≤ sqrt(a) < 2**e → (2**(e-1))² ≤ a < (2**e)² → 2**(2*e-2) ≤ a < 2**(2*e)`
408
+ // we can deduce that `e - 1` is `log2(a) / 2` (with `log2` rounded down and integer division). We can thus compute `x_n = 2**(e-1)` using a method similar
409
+ // to the msb function.
410
+ uint256 aa = a;
411
+ uint256 xn = 1 ;
412
+
413
+ if (aa >= (1 << 128 )) {
414
+ aa >>= 128 ;
415
+ xn <<= 64 ;
416
+ }
417
+ if (aa >= (1 << 64 )) {
418
+ aa >>= 64 ;
419
+ xn <<= 32 ;
420
+ }
421
+ if (aa >= (1 << 32 )) {
422
+ aa >>= 32 ;
423
+ xn <<= 16 ;
424
+ }
425
+ if (aa >= (1 << 16 )) {
426
+ aa >>= 16 ;
427
+ xn <<= 8 ;
428
+ }
429
+ if (aa >= (1 << 8 )) {
430
+ aa >>= 8 ;
431
+ xn <<= 4 ;
432
+ }
433
+ if (aa >= (1 << 4 )) {
434
+ aa >>= 4 ;
435
+ xn <<= 2 ;
436
+ }
437
+ if (aa >= (1 << 2 )) {
438
+ xn <<= 1 ;
439
+ }
440
+
441
+ // We now have x_n such that `x_n = 2**(e-1) ≤ sqrt(a) < 2**e = 2 * x_n`. This implies ε_n ≤ 2**(e-1).
442
+ //
443
+ // We can refine our estimation by noticing that the middle of that interval minimizes the error.
444
+ // If we move x_n to equal 2**(e-1) + 2**(e-2), then we reduce the error to ε_n ≤ 2**(e-2).
445
+ // This is going to be our x_0 (and ε_0)
446
+ xn = (3 * xn) >> 1 ; // ε_0 := | x_0 - sqrt(a) | ≤ 2**(e-2)
447
+
448
+ // From here, Newton's method gives us:
449
+ // x_{n+1} = (x_n + a / x_n) / 2
450
+ //
451
+ // One should note that:
452
+ // x_{n+1}² - a = ((x_n + a / x_n) / 2)² - a
453
+ // = ((x_n² + a) / (2 * x_n))² - a
454
+ // = (x_n⁴ + 2 * a * x_n² + a²) / (4 * x_n²) - a
455
+ // = (x_n⁴ + 2 * a * x_n² + a² - 4 * a * x_n²) / (4 * x_n²)
456
+ // = (x_n⁴ - 2 * a * x_n² + a²) / (4 * x_n²)
457
+ // = (x_n² - a)² / (2 * x_n)²
458
+ // = ((x_n² - a) / (2 * x_n))²
459
+ // ≥ 0
460
+ // Which proves that for all n ≥ 1, sqrt(a) ≤ x_n
461
+ //
462
+ // This gives us the proof of quadratic convergence of the sequence:
463
+ // ε_{n+1} = | x_{n+1} - sqrt(a) |
464
+ // = | (x_n + a / x_n) / 2 - sqrt(a) |
465
+ // = | (x_n² + a - 2*x_n*sqrt(a)) / (2 * x_n) |
466
+ // = | (x_n - sqrt(a))² / (2 * x_n) |
467
+ // = | ε_n² / (2 * x_n) |
468
+ // = ε_n² / | (2 * x_n) |
469
+ //
470
+ // For the first iteration, we have a special case where x_0 is known:
471
+ // ε_1 = ε_0² / | (2 * x_0) |
472
+ // ≤ (2**(e-2))² / (2 * (2**(e-1) + 2**(e-2)))
473
+ // ≤ 2**(2*e-4) / (3 * 2**(e-1))
474
+ // ≤ 2**(e-3) / 3
475
+ // ≤ 2**(e-3-log2(3))
476
+ // ≤ 2**(e-4.5)
477
+ //
478
+ // For the following iterations, we assume ε_n ≤ 2**(e-k) and use the fact that 2**(e-1) ≤ sqrt(a) ≤ x_n:
479
+ // ε_{n+1} = ε_n² / | (2 * x_n) |
480
+ // ≤ (2**(e-k))² / (2 * 2**(e-1))
481
+ // ≤ 2**(2*e-2*k) / 2**e
482
+ // ≤ 2**(e-2*k)
483
+ xn = (xn + a / xn) >> 1 ; // ε_1 := | x_1 - sqrt(a) | ≤ 2**(e-4.5) -- special case, see above
484
+ xn = (xn + a / xn) >> 1 ; // ε_2 := | x_2 - sqrt(a) | ≤ 2**(e-9) -- general case with k = 4.5
485
+ xn = (xn + a / xn) >> 1 ; // ε_3 := | x_3 - sqrt(a) | ≤ 2**(e-18) -- general case with k = 9
486
+ xn = (xn + a / xn) >> 1 ; // ε_4 := | x_4 - sqrt(a) | ≤ 2**(e-36) -- general case with k = 18
487
+ xn = (xn + a / xn) >> 1 ; // ε_5 := | x_5 - sqrt(a) | ≤ 2**(e-72) -- general case with k = 36
488
+ xn = (xn + a / xn) >> 1 ; // ε_6 := | x_6 - sqrt(a) | ≤ 2**(e-144) -- general case with k = 72
489
+
490
+ // Because e ≤ 128 (as discussed during the first estimation phase), we now know that we have reached a precision
491
+ // ε_6 ≤ 2**(e-144) < 1. Given that we're operating on integers, we can ensure that xn is now either
492
+ // sqrt(a) or sqrt(a) + 1.
493
+ return xn - SafeCast.toUint (xn > a / xn);
420
494
}
421
495
}
422
496
0 commit comments