diff --git a/src/utils/512Math.sol b/src/utils/512Math.sol index 8ff4db340..2959fead2 100644 --- a/src/utils/512Math.sol +++ b/src/utils/512Math.sol @@ -7,6 +7,7 @@ import {Clz} from "../vendor/Clz.sol"; import {Ternary} from "./Ternary.sol"; import {FastLogic} from "./FastLogic.sol"; import {Sqrt} from "../vendor/Sqrt.sol"; +import {Cbrt} from "../vendor/Cbrt.sol"; /* @@ -180,6 +181,11 @@ WARNING *** WARNING *** WARNING *** WARNING *** WARNING *** WARNING *** WARNING /// * osqrtUp(uint512,uint512) /// * isqrtUp(uint512) /// +/// ### Cube root +/// +/// * cbrt(uint512) returns (uint256) +/// * cbrtUp(uint512) returns (uint256) +/// /// ### Shifting /// /// * oshr(uint512,uint512,uint256) @@ -358,6 +364,7 @@ library Lib512MathArithmetic { using Ternary for bool; using FastLogic for bool; using Sqrt for uint256; + using Cbrt for uint256; function _add(uint256 x, uint256 y) private pure returns (uint256 r_hi, uint256 r_lo) { assembly ("memory-safe") { @@ -1690,13 +1697,12 @@ library Lib512MathArithmetic { } function _sqrt(uint256 x_hi, uint256 x_lo) private pure returns (uint256 r) { + /// Our general approach is to apply Zimmerman's "Karatsuba Square Root" algorithm + /// https://inria.hal.science/inria-00072854/document with the helpers from Solady and + /// 512Math. This approach is inspired by https://github.com/SimonSuckut/Solidity_Uint512/ unchecked { - /// Our general approach is to apply Zimmerman's "Karatsuba Square Root" algorithm - /// https://inria.hal.science/inria-00072854/document with the helpers from Solady and - /// 512Math. This approach is inspired by - /// https://github.com/SimonSuckut/Solidity_Uint512/ - - // Normalize `x` so the top word has its MSB in bit 255 or 254. + // Normalize `x` so the top word has its MSB in bit 255 or 254. This makes the "shift + // back" step exact. 
// x ≥ 2⁵¹⁰ uint256 shift = x_hi.clz(); (, x_hi, x_lo) = _shl256(x_hi, x_lo, shift & 0xfe); @@ -1755,9 +1761,8 @@ library Lib512MathArithmetic { // It's possible that `n` was 257 bits and overflowed (`res` was not just a single // limb). Explicitly handling the carry avoids 512-bit division. if c { - let neg_c := not(0x00) - r_lo := add(r_lo, div(neg_c, d)) - res := add(res, add(0x01, mod(neg_c, d))) + r_lo := add(r_lo, div(not(0x00), d)) + res := add(res, add(0x01, mod(not(0x00), d))) r_lo := add(r_lo, div(res, d)) res := mod(res, d) } @@ -1809,6 +1814,195 @@ library Lib512MathArithmetic { return osqrtUp(r, r); } + function _cbrt(uint256 x_hi, uint256 x_lo) private pure returns (uint256 r) { + /// This is the same general technique as we applied in `_sqrt`, patterned after Zimmermann's + /// "Karatsuba Square Root" algorithm, but adapted to compute cube roots instead. + unchecked { + // Normalize `x` so that its MSB is in bit 255, 254, or 253. This makes the left shift a + // multiple of 3 so that the "shift back" un-normalization step is exact. + // x ≥ 2⁵⁰⁹ + uint256 shift = x_hi.clz() / 3; + (, x_hi, x_lo) = _shl256(x_hi, x_lo, shift * 3); + + // Zimmermann's "Karatsuba Square Root" algorithm works with limbs of `r` that are half + // of a word. For cube root, we use limbs of `r` that are (roughly) one third of a + // word. The initial step to compute the first "limb" of `r` uses the "normal" cube root + // algorithm and consumes the first (almost) word of `x`. The second and final limb of + // `r` is computed using an analogue of the Karatsuba step from the original algorithm, + // followed by a pair of cleanup steps. + + // Now we run the "normal" cube root algorithm to obtain the first limb of `r`, which we + // store in `r_hi`. `res` is the residue after this first operation and `d` is the + // derivative/denominator for the subsequent Karatsuba step. 
+ uint256 r_hi; + uint256 res; + uint256 d; + assembly ("memory-safe") { + let w := shr(0x02, x_hi) // w ≥ 2²⁵¹; w < 2²⁵⁴ from the above normalization + r_hi := 0x1000000000000000000000 // Given `w` in its range, this seed is suitable + + r_hi := div(add(add(div(w, mul(r_hi, r_hi)), r_hi), r_hi), 0x03) + r_hi := div(add(add(div(w, mul(r_hi, r_hi)), r_hi), r_hi), 0x03) + r_hi := div(add(add(div(w, mul(r_hi, r_hi)), r_hi), r_hi), 0x03) + r_hi := div(add(add(div(w, mul(r_hi, r_hi)), r_hi), r_hi), 0x03) + r_hi := div(add(add(div(w, mul(r_hi, r_hi)), r_hi), r_hi), 0x03) + r_hi := div(add(add(div(w, mul(r_hi, r_hi)), r_hi), r_hi), 0x03) + + let r_hi_sq := mul(r_hi, r_hi) + let r_hi_cube := mul(r_hi_sq, r_hi) + if gt(r_hi_cube, w) { + // We gate the 7th Newton-Raphson step and ceil/floor cleanup on whether it has + // overestimated. The second-order correction further below is sufficient to + // correct for small underestimation. This branch is net gas-optimizing. + r_hi := div(add(add(div(w, r_hi_sq), r_hi), r_hi), 0x03) + r_hi := sub(r_hi, lt(div(w, mul(r_hi, r_hi)), r_hi)) + r_hi_sq := mul(r_hi, r_hi) + r_hi_cube := mul(r_hi_sq, r_hi) + } + + res := sub(w, r_hi_cube) + d := mul(0x03, r_hi_sq) + } + + // `limb_hi` is the next 86-bit limb of `x` after the first whole-ish word. + uint256 limb_hi; + assembly ("memory-safe") { + limb_hi := or(shl(0x54, and(x_hi, 0x03)), shr(0xac, x_lo)) + } + // This is the Karatsuba step. The 86-bit lower limb of `r` is (almost): + // r_lo = ⌊(res ⋅ 2⁸⁶ + limb_hi) / (3 ⋅ r_hi²)⌋ + // Where `res` is the (nearly) 2-limb residue from the previous "normal" cube root + // step. We discard `res` after this step and perform a quadratic correction instead of + // the underflow check from Zimmermann. + uint256 r_lo; + assembly ("memory-safe") { + let n := or(shl(0x56, res), limb_hi) + r_lo := div(n, d) + + // If `res` was 171 bits (one more than expected), then `n` overflowed to 257 + // bits. Explicitly handling the carry avoids 512-bit division. 
+ if shr(0xaa, res) { + let rem := mod(n, d) + r_lo := add(r_lo, div(not(0x00), d)) + rem := add(rem, add(0x01, mod(not(0x00), d))) + r_lo := add(r_lo, div(rem, d)) + } + } + + // Unlike the square-root case, the error from the linear Karatsuba step can still be + // large because the expansion has more terms. We do a quadratic correction to get close + // enough that the single subtraction is sufficient for exactness. + // + // In the square-root version, the only ignored term in (s + q)² is q², which is small + // enough for a 1ulp correction. For cube root, the binomial expansion (r_hi·2⁸⁶ + + // r_lo)³ contains the cross term 3·(r_hi·2⁸⁶)·r_lo². The linear Karatsuba step + // overestimates r_lo by ≈r_lo²/(r_hi·2⁸⁶). After correction, this leaves only the r_lo³ + // term, on the order of 2²⁵⁸/(3·2³⁴²), much less than 1ulp. + r_lo -= (r_lo * r_lo).unsafeDiv(r_hi << 86); + r = (r_hi << 86) + r_lo; + // Our error is now down to 1ulp. + + // Un-normalize + return r >> shift; + } + } + + function cbrt(uint512 x) internal pure returns (uint256 r) { + (uint256 x_hi, uint256 x_lo) = x.into(); + + if (x_hi == 0) { + return x_lo.cbrt(); + } + + r = _cbrt(x_hi, x_lo); + + // The following cube-and-compare technique for obtaining the floor appears, at first, to + // have an overflow bug in it. Consider that `_cbrt` returns a value within 1ulp of the + // correct value. Define: + // r_max = 0x6597fa94f5b8f20ac16666ad0f7137bc6601d885628 + // this means that for values of x in [r_max³, 2⁵¹² - 1], `_cbrt` could return r_max + 1, + // which would result in overflow when cubing `r`. However, this does not happen. Given `x` + // in the specified range, the `_cbrt` follows the steps below: + // + // 1) shift = ⌊clz(x_hi) / 3⌋ = 0 + // 2) w = x_hi >> 2 lies in [0x3fff..fffb0959fdf442978718ddcb, 2²⁵⁴ - 1] + // 3) In that full interval, ⌊∛w⌋ is constant. 
For `r_hi`, we get: + // after 6 Newton-Raphson iterations: r_hi = 0x1965fea53d6e3c82b0c310 + // which forces a 7th iteration + // after the branch is taken: r_hi = 0x1965fea53d6e3c82b05999 + // 4) Therefore d = 3 ⋅ r_hi² is constant: + // d = 0x78f3d1d950af414cd731fe48f48fde1309821333853 + // 5) n = (res << 86) | limb_hi overflows and is truncated to 256 bits. The truncated ⌊n / d⌋ + // is constant: + // ⌊n / d⌋ = 0x8f3a38c7f3364c49d3405 + // The carry branch (res >> 170 != 0) fires. The carry adjustment modifies the truncated + // quotient by adding: + // ⌊(2²⁵⁶ - 1) / d⌋ = 0x21dd5386fc92fb58eb2224 + // and the final carry refinement term is zero, giving: + // r_lo = 0x2ad0f7137bc6601d885629 + // The quotient stays in one "bucket" because `res` varies by only ~0.62·2⁸³, and + // `limb_hi`'s full 86-bit range contributes <1/2⁸⁴ to n/d. Total swing in the continuous + // quotient is ~0.164. At the boundaries, frac(n/d) ≈ 0.118 (at x = r_max³) and ≈0.283 + // (at x = 2⁵¹² - 1), so the floor never crosses an integer boundary + // 6) After the carry adjustment branch, `r_lo` is constant: + // r_lo = 0x2ad0f7137bc6601d885629 + // 7) The quadratic correction subtracts exactly 1: + // ⌊r_lo² / (r_hi·2⁸⁶)⌋ = 1 + // so r_lo = 0x2ad0f7137bc6601d885628 and + // r = r_hi·2⁸⁶ + r_lo = r_max + // + // So, the cube-and-compare code below only cubes a value of at most `r_max`, which fits in + // 512 bits. 
`cbrtUp` reaches `r_max + 1` only via its final +1 correction + // + // The following assembly block is identical to + // (uint256 r2_hi, uint256 r2_lo) = _mul(r, r); + // (uint256 r3_hi, uint256 r3_lo) = _mul(r2_hi, r2_lo, r); + // r = r.unsafeDec(_gt(r3_hi, r3_lo, x_hi, x_lo)); + // but is substantially more gas efficient for inexplicable reasons + assembly ("memory-safe") { + let mm := mulmod(r, r, not(0x00)) + let r2_lo := mul(r, r) + let r2_hi := sub(sub(mm, r2_lo), lt(mm, r2_lo)) + + mm := mulmod(r2_lo, r, not(0x00)) + let r3_lo := mul(r2_lo, r) + let r3_hi := add(sub(sub(mm, r3_lo), lt(mm, r3_lo)), mul(r2_hi, r)) + + r := sub(r, or(gt(r3_hi, x_hi), and(eq(r3_hi, x_hi), gt(r3_lo, x_lo)))) + } + } + + function cbrtUp(uint512 x) internal pure returns (uint256 r) { + (uint256 x_hi, uint256 x_lo) = x.into(); + + if (x_hi == 0) { + return x_lo.cbrtUp(); + } + + r = _cbrt(x_hi, x_lo); + + // `_cbrt` gives a result within 1ulp. Check if `r` is too low and correct. + // + // The following assembly block is identical to + // (uint256 r2_hi, uint256 r2_lo) = _mul(r, r); + // (uint256 r3_hi, uint256 r3_lo) = _mul(r2_hi, r2_lo, r); + // r = r.unsafeInc(_gt(x_hi, x_lo, r3_hi, r3_lo)); + // but is substantially more gas efficient for inexplicable reasons + assembly ("memory-safe") { + // See the detailed overflow-regime note in `cbrt` above. In particular, near x = 2⁵¹², + // `_cbrt` is pinned at `r_max` and does not return `r_max + 1` directly. 
+ let mm := mulmod(r, r, not(0x00)) + let r2_lo := mul(r, r) + let r2_hi := sub(sub(mm, r2_lo), lt(mm, r2_lo)) + + mm := mulmod(r2_lo, r, not(0x00)) + let r3_lo := mul(r2_lo, r) + let r3_hi := add(sub(sub(mm, r3_lo), lt(mm, r3_lo)), mul(r2_hi, r)) + + r := add(r, or(lt(r3_hi, x_hi), and(eq(r3_hi, x_hi), lt(r3_lo, x_lo)))) + } + } + function oshr(uint512 r, uint512 x, uint256 s) internal pure returns (uint512) { (uint256 x_hi, uint256 x_lo) = x.into(); (uint256 r_hi, uint256 r_lo) = _shr(x_hi, x_lo, s); diff --git a/test/0.8.25/512Math.t.sol b/test/0.8.25/512Math.t.sol index 727817109..9036f09e5 100644 --- a/test/0.8.25/512Math.t.sol +++ b/test/0.8.25/512Math.t.sol @@ -434,6 +434,82 @@ contract Lib512MathTest is Test { assertEq(r_lo, r); } + function test512Math_cbrt(uint256 x_hi, uint256 x_lo) external pure { + uint512 x = alloc().from(x_hi, x_lo); + uint256 r = x.cbrt(); + uint512 r3 = alloc().omul(r, r).imul(r); + + assertTrue(r3 <= x, "cbrt too high"); + if ( + x_hi > 0xffffffffffffffffffffffffffffffffffffffffffec2567f7d10a5e1c63772f + || (x_hi == 0xffffffffffffffffffffffffffffffffffffffffffec2567f7d10a5e1c63772f + && x_lo > 0xd70b34358c5c72dd2dbdc27132d143e3a7f08c1088df427db0884640df2d79ff) + ) { + assertEq(r, 0x6597fa94f5b8f20ac16666ad0f7137bc6601d885628, "cbrt overflow"); + } else { + r++; + r3.omul(r, r).imul(r); + assertTrue(r3 > x, "cbrt too low"); + } + } + + function test512Math_cbrt_perfectCube(uint256 r) external pure { + r = bound(r, 1, 0x6597fa94f5b8f20ac16666ad0f7137bc6601d885628); + uint512 x = alloc().omul(r, r).imul(r); + assertEq(x.cbrt(), r); + } + + function test512Math_cbrt_overflowCubeRegime(uint256 x_hi, uint256 x_lo) external pure { + uint256 r_max = 0x6597fa94f5b8f20ac16666ad0f7137bc6601d885628; + uint256 r_max_plus_one = 0x6597fa94f5b8f20ac16666ad0f7137bc6601d885629; + uint256 r_max_cube_hi = 0xffffffffffffffffffffffffffffffffffffffffffec2567f7d10a5e1c63772f; + uint256 r_max_cube_lo = 
0xd70b34358c5c72dd2dbdc27132d143e3a7f08c1088df427db0884640df2d7a00; + + // Force x > r_max^3 so cbrtUp(x) must return r_max + 1, whose cube is 513 bits. + // + // Why this still passes with the current implementation: + // `_cbrt` is returning `r_max` (not `r_max + 1`) in this regime. If `_cbrt` returned + // `r_max + 1`, then `cbrt` would have to decrement based on an overflowed cube-and-compare + // and this assertion would fail for these near-2^512 inputs. Because `cbrt` stays equal to + // `r_max`, both cube-and-compare paths only cube `r_max` (which fits in 512 bits), and + // `cbrtUp` reaches `r_max + 1` only via its final `+1` correction. + x_hi = bound(x_hi, r_max_cube_hi, type(uint256).max); + if (x_hi == r_max_cube_hi) { + x_lo = bound(x_lo, r_max_cube_lo + 1, type(uint256).max); + } + + uint512 x = alloc().from(x_hi, x_lo); + assertEq(x.cbrt(), r_max, "cbrt in overflow-cube regime"); + assertEq(x.cbrtUp(), r_max_plus_one, "cbrtUp in overflow-cube regime"); + } + + function test512Math_cbrtUp(uint256 x_hi, uint256 x_lo) external pure { + uint512 x = alloc().from(x_hi, x_lo); + uint256 r = x.cbrtUp(); + uint512 r3 = alloc().omul(r, r).imul(r); + + if ( + x_hi > 0xffffffffffffffffffffffffffffffffffffffffffec2567f7d10a5e1c63772f + || (x_hi == 0xffffffffffffffffffffffffffffffffffffffffffec2567f7d10a5e1c63772f + && x_lo > 0xd70b34358c5c72dd2dbdc27132d143e3a7f08c1088df427db0884640df2d7a00) + ) { + assertEq(r, 0x6597fa94f5b8f20ac16666ad0f7137bc6601d885629, "cbrtUp overflow"); + } else { + assertTrue(r3 >= x, "cbrtUp too low"); + } + if (x_hi != 0 || x_lo != 0) { + r--; + r3.omul(r, r).imul(r); + assertTrue(r3 < x, "cbrtUp too high"); + } + } + + function test512Math_cbrtUp_perfectCube(uint256 r) external pure { + r = bound(r, 1, 0x6597fa94f5b8f20ac16666ad0f7137bc6601d885628); + uint512 x = alloc().omul(r, r).imul(r); + assertEq(x.cbrtUp(), r); + } + function test512Math_oshrUp(uint256 x_hi, uint256 x_lo, uint256 s) external pure { s = bound(s, 0, 512);