@@ -39,40 +39,50 @@ library CurveLib {
39
39
}
40
40
}
41
41
42
+ /// @dev EulerSwap inverse function definition
43
+ /// Pre-conditions: 0 < x <= x0, 1 <= {px,py} <= 1e36, {x0,y0} <= type(uint112).max, c <= 1e18
42
44
function fInverse (uint256 y , uint256 px , uint256 py , uint256 x0 , uint256 y0 , uint256 c )
43
45
internal
44
46
pure
45
47
returns (uint256 )
46
48
{
47
49
// components of quadratic equation
48
- int256 B = int256 ((py * (y - y0) + (px - 1 )) / px) - ( 2 * int256 (c) - int256 ( 1e18 )) * int256 (x0) / 1e18 ;
50
+ int256 B;
49
51
uint256 C;
50
52
uint256 fourAC;
51
- if (x0 < 1e18 ) {
52
- C = ((1e18 - c) * x0 * x0 + (1e18 - 1 )) / 1e18 ; // upper bound of 1e28 for x0 means this is safe
53
- fourAC = Math.mulDiv (4 * c, C, 1e18 , Math.Rounding.Ceil);
54
- } else {
55
- C = Math.mulDiv ((1e18 - c), x0 * x0, 1e36 , Math.Rounding.Ceil); // upper bound of 1e28 for x0 means this is safe
56
- fourAC = Math.mulDiv (4 * c, C, 1 , Math.Rounding.Ceil);
53
+ unchecked {
54
+ B = int256 ((py * (y - y0) + (px - 1 )) / px) - (2 * int256 (c) - int256 (1e18 )) * int256 (x0) / 1e18 ;
55
+ if (x0 >= 1e18 ) {
56
+ // if x0 >= 1, scale as normal
57
+ C = Math.mulDiv ((1e18 - c), x0 * x0, 1e36 , Math.Rounding.Ceil);
58
+ fourAC = 4 * c * C;
59
+ } else {
60
+ // if x0 < 1, then numbers get very small, so decrease scale to 1e18 to increase precision later
61
+ C = ((1e18 - c) * x0 * x0 + (1e18 - 1 )) / 1e18 ;
62
+ fourAC = Math.mulDiv (4 * c, C, 1e18 , Math.Rounding.Ceil);
63
+ }
57
64
}
58
-
59
- // solve for the square root
60
- uint256 absB = abs (B);
65
+
66
+ uint256 absB = uint256 (B >= 0 ? B : - B);
61
67
uint256 squaredB;
62
68
uint256 discriminant;
63
69
uint256 sqrt;
64
- if (absB > 1e33 ) {
70
+ if (absB < 1e36 ) {
71
+ // safe to use naive squaring
72
+ unchecked {
73
+ squaredB = absB * absB;
74
+ discriminant = squaredB + fourAC; // keep in 1e36 scale for increased precision ahead of sqrt
75
+ sqrt = Math.sqrt (discriminant); // drop back to 1e18 scale
76
+ sqrt = (sqrt * sqrt < discriminant) ? sqrt + 1 : sqrt;
77
+ }
78
+ } else {
79
+ // use scaled, overflow-safe path
65
80
uint256 scale = computeScale (absB);
66
81
squaredB = Math.mulDiv (absB / scale, absB, scale, Math.Rounding.Ceil);
67
82
discriminant = squaredB + fourAC / (scale * scale);
68
83
sqrt = Math.sqrt (discriminant);
69
84
sqrt = (sqrt * sqrt < discriminant) ? sqrt + 1 : sqrt;
70
85
sqrt = sqrt * scale;
71
- } else {
72
- squaredB = Math.mulDiv (absB, absB, 1 , Math.Rounding.Ceil);
73
- discriminant = squaredB + fourAC; // keep in 1e36 scale for increased precision ahead of sqrt
74
- sqrt = Math.sqrt (discriminant); // drop back to 1e18 scale
75
- sqrt = (sqrt * sqrt < discriminant) ? sqrt + 1 : sqrt;
76
86
}
77
87
78
88
uint256 x;
@@ -90,40 +100,27 @@ library CurveLib {
90
100
}
91
101
92
102
function computeScale (uint256 x ) internal pure returns (uint256 scale ) {
103
+ // calculate number of bits in x
93
104
uint256 bits = 0 ;
94
- uint256 tmp = x;
95
-
96
- while (tmp > 0 ) {
97
- tmp >>= 1 ;
105
+ while (x > 0 ) {
106
+ x >>= 1 ;
98
107
bits++ ;
99
108
}
100
109
101
- // absB * absB must be <= 2^256 ⇒ bits(B) ≤ 128
110
+ // 2^excessBits is how much we need to scale down to prevent overflow when squaring x
102
111
if (bits > 128 ) {
103
- uint256 excessBits = bits - 128 ;
104
- // 2^excessBits is how much we need to scale down to prevent overflow
112
+ uint256 excessBits = bits - 128 ;
105
113
scale = 1 << excessBits;
106
114
} else {
107
115
scale = 1 ;
108
116
}
109
117
}
110
118
111
- function abs (int256 x ) internal pure returns (uint256 ) {
112
- return uint256 (x >= 0 ? x : - x);
113
- }
114
-
115
- function binarySearch (
116
- IEulerSwap.Params memory p ,
117
- uint256 newReserve1 ,
118
- // uint256 y,
119
- // uint256 px,
120
- // uint256 py,
121
- // uint256 x0,
122
- // uint256 y0,
123
- // uint256 c,
124
- uint256 xMin ,
125
- uint256 xMax
126
- ) internal pure returns (uint256 ) {
119
+ function binarySearch (IEulerSwap.Params memory p , uint256 newReserve1 , uint256 xMin , uint256 xMax )
120
+ internal
121
+ pure
122
+ returns (uint256 )
123
+ {
127
124
if (xMin < 1 ) {
128
125
xMin = 1 ;
129
126
}
0 commit comments