Skip to content
Open
238 changes: 82 additions & 156 deletions src/utils/512Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -1689,160 +1689,94 @@ library Lib512MathArithmetic {
return omodAlt(r, y, r);
}

// Cheap estimate of the upper word of the 512-bit product `x · y`:
//   hi ≈ x · y / 2²⁵⁶ (±1)
// The borrow correction that an exact "mul-high" performs is deliberately omitted, which is
// where the ±1 slack comes from.
function _inaccurateMulHi(uint256 x, uint256 y) private pure returns (uint256 hi) {
    assembly ("memory-safe") {
        // Product reduced modulo 2²⁵⁶ - 1
        let mm := mulmod(x, y, not(0x00))
        // Product reduced modulo 2²⁵⁶ (wrapping multiply)
        let lo := mul(x, y)
        // The difference of the two residues recovers the high word, up to the missing borrow
        hi := sub(mm, lo)
    }
}

// gas benchmark 2025/09/20: ~1425 gas
function _sqrt(uint256 x_hi, uint256 x_lo) private pure returns (uint256 r) {
/// Our general approach here is to compute the inverse of the square root of the argument
/// using Newton-Raphson iterations. Then we combine (multiply) this inverse square root
/// approximation with the argument to approximate the square root of the argument. After
/// that, a final fixup step is applied to get the exact result. We compute the inverse of
/// the square root rather than the square root directly because then our Newton-Raphson
/// iteration can avoid the extremely expensive 512-bit division subroutine.
unchecked {
/// First, we normalize `x` by separating it into a mantissa and exponent. We use
/// even-exponent normalization.

// `e` is half the exponent of `x`
// e = ⌊bitlength(x)/2⌋
// invE = 256 - e
uint256 invE = (x_hi.clz() + 1) >> 1; // invE ∈ [0, 128]

// Extract mantissa M by shifting x right by 2·e - 255 bits
// `M` is the mantissa of `x` as a Q1.255; M ∈ [½, 2)
(, uint256 M) = _shr(x_hi, x_lo, 257 - (invE << 1)); // scale: 2⁽²⁵⁵⁻²ᵉ⁾

/// Pick an initial estimate (seed) for Y using a lookup table. Even-exponent
/// normalization means our mantissa is geometrically symmetric around 1, leading to 16
/// buckets on the low side and 32 buckets on the high side.
// `Y` _ultimately_ approximates the inverse square root of fixnum `M` as a
// Q3.253. However, as a gas optimization, the number of fractional bits in `Y` rises
// through the steps, giving an inhomogeneous fixed-point representation. Y ≈∈ [√½, √2]
uint256 Y; // scale: 2⁽²⁵³⁺ᵉ⁾
uint256 Mbucket;
/// Our general approach is to apply Zimmermann's "Karatsuba Square Root" algorithm
/// https://inria.hal.science/inria-00072854/document with the helpers from Solady and
/// 512Math. This approach is inspired by
/// https://github.com/SimonSuckut/Solidity_Uint512/

// Normalize `x` so the top word has its MSB in bit 255 or 254.
// x ≥ 2⁵¹⁰
uint256 shift = x_hi.clz();
(, x_hi, x_lo) = _shl256(x_hi, x_lo, shift & 0xfe);
shift >>= 1;

// We treat `r` as a ≤2-limb bigint where each limb is half a machine word (128 bits).
// Splitting √x in this way lets us apply "ordinary" 256-bit `sqrt` to the top word of
// `x`. Then we can recover the bottom limb of `r` without 512-bit division.
//
// Implementing this as:
// uint256 r_hi = x_hi.sqrt();
// is correct, but duplicates the normalization that we just did above and performs a
// more-costly initialization step. solc is not smart enough to optimize this away, so
// we inline and do it ourselves.
uint256 r_hi;
assembly ("memory-safe") {
// Extract the upper 6 bits of `M` to be used as a table index. `M >> 250 < 16` is
// invalid (that would imply M<½), so our lookup table needs to handle only 16
// through 63.
Mbucket := shr(0xfa, M)
// We can't fit 48 seeds into a single word, so we split the table in 2 and use `c`
// to select which table we index.
let c := lt(0x27, Mbucket)

// Each entry is 10 bits and the entries are ordered from lowest `i` to highest. The
// seed is the value for `Y` for the midpoint of the bucket, rounded to 10
// significant bits. That is, Y ≈ 1/√(2·M_mid), as a Q247.9. The 2 comes from the
// half-scale difference between Y and √M. The optimality of this choice was
// verified by fuzzing.
let table_hi := 0x71dc26f1b76c9ad6a5a46819c661946418c621856057e5ed775d1715b96b
let table_lo := 0xb26b4a8690a027198e559263e8ce2887e15832047f1f47b5e677dd974dcd
let table := xor(table_lo, mul(xor(table_hi, table_lo), c))

// Index the table to obtain the initial seed of `Y`.
let shift := add(0x186, mul(0x0a, sub(mul(0x18, c), Mbucket)))
// We begin the Newton-Raphson iterations with `Y` in Q247.9 format.
Y := and(0x3ff, shr(shift, table))

// The worst-case seed for `Y` occurs when `Mbucket = 16`. For monotone quadratic
// convergence, we desire that 1/√3 < Y·√M < √(5/3). At the boundaries (worst case)
// of the `Mbucket = 16` range, we are 0.407351 (41.3680%) from the lower bound and
// 0.275987 (27.1906%) from the higher bound.
// Initialization requires that the first bit of `r_hi` be in the correct
// position. This is correct from the normalization above.
r_hi := 0x80000000000000000000000000000000

// Seven Babylonian steps are sufficient for convergence.
r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi)))
r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi)))
r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi)))
r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi)))
r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi)))
r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi)))
r_hi := shr(0x01, add(r_hi, div(x_hi, r_hi)))

// The Babylonian step can oscillate between ⌊√x_hi⌋ and ⌈√x_hi⌉. Clean that up.
r_hi := sub(r_hi, lt(div(x_hi, r_hi), r_hi))
}

/// Perform 5 Newton-Raphson iterations. 5 is enough iterations for sufficient
/// convergence that our final fixup step produces an exact result.
// The Newton-Raphson iteration for 1/√M is:
// Y ≈ Y · (3 - M · Y²) / 2
// The implementation of this iteration is deliberately imprecise. No matter how many
// times you run it, you won't converge `Y` on the closest Q3.253 to 1/√M. However, this
// is acceptable because the cleanup step applied after the final call is very tolerant
// of error in the low bits of `Y`.

// `M` is Q1.255
// `Y` is Q247.9
{
uint256 Y2 = Y * Y; // scale: 2¹⁸
// Because `M` is Q1.255, multiplying `Y2` by `M` and taking the high word
// implicitly divides `MY2` by 2. We move the division by 2 inside the subtraction
// from 3 by adjusting the minuend.
uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2¹⁸
uint256 T = 1.5 * 2 ** 18 - MY2; // scale: 2¹⁸
Y *= T; // scale: 2²⁷
}
// `Y` is Q229.27
{
uint256 Y2 = Y * Y; // scale: 2⁵⁴
uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2⁵⁴
uint256 T = 1.5 * 2 ** 54 - MY2; // scale: 2⁵⁴
Y *= T; // scale: 2⁸¹
}
// `Y` is Q175.81
{
uint256 Y2 = Y * Y; // scale: 2¹⁶²
uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2¹⁶²
uint256 T = 1.5 * 2 ** 162 - MY2; // scale: 2¹⁶²
Y = Y * T >> 116; // scale: 2¹²⁷
}
// `Y` is Q129.127
if (invE < 95 - Mbucket) {
// Generally speaking, for relatively smaller `e` (lower values of `x`) and for
// relatively larger `M`, we can skip the 5th N-R iteration. The constant `95` is
// derived by extensive fuzzing. Attempting a higher-order approximation of the
// relationship between `M` and `invE` consumes, on average, more gas. When this
// branch is not taken, the correct bits that this iteration would obtain are
// shifted away during the denormalization step. This branch is net gas-optimizing.
uint256 Y2 = Y * Y; // scale: 2²⁵⁴
uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2²⁵⁴
uint256 T = 1.5 * 2 ** 254 - MY2; // scale: 2²⁵⁴
Y = _inaccurateMulHi(Y << 2, T); // scale: 2¹²⁷
// This is cheaper than
// uint256 res = x_hi - r_hi * r_hi;
// for no clear reason
uint256 res;
assembly ("memory-safe") {
res := sub(x_hi, mul(r_hi, r_hi))
}
// `Y` is Q129.127
{
uint256 Y2 = Y * Y; // scale: 2²⁵⁴
uint256 MY2 = _inaccurateMulHi(M, Y2); // scale: 2²⁵⁴
uint256 T = 1.5 * 2 ** 254 - MY2; // scale: 2²⁵⁴
Y = _inaccurateMulHi(Y << 128, T); // scale: 2²⁵³

uint256 r_lo;
// `res` is (almost) a single limb. Create a new (almost) machine word `n` with `res` as
// the upper limb and shifting in the next limb of `x` (namely `x_lo >> 128`) as the
// lower limb. The next step of Zimmermann's algorithm is:
// r_lo = n / (2 · r_hi)
// res = n % (2 · r_hi)
assembly ("memory-safe") {
let n := or(shl(0x80, res), shr(0x80, x_lo))
let d := shl(0x01, r_hi)
r_lo := div(n, d)

let c := shr(0x80, res)
res := mod(n, d)

// It's possible that `n` was 257 bits and overflowed (`res` was not just a single
// limb). Explicitly handling the carry avoids 512-bit division.
if c {
let neg_c := not(0x00)
r_lo := add(r_lo, div(neg_c, d))
res := add(res, add(0x01, mod(neg_c, d)))
r_lo := add(r_lo, div(res, d))
res := mod(res, d)
}
}
// `Y` is Q3.253

/// When we combine `Y` with `M` to form our approximation of the square root, we have
/// to un-normalize by the half-scale value. This is where even-exponent normalization
/// comes in because the half-scale is integral.
/// M = ⌊x · 2⁽²⁵⁵⁻²ᵉ⁾⌋
/// Y ≈ 2²⁵³ / √(M / 2²⁵⁵)
/// Y ≈ 2³⁸¹ / √(2·M)
/// M·Y ≈ 2³⁸¹ · √(M/2)
/// M·Y ≈ 2⁽⁵⁰⁸⁻ᵉ⁾ · √x
/// r0 ≈ M·Y / 2⁽⁵⁰⁸⁻ᵉ⁾ ≈ ⌊√x⌋
// We shift right by `508 - e` to account for both the Q3.253 scaling and
// denormalization. We don't care about accuracy in the low bits of `r0`, so we can cut
// some corners.
(, uint256 r0) = _shr(_inaccurateMulHi(M, Y), 0, 252 + invE);

/// `r0` is only an approximation of √x, so we perform a single Babylonian step to fully
/// converge on ⌊√x⌋ or ⌈√x⌉. The Babylonian step is:
/// r = ⌊(r0 + ⌊x/r0⌋) / 2⌋
// Rather than use the more-expensive division routine that returns a 512-bit result,
// because the value the upper word of the quotient can take is highly constrained, we
// can compute the quotient mod 2²⁵⁶ and recover the high word separately. Although
// `_div` does an expensive Newton-Raphson-Hensel modular inversion:
// ⌊x/r0⌋ ≡ ⌊x/2ⁿ⌋·⌊r0/2ⁿ⌋⁻¹ mod 2²⁵⁶ (for r % 2⁽ⁿ⁺¹⁾ = 2ⁿ)
// and we already have a pretty good estimate for r0⁻¹, namely `Y`, refining `Y` into
// the appropriate inverse requires a series of 768-bit multiplications that take more
// gas.
uint256 q_lo = _div(x_hi, x_lo, r0);
uint256 q_hi = (r0 <= x_hi).toUint();
(uint256 s_hi, uint256 s_lo) = _add(q_hi, q_lo, r0);
// `oflo` here is either 0 or 1. When `oflo == 1`, `r == 0`, and the correct value for
// `r` is `type(uint256).max`.
uint256 oflo;
(oflo, r) = _shr256(s_hi, s_lo, 1);
r -= oflo; // underflow is desired
r = (r_hi << 128) + r_lo;

// Then, if res · 2¹²⁸ + x_lo % 2¹²⁸ < r_lo², decrement `r`. We have to do this in a
// complicated manner because both `res` and `r_lo` can be _slightly_ longer than 1 limb
// (128 bits). This is more efficient than performing the full 257-bit comparison.
r = r.unsafeDec(
((res >> 128) < (r_lo >> 128))
.or(
((res >> 128) == (r_lo >> 128))
.and((res << 128) | (x_lo & 0xffffffffffffffffffffffffffffffff) < r_lo * r_lo)
)
);

// Un-normalize
return r >> shift;
}
}

Expand All @@ -1853,13 +1787,7 @@ library Lib512MathArithmetic {
return x_lo.sqrt();
}

uint256 r = _sqrt(x_hi, x_lo);

// Because the Babylonian step can give ⌈√x⌉ if x+1 is a perfect square, we have to
// check whether we've overstepped by 1 and clamp as appropriate. ref:
// https://en.wikipedia.org/wiki/Integer_square_root#Using_only_integer_division
(uint256 r2_hi, uint256 r2_lo) = _mul(r, r);
return r.unsafeDec(_gt(r2_hi, r2_lo, x_hi, x_lo));
return _sqrt(x_hi, x_lo);
}

function osqrtUp(uint512 r, uint512 x) internal pure returns (uint512) {
Expand All @@ -1871,8 +1799,6 @@ library Lib512MathArithmetic {

uint256 r_lo = _sqrt(x_hi, x_lo);

// The Babylonian step can give ⌈√x⌉ if x+1 is a perfect square. This is
// fine. If the Babylonian step gave ⌊√x⌋ ≠ √x, we have to round up.
(uint256 r2_hi, uint256 r2_lo) = _mul(r_lo, r_lo);
uint256 r_hi;
(r_hi, r_lo) = _add(0, r_lo, _gt(x_hi, x_lo, r2_hi, r2_lo).toUint());
Expand Down
12 changes: 12 additions & 0 deletions test/0.8.25/512Math.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,18 @@ contract Lib512MathTest is Test {
}
}

/// Fuzz test: taking `sqrt` of the perfect square `r · r` must give back exactly `r`.
function test512Math_sqrt_perfectSquare(uint256 r) external pure {
    // Build the square as a `uint512`; `r · r` does not fit in one word for large `r`.
    uint512 square = alloc().omul(r, r);
    uint256 root = square.sqrt();
    assertEq(root, r);
}

/// Fuzz test: `osqrtUp` of the perfect square `r · r` must give back exactly `r`, with
/// nothing carried into the high word of the result.
function test512Math_osqrtUp_perfectSquare(uint256 r) external pure {
    // Build the square as a `uint512`; `r · r` does not fit in one word for large `r`.
    uint512 square = alloc().omul(r, r);
    uint512 root = alloc().osqrtUp(square);
    (uint256 root_hi, uint256 root_lo) = root.into();
    assertEq(root_hi, 0);
    assertEq(root_lo, r);
}

function test512Math_oshrUp(uint256 x_hi, uint256 x_lo, uint256 s) external pure {
s = bound(s, 0, 512);

Expand Down