diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index cf67313fd4..6dff9b1283 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -6,11 +6,32 @@ import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "./PulseEvents.sol"; import "./PulseState.sol"; -interface IPulseConsumer { +abstract contract IPulseConsumer { + // This method is called by Pulse to provide the price updates to the consumer. + // It asserts that the msg.sender is the Pulse contract. It is not meant to be + // overridden by the consumer. + function _pulseCallback( + uint64 sequenceNumber, + PythStructs.PriceFeed[] memory priceFeeds + ) external { + address pulse = getPulse(); + require(pulse != address(0), "Pulse address not set"); + require(msg.sender == pulse, "Only Pulse can call this function"); + + pulseCallback(sequenceNumber, priceFeeds); + } + + // getPulse returns the Pulse contract address. The method is being used to check that the + // callback is indeed from the Pulse contract. The consumer is expected to implement this method. + function getPulse() internal view virtual returns (address); + + // This method is expected to be implemented by the consumer to handle the price updates. + // It will be called by _pulseCallback after _pulseCallback ensures that the call is + // indeed from Pulse contract. function pulseCallback( uint64 sequenceNumber, PythStructs.PriceFeed[] memory priceFeeds - ) external; + ) internal virtual; } interface IPulse is PulseEvents { @@ -18,10 +39,11 @@ interface IPulse is PulseEvents { /** * @notice Requests price updates with a callback * @dev The msg.value must be equal to getFee(callbackGasLimit) - * @param callbackGasLimit The amount of gas allocated for the callback execution + * @param provider The provider to fulfill the request * @param publishTime The minimum publish time for price updates, it should be less than or equal to block.timestamp + 60 * @param priceIds The price feed IDs to update. Maximum 10 price feeds per request. * Requests requiring more feeds should be split into multiple calls. + * @param callbackGasLimit The amount of gas allocated for the callback execution * @return sequenceNumber The sequence number assigned to this request * @dev Security note: The 60-second future limit on publishTime prevents a DoS vector where * attackers could submit many low-fee requests for far-future updates when gas prices @@ -30,7 +52,8 @@ interface IPulse is PulseEvents { * the fee estimation unreliable. */ function requestPriceUpdatesWithCallback( - uint256 publishTime, + address provider, + uint64 publishTime, bytes32[] calldata priceIds, uint256 callbackGasLimit ) external payable returns (uint64 sequenceNumber); @@ -39,11 +62,13 @@ interface IPulse is PulseEvents { * @notice Executes the callback for a price update request * @dev Requires 1.5x the callback gas limit to account for cross-contract call overhead * For example, if callbackGasLimit is 1M, the transaction needs at least 1.5M gas + some gas for some other operations in the function before the callback + * @param providerToCredit The provider to credit for fulfilling the request. This may not be the provider that submitted the request (if the exclusivity period has elapsed). * @param sequenceNumber The sequence number of the request * @param updateData The raw price update data from Pyth * @param priceIds The price feed IDs to update, must match the request */ function executeCallback( + address providerToCredit, uint64 sequenceNumber, bytes[] calldata updateData, bytes32[] calldata priceIds @@ -59,15 +84,22 @@ interface IPulse is PulseEvents { /** * @notice Calculates the total fee required for a price update request - * @dev Total fee = base Pyth protocol fee + gas costs for callback + * @dev Total fee = base Pyth protocol fee + base provider fee + provider fee per feed + gas costs for callback + * @param provider The provider to fulfill the request * @param callbackGasLimit The amount of gas allocated for callback execution + * @param priceIds The price feed IDs to update. * @return feeAmount The total fee in wei that must be provided as msg.value */ function getFee( - uint256 callbackGasLimit + address provider, + uint256 callbackGasLimit, + bytes32[] calldata priceIds ) external view returns (uint128 feeAmount); - function getAccruedFees() external view returns (uint128 accruedFeesInWei); + function getAccruedPythFees() + external + view + returns (uint128 accruedFeesInWei); function getRequest( uint64 sequenceNumber @@ -83,9 +115,18 @@ interface IPulse is PulseEvents { function withdrawAsFeeManager(address provider, uint128 amount) external; - function registerProvider(uint128 feeInWei) external; + function registerProvider( + uint128 baseFeeInWei, + uint128 feePerFeedInWei, + uint128 feePerGasInWei + ) external; - function setProviderFee(uint128 newFeeInWei) external; + function setProviderFee( + address provider, + uint128 newBaseFeeInWei, + uint128 newFeePerFeedInWei, + uint128 newFeePerGasInWei + ) external; function getProviderInfo( address provider diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 60e097aef7..9483bc015c 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -53,17 +53,19 @@ abstract contract Pulse is IPulse, PulseState { } } + // TODO: there can be a separate wrapper function that defaults the provider (or uses the cheapest or something). function requestPriceUpdatesWithCallback( - uint256 publishTime, + address provider, + uint64 publishTime, bytes32[] calldata priceIds, uint256 callbackGasLimit ) external payable override returns (uint64 requestSequenceNumber) { - address provider = _state.defaultProvider; require( _state.providers[provider].isRegistered, "Provider not registered" ); + // FIXME: this comment is wrong. (we're not using tx.gasprice) // NOTE: The 60-second future limit on publishTime prevents a DoS vector where // attackers could submit many low-fee requests for far-future updates when gas prices // are low, forcing executors to fulfill them later when gas prices might be much higher. @@ -75,7 +77,7 @@ abstract contract Pulse is IPulse, PulseState { } requestSequenceNumber = _state.currentSequenceNumber++; - uint128 requiredFee = getFee(callbackGasLimit); + uint128 requiredFee = getFee(provider, callbackGasLimit, priceIds); if (msg.value < requiredFee) revert InsufficientFee(); Request storage req = allocRequest(requestSequenceNumber); @@ -85,21 +87,20 @@ abstract contract Pulse is IPulse, PulseState { req.requester = msg.sender; req.numPriceIds = uint8(priceIds.length); req.provider = provider; + req.fee = SafeCast.toUint128(msg.value - _state.pythFeeInWei); // Copy price IDs to storage for (uint8 i = 0; i < priceIds.length; i++) { req.priceIds[i] = priceIds[i]; } - - _state.providers[provider].accruedFeesInWei += SafeCast.toUint128( - msg.value - _state.pythFeeInWei - ); _state.accruedFeesInWei += _state.pythFeeInWei; emit PriceUpdateRequested(req, priceIds); } + // TODO: does this need to be payable? Any cost paid to Pyth could be taken out of the provider's accrued fees. function executeCallback( + address providerToCredit, uint64 sequenceNumber, bytes[] calldata updateData, bytes32[] calldata priceIds @@ -111,7 +112,7 @@ abstract contract Pulse is IPulse, PulseState { block.timestamp < req.publishTime + _state.exclusivityPeriodSeconds ) { require( - msg.sender == req.provider, + providerToCredit == req.provider, "Only assigned provider during exclusivity period" ); } @@ -127,19 +128,41 @@ abstract contract Pulse is IPulse, PulseState { } } - // Parse price feeds first to measure gas usage - PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) - .parsePriceFeedUpdates( - updateData, - priceIds, - SafeCast.toUint64(req.publishTime), - SafeCast.toUint64(req.publishTime) - ); + // TODO: should this use parsePriceFeedUpdatesUnique? also, do we need to add 1 to maxPublishTime? + IPyth pyth = IPyth(_state.pyth); + uint256 pythFee = pyth.getUpdateFee(updateData); + PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{ + value: pythFee + }( + updateData, + priceIds, + SafeCast.toUint64(req.publishTime), + SafeCast.toUint64(req.publishTime) + ); + + // TODO: if this effect occurs here, we need to guarantee that executeCallback can never revert. + // If executeCallback can revert, then funds can be permanently locked in the contract. + // TODO: there also needs to be some penalty mechanism in case the expected provider doesn't execute the callback. + // This should take funds from the expected provider and give to providerToCredit. The penalty should probably scale + // with time in order to ensure that the callback eventually gets executed. + // (There may be exploits with ^ though if the consumer contract is malicious ?) + _state.providers[providerToCredit].accruedFeesInWei += SafeCast + .toUint128((req.fee + msg.value) - pythFee); clearRequest(sequenceNumber); + // TODO: I'm pretty sure this is going to use a lot of gas because it's doing a storage lookup for each sequence number. + // a better solution would be a doubly-linked list of active requests. + // After successful callback, update firstUnfulfilledSeq if needed + while ( + _state.firstUnfulfilledSeq < _state.currentSequenceNumber && + !isActive(findRequest(_state.firstUnfulfilledSeq)) + ) { + _state.firstUnfulfilledSeq++; + } + try - IPulseConsumer(req.requester).pulseCallback{ + IPulseConsumer(req.requester)._pulseCallback{ gas: req.callbackGasLimit }(sequenceNumber, priceFeeds) { @@ -149,7 +172,7 @@ abstract contract Pulse is IPulse, PulseState { // Explicit revert/require emit PriceUpdateCallbackFailed( sequenceNumber, - msg.sender, + providerToCredit, priceIds, req.requester, reason @@ -158,20 +181,12 @@ abstract contract Pulse is IPulse, PulseState { // Out of gas or other low-level errors emit PriceUpdateCallbackFailed( sequenceNumber, - msg.sender, + providerToCredit, priceIds, req.requester, "low-level error (possibly out of gas)" ); } - - // After successful callback, update firstUnfulfilledSeq if needed - while ( - _state.firstUnfulfilledSeq < _state.currentSequenceNumber && - !isActive(findRequest(_state.firstUnfulfilledSeq)) - ) { - _state.firstUnfulfilledSeq++; - } } function emitPriceUpdate( @@ -182,13 +197,16 @@ abstract contract Pulse is IPulse, PulseState { int64[] memory prices = new int64[](priceFeeds.length); uint64[] memory conf = new uint64[](priceFeeds.length); int32[] memory expos = new int32[](priceFeeds.length); - uint256[] memory publishTimes = new uint256[](priceFeeds.length); + uint64[] memory publishTimes = new uint64[](priceFeeds.length); for (uint i = 0; i < priceFeeds.length; i++) { prices[i] = priceFeeds[i].price.price; conf[i] = priceFeeds[i].price.conf; expos[i] = priceFeeds[i].price.expo; - publishTimes[i] = priceFeeds[i].price.publishTime; + // Safe cast because this is a unix timestamp in seconds. + publishTimes[i] = SafeCast.toUint64( + priceFeeds[i].price.publishTime + ); } emit PriceUpdateExecuted( @@ -203,14 +221,25 @@ abstract contract Pulse is IPulse, PulseState { } function getFee( - uint256 callbackGasLimit + address provider, + uint256 callbackGasLimit, + bytes32[] calldata priceIds ) public view override returns (uint128 feeAmount) { uint128 baseFee = _state.pythFeeInWei; // Fixed fee to Pyth - uint128 providerFeeInWei = _state - .providers[_state.defaultProvider] - .feeInWei; // Provider's per-gas rate + // Note: The provider needs to set its fees to include the fee charged by the Pyth contract. + // Ideally, we would be able to automatically compute the pyth fees from the priceIds, but the + // fee computation on IPyth assumes it has the full updated data. + uint128 providerBaseFee = _state.providers[provider].baseFeeInWei; + uint128 providerFeedFee = SafeCast.toUint128( + priceIds.length * _state.providers[provider].feePerFeedInWei + ); + uint128 providerFeeInWei = _state.providers[provider].feePerGasInWei; // Provider's per-gas rate uint256 gasFee = callbackGasLimit * providerFeeInWei; // Total provider fee based on gas - feeAmount = baseFee + SafeCast.toUint128(gasFee); // Total fee user needs to pay + feeAmount = + baseFee + + providerBaseFee + + providerFeedFee + + SafeCast.toUint128(gasFee); // Total fee user needs to pay } function getPythFeeInWei() @@ -222,7 +251,7 @@ abstract contract Pulse is IPulse, PulseState { pythFeeInWei = _state.pythFeeInWei; } - function getAccruedFees() + function getAccruedPythFees() public view override @@ -244,6 +273,7 @@ abstract contract Pulse is IPulse, PulseState { shortHash = uint8(hash[0] & NUM_REQUESTS_MASK); } + // TODO: move out governance functions into a separate PulseGovernance contract function withdrawFees(uint128 amount) external override { require(msg.sender == _state.admin, "Only admin can withdraw fees"); require(_state.accruedFeesInWei >= amount, "Insufficient balance"); @@ -336,22 +366,51 @@ abstract contract Pulse is IPulse, PulseState { emit FeesWithdrawn(msg.sender, amount); } - function registerProvider(uint128 feeInWei) external override { + function registerProvider( + uint128 baseFeeInWei, + uint128 feePerFeedInWei, + uint128 feePerGasInWei + ) external override { ProviderInfo storage provider = _state.providers[msg.sender]; require(!provider.isRegistered, "Provider already registered"); - provider.feeInWei = feeInWei; + provider.baseFeeInWei = baseFeeInWei; + provider.feePerFeedInWei = feePerFeedInWei; + provider.feePerGasInWei = feePerGasInWei; provider.isRegistered = true; - emit ProviderRegistered(msg.sender, feeInWei); + emit ProviderRegistered(msg.sender, feePerGasInWei); } - function setProviderFee(uint128 newFeeInWei) external override { + function setProviderFee( + address provider, + uint128 newBaseFeeInWei, + uint128 newFeePerFeedInWei, + uint128 newFeePerGasInWei + ) external override { require( - _state.providers[msg.sender].isRegistered, + _state.providers[provider].isRegistered, "Provider not registered" ); - uint128 oldFee = _state.providers[msg.sender].feeInWei; - _state.providers[msg.sender].feeInWei = newFeeInWei; - emit ProviderFeeUpdated(msg.sender, oldFee, newFeeInWei); + require( + msg.sender == provider || + msg.sender == _state.providers[provider].feeManager, + "Only provider or fee manager can invoke this method" + ); + + uint128 oldBaseFee = _state.providers[provider].baseFeeInWei; + uint128 oldFeePerFeed = _state.providers[provider].feePerFeedInWei; + uint128 oldFeePerGas = _state.providers[provider].feePerGasInWei; + _state.providers[provider].baseFeeInWei = newBaseFeeInWei; + _state.providers[provider].feePerFeedInWei = newFeePerFeedInWei; + _state.providers[provider].feePerGasInWei = newFeePerGasInWei; + emit ProviderFeeUpdated( + provider, + oldBaseFee, + oldFeePerFeed, + oldFeePerGas, + newBaseFeeInWei, + newFeePerFeedInWei, + newFeePerGasInWei + ); } function getProviderInfo( diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol index c92f4e0858..e57719d4da 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.0; error NoSuchProvider(); error NoSuchRequest(); +// TODO: add expected / provided values error InsufficientFee(); error Unauthorized(); error InvalidCallbackGas(); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index b83a8c244d..f01069e60d 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -13,7 +13,7 @@ interface PulseEvents { int64[] prices, uint64[] conf, int32[] expos, - uint256[] publishTimes + uint64[] publishTimes ); event FeesWithdrawn(address indexed recipient, uint128 amount); @@ -35,8 +35,12 @@ interface PulseEvents { event ProviderRegistered(address indexed provider, uint128 feeInWei); event ProviderFeeUpdated( address indexed provider, - uint128 oldFee, - uint128 newFee + uint128 oldBaseFee, + uint128 oldFeePerFeed, + uint128 oldFeePerGas, + uint128 newBaseFee, + uint128 newFeePerFeed, + uint128 newFeePerGas ); event DefaultProviderUpdated(address oldProvider, address newProvider); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index a561b1544a..57560d276a 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -11,16 +11,22 @@ contract PulseState { struct Request { uint64 sequenceNumber; - uint256 publishTime; + uint64 publishTime; + // TODO: this is going to absolutely explode gas costs. Need to do something smarter here. + // possible solution is to hash the price ids and store the hash instead. + // The ids themselves can be retrieved from the event. bytes32[MAX_PRICE_IDS] priceIds; uint8 numPriceIds; // Actual number of price IDs used uint256 callbackGasLimit; address requester; address provider; + uint128 fee; } struct ProviderInfo { - uint128 feeInWei; + uint128 baseFeeInWei; + uint128 feePerFeedInWei; + uint128 feePerGasInWei; uint128 accruedFeesInWei; address feeManager; bool isRegistered; diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 147ac4e0d7..9187987cb9 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -12,13 +12,22 @@ import "../contracts/pulse/PulseEvents.sol"; import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { + address private _pulse; uint64 public lastSequenceNumber; PythStructs.PriceFeed[] private _lastPriceFeeds; + constructor(address pulse) { + _pulse = pulse; + } + + function getPulse() internal view override returns (address) { + return _pulse; + } + function pulseCallback( uint64 sequenceNumber, PythStructs.PriceFeed[] memory priceFeeds - ) external override { + ) internal override { lastSequenceNumber = sequenceNumber; for (uint i = 0; i < priceFeeds.length; i++) { _lastPriceFeeds.push(priceFeeds[i]); @@ -35,10 +44,20 @@ contract MockPulseConsumer is IPulseConsumer { } contract FailingPulseConsumer is IPulseConsumer { + address private _pulse; + + constructor(address pulse) { + _pulse = pulse; + } + + function getPulse() internal view override returns (address) { + return _pulse; + } + function pulseCallback( uint64, PythStructs.PriceFeed[] memory - ) external pure override { + ) internal pure override { revert("callback failed"); } } @@ -46,14 +65,25 @@ contract FailingPulseConsumer is IPulseConsumer { contract CustomErrorPulseConsumer is IPulseConsumer { error CustomError(string message); + address private _pulse; + + constructor(address pulse) { + _pulse = pulse; + } + + function getPulse() internal view override returns (address) { + return _pulse; + } + function pulseCallback( uint64, PythStructs.PriceFeed[] memory - ) external pure override { + ) internal pure override { revert CustomError("callback failed"); } } +// FIXME: this shouldn't be IPulseConsumer. contract PulseTest is Test, PulseEvents, IPulseConsumer { ERC1967Proxy public proxy; PulseUpgradeable public pulse; @@ -64,7 +94,11 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { address public defaultProvider; // Constants uint128 constant PYTH_FEE = 1 wei; - uint128 constant DEFAULT_PROVIDER_FEE = 1 wei; + uint128 constant DEFAULT_PROVIDER_FEE_PER_GAS = 1 wei; + uint128 constant DEFAULT_PROVIDER_BASE_FEE = 1 wei; + uint128 constant DEFAULT_PROVIDER_FEE_PER_FEED = 10 wei; + uint constant MOCK_PYTH_FEE_PER_FEED = 10 wei; + uint128 constant CALLBACK_GAS_LIMIT = 1_000_000; bytes32 constant BTC_PRICE_FEED_ID = 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43; @@ -97,8 +131,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { 15 ); vm.prank(defaultProvider); - pulse.registerProvider(DEFAULT_PROVIDER_FEE); - consumer = new MockPulseConsumer(); + pulse.registerProvider( + DEFAULT_PROVIDER_BASE_FEE, + DEFAULT_PROVIDER_FEE_PER_FEED, + DEFAULT_PROVIDER_FEE_PER_GAS + ); + consumer = new MockPulseConsumer(address(proxy)); } // Helper function to create price IDs array @@ -136,8 +174,17 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { function mockParsePriceFeedUpdates( PythStructs.PriceFeed[] memory priceFeeds ) internal { + uint expectedFee = MOCK_PYTH_FEE_PER_FEED * priceFeeds.length; + vm.mockCall( address(pyth), + abi.encodeWithSelector(IPyth.getUpdateFee.selector), + abi.encode(expectedFee) + ); + + vm.mockCall( + address(pyth), + expectedFee, abi.encodeWithSelector(IPyth.parsePriceFeedUpdates.selector), abi.encode(priceFeeds) ); @@ -154,8 +201,10 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { } // Helper function to calculate total fee + // FIXME: I think this helper probably needs to take some arguments. function calculateTotalFee() internal view returns (uint128) { - return pulse.getFee(CALLBACK_GAS_LIMIT); + return + pulse.getFee(defaultProvider, CALLBACK_GAS_LIMIT, createPriceIds()); } // Helper function to setup consumer request @@ -166,17 +215,18 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { returns ( uint64 sequenceNumber, bytes32[] memory priceIds, - uint256 publishTime + uint64 publishTime ) { priceIds = createPriceIds(); - publishTime = block.timestamp; + publishTime = SafeCast.toUint64(block.timestamp); vm.deal(consumerAddress, 1 gwei); uint128 totalFee = calculateTotalFee(); vm.prank(consumerAddress); sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}( + defaultProvider, publishTime, priceIds, CALLBACK_GAS_LIMIT @@ -190,7 +240,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.txGasPrice(30 gwei); bytes32[] memory priceIds = createPriceIds(); - uint256 publishTime = block.timestamp; + uint64 publishTime = SafeCast.toUint64(block.timestamp); // Fund the consumer contract with enough ETH for higher gas price vm.deal(address(consumer), 1 ether); @@ -215,7 +265,8 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { numPriceIds: 2, callbackGasLimit: CALLBACK_GAS_LIMIT, requester: address(consumer), - provider: defaultProvider + provider: defaultProvider, + fee: totalFee - PYTH_FEE }); vm.expectEmit(); @@ -223,6 +274,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: totalFee}( + defaultProvider, publishTime, priceIds, CALLBACK_GAS_LIMIT @@ -256,7 +308,8 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.prank(address(consumer)); vm.expectRevert(InsufficientFee.selector); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee - block.timestamp, + defaultProvider, + SafeCast.toUint64(block.timestamp), priceIds, CALLBACK_GAS_LIMIT ); @@ -264,7 +317,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { function testExecuteCallback() public { bytes32[] memory priceIds = createPriceIds(); - uint256 publishTime = block.timestamp; + uint64 publishTime = SafeCast.toUint64(block.timestamp); // Fund the consumer contract vm.deal(address(consumer), 1 gwei); @@ -274,12 +327,13 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee - }(publishTime, priceIds, CALLBACK_GAS_LIMIT); + }(defaultProvider, publishTime, priceIds, CALLBACK_GAS_LIMIT); // Step 2: Create mock price feeds and setup Pyth response PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); + // FIXME: this test doesn't ensure the Pyth fee is paid. mockParsePriceFeedUpdates(priceFeeds); // Create arrays for expected event data @@ -295,7 +349,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { expectedExpos[0] = MOCK_PRICE_FEED_EXPO; expectedExpos[1] = MOCK_PRICE_FEED_EXPO; - uint256[] memory expectedPublishTimes = new uint256[](2); + uint64[] memory expectedPublishTimes = new uint64[](2); expectedPublishTimes[0] = publishTime; expectedPublishTimes[1] = publishTime; @@ -315,7 +369,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(defaultProvider); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + defaultProvider, + sequenceNumber, + updateData, + priceIds + ); // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); @@ -338,7 +397,9 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { } function testExecuteCallbackFailure() public { - FailingPulseConsumer failingConsumer = new FailingPulseConsumer(); + FailingPulseConsumer failingConsumer = new FailingPulseConsumer( + address(proxy) + ); ( uint64 sequenceNumber, @@ -362,11 +423,18 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { ); vm.prank(defaultProvider); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + defaultProvider, + sequenceNumber, + updateData, + priceIds + ); } function testExecuteCallbackCustomErrorFailure() public { - CustomErrorPulseConsumer failingConsumer = new CustomErrorPulseConsumer(); + CustomErrorPulseConsumer failingConsumer = new CustomErrorPulseConsumer( + address(proxy) + ); ( uint64 sequenceNumber, @@ -390,7 +458,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { ); vm.prank(defaultProvider); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + defaultProvider, + sequenceNumber, + updateData, + priceIds + ); } function testExecuteCallbackWithInsufficientGas() public { @@ -412,6 +485,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.prank(defaultProvider); vm.expectRevert(); // Just expect any revert since it will be an out-of-gas error pulse.executeCallback{gas: 100000}( + defaultProvider, sequenceNumber, updateData, priceIds @@ -421,14 +495,14 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { function testExecuteCallbackWithFutureTimestamp() public { // Setup request with future timestamp bytes32[] memory priceIds = createPriceIds(); - uint256 futureTime = block.timestamp + 10; // 10 seconds in future + uint64 futureTime = SafeCast.toUint64(block.timestamp + 10); // 10 seconds in future vm.deal(address(consumer), 1 gwei); uint128 totalFee = calculateTotalFee(); vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee - }(futureTime, priceIds, CALLBACK_GAS_LIMIT); + }(defaultProvider, futureTime, priceIds, CALLBACK_GAS_LIMIT); // Try to execute callback before the requested timestamp PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -439,7 +513,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.prank(defaultProvider); // Should succeed because we're simulating receiving future-dated price updates - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + defaultProvider, + sequenceNumber, + updateData, + priceIds + ); // Compare price feeds array length PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); @@ -456,7 +535,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { function testRevertOnTooFarFutureTimestamp() public { bytes32[] memory priceIds = createPriceIds(); - uint256 farFutureTime = block.timestamp + 61; // Just over 1 minute + uint64 farFutureTime = SafeCast.toUint64(block.timestamp + 61); // Just over 1 minute vm.deal(address(consumer), 1 gwei); uint128 totalFee = calculateTotalFee(); @@ -464,6 +543,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.expectRevert("Too far in future"); pulse.requestPriceUpdatesWithCallback{value: totalFee}( + defaultProvider, farFutureTime, priceIds, CALLBACK_GAS_LIMIT @@ -485,12 +565,22 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { // First execution vm.prank(defaultProvider); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + defaultProvider, + sequenceNumber, + updateData, + priceIds + ); // Second execution should fail vm.prank(defaultProvider); vm.expectRevert(NoSuchRequest.selector); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + defaultProvider, + sequenceNumber, + updateData, + priceIds + ); } function testGetFee() public { @@ -500,12 +590,22 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { gasLimits[1] = 500_000; gasLimits[2] = 1_000_000; + bytes32[] memory priceIds = createPriceIds(); + for (uint256 i = 0; i < gasLimits.length; i++) { uint256 gasLimit = gasLimits[i]; uint128 expectedFee = SafeCast.toUint128( - DEFAULT_PROVIDER_FEE * gasLimit + DEFAULT_PROVIDER_BASE_FEE + + DEFAULT_PROVIDER_FEE_PER_FEED * + priceIds.length + + DEFAULT_PROVIDER_FEE_PER_GAS * + gasLimit ) + PYTH_FEE; - uint128 actualFee = pulse.getFee(gasLimit); + uint128 actualFee = pulse.getFee( + defaultProvider, + gasLimit, + priceIds + ); assertEq( actualFee, expectedFee, @@ -514,8 +614,13 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { } // Test with zero gas limit - uint128 expectedMinFee = PYTH_FEE; - uint128 actualMinFee = pulse.getFee(0); + uint128 expectedMinFee = SafeCast.toUint128( + PYTH_FEE + + DEFAULT_PROVIDER_BASE_FEE + + DEFAULT_PROVIDER_FEE_PER_FEED * + priceIds.length + ); + uint128 actualMinFee = pulse.getFee(defaultProvider, 0, priceIds); assertEq( actualMinFee, expectedMinFee, @@ -530,14 +635,15 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - block.timestamp, + defaultProvider, + SafeCast.toUint64(block.timestamp), priceIds, CALLBACK_GAS_LIMIT ); // Get admin's balance before withdrawal uint256 adminBalanceBefore = admin.balance; - uint128 accruedFees = pulse.getAccruedFees(); + uint128 accruedFees = pulse.getAccruedPythFees(); // Withdraw fees as admin vm.prank(admin); @@ -550,7 +656,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { "Admin balance should increase by withdrawn amount" ); assertEq( - pulse.getAccruedFees(), + pulse.getAccruedPythFees(), 0, "Contract should have no fees after withdrawal" ); @@ -580,7 +686,8 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - block.timestamp, + defaultProvider, + SafeCast.toUint64(block.timestamp), priceIds, CALLBACK_GAS_LIMIT ); @@ -662,7 +769,12 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { priceIds[0] ) ); - pulse.executeCallback(sequenceNumber, updateData, wrongPriceIds); + pulse.executeCallback( + defaultProvider, + sequenceNumber, + updateData, + wrongPriceIds + ); } function testRevertOnTooManyPriceIds() public { @@ -685,7 +797,8 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { ) ); pulse.requestPriceUpdatesWithCallback{value: totalFee}( - block.timestamp, + defaultProvider, + SafeCast.toUint64(block.timestamp), priceIds, CALLBACK_GAS_LIMIT ); @@ -696,26 +809,36 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { uint128 providerFee = 1000; vm.prank(provider); - pulse.registerProvider(providerFee); + pulse.registerProvider(providerFee, providerFee, providerFee); PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); - assertEq(info.feeInWei, providerFee); + assertEq(info.feePerGasInWei, providerFee); assertTrue(info.isRegistered); } function testSetProviderFee() public { address provider = address(0x123); - uint128 initialFee = 1000; - uint128 newFee = 2000; + uint128 initialBaseFee = 1000; + uint128 initialFeePerFeed = 2000; + uint128 initialFeePerGas = 3000; + uint128 newFeePerFeed = 4000; + uint128 newBaseFee = 5000; + uint128 newFeePerGas = 6000; vm.prank(provider); - pulse.registerProvider(initialFee); + pulse.registerProvider( + initialBaseFee, + initialFeePerFeed, + initialFeePerGas + ); vm.prank(provider); - pulse.setProviderFee(newFee); + pulse.setProviderFee(provider, newBaseFee, newFeePerFeed, newFeePerGas); PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); - assertEq(info.feeInWei, newFee); + assertEq(info.baseFeeInWei, newBaseFee); + assertEq(info.feePerFeedInWei, newFeePerFeed); + assertEq(info.feePerGasInWei, newFeePerGas); } function testDefaultProvider() public { @@ -723,7 +846,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { uint128 providerFee = 1000; vm.prank(provider); - pulse.registerProvider(providerFee); + pulse.registerProvider(providerFee, providerFee, providerFee); vm.prank(admin); pulse.setDefaultProvider(provider); @@ -736,21 +859,23 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { uint128 providerFee = 1000; vm.prank(provider); - pulse.registerProvider(providerFee); - - vm.prank(admin); - pulse.setDefaultProvider(provider); + pulse.registerProvider(providerFee, providerFee, providerFee); bytes32[] memory priceIds = new bytes32[](1); priceIds[0] = bytes32(uint256(1)); - uint128 totalFee = pulse.getFee(CALLBACK_GAS_LIMIT); + uint128 totalFee = pulse.getFee(provider, CALLBACK_GAS_LIMIT, priceIds); vm.deal(address(consumer), totalFee); vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee - }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT); + }( + provider, + SafeCast.toUint64(block.timestamp), + priceIds, + CALLBACK_GAS_LIMIT + ); PulseState.Request memory req = pulse.getRequest(sequenceNumber); assertEq(req.provider, provider); @@ -787,7 +912,11 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { // Register a second provider address secondProvider = address(0x456); vm.prank(secondProvider); - pulse.registerProvider(DEFAULT_PROVIDER_FEE); + pulse.registerProvider( + DEFAULT_PROVIDER_BASE_FEE, + DEFAULT_PROVIDER_FEE_PER_FEED, + DEFAULT_PROVIDER_FEE_PER_GAS + ); // Setup request ( @@ -804,20 +933,32 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { bytes[] memory updateData = createMockUpdateData(priceFeeds); // Try to execute with second provider during exclusivity period - vm.prank(secondProvider); vm.expectRevert("Only assigned provider during exclusivity period"); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + secondProvider, + sequenceNumber, + updateData, + priceIds + ); // Original provider should succeed - vm.prank(defaultProvider); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + defaultProvider, + sequenceNumber, + updateData, + priceIds + ); } function testExecuteCallbackAfterExclusivity() public { // Register a second provider address secondProvider = address(0x456); vm.prank(secondProvider); - pulse.registerProvider(DEFAULT_PROVIDER_FEE); + pulse.registerProvider( + DEFAULT_PROVIDER_BASE_FEE, + DEFAULT_PROVIDER_FEE_PER_FEED, + DEFAULT_PROVIDER_FEE_PER_GAS + ); // Setup request ( @@ -838,14 +979,23 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { // Second provider should now succeed vm.prank(secondProvider); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + defaultProvider, + sequenceNumber, + updateData, + priceIds + ); } function testExecuteCallbackWithCustomExclusivityPeriod() public { // Register a second provider address secondProvider = address(0x456); vm.prank(secondProvider); - pulse.registerProvider(DEFAULT_PROVIDER_FEE); + pulse.registerProvider( + DEFAULT_PROVIDER_BASE_FEE, + DEFAULT_PROVIDER_FEE_PER_FEED, + DEFAULT_PROVIDER_FEE_PER_GAS + ); // Set custom exclusivity period vm.prank(admin); @@ -867,14 +1017,22 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { // Try at 29 seconds (should fail for second provider) vm.warp(block.timestamp + 29); - vm.prank(secondProvider); vm.expectRevert("Only assigned provider during exclusivity period"); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + secondProvider, + sequenceNumber, + updateData, + priceIds + ); // Try at 31 seconds (should succeed for second provider) vm.warp(block.timestamp + 2); - vm.prank(secondProvider); - pulse.executeCallback(sequenceNumber, updateData, priceIds); + pulse.executeCallback( + secondProvider, + sequenceNumber, + updateData, + priceIds + ); } function testGetFirstActiveRequests() public { @@ -902,10 +1060,11 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { } function createTestRequests(bytes32[] memory priceIds) private { - uint256 publishTime = block.timestamp; + uint64 publishTime = SafeCast.toUint64(block.timestamp); for (uint i = 0; i < 5; i++) { vm.deal(address(this), 1 ether); pulse.requestPriceUpdatesWithCallback{value: 1 ether}( + defaultProvider, publishTime, priceIds, 1000000 @@ -919,15 +1078,25 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { ) private { // Create mock price feeds and setup Pyth response PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( - block.timestamp + SafeCast.toUint64(block.timestamp) ); mockParsePriceFeedUpdates(priceFeeds); updateData = createMockUpdateData(priceFeeds); vm.deal(defaultProvider, 2 ether); // Increase ETH allocation to prevent OutOfFunds vm.startPrank(defaultProvider); - pulse.executeCallback{value: 1 ether}(2, updateData, priceIds); - pulse.executeCallback{value: 1 ether}(4, updateData, priceIds); + pulse.executeCallback{value: 1 ether}( + defaultProvider, + 2, + updateData, + priceIds + ); + pulse.executeCallback{value: 1 ether}( + defaultProvider, + 4, + updateData, + priceIds + ); vm.stopPrank(); } @@ -1000,9 +1169,24 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { ) private { vm.deal(defaultProvider, 3 ether); // Increase ETH allocation vm.startPrank(defaultProvider); - pulse.executeCallback{value: 1 ether}(1, updateData, priceIds); - pulse.executeCallback{value: 1 ether}(3, updateData, priceIds); - pulse.executeCallback{value: 1 ether}(5, updateData, priceIds); + pulse.executeCallback{value: 1 ether}( + defaultProvider, + 1, + updateData, + priceIds + ); + pulse.executeCallback{value: 1 ether}( + defaultProvider, + 3, + updateData, + priceIds + ); + pulse.executeCallback{value: 1 ether}( + defaultProvider, + 5, + updateData, + priceIds + ); vm.stopPrank(); } @@ -1017,7 +1201,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { // Setup test data bytes32[] memory priceIds = new bytes32[](1); priceIds[0] = bytes32(uint256(1)); - uint256 publishTime = block.timestamp; + uint64 publishTime = SafeCast.toUint64(block.timestamp); uint256 callbackGasLimit = 1000000; // Create mock price feeds and setup Pyth response @@ -1031,6 +1215,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { for (uint i = 0; i < 20; i++) { vm.deal(address(this), 1 ether); pulse.requestPriceUpdatesWithCallback{value: 1 ether}( + defaultProvider, publishTime, priceIds, callbackGasLimit @@ -1041,6 +1226,7 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { vm.deal(defaultProvider, 1 ether); vm.prank(defaultProvider); pulse.executeCallback{value: 1 ether}( + defaultProvider, uint64(i + 1), updateData, priceIds @@ -1071,11 +1257,15 @@ contract PulseTest is Test, PulseEvents, IPulseConsumer { ); } + function getPulse() internal view override returns (address) { + return address(pulse); + } + // Mock implementation of pulseCallback function pulseCallback( uint64 sequenceNumber, PythStructs.PriceFeed[] memory priceFeeds - ) external override { + ) internal override { // Just accept the callback, no need to do anything with the data // This prevents the revert we're seeing }