Skip to content

Commit 2784001

Browse files
committed
update PythUtils
1 parent dfe99e2 commit 2784001

File tree

4 files changed

+147
-62
lines changed

4 files changed

+147
-62
lines changed

target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -370,36 +370,53 @@ abstract contract PythTestUtils is Test, WormholeTestUtils, RandTestUtils {
370370
}
371371

372372
contract PythUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvents {
373-
function successTest(int64 price1, int32 expo1, int64 price2, int32 expo2, int64 expectedPrice, int32 expectedExpo) internal {
374-
(int64 price, int32 expo) = PythUtils.deriveCrossRate(price1, expo1, price2, expo2);
373+
function assertCrossRateEquals(
374+
int64 price1,
375+
int32 expo1,
376+
int64 price2,
377+
int32 expo2,
378+
int32 targetExpo,
379+
380+
int64 expectedPrice,
381+
int32 expectedExpo
382+
) internal {
383+
(int64 price, int32 expo) = PythUtils.deriveCrossRate(price1, expo1, price2, expo2, targetExpo);
375384
assertEq(price, expectedPrice);
376385
assertEq(expo, expectedExpo);
377386
}
378387

379-
function revertTest(int64 price1, int32 expo1, int64 price2, int32 expo2) internal {
380-
vm.expectRevert();
381-
PythUtils.deriveCrossRate(price1, expo1, price2, expo2);
388+
function assertCrossRateReverts(
389+
int64 price1,
390+
int32 expo1,
391+
int64 price2,
392+
int32 expo2,
393+
int32 targetExpo,
394+
bytes4 expectedError
395+
) internal {
396+
vm.expectRevert(expectedError);
397+
PythUtils.deriveCrossRate(price1, expo1, price2, expo2, targetExpo);
382398
}
383399

384400
function testConvertToUnit() public {
385401
// Price can't be negative
386-
vm.expectRevert();
402+
vm.expectRevert(PythErrors.NegativeInputPrice.selector);
387403
PythUtils.convertToUint(-100, -5, 18);
388404

389-
// Exponent can't be positive
390-
vm.expectRevert();
391-
PythUtils.convertToUint(100, 5, 18);
405+
// Exponent can't be less than -255
406+
vm.expectRevert(PythErrors.InvalidInputExpo.selector);
407+
PythUtils.convertToUint(100, -256, 18);
392408

409+
// Negative Exponent Tests
393410
// Price with 18 decimals and exponent -5
394411
assertEq(
395412
PythUtils.convertToUint(100, -5, 18),
396-
1000000000000000 // 100 * 10^13
413+
100_0_000_000_000_000 // 100 * 10^13
397414
);
398415

399416
// Price with 9 decimals and exponent -2
400417
assertEq(
401418
PythUtils.convertToUint(100, -2, 9),
402-
1000000000 // 100 * 10^7
419+
100_0_000_000 // 100 * 10^7
403420
);
404421

405422
// Price with 4 decimals and exponent -5
@@ -409,45 +426,66 @@ contract PythUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvents {
409426
// @note: We will lose precision here as price is
410427
// 0.00001 and we are targetDecimals is 2.
411428
assertEq(PythUtils.convertToUint(100, -5, 2), 0);
429+
430+
assertEq(PythUtils.convertToUint(123, -8, 5), 0);
431+
432+
// Positive Exponent Tests
433+
// Price with 18 decimals and exponent 5
434+
assertEq(PythUtils.convertToUint(100, 5, 18), 100_00_000_000_000_000_000_000_000); // 100 with23 zeros
435+
436+
// Price with 9 decimals and exponent 2
437+
assertEq(PythUtils.convertToUint(100, 2, 9), 100_00_000_000_000); // 100 with 11 zeros
438+
439+
// Price with 4 decimals and exponent 5
440+
assertEq(PythUtils.convertToUint(100, 1, 2), 100_000); // 100 with 3 zeros
412441
}
413442

414443
function testCombinePrices() public {
415444

416445
// Basic Tests
417-
successTest(100, -2, 100, -2, 100, -2);
418-
successTest(10000, -2, 100, -2, 10000, -2);
419-
successTest(1_000_000, -2, 10_000, -2, 10_000, -2);
446+
assertCrossRateEquals(500, -8, 500, -8, -5, 100000, -5);
447+
assertCrossRateEquals(10_000, -8, 100, -2, -5, 10, -5);
448+
assertCrossRateEquals(10_000, -2, 100, -8, -4, 1_000_000_000_000, -4);
420449

421450
// Negative Price Tests
422-
revertTest(-100, -2, 100, -2);
423-
revertTest(100, -2, -100, -2);
424-
revertTest(-100, -2, -100, -2);
451+
assertCrossRateReverts(-100, -2, 100, -2, -5, PythErrors.NegativeInputPrice.selector);
452+
assertCrossRateReverts(100, -2, -100, -2, -5, PythErrors.NegativeInputPrice.selector);
453+
assertCrossRateReverts(-100, -2, -100, -2, -5, PythErrors.NegativeInputPrice.selector);
425454

426455
// Positive Exponent Tests
427-
revertTest(100, 2, 100, -2);
428-
revertTest(100, -2, 100, 2);
429-
revertTest(100, 2, 100, 2);
456+
assertCrossRateReverts(100, 2, 100, -2, -5, PythErrors.InvalidInputExpo.selector);
457+
assertCrossRateReverts(100, -2, 100, 2, -5, PythErrors.InvalidInputExpo.selector);
458+
assertCrossRateReverts(100, 2, 100, 2, -5, PythErrors.InvalidInputExpo.selector);
459+
460+
// Invalid Target Exponent Tests
461+
assertCrossRateReverts(100, -2, 100, -2, 1, PythErrors.InvalidTargetExpo.selector);
430462

431463
// Different Exponent Tests
432-
successTest(10_000, -2, 100, -4, 100_000_000, -4);
433-
successTest(10_000, -2, 10_000, 0, 1, -2);
434-
successTest(10_000, 0, 10_000, 0, 1, 0);
435-
436-
// End Range Tests
437-
successTest(int64(type(int64).max), 0, int64(type(int64).max), 0, 1, 0);
438-
successTest(int64(type(int64).max), 0, 1, 0, int64(type(int64).max), 0);
439-
successTest(1, 0, int64(type(int64).max), 0, 1 / int64(type(int64).max), 0);
440-
revertTest(10_000, -2, 10_000, -256);
464+
assertCrossRateEquals(10_000, -2, 100, -4, -4, 100_000_000, -4);
465+
assertCrossRateEquals(10_000, -2, 10_000, -1, -2, 10, -2);
466+
assertCrossRateEquals(10_000, -10, 10_000, -2, 0, 0, 0); // It will truncate to 0
467+
468+
// Exponent Edge Tests
469+
assertCrossRateEquals(10_000, 0, 100, 0, 0, 100, 0);
470+
assertCrossRateEquals(10_000, 0, 100, 0, -255, 100, -255);
471+
// assertCrossRateEquals(10_000, 0, 100, -255, -255, 100, -255);
472+
// assertCrossRateEquals(10_000, -255, 100, 0, 0, 100, 0);
473+
474+
// // End Range Tests
475+
// successTest(int64(type(int64).max), 0, int64(type(int64).max), 0, 1, 0);
476+
// successTest(int64(type(int64).max), 0, 1, 0, int64(type(int64).max), 0);
477+
// successTest(1, 0, int64(type(int64).max), 0, 1 / int64(type(int64).max), 0);
478+
// revertTest(10_000, -2, 10_000, -256);
441479

442-
// More Realistic Tests
443-
// Test case 1: (StEth/Eth / Eth/USD = ETH/BTC)
444-
(int64 price, int32 expo) = PythUtils.deriveCrossRate(206487956502, -8, 206741615681, -8);
445-
assertApproxEqRel(price, 100000000, 9e17); // $1
446-
assertEq(expo, -8);
447-
448-
// Test case 2:
449-
(price, expo) = PythUtils.deriveCrossRate(520010, -8, 38591, -8);
450-
assertApproxEqRel(price, 1347490347, 9e17); // $1
451-
assertEq(expo, -8);
480+
// // More Realistic Tests
481+
// // Test case 1: (StEth/Eth / Eth/USD = ETH/BTC)
482+
// (int64 price, int32 expo) = PythUtils.deriveCrossRate(206487956502, -8, 206741615681, -8);
483+
// assertApproxEqRel(price, 100000000, 9e17); // $1
484+
// assertEq(expo, -8);
485+
486+
// // Test case 2:
487+
// (price, expo) = PythUtils.deriveCrossRate(520010, -8, 38591, -8);
488+
// assertApproxEqRel(price, 1347490347, 9e17); // $1
489+
// assertEq(expo, -8);
452490
}
453491
}

target_chains/ethereum/sdk/solidity/Math.sol

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,21 @@ library Math {
157157
// return mulDiv(x, y, denominator) + toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
158158
// }
159159

160+
/**
161+
* @dev Returns the absolute unsigned value of a signed value.
162+
*/
163+
function abs(int256 n) internal pure returns (uint256) {
164+
unchecked {
165+
// Formula from the "Bit Twiddling Hacks" by Sean Eron Anderson.
166+
// Since `n` is a signed integer, the generated bytecode will use the SAR opcode to perform the right shift,
167+
// taking advantage of the most significant (or "sign" bit) in two's complement representation.
168+
// This opcode adds new most significant bits set to the value of the previous most significant bit. As a result,
169+
// the mask will either be `bytes32(0)` (if n is positive) or `~bytes32(0)` (if n is negative).
170+
int256 mask = n >> 255;
171+
172+
// A `bytes32(0)` mask leaves the input unchanged, while a `~bytes32(0)` mask complements it.
173+
return uint256((n + mask) ^ mask);
174+
}
175+
}
176+
160177
}

target_chains/ethereum/sdk/solidity/PythErrors.sol

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ library PythErrors {
5151
error InvalidTwapUpdateDataSet();
5252
// The Input Price is negative.
5353
error NegativeInputPrice();
54-
// The Input Exponent is greater than 0.
55-
error PositiveInputExpo();
54+
// The Input Exponent is invalid.
55+
error InvalidInputExpo();
5656
// The target exponent is invalid.
5757
error InvalidTargetExpo();
5858
// The combined price is greater than int64.max.

target_chains/ethereum/sdk/solidity/PythUtils.sol

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@
22
pragma solidity ^0.8.0;
33

44
import "./PythStructs.sol";
5+
import "./PythErrors.sol";
6+
import "./Math.sol";
7+
import "forge-std/console.sol";
58

69
library PythUtils {
10+
11+
uint32 internal constant PRECISION = 36;
712
/// @notice Converts a Pyth price to a uint256 with a target number of decimals
813
/// @param price The Pyth price
914
/// @param expo The Pyth price exponent
@@ -17,26 +22,28 @@ library PythUtils {
1722
int32 expo,
1823
uint8 targetDecimals
1924
) public pure returns (uint256) {
20-
if (price < 0 || expo > 0 || expo < -255) {
21-
revert();
25+
if (price < 0) {
26+
revert PythErrors.NegativeInputPrice();
27+
}
28+
if (expo < -255) {
29+
revert PythErrors.InvalidInputExpo();
2230
}
2331

24-
uint8 priceDecimals = uint8(uint32(-1 * expo));
32+
int256 combinedExpo = int32(int8(targetDecimals)) + int32(int8(uint8(uint32(expo))));
2533

26-
if (targetDecimals >= priceDecimals) {
27-
return
28-
uint(uint64(price)) *
29-
10 ** uint32(targetDecimals - priceDecimals);
34+
if (combinedExpo > 0) {
35+
return uint(uint64(price)) * 10 ** uint32(int32(combinedExpo));
3036
} else {
31-
return
32-
uint(uint64(price)) /
33-
10 ** uint32(priceDecimals - targetDecimals);
37+
return uint(uint64(price)) / 10 ** uint32(Math.abs(combinedExpo));
3438
}
3539
}
3640

3741
/// @notice Combines two prices to get a cross-rate
3842
/// @param price1 The first price (a/b)
43+
/// @param expo1 The exponent of the first price
3944
/// @param price2 The second price (c/b)
45+
/// @param expo2 The exponent of the second price
46+
/// @param targetExpo The target exponent of the cross-rate
4047
/// @return crossRate The cross-rate (a/c)
4148
/// @return expo The exponent of the cross-rate
4249
/// @dev This function will revert if either price is negative or if the exponents are invalid.
@@ -46,25 +53,48 @@ library PythUtils {
4653
int64 price1,
4754
int32 expo1,
4855
int64 price2,
49-
int32 expo2
56+
int32 expo2,
57+
int32 targetExpo
5058
) public pure returns (int64 crossRate, int32 expo) {
51-
if (price1 < 0 || price2 < 0 || expo1 > 0 || expo2 > 0 || expo1 < -255 || expo2 < -255) {
52-
revert();
59+
// Check if the input prices are negative
60+
if (price1 < 0 || price2 < 0) {
61+
revert PythErrors.NegativeInputPrice();
62+
}
63+
// Check if the input exponents are valid and not less than -255
64+
if (expo1 > 0 || expo2 > 0 || expo1 < -255 || expo2 < -255) {
65+
revert PythErrors.InvalidInputExpo();
66+
}
67+
// Check if the target exponent is valid and not less than -255
68+
if (targetExpo > 0 || targetExpo < -255) {
69+
revert PythErrors.InvalidTargetExpo();
70+
}
71+
72+
// Calculate the combined price with precision of 36
73+
uint256 fixedPointPrice = Math.mulDiv(uint64(price1), 10 ** PRECISION, uint64(price2));
74+
// TODO: Check for underflow
75+
int32 combinedExpo = expo1 - expo2 - int32(PRECISION);
76+
console.log("fixedPointPrice", fixedPointPrice);
77+
console.log("combinedExpo", combinedExpo);
78+
console.log("targetExpo", targetExpo);
79+
// Convert the price to the target exponent
80+
uint256 combined;
81+
if (combinedExpo >= targetExpo) {
82+
console.log("combinedExpo >= targetExpo");
83+
// If combinedExpo is greater than or equal to targetExpo, we need to multiply
84+
combined = fixedPointPrice * 10 ** uint32(combinedExpo + targetExpo);
85+
} else {
86+
console.log("combinedExpo - targetExpo", combinedExpo - targetExpo);
87+
// If combinedExpo is less than targetExpo, we need to divide
88+
combined = fixedPointPrice / 10 ** uint32(targetExpo - combinedExpo);
5389
}
5490

55-
// Convert both prices to the same decimal places (using the larger of the two)
56-
uint8 maxDecimals = uint8(uint32(-1 * (expo1 < expo2 ? expo1 : expo2)));
57-
uint256 p1 = convertToUint(price1, expo1, maxDecimals);
58-
uint256 p2 = convertToUint(price2, expo2, maxDecimals);
91+
console.log("combined", combined);
5992

60-
// Calculate the combined price with precision
61-
uint256 combined = (p1 * 10**18) / p2;
62-
combined = combined / 10 ** (18 - maxDecimals);
6393
// Check if the combined price fits in int64
6494
if (combined > uint256(uint64(type(int64).max))) {
6595
revert();
6696
}
6797

68-
return (int64(uint64(combined)), expo1 < expo2 ? expo1 : expo2);
98+
return (int64(uint64(combined)), targetExpo);
6999
}
70100
}

0 commit comments

Comments
 (0)