diff --git a/src/utils/512Math.sol b/src/utils/512Math.sol index 7c1753e99..8ff4db340 100644 --- a/src/utils/512Math.sol +++ b/src/utils/512Math.sol @@ -1689,160 +1689,94 @@ library Lib512MathArithmetic { return omodAlt(r, y, r); } - // hi ≈ x · y / 2²⁵⁶ (±1) - function _inaccurateMulHi(uint256 x, uint256 y) private pure returns (uint256 hi) { - assembly ("memory-safe") { - hi := sub(mulmod(x, y, not(0x00)), mul(x, y)) - } - } - - // gas benchmark 2025/09/20: ~1425 gas function _sqrt(uint256 x_hi, uint256 x_lo) private pure returns (uint256 r) { - /// Our general approach here is to compute the inverse of the square root of the argument - /// using Newton-Raphson iterations. Then we combine (multiply) this inverse square root - /// approximation with the argument to approximate the square root of the argument. After - /// that, a final fixup step is applied to get the exact result. We compute the inverse of - /// the square root rather than the square root directly because then our Newton-Raphson - /// iteration can avoid the extremely expensive 512-bit division subroutine. unchecked { - /// First, we normalize `x` by separating it into a mantissa and exponent. We use - /// even-exponent normalization. - - // `e` is half the exponent of `x` - // e = ⌊bitlength(x)/2⌋ - // invE = 256 - e - uint256 invE = (x_hi.clz() + 1) >> 1; // invE ∈ [0, 128] - - // Extract mantissa M by shifting x right by 2·e - 255 bits - // `M` is the mantissa of `x` as a Q1.255; M ∈ [½, 2) - (, uint256 M) = _shr(x_hi, x_lo, 257 - (invE << 1)); // scale: 2⁽²⁵⁵⁻²ᵉ⁾ - - /// Pick an initial estimate (seed) for Y using a lookup table. Even-exponent - /// normalization means our mantissa is geometrically symmetric around 1, leading to 16 - /// buckets on the low side and 32 buckets on the high side. - // `Y` _ultimately_ approximates the inverse square root of fixnum `M` as a - // Q3.253. 
However, as a gas optimization, the number of fractional bits in `Y` rises - // through the steps, giving an inhomogeneous fixed-point representation. Y ≈∈ [√½, √2] - uint256 Y; // scale: 2⁽²⁵³⁺ᵉ⁾ - uint256 Mbucket; + /// Our general approach is to apply Zimmermann's "Karatsuba Square Root" algorithm + /// https://inria.hal.science/inria-00072854/document with the helpers from Solady and + /// 512Math. This approach is inspired by + /// https://github.com/SimonSuckut/Solidity_Uint512/ + + // Normalize `x` so the top word has its MSB in bit 255 or 254. + // x ≥ 2⁵¹⁰ + uint256 shift = x_hi.clz(); + (, x_hi, x_lo) = _shl256(x_hi, x_lo, shift & 0xfe); + shift >>= 1; + + // We treat `r` as a ≤2-limb bigint where each limb is half a machine word (128 bits). + // Splitting √x in this way lets us apply "ordinary" 256-bit `sqrt` to the top word of + // `x`. Then we can recover the bottom limb of `r` without 512-bit division. + // + // Implementing this as: + // uint256 r_hi = x_hi.sqrt(); + // is correct, but duplicates the normalization that we just did above and performs a + // more-costly initialization step. solc is not smart enough to optimize this away, so + // we inline and do it ourselves. + uint256 r_hi; assembly ("memory-safe") { - // Extract the upper 6 bits of `M` to be used as a table index. `M >> 250 < 16` is - // invalid (that would imply M<½), so our lookup table only needs to handle only 16 - // through 63. - Mbucket := shr(0xfa, M) - // We can't fit 48 seeds into a single word, so we split the table in 2 and use `c` - // to select which table we index. - let c := lt(0x27, Mbucket) - - // Each entry is 10 bits and the entries are ordered from lowest `i` to highest. The - // seed is the value for `Y` for the midpoint of the bucket, rounded to 10 - // significant bits. That is, Y ≈ 1/√(2·M_mid), as a Q247.9. The 2 comes from the - // half-scale difference between Y and √M. The optimality of this choice was - // verified by fuzzing.
- let table_hi := 0x71dc26f1b76c9ad6a5a46819c661946418c621856057e5ed775d1715b96b - let table_lo := 0xb26b4a8690a027198e559263e8ce2887e15832047f1f47b5e677dd974dcd - let table := xor(table_lo, mul(xor(table_hi, table_lo), c)) - - // Index the table to obtain the initial seed of `Y`. - let shift := add(0x186, mul(0x0a, sub(mul(0x18, c), Mbucket))) - // We begin the Newton-Raphson iterations with `Y` in Q247.9 format. - Y := and(0x3ff, shr(shift, table)) - - // The worst-case seed for `Y` occurs when `Mbucket = 16`. For monotone quadratic - // convergence, we desire that 1/√3 < Y·√M < √(5/3). At the boundaries (worst case) - // of the `Mbucket = 16` range, we are 0.407351 (41.3680%) from the lower bound and - // 0.275987 (27.1906%) from the higher bound. + // Initialization requires that the first bit of `r_hi` be in the correct + // position. This is correct from the normalization above. + r_hi := 0x80000000000000000000000000000000 + + // Seven Babylonian steps is sufficient for convergence. + r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) + r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) + r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) + r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) + r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) + r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) + r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi))) + + // The Babylonian step can oscillate between ⌊√x_hi⌋ and ⌈√x_hi⌉. Clean that up. + r_hi := sub(r_hi, lt(div(x_hi, r_hi), r_hi)) } - /// Perform 5 Newton-Raphson iterations. 5 is enough iterations for sufficient - /// convergence that our final fixup step produces an exact result. - // The Newton-Raphson iteration for 1/√M is: - // Y ≈ Y · (3 - M · Y²) / 2 - // The implementation of this iteration is deliberately imprecise. No matter how many - // times you run it, you won't converge `Y` on the closest Q3.253 to √M. 
However, this - // is acceptable because the cleanup step applied after the final call is very tolerant - // of error in the low bits of `Y`. - - // `M` is Q1.255 - // `Y` is Q247.9 - { - uint256 Y2 = Y * Y; // scale: 2¹⁸ - // Because `M` is Q1.255, multiplying `Y2` by `M` and taking the high word - // implicitly divides `MY2` by 2. We move the division by 2 inside the subtraction - // from 3 by adjusting the minuend. - uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2¹⁸ - uint256 T = 1.5 * 2 ** 18 - MY2; // scale: 2¹⁸ - Y *= T; // scale: 2²⁷ - } - // `Y` is Q229.27 - { - uint256 Y2 = Y * Y; // scale: 2⁵⁴ - uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2⁵⁴ - uint256 T = 1.5 * 2 ** 54 - MY2; // scale: 2⁵⁴ - Y *= T; // scale: 2⁸¹ - } - // `Y` is Q175.81 - { - uint256 Y2 = Y * Y; // scale: 2¹⁶² - uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2¹⁶² - uint256 T = 1.5 * 2 ** 162 - MY2; // scale: 2¹⁶² - Y = Y * T >> 116; // scale: 2¹²⁷ - } - // `Y` is Q129.127 - if (invE < 95 - Mbucket) { - // Generally speaking, for relatively smaller `e` (lower values of `x`) and for - // relatively larger `M`, we can skip the 5th N-R iteration. The constant `95` is - // derived by extensive fuzzing. Attempting a higher-order approximation of the - // relationship between `M` and `invE` consumes, on average, more gas. When this - // branch is not taken, the correct bits that this iteration would obtain are - // shifted away during the denormalization step. This branch is net gas-optimizing. 
- uint256 Y2 = Y * Y; // scale: 2²⁵⁴ - uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2²⁵⁴ - uint256 T = 1.5 * 2 ** 254 - MY2; // scale: 2²⁵⁴ - Y = _inaccurateMulHi(Y << 2, T); // scale: 2¹²⁷ + // This is cheaper than + // uint256 res = x_hi - r_hi * r_hi; + // for no clear reason + uint256 res; + assembly ("memory-safe") { + res := sub(x_hi, mul(r_hi, r_hi)) } - // `Y` is Q129.127 - { - uint256 Y2 = Y * Y; // scale: 2²⁵⁴ - uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2²⁵⁴ - uint256 T = 1.5 * 2 ** 254 - MY2; // scale: 2²⁵⁴ - Y = _inaccurateMulHi(Y << 128, T); // scale: 2²⁵³ + + uint256 r_lo; + // `res` is (almost) a single limb. Create a new (almost) machine word `n` with `res` as + // the upper limb and shifting in the next limb of `x` (namely `x_lo >> 128`) as the + // lower limb. The next step of Zimmermann's algorithm is: + // r_lo = n / (2 · r_hi) + // res = n % (2 · r_hi) + assembly ("memory-safe") { + let n := or(shl(0x80, res), shr(0x80, x_lo)) + let d := shl(0x01, r_hi) + r_lo := div(n, d) + + let c := shr(0x80, res) + res := mod(n, d) + + // It's possible that `n` was 257 bits and overflowed (`res` was not just a single + // limb). Explicitly handling the carry avoids 512-bit division. + if c { + let neg_c := not(0x00) + r_lo := add(r_lo, div(neg_c, d)) + res := add(res, add(0x01, mod(neg_c, d))) + r_lo := add(r_lo, div(res, d)) + res := mod(res, d) + } } - // `Y` is Q3.253 - - /// When we combine `Y` with `M` to form our approximation of the square root, we have - /// to un-normalize by the half-scale value. This is where even-exponent normalization - /// comes in because the half-scale is integral. - /// M = ⌊x · 2⁽²⁵⁵⁻²ᵉ⁾⌋ - /// Y ≈ 2²⁵³ / √(M / 2²⁵⁵) - /// Y ≈ 2³⁸¹ / √(2·M) - /// M·Y ≈ 2³⁸¹ · √(M/2) - /// M·Y ≈ 2⁽⁵⁰⁸⁻ᵉ⁾ · √x - /// r0 ≈ M·Y / 2⁽⁵⁰⁸⁻ᵉ⁾ ≈ ⌊√x⌋ - // We shift right by `508 - e` to account for both the Q3.253 scaling and - // denormalization. We don't care about accuracy in the low bits of `r0`, so we can cut - // some corners.
- (, uint256 r0) = _shr(_inaccurateMulHi(M, Y), 0, 252 + invE); - - /// `r0` is only an approximation of √x, so we perform a single Babylonian step to fully - /// converge on ⌊√x⌋ or ⌈√x⌉. The Babylonian step is: - /// r = ⌊(r0 + ⌊x/r0⌋) / 2⌋ - // Rather than use the more-expensive division routine that returns a 512-bit result, - // because the value the upper word of the quotient can take is highly constrained, we - // can compute the quotient mod 2²⁵⁶ and recover the high word separately. Although - // `_div` does an expensive Newton-Raphson-Hensel modular inversion: - // ⌊x/r0⌋ ≡ ⌊x/2ⁿ⌋·⌊r0/2ⁿ⌋⁻¹ mod 2²⁵⁶ (for r % 2⁽ⁿ⁺¹⁾ = 2ⁿ) - // and we already have a pretty good estimate for r0⁻¹, namely `Y`, refining `Y` into - // the appropriate inverse requires a series of 768-bit multiplications that take more - // gas. - uint256 q_lo = _div(x_hi, x_lo, r0); - uint256 q_hi = (r0 <= x_hi).toUint(); - (uint256 s_hi, uint256 s_lo) = _add(q_hi, q_lo, r0); - // `oflo` here is either 0 or 1. When `oflo == 1`, `r == 0`, and the correct value for - // `r` is `type(uint256).max`. - uint256 oflo; - (oflo, r) = _shr256(s_hi, s_lo, 1); - r -= oflo; // underflow is desired + r = (r_hi << 128) + r_lo; + + // Then, if res · 2¹²⁸ + x_lo % 2¹²⁸ < r_lo², decrement `r`. We have to do this in a + // complicated manner because both `res` and `r_lo` can be _slightly_ longer than 1 limb + // (128 bits). This is more efficient than performing the full 257-bit comparison. + r = r.unsafeDec( + ((res >> 128) < (r_lo >> 128)) + .or( + ((res >> 128) == (r_lo >> 128)) + .and((res << 128) | (x_lo & 0xffffffffffffffffffffffffffffffff) < r_lo * r_lo) + ) + ); + + // Un-normalize + return r >> shift; } } @@ -1853,13 +1787,7 @@ library Lib512MathArithmetic { return x_lo.sqrt(); } - uint256 r = _sqrt(x_hi, x_lo); - - // Because the Babylonian step can give ⌈√x⌉ if x+1 is a perfect square, we have to - // check whether we've overstepped by 1 and clamp as appropriate. 
ref: - // https://en.wikipedia.org/wiki/Integer_square_root#Using_only_integer_division - (uint256 r2_hi, uint256 r2_lo) = _mul(r, r); - return r.unsafeDec(_gt(r2_hi, r2_lo, x_hi, x_lo)); + return _sqrt(x_hi, x_lo); } function osqrtUp(uint512 r, uint512 x) internal pure returns (uint512) { @@ -1871,8 +1799,6 @@ library Lib512MathArithmetic { uint256 r_lo = _sqrt(x_hi, x_lo); - // The Babylonian step can give ⌈√x⌉ if x+1 is a perfect square. This is - // fine. If the Babylonian step gave ⌊√x⌋ ≠ √x, we have to round up. (uint256 r2_hi, uint256 r2_lo) = _mul(r_lo, r_lo); uint256 r_hi; (r_hi, r_lo) = _add(0, r_lo, _gt(x_hi, x_lo, r2_hi, r2_lo).toUint()); diff --git a/test/0.8.25/512Math.t.sol b/test/0.8.25/512Math.t.sol index d202c4518..727817109 100644 --- a/test/0.8.25/512Math.t.sol +++ b/test/0.8.25/512Math.t.sol @@ -422,6 +422,18 @@ contract Lib512MathTest is Test { } } + function test512Math_sqrt_perfectSquare(uint256 r) external pure { + uint512 x = alloc().omul(r, r); + assertEq(x.sqrt(), r); + } + + function test512Math_osqrtUp_perfectSquare(uint256 r) external pure { + uint512 x = alloc().omul(r, r); + (uint256 r_hi, uint256 r_lo) = alloc().osqrtUp(x).into(); + assertEq(r_hi, 0); + assertEq(r_lo, r); + } + function test512Math_oshrUp(uint256 x_hi, uint256 x_lo, uint256 s) external pure { s = bound(s, 0, 512);