Skip to content

Commit 140d66f

Browse files
chgormanAmxxernestognw
authored
Improved integer square root (#4403)
Co-authored-by: Hadrien Croubois <[email protected]> Co-authored-by: Ernesto García <[email protected]>
1 parent 96e5c08 commit 140d66f

File tree

1 file changed

+103
-29
lines changed

1 file changed

+103
-29
lines changed

contracts/utils/math/Math.sol

Lines changed: 103 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -385,38 +385,112 @@ library Math {
385385
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
386386
* towards zero.
387387
*
388-
* Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11).
388+
* This method is based on Newton's method for computing square roots; the algorithm is restricted to only
389+
* using integer operations.
389390
*/
390391
function sqrt(uint256 a) internal pure returns (uint256) {
391-
if (a == 0) {
392-
return 0;
393-
}
394-
395-
// For our first guess, we get the biggest power of 2 which is smaller than the square root of the target.
396-
//
397-
// We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have
398-
// `msb(a) <= a < 2*msb(a)`. This value can be written `msb(a)=2**k` with `k=log2(a)`.
399-
//
400-
// This can be rewritten `2**log2(a) <= a < 2**(log2(a) + 1)`
401-
// → `sqrt(2**k) <= sqrt(a) < sqrt(2**(k+1))`
402-
// → `2**(k/2) <= sqrt(a) < 2**((k+1)/2) <= 2**(k/2 + 1)`
403-
//
404-
// Consequently, `2**(log2(a) / 2)` is a good first approximation of `sqrt(a)` with at least 1 correct bit.
405-
uint256 result = 1 << (log2(a) >> 1);
406-
407-
// At this point `result` is an estimation with one bit of precision. We know the true value is a uint128,
408-
// since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at
409-
// every iteration). We thus need at most 7 iteration to turn our partial result with one bit of precision
410-
// into the expected uint128 result.
411392
unchecked {
412-
result = (result + a / result) >> 1;
413-
result = (result + a / result) >> 1;
414-
result = (result + a / result) >> 1;
415-
result = (result + a / result) >> 1;
416-
result = (result + a / result) >> 1;
417-
result = (result + a / result) >> 1;
418-
result = (result + a / result) >> 1;
419-
return min(result, a / result);
393+
// Take care of easy edge cases when a == 0 or a == 1
394+
if (a <= 1) {
395+
return a;
396+
}
397+
398+
// In this function, we use Newton's method to get a root of `f(x) := x² - a`. It involves building a
399+
// sequence x_n that converges toward sqrt(a). For each iteration x_n, we also define the error between
400+
// the current value as `ε_n = | x_n - sqrt(a) |`.
401+
//
402+
// For our first estimation, we consider `e` the smallest power of 2 which is bigger than the square root
403+
// of the target. (i.e. `2**(e-1) ≤ sqrt(a) < 2**e`). We know that `e ≤ 128` because `(2¹²⁸)² = 2²⁵⁶` is
404+
// bigger than any uint256.
405+
//
406+
// By noticing that
407+
// `2**(e-1) ≤ sqrt(a) < 2**e → (2**(e-1))² ≤ a < (2**e)² → 2**(2*e-2) ≤ a < 2**(2*e)`
408+
// we can deduce that `e - 1` is `log2(a) / 2`. We can thus compute `x_n = 2**(e-1)` using a method similar
409+
// to the msb function.
410+
uint256 aa = a;
411+
uint256 xn = 1;
412+
413+
if (aa >= (1 << 128)) {
414+
aa >>= 128;
415+
xn <<= 64;
416+
}
417+
if (aa >= (1 << 64)) {
418+
aa >>= 64;
419+
xn <<= 32;
420+
}
421+
if (aa >= (1 << 32)) {
422+
aa >>= 32;
423+
xn <<= 16;
424+
}
425+
if (aa >= (1 << 16)) {
426+
aa >>= 16;
427+
xn <<= 8;
428+
}
429+
if (aa >= (1 << 8)) {
430+
aa >>= 8;
431+
xn <<= 4;
432+
}
433+
if (aa >= (1 << 4)) {
434+
aa >>= 4;
435+
xn <<= 2;
436+
}
437+
if (aa >= (1 << 2)) {
438+
xn <<= 1;
439+
}
440+
441+
// We now have x_n such that `x_n = 2**(e-1) ≤ sqrt(a) < 2**e = 2 * x_n`. This implies ε_n ≤ 2**(e-1).
442+
//
443+
// We can refine our estimation by noticing that the the middle of that interval minimizes the error.
444+
// If we move x_n to equal 2**(e-1) + 2**(e-2), then we reduce the error to ε_n ≤ 2**(e-2).
445+
// This is going to be our x_0 (and ε_0)
446+
xn = (3 * xn) >> 1; // ε_0 := | x_0 - sqrt(a) | ≤ 2**(e-2)
447+
448+
// From here, Newton's method give us:
449+
// x_{n+1} = (x_n + a / x_n) / 2
450+
//
451+
// One should note that:
452+
// x_{n+1}² - a = ((x_n + a / x_n) / 2)² - a
453+
// = ((x_n² + a) / (2 * x_n))² - a
454+
// = (x_n⁴ + 2 * a * x_n² + a²) / (4 * x_n²) - a
455+
// = (x_n⁴ + 2 * a * x_n² + a² - 4 * a * x_n²) / (4 * x_n²)
456+
// = (x_n⁴ - 2 * a * x_n² + a²) / (4 * x_n²)
457+
// = (x_n² - a)² / (2 * x_n)²
458+
// = ((x_n² - a) / (2 * x_n))²
459+
// ≥ 0
460+
// Which proves that for all n ≥ 1, sqrt(a) ≤ x_n
461+
//
462+
// This gives us the proof of quadratic convergence of the sequence:
463+
// ε_{n+1} = | x_{n+1} - sqrt(a) |
464+
// = | (x_n + a / x_n) / 2 - sqrt(a) |
465+
// = | (x_n² + a - 2*x_n*sqrt(a)) / (2 * x_n) |
466+
// = | (x_n - sqrt(a))² / (2 * x_n) |
467+
// = | ε_n² / (2 * x_n) |
468+
// = ε_n² / | (2 * x_n) |
469+
//
470+
// For the first iteration, we have a special case where x_0 is known:
471+
// ε_1 = ε_0² / | (2 * x_0) |
472+
// ≤ (2**(e-2))² / (2 * (2**(e-1) + 2**(e-2)))
473+
// ≤ 2**(2*e-4) / (3 * 2**(e-1))
474+
// ≤ 2**(e-3) / 3
475+
// ≤ 2**(e-3-log2(3))
476+
// ≤ 2**(e-4.5)
477+
//
478+
// For the following iterations, we use the fact that, 2**(e-1) ≤ sqrt(a) ≤ x_n:
479+
// ε_{n+1} = ε_n² / | (2 * x_n) |
480+
// ≤ (2**(e-k))² / (2 * 2**(e-1))
481+
// ≤ 2**(2*e-2*k) / 2**e
482+
// ≤ 2**(e-2*k)
483+
xn = (xn + a / xn) >> 1; // ε_1 := | x_1 - sqrt(a) | ≤ 2**(e-4.5) -- special case, see above
484+
xn = (xn + a / xn) >> 1; // ε_2 := | x_2 - sqrt(a) | ≤ 2**(e-9) -- general case with k = 4.5
485+
xn = (xn + a / xn) >> 1; // ε_3 := | x_3 - sqrt(a) | ≤ 2**(e-18) -- general case with k = 9
486+
xn = (xn + a / xn) >> 1; // ε_4 := | x_4 - sqrt(a) | ≤ 2**(e-36) -- general case with k = 18
487+
xn = (xn + a / xn) >> 1; // ε_5 := | x_5 - sqrt(a) | ≤ 2**(e-72) -- general case with k = 36
488+
xn = (xn + a / xn) >> 1; // ε_6 := | x_6 - sqrt(a) | ≤ 2**(e-144) -- general case with k = 72
489+
490+
// Because e ≤ 128 (as discussed during the first estimation phase), we know have reached a precision
491+
// ε_6 ≤ 2**(e-144) < 1. Given we're operating on integers, then we can ensure that xn is now either
492+
// sqrt(a) or sqrt(a) + 1.
493+
return xn - SafeCast.toUint(xn > a / xn);
420494
}
421495
}
422496

0 commit comments

Comments
 (0)