Skip to content

Commit 8b66dc7

Browse files
committed
Merge remote-tracking branch 'origin/mega-refactor-curve-gas'
2 parents 22fd8de + f16493d commit 8b66dc7

File tree

2 files changed

+141
-25
lines changed

2 files changed

+141
-25
lines changed

src/libraries/CurveLib.sol

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,40 +39,50 @@ library CurveLib {
3939
}
4040
}
4141

42+
/// @dev EulerSwap inverse function definition
43+
/// Pre-conditions: 0 < x <= x0, 1 <= {px,py} <= 1e36, {x0,y0} <= type(uint112).max, c <= 1e18
4244
function fInverse(uint256 y, uint256 px, uint256 py, uint256 x0, uint256 y0, uint256 c)
4345
internal
4446
pure
4547
returns (uint256)
4648
{
4749
// components of quadratic equation
48-
int256 B = int256((py * (y - y0) + (px - 1)) / px) - (2 * int256(c) - int256(1e18)) * int256(x0) / 1e18;
50+
int256 B;
4951
uint256 C;
5052
uint256 fourAC;
51-
if (x0 < 1e18) {
52-
C = ((1e18 - c) * x0 * x0 + (1e18 - 1)) / 1e18; // upper bound of 1e28 for x0 means this is safe
53-
fourAC = Math.mulDiv(4 * c, C, 1e18, Math.Rounding.Ceil);
54-
} else {
55-
C = Math.mulDiv((1e18 - c), x0 * x0, 1e36, Math.Rounding.Ceil); // upper bound of 1e28 for x0 means this is safe
56-
fourAC = Math.mulDiv(4 * c, C, 1, Math.Rounding.Ceil);
53+
unchecked {
54+
B = int256((py * (y - y0) + (px - 1)) / px) - (2 * int256(c) - int256(1e18)) * int256(x0) / 1e18;
55+
if (x0 >= 1e18) {
56+
// if x0 >= 1, scale as normal
57+
C = Math.mulDiv((1e18 - c), x0 * x0, 1e36, Math.Rounding.Ceil);
58+
fourAC = 4 * c * C;
59+
} else {
60+
// if x0 < 1, then numbers get very small, so decrease scale to 1e18 to increase precision later
61+
C = ((1e18 - c) * x0 * x0 + (1e18 - 1)) / 1e18;
62+
fourAC = Math.mulDiv(4 * c, C, 1e18, Math.Rounding.Ceil);
63+
}
5764
}
58-
59-
// solve for the square root
60-
uint256 absB = abs(B);
65+
66+
uint256 absB = uint256(B >= 0 ? B : -B);
6167
uint256 squaredB;
6268
uint256 discriminant;
6369
uint256 sqrt;
64-
if (absB > 1e33) {
70+
if (absB < 1e36) {
71+
// safe to use naive squaring
72+
unchecked {
73+
squaredB = absB * absB;
74+
discriminant = squaredB + fourAC; // keep in 1e36 scale for increased precision ahead of sqrt
75+
sqrt = Math.sqrt(discriminant); // drop back to 1e18 scale
76+
sqrt = (sqrt * sqrt < discriminant) ? sqrt + 1 : sqrt;
77+
}
78+
} else {
79+
// use scaled, overflow-safe path
6580
uint256 scale = computeScale(absB);
6681
squaredB = Math.mulDiv(absB / scale, absB, scale, Math.Rounding.Ceil);
6782
discriminant = squaredB + fourAC / (scale * scale);
6883
sqrt = Math.sqrt(discriminant);
6984
sqrt = (sqrt * sqrt < discriminant) ? sqrt + 1 : sqrt;
7085
sqrt = sqrt * scale;
71-
} else {
72-
squaredB = Math.mulDiv(absB, absB, 1, Math.Rounding.Ceil);
73-
discriminant = squaredB + fourAC; // keep in 1e36 scale for increased precision ahead of sqrt
74-
sqrt = Math.sqrt(discriminant); // drop back to 1e18 scale
75-
sqrt = (sqrt * sqrt < discriminant) ? sqrt + 1 : sqrt;
7686
}
7787

7888
uint256 x;
@@ -90,25 +100,42 @@ library CurveLib {
90100
}
91101

92102
function computeScale(uint256 x) internal pure returns (uint256 scale) {
103+
// calculate number of bits in x
93104
uint256 bits = 0;
94-
uint256 tmp = x;
95-
96-
while (tmp > 0) {
97-
tmp >>= 1;
105+
while (x > 0) {
106+
x >>= 1;
98107
bits++;
99108
}
100109

101-
// absB * absB must be <= 2^256 ⇒ bits(B) ≤ 128
110+
// 2^excessBits is how much we need to scale down to prevent overflow when squaring x
102111
if (bits > 128) {
103-
uint256 excessBits = bits - 128;
104-
// 2^excessBits is how much we need to scale down to prevent overflow
112+
uint256 excessBits = bits - 128;
105113
scale = 1 << excessBits;
106114
} else {
107115
scale = 1;
108116
}
109117
}
110118

111-
function abs(int256 x) internal pure returns (uint256) {
112-
return uint256(x >= 0 ? x : -x);
119+
function binarySearch(IEulerSwap.Params memory p, uint256 newReserve1, uint256 xMin, uint256 xMax)
120+
internal
121+
pure
122+
returns (uint256)
123+
{
124+
if (xMin < 1) {
125+
xMin = 1;
126+
}
127+
while (xMin < xMax) {
128+
uint256 xMid = (xMin + xMax) / 2;
129+
uint256 fxMid = f(xMid, p.priceX, p.priceY, p.equilibriumReserve0, p.equilibriumReserve1, p.concentrationX);
130+
if (newReserve1 >= fxMid) {
131+
xMax = xMid;
132+
} else {
133+
xMin = xMid + 1;
134+
}
135+
}
136+
if (newReserve1 < f(xMin, p.priceX, p.priceY, p.equilibriumReserve0, p.equilibriumReserve1, p.concentrationX)) {
137+
xMin += 1;
138+
}
139+
return xMin;
113140
}
114141
}

test/CurveLib.t.sol

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// SPDX-License-Identifier: GPL-2.0-or-later
2+
pragma solidity ^0.8.24;
3+
4+
import "forge-std/Test.sol";
5+
import "forge-std/console.sol";
6+
import {EulerSwapTestBase, EulerSwap, TestERC20} from "./EulerSwapTestBase.t.sol";
7+
import {IEulerSwap} from "../src/interfaces/IEulerSwap.sol";
8+
import {CurveLib} from "../src/CurveLib.sol";
9+
10+
contract CurveLibTest is EulerSwapTestBase {
11+
EulerSwap public eulerSwap;
12+
13+
function setUp() public virtual override {
14+
super.setUp();
15+
}
16+
17+
function testGas_fInverse() public pure {
18+
// Set representative values within valid bounds
19+
uint256 px = 1e18;
20+
uint256 py = 1e18;
21+
uint256 x0 = 1e14;
22+
uint256 y0 = 1e14;
23+
uint256 c = 1e18;
24+
25+
// Use CurveLib.f to get a valid y
26+
uint256 x = 1e12;
27+
uint256 y = CurveLib.f(x, px, py, x0, y0, c);
28+
29+
// Measure gas of fInverse
30+
CurveLib.fInverse(y, px, py, x0, y0, c);
31+
}
32+
33+
function test_fuzzfInverse(uint256 x, uint256 px, uint256 py, uint256 x0, uint256 y0, uint256 cx, uint256 cy)
34+
public
35+
view
36+
{
37+
// Params
38+
px = 1e18;
39+
py = bound(py, 1, 1e36);
40+
x0 = bound(x0, 1e2, 1e28);
41+
y0 = bound(y0, 0, 1e28);
42+
cx = bound(cx, 1, 1e18);
43+
cy = bound(cy, 1, 1e18);
44+
console.log("px", px);
45+
console.log("py", py);
46+
console.log("x0", x0);
47+
console.log("y0", y0);
48+
console.log("cx", cx);
49+
console.log("cy", cy);
50+
51+
IEulerSwap.Params memory p = IEulerSwap.Params({
52+
vault0: address(0),
53+
vault1: address(0),
54+
eulerAccount: address(0),
55+
equilibriumReserve0: uint112(x0),
56+
equilibriumReserve1: uint112(y0),
57+
priceX: px,
58+
priceY: py,
59+
concentrationX: cx,
60+
concentrationY: cy,
61+
fee: 0,
62+
protocolFee: 0,
63+
protocolFeeRecipient: address(0)
64+
});
65+
66+
// Note without -2 in the max bound, f() sometimes fails when x gets too close to centre.
67+
// Note small x values lead to large y-values, which causes problems for both f() and fInverse(), so we cap it here
68+
x = bound(x, 1e2 - 3, x0 - 3);
69+
70+
uint256 y = CurveLib.f(x, px, py, x0, y0, cx);
71+
console.log("y ", y);
72+
uint256 xCalc = CurveLib.fInverse(y, px, py, x0, y0, cx);
73+
console.log("xCalc", xCalc);
74+
uint256 yCalc = CurveLib.f(xCalc, px, py, x0, y0, cx);
75+
uint256 xBin = CurveLib.binarySearch(p, y, 1, x0);
76+
uint256 yBin = CurveLib.f(xBin, px, py, x0, y0, cx);
77+
console.log("x ", x);
78+
console.log("xCalc", xCalc);
79+
console.log("xBin ", xBin);
80+
console.log("y ", y);
81+
console.log("yCalc", yCalc);
82+
console.log("yBin ", yBin);
83+
84+
if (x < type(uint112).max && y < type(uint112).max) {
85+
assert(CurveLib.verify(p, xCalc, y));
86+
assert(int256(xCalc) - int256(xBin) <= 3 || int256(yCalc) - int256(yBin) <= 3); // suspect this is 2 wei error in fInverse() + 1 wei error in f()
87+
}
88+
}
89+
}

0 commit comments

Comments
 (0)