From 4c7f090105992fb288229265781ef06df1be5acd Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 20 Jan 2025 21:29:49 +0900 Subject: [PATCH 1/9] add provider --- .../contracts/contracts/pulse/IPulse.sol | 24 +- .../contracts/contracts/pulse/Pulse.sol | 104 +++++++- .../contracts/contracts/pulse/PulseEvents.sol | 12 +- .../contracts/contracts/pulse/PulseState.sol | 11 +- .../contracts/pulse/PulseUpgradeable.sol | 2 + .../ethereum/contracts/forge-test/Pulse.t.sol | 223 ++++++++++++------ 6 files changed, 284 insertions(+), 92 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index f3d06a7704..54962e929a 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -9,7 +9,7 @@ import "./PulseState.sol"; interface IPulseConsumer { function pulseCallback( uint64 sequenceNumber, - address updater, + address provider, PythStructs.PriceFeed[] memory priceFeeds ) external; } @@ -33,7 +33,8 @@ interface IPulse is PulseEvents { function requestPriceUpdatesWithCallback( uint256 publishTime, bytes32[] calldata priceIds, - uint256 callbackGasLimit + uint256 callbackGasLimit, + address provider ) external payable returns (uint64 sequenceNumber); /** @@ -62,10 +63,12 @@ interface IPulse is PulseEvents { * @notice Calculates the total fee required for a price update request * @dev Total fee = base Pyth protocol fee + gas costs for callback * @param callbackGasLimit The amount of gas allocated for callback execution + * @param provider The provider to use for the fee calculation * @return feeAmount The total fee in wei that must be provided as msg.value */ function getFee( - uint256 callbackGasLimit + uint256 callbackGasLimit, + address provider ) external view returns (uint128 feeAmount); function getAccruedFees() external view returns (uint128 accruedFeesInWei); @@ -74,8 +77,19 @@ interface IPulse is PulseEvents { uint64 sequenceNumber ) external view returns (PulseState.Request memory req); - // Add these functions to the IPulse interface function setFeeManager(address manager) external; - function withdrawAsFeeManager(uint128 amount) external; + function withdrawAsFeeManager(address provider, uint128 amount) external; + + function registerProvider(uint128 feeInWei) external; + + function setProviderFee(uint128 newFeeInWei) external; + + function getProviderInfo( + address provider + ) external view returns (PulseState.ProviderInfo memory); + + function getDefaultProvider() external view returns (address); + + function setDefaultProvider(address provider) external; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index dcd84211fc..0cc7d3a384 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -13,16 +13,22 @@ abstract contract Pulse is IPulse, PulseState { address admin, uint128 pythFeeInWei, address pythAddress, + address defaultProvider, bool prefillRequestStorage ) internal { require(admin != address(0), "admin is zero address"); require(pythAddress != address(0), "pyth is zero address"); + require( + defaultProvider != address(0), + "defaultProvider is zero address" + ); _state.admin = admin; _state.accruedFeesInWei = 0; _state.pythFeeInWei = pythFeeInWei; _state.pyth = pythAddress; _state.currentSequenceNumber = 1; + _state.defaultProvider = defaultProvider; if (prefillRequestStorage) { for (uint8 i = 0; i < NUM_REQUESTS; i++) { @@ -43,8 +49,17 @@ abstract contract Pulse is IPulse, PulseState { function requestPriceUpdatesWithCallback( uint256 publishTime, bytes32[] calldata priceIds, - uint256 callbackGasLimit + uint256 callbackGasLimit, + address provider ) external payable override returns (uint64 requestSequenceNumber) { + if (provider == address(0)) { + provider = _state.defaultProvider; + } + require( + _state.providers[provider].isRegistered, + "Provider not registered" + ); + // NOTE: The 60-second future limit on publishTime prevents a DoS vector where // attackers could submit many low-fee requests for far-future updates when gas prices // are low, forcing executors to fulfill them later when gas prices might be much higher. @@ -56,7 +71,7 @@ abstract contract Pulse is IPulse, PulseState { } requestSequenceNumber = _state.currentSequenceNumber++; - uint128 requiredFee = getFee(callbackGasLimit); + uint128 requiredFee = getFee(callbackGasLimit, provider); if (msg.value < requiredFee) revert InsufficientFee(); Request storage req = allocRequest(requestSequenceNumber); @@ -65,13 +80,17 @@ abstract contract Pulse is IPulse, PulseState { req.callbackGasLimit = callbackGasLimit; req.requester = msg.sender; req.numPriceIds = uint8(priceIds.length); + req.provider = provider; // Copy price IDs to storage for (uint8 i = 0; i < priceIds.length; i++) { req.priceIds[i] = priceIds[i]; } - _state.accruedFeesInWei += SafeCast.toUint128(msg.value); + _state.providers[provider].accruedFeesInWei += SafeCast.toUint128( + msg.value - _state.pythFeeInWei + ); + _state.accruedFeesInWei += _state.pythFeeInWei; emit PriceUpdateRequested(req, priceIds); } @@ -171,10 +190,15 @@ abstract contract Pulse is IPulse, PulseState { } function getFee( - uint256 callbackGasLimit + uint256 callbackGasLimit, + address provider ) public view override returns (uint128 feeAmount) { + if (provider == address(0)) { + provider = _state.defaultProvider; + } uint128 baseFee = _state.pythFeeInWei; - uint256 gasFee = callbackGasLimit * tx.gasprice; + uint128 providerFeeInWei = _state.providers[provider].feeInWei; + uint256 gasFee = callbackGasLimit * providerFeeInWei; feeAmount = baseFee + SafeCast.toUint128(gasFee); } @@ -271,21 +295,75 @@ abstract contract Pulse is IPulse, PulseState { } function setFeeManager(address manager) external override { - require(msg.sender == _state.admin, "Only admin can set fee manager"); - address oldFeeManager = _state.feeManager; - _state.feeManager = manager; - emit FeeManagerUpdated(_state.admin, oldFeeManager, manager); + require( + _state.providers[msg.sender].isRegistered, + "Provider not registered" + ); + address oldFeeManager = _state.providers[msg.sender].feeManager; + _state.providers[msg.sender].feeManager = manager; + emit FeeManagerUpdated(msg.sender, oldFeeManager, manager); } - function withdrawAsFeeManager(uint128 amount) external override { - require(msg.sender == _state.feeManager, "Only fee manager"); - require(_state.accruedFeesInWei >= amount, "Insufficient balance"); + function withdrawAsFeeManager( + address provider, + uint128 amount + ) external override { + require( + msg.sender == _state.providers[provider].feeManager, + "Only fee manager" + ); + require( + _state.providers[provider].accruedFeesInWei >= amount, + "Insufficient balance" + ); - _state.accruedFeesInWei -= amount; + _state.providers[provider].accruedFeesInWei -= amount; (bool sent, ) = msg.sender.call{value: amount}(""); require(sent, "Failed to send fees"); emit FeesWithdrawn(msg.sender, amount); } + + function registerProvider(uint128 feeInWei) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + require(!provider.isRegistered, "Provider already registered"); + provider.feeInWei = feeInWei; + provider.isRegistered = true; + emit ProviderRegistered(msg.sender, feeInWei); + } + + function setProviderFee(uint128 newFeeInWei) external override { + require( + _state.providers[msg.sender].isRegistered, + "Provider not registered" + ); + uint128 oldFee = _state.providers[msg.sender].feeInWei; + _state.providers[msg.sender].feeInWei = newFeeInWei; + emit ProviderFeeUpdated(msg.sender, oldFee, newFeeInWei); + } + + function getProviderInfo( + address provider + ) external view override returns (ProviderInfo memory) { + return _state.providers[provider]; + } + + function getDefaultProvider() external view override returns (address) { + return _state.defaultProvider; + } + + function setDefaultProvider(address provider) external override { + require( + msg.sender == _state.admin, + "Only admin can set default provider" + ); + require( + _state.providers[provider].isRegistered, + "Provider not registered" + ); + address oldProvider = _state.defaultProvider; + _state.defaultProvider = provider; + emit DefaultProviderUpdated(oldProvider, provider); + } } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index c3d29f168d..b3a8e10fe0 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -8,7 +8,7 @@ interface PulseEvents { event PriceUpdateExecuted( uint64 indexed sequenceNumber, - address indexed updater, + address indexed provider, bytes32[] priceIds, int64[] prices, uint64[] conf, @@ -20,7 +20,7 @@ interface PulseEvents { event PriceUpdateCallbackFailed( uint64 indexed sequenceNumber, - address indexed updater, + address indexed provider, bytes32[] priceIds, address requester, string reason @@ -31,4 +31,12 @@ interface PulseEvents { address oldFeeManager, address newFeeManager ); + + event ProviderRegistered(address indexed provider, uint128 feeInWei); + event ProviderFeeUpdated( + address indexed provider, + uint128 oldFee, + uint128 newFee + ); + event DefaultProviderUpdated(address oldProvider, address newProvider); } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 50ef0147cd..02291ed77d 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -16,6 +16,14 @@ contract PulseState { uint8 numPriceIds; // Actual number of price IDs used uint256 callbackGasLimit; address requester; + address provider; + } + + struct ProviderInfo { + uint128 feeInWei; + uint128 accruedFeesInWei; + address feeManager; + bool isRegistered; } struct State { @@ -24,9 +32,10 @@ contract PulseState { uint128 accruedFeesInWei; address pyth; uint64 currentSequenceNumber; - address feeManager; + address defaultProvider; Request[NUM_REQUESTS] requests; mapping(bytes32 => Request) requestsOverflow; + mapping(address => ProviderInfo) providers; } State internal _state; diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol index 48fc694e69..dc4f5e5a5b 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -23,6 +23,7 @@ contract PulseUpgradeable is address admin, uint128 pythFeeInWei, address pythAddress, + address defaultProvider, bool prefillRequestStorage ) public initializer { require(owner != address(0), "owner is zero address"); @@ -35,6 +36,7 @@ contract PulseUpgradeable is admin, pythFeeInWei, pythAddress, + defaultProvider, prefillRequestStorage ); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 1b0af6aade..67754f970f 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -13,16 +13,16 @@ import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; - address public lastUpdater; + address public lastProvider; PythStructs.PriceFeed[] private _lastPriceFeeds; function pulseCallback( uint64 sequenceNumber, - address updater, + address provider, PythStructs.PriceFeed[] memory priceFeeds ) external override { lastSequenceNumber = sequenceNumber; - lastUpdater = updater; + lastProvider = provider; for (uint i = 0; i < priceFeeds.length; i++) { _lastPriceFeeds.push(priceFeeds[i]); } @@ -65,11 +65,11 @@ contract PulseTest is Test, PulseEvents { MockPulseConsumer public consumer; address public owner; address public admin; - address public updater; address public pyth; - + address public defaultProvider; // Constants uint128 constant PYTH_FEE = 1 wei; + uint128 constant DEFAULT_PROVIDER_FEE = 1 wei; uint128 constant CALLBACK_GAS_LIMIT = 1_000_000; bytes32 constant BTC_PRICE_FEED_ID = 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43; @@ -86,14 +86,15 @@ contract PulseTest is Test, PulseEvents { function setUp() public { owner = address(1); admin = address(2); - updater = address(3); - pyth = address(4); - + pyth = address(3); + defaultProvider = address(4); PulseUpgradeable _pulse = new PulseUpgradeable(); proxy = new ERC1967Proxy(address(_pulse), ""); pulse = PulseUpgradeable(address(proxy)); - pulse.initialize(owner, admin, PYTH_FEE, pyth, false); + pulse.initialize(owner, admin, PYTH_FEE, pyth, defaultProvider, false); + vm.prank(defaultProvider); + pulse.registerProvider(DEFAULT_PROVIDER_FEE); consumer = new MockPulseConsumer(); } @@ -150,13 +151,16 @@ contract PulseTest is Test, PulseEvents { } // Helper function to calculate total fee - function calculateTotalFee() internal view returns (uint128) { - return pulse.getFee(CALLBACK_GAS_LIMIT); + function calculateTotalFee( + address provider + ) internal view returns (uint128) { + return pulse.getFee(CALLBACK_GAS_LIMIT, provider); } // Helper function to setup consumer request function setupConsumerRequest( - address consumerAddress + address consumerAddress, + address provider ) internal returns ( @@ -169,13 +173,14 @@ contract PulseTest is Test, PulseEvents { publishTime = block.timestamp; vm.deal(consumerAddress, 1 gwei); - uint128 totalFee = calculateTotalFee(); + uint128 totalFee = calculateTotalFee(provider); vm.prank(consumerAddress); sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}( publishTime, priceIds, - CALLBACK_GAS_LIMIT + CALLBACK_GAS_LIMIT, + provider ); return (sequenceNumber, priceIds, publishTime); @@ -190,7 +195,7 @@ contract PulseTest is Test, PulseEvents { // Fund the consumer contract with enough ETH for higher gas price vm.deal(address(consumer), 1 ether); - uint128 totalFee = calculateTotalFee(); + uint128 totalFee = calculateTotalFee(defaultProvider); // Create the event data we expect to see PulseState.Request memory expectedRequest = PulseState.Request({ @@ -210,7 +215,8 @@ contract PulseTest is Test, PulseEvents { ], numPriceIds: 2, callbackGasLimit: CALLBACK_GAS_LIMIT, - requester: address(consumer) + requester: address(consumer), + provider: defaultProvider }); vm.expectEmit(); @@ -220,7 +226,8 @@ contract PulseTest is Test, PulseEvents { pulse.requestPriceUpdatesWithCallback{value: totalFee}( publishTime, priceIds, - CALLBACK_GAS_LIMIT + CALLBACK_GAS_LIMIT, + defaultProvider ); // Additional assertions to verify event data was stored correctly @@ -253,7 +260,8 @@ contract PulseTest is Test, PulseEvents { pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee block.timestamp, priceIds, - CALLBACK_GAS_LIMIT + CALLBACK_GAS_LIMIT, + defaultProvider ); } @@ -263,13 +271,13 @@ contract PulseTest is Test, PulseEvents { // Fund the consumer contract vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); + uint128 totalFee = calculateTotalFee(defaultProvider); // Step 1: Make the request as consumer vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee - }(publishTime, priceIds, CALLBACK_GAS_LIMIT); + }(publishTime, priceIds, CALLBACK_GAS_LIMIT, defaultProvider); // Step 2: Create mock price feeds and setup Pyth response PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -298,7 +306,7 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(); emit PriceUpdateExecuted( sequenceNumber, - updater, + defaultProvider, priceIds, expectedPrices, expectedConf, @@ -309,11 +317,12 @@ contract PulseTest is Test, PulseEvents { // Create mock update data and execute callback bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(updater); + vm.prank(defaultProvider); pulse.executeCallback(sequenceNumber, updateData, priceIds); // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); + assertEq(consumer.lastProvider(), defaultProvider); // Compare price feeds array length PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); @@ -339,7 +348,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(failingConsumer)); + ) = setupConsumerRequest(address(failingConsumer), defaultProvider); PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime @@ -350,13 +359,13 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(); emit PriceUpdateCallbackFailed( sequenceNumber, - updater, + defaultProvider, priceIds, address(failingConsumer), "callback failed" ); - vm.prank(updater); + vm.prank(defaultProvider); pulse.executeCallback(sequenceNumber, updateData, priceIds); } @@ -367,7 +376,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(failingConsumer)); + ) = setupConsumerRequest(address(failingConsumer), defaultProvider); PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime @@ -378,13 +387,13 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(); emit PriceUpdateCallbackFailed( sequenceNumber, - updater, + defaultProvider, priceIds, address(failingConsumer), "low-level error (possibly out of gas)" ); - vm.prank(updater); + vm.prank(defaultProvider); pulse.executeCallback(sequenceNumber, updateData, priceIds); } @@ -394,7 +403,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(consumer)); + ) = setupConsumerRequest(address(consumer), defaultProvider); // Setup mock data PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -404,7 +413,7 @@ contract PulseTest is Test, PulseEvents { bytes[] memory updateData = createMockUpdateData(priceFeeds); // Try executing with only 100K gas when 1M is required - vm.prank(updater); + vm.prank(defaultProvider); vm.expectRevert(InsufficientGas.selector); pulse.executeCallback{gas: 100000}( sequenceNumber, @@ -419,11 +428,11 @@ contract PulseTest is Test, PulseEvents { uint256 futureTime = block.timestamp + 10; // 10 seconds in future vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); + uint128 totalFee = calculateTotalFee(defaultProvider); vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee - }(futureTime, priceIds, CALLBACK_GAS_LIMIT); + }(futureTime, priceIds, CALLBACK_GAS_LIMIT, defaultProvider); // Try to execute callback before the requested timestamp PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -432,7 +441,7 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(updater); + vm.prank(defaultProvider); // Should succeed because we're simulating receiving future-dated price updates pulse.executeCallback(sequenceNumber, updateData, priceIds); @@ -454,14 +463,15 @@ contract PulseTest is Test, PulseEvents { uint256 farFutureTime = block.timestamp + 61; // Just over 1 minute vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); + uint128 totalFee = calculateTotalFee(defaultProvider); vm.prank(address(consumer)); vm.expectRevert("Too far in future"); pulse.requestPriceUpdatesWithCallback{value: totalFee}( farFutureTime, priceIds, - CALLBACK_GAS_LIMIT + CALLBACK_GAS_LIMIT, + defaultProvider ); } @@ -470,7 +480,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(consumer)); + ) = setupConsumerRequest(address(consumer), defaultProvider); PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime @@ -479,11 +489,11 @@ contract PulseTest is Test, PulseEvents { bytes[] memory updateData = createMockUpdateData(priceFeeds); // First execution - vm.prank(updater); + vm.prank(defaultProvider); pulse.executeCallback(sequenceNumber, updateData, priceIds); // Second execution should fail - vm.prank(updater); + vm.prank(defaultProvider); vm.expectRevert(NoSuchRequest.selector); pulse.executeCallback(sequenceNumber, updateData, priceIds); } @@ -497,9 +507,10 @@ contract PulseTest is Test, PulseEvents { for (uint256 i = 0; i < gasLimits.length; i++) { uint256 gasLimit = gasLimits[i]; - uint128 expectedFee = SafeCast.toUint128(tx.gasprice * gasLimit) + - PYTH_FEE; - uint128 actualFee = pulse.getFee(gasLimit); + uint128 expectedFee = SafeCast.toUint128( + DEFAULT_PROVIDER_FEE * gasLimit + ) + PYTH_FEE; + uint128 actualFee = pulse.getFee(gasLimit, defaultProvider); assertEq( actualFee, expectedFee, @@ -509,7 +520,7 @@ contract PulseTest is Test, PulseEvents { // Test with zero gas limit uint128 expectedMinFee = PYTH_FEE; - uint128 actualMinFee = pulse.getFee(0); + uint128 actualMinFee = pulse.getFee(0, defaultProvider); assertEq( actualMinFee, expectedMinFee, @@ -523,11 +534,9 @@ contract PulseTest is Test, PulseEvents { vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - block.timestamp, - priceIds, - CALLBACK_GAS_LIMIT - ); + pulse.requestPriceUpdatesWithCallback{ + value: calculateTotalFee(defaultProvider) + }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT, defaultProvider); // Get admin's balance before withdrawal uint256 adminBalanceBefore = admin.balance; @@ -565,8 +574,7 @@ contract PulseTest is Test, PulseEvents { function testSetAndWithdrawAsFeeManager() public { address feeManager = address(0x789); - // Set fee manager as admin - vm.prank(admin); + vm.prank(defaultProvider); pulse.setFeeManager(feeManager); // Setup: Request price update to accrue some fees @@ -574,53 +582,57 @@ contract PulseTest is Test, PulseEvents { vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - block.timestamp, - priceIds, - CALLBACK_GAS_LIMIT + pulse.requestPriceUpdatesWithCallback{ + value: calculateTotalFee(defaultProvider) + }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT, defaultProvider); + + // Get provider's accrued fees instead of total fees + PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( + defaultProvider ); + uint128 providerAccruedFees = providerInfo.accruedFeesInWei; - // Test withdrawal as fee manager uint256 managerBalanceBefore = feeManager.balance; - uint128 accruedFees = pulse.getAccruedFees(); vm.prank(feeManager); - pulse.withdrawAsFeeManager(accruedFees); + pulse.withdrawAsFeeManager(defaultProvider, providerAccruedFees); assertEq( feeManager.balance, - managerBalanceBefore + accruedFees, + managerBalanceBefore + providerAccruedFees, "Fee manager balance should increase by withdrawn amount" ); + + providerInfo = pulse.getProviderInfo(defaultProvider); assertEq( - pulse.getAccruedFees(), + providerInfo.accruedFeesInWei, 0, - "Contract should have no fees after withdrawal" + "Provider should have no fees after withdrawal" ); } function testSetFeeManagerUnauthorized() public { address feeManager = address(0x789); vm.prank(address(0xdead)); - vm.expectRevert("Only admin can set fee manager"); + vm.expectRevert("Provider not registered"); pulse.setFeeManager(feeManager); } function testWithdrawAsFeeManagerUnauthorized() public { vm.prank(address(0xdead)); vm.expectRevert("Only fee manager"); - pulse.withdrawAsFeeManager(1 ether); + pulse.withdrawAsFeeManager(defaultProvider, 1 ether); } function testWithdrawAsFeeManagerInsufficientBalance() public { // Set up fee manager first address feeManager = address(0x789); - vm.prank(admin); + vm.prank(defaultProvider); pulse.setFeeManager(feeManager); vm.prank(feeManager); vm.expectRevert("Insufficient balance"); - pulse.withdrawAsFeeManager(1 ether); + pulse.withdrawAsFeeManager(defaultProvider, 1 ether); } // Add new test for invalid priceIds @@ -629,7 +641,10 @@ contract PulseTest is Test, PulseEvents { uint256 publishTime = block.timestamp; // Setup request - (uint64 sequenceNumber, , ) = setupConsumerRequest(address(consumer)); + (uint64 sequenceNumber, , ) = setupConsumerRequest( + address(consumer), + defaultProvider + ); // Create different priceIds bytes32[] memory wrongPriceIds = new bytes32[](2); @@ -643,7 +658,7 @@ contract PulseTest is Test, PulseEvents { bytes[] memory updateData = createMockUpdateData(priceFeeds); // Should revert when trying to execute with wrong priceIds - vm.prank(updater); + vm.prank(defaultProvider); vm.expectRevert( abi.encodeWithSelector( InvalidPriceIds.selector, @@ -660,7 +675,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(consumer)); + ) = setupConsumerRequest(address(consumer), defaultProvider); // Setup mock data PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -670,7 +685,7 @@ contract PulseTest is Test, PulseEvents { bytes[] memory updateData = createMockUpdateData(priceFeeds); // Should fail with exactly 1.4x gas (less than required 1.5x) - vm.prank(updater); + vm.prank(defaultProvider); vm.expectRevert(InsufficientGas.selector); pulse.executeCallback{gas: (CALLBACK_GAS_LIMIT * 14) / 10}( sequenceNumber, @@ -679,7 +694,7 @@ contract PulseTest is Test, PulseEvents { ); // Should succeed with 1.6x gas - vm.prank(updater); + vm.prank(defaultProvider); pulse.executeCallback{gas: (CALLBACK_GAS_LIMIT * 16) / 10}( sequenceNumber, updateData, @@ -688,7 +703,7 @@ contract PulseTest is Test, PulseEvents { // Verify callback was executed successfully assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastUpdater(), updater); + assertEq(consumer.lastProvider(), defaultProvider); } function testRevertOnTooManyPriceIds() public { @@ -700,7 +715,7 @@ contract PulseTest is Test, PulseEvents { } vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); + uint128 totalFee = calculateTotalFee(defaultProvider); vm.prank(address(consumer)); vm.expectRevert( @@ -713,7 +728,73 @@ contract PulseTest is Test, PulseEvents { pulse.requestPriceUpdatesWithCallback{value: totalFee}( block.timestamp, priceIds, - CALLBACK_GAS_LIMIT + CALLBACK_GAS_LIMIT, + defaultProvider ); } + + function testProviderRegistration() public { + address provider = address(0x123); + uint128 providerFee = 1000; + + vm.prank(provider); + pulse.registerProvider(providerFee); + + PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); + assertEq(info.feeInWei, providerFee); + assertTrue(info.isRegistered); + } + + function testSetProviderFee() public { + address provider = address(0x123); + uint128 initialFee = 1000; + uint128 newFee = 2000; + + vm.prank(provider); + pulse.registerProvider(initialFee); + + vm.prank(provider); + pulse.setProviderFee(newFee); + + PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); + assertEq(info.feeInWei, newFee); + } + + function testDefaultProvider() public { + address provider = address(0x123); + uint128 providerFee = 1000; + + vm.prank(provider); + pulse.registerProvider(providerFee); + + vm.prank(admin); + pulse.setDefaultProvider(provider); + + assertEq(pulse.getDefaultProvider(), provider); + } + + function testRequestWithProvider() public { + address provider = address(0x123); + uint128 providerFee = 1000; + + vm.prank(provider); + pulse.registerProvider(providerFee); + + vm.prank(admin); + pulse.setDefaultProvider(provider); + + bytes32[] memory priceIds = new bytes32[](1); + priceIds[0] = bytes32(uint256(1)); + + uint128 totalFee = pulse.getFee(CALLBACK_GAS_LIMIT, provider); + + vm.deal(address(consumer), totalFee); + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: totalFee + }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT, provider); + + PulseState.Request memory req = pulse.getRequest(sequenceNumber); + assertEq(req.provider, provider); + } } From 0b4d34a3014abce0c95935328ab75c8ea5b02b4e Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 20 Jan 2025 21:45:02 +0900 Subject: [PATCH 2/9] remove unnecessary code --- .../ethereum/contracts/contracts/pulse/Pulse.sol | 9 --------- 1 file changed, 9 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 0cc7d3a384..43c63c880b 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -124,15 +124,6 @@ abstract contract Pulse is IPulse, PulseState { clearRequest(sequenceNumber); - // Check if enough gas remains for callback + events/cleanup - // We need extra gas beyond callbackGasLimit for: - // 1. Emitting success/failure events - // 2. Error handling in catch blocks - // 3. State cleanup operations - if (gasleft() < (req.callbackGasLimit * 3) / 2) { - revert InsufficientGas(); - } - try IPulseConsumer(req.requester).pulseCallback{ gas: req.callbackGasLimit From d68155e38efef7a3b4f0af5add398db5d52191a1 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 20 Jan 2025 21:58:51 +0900 Subject: [PATCH 3/9] fix test --- .../contracts/contracts/pulse/PulseErrors.sol | 1 - .../ethereum/contracts/forge-test/Pulse.t.sol | 39 +------------------ 2 files changed, 1 insertion(+), 39 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol index aacb123ba5..c92f4e0858 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -11,5 +11,4 @@ error CallbackFailed(); error InvalidPriceIds(bytes32 providedPriceIdsHash, bytes32 storedPriceIdsHash); error InvalidCallbackGasLimit(uint256 requested, uint256 stored); error ExceedsMaxPrices(uint32 requested, uint32 maxAllowed); -error InsufficientGas(); error TooManyPriceIds(uint256 provided, uint256 maximum); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 67754f970f..1e6d2ade1f 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -414,7 +414,7 @@ contract PulseTest is Test, PulseEvents { // Try executing with only 100K gas when 1M is required vm.prank(defaultProvider); - vm.expectRevert(InsufficientGas.selector); + vm.expectRevert(); // Just expect any revert since it will be an out-of-gas error pulse.executeCallback{gas: 100000}( sequenceNumber, updateData, @@ -669,43 +669,6 @@ contract PulseTest is Test, PulseEvents { pulse.executeCallback(sequenceNumber, updateData, wrongPriceIds); } - function testExecuteCallbackGasOverhead() public { - // Setup request with 1M gas limit - ( - uint64 sequenceNumber, - bytes32[] memory priceIds, - uint256 publishTime - ) = setupConsumerRequest(address(consumer), defaultProvider); - - // Setup mock data - PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( - publishTime - ); - mockParsePriceFeedUpdates(priceFeeds); - bytes[] memory updateData = createMockUpdateData(priceFeeds); - - // Should fail with exactly 1.4x gas (less than required 1.5x) - vm.prank(defaultProvider); - vm.expectRevert(InsufficientGas.selector); - pulse.executeCallback{gas: (CALLBACK_GAS_LIMIT * 14) / 10}( - sequenceNumber, - updateData, - priceIds - ); - - // Should succeed with 1.6x gas - vm.prank(defaultProvider); - pulse.executeCallback{gas: (CALLBACK_GAS_LIMIT * 16) / 10}( - sequenceNumber, - updateData, - priceIds - ); - - // Verify callback was executed successfully - assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastProvider(), defaultProvider); - } - function testRevertOnTooManyPriceIds() public { uint256 maxPriceIds = uint256(pulse.MAX_PRICE_IDS()); // Create array with MAX_PRICE_IDS + 1 price IDs From f67734fa6670ab42c5957b6742a9b61b12237b26 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Thu, 23 Jan 2025 14:24:16 +0900 Subject: [PATCH 4/9] add exclusivity period to provider (#2282) --- .../contracts/contracts/pulse/IPulse.sol | 7 +- .../contracts/contracts/pulse/Pulse.sol | 32 +++- .../contracts/contracts/pulse/PulseEvents.sol | 5 + .../contracts/contracts/pulse/PulseState.sol | 1 + .../ethereum/contracts/forge-test/Pulse.t.sol | 146 ++++++++++++++++-- 5 files changed, 169 insertions(+), 22 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 54962e929a..e913e5e3dd 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -33,8 +33,7 @@ interface IPulse is PulseEvents { function requestPriceUpdatesWithCallback( uint256 publishTime, bytes32[] calldata priceIds, - uint256 callbackGasLimit, - address provider + uint256 callbackGasLimit ) external payable returns (uint64 sequenceNumber); /** @@ -92,4 +91,8 @@ interface IPulse is PulseEvents { function getDefaultProvider() external view returns (address); function setDefaultProvider(address provider) external; + + function setExclusivityPeriod(uint256 periodSeconds) external; + + function getExclusivityPeriod() external view returns (uint256); } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 43c63c880b..55162499fa 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -29,6 +29,7 @@ abstract contract Pulse is IPulse, PulseState { _state.pyth = pythAddress; _state.currentSequenceNumber = 1; _state.defaultProvider = defaultProvider; + _state.exclusivityPeriodSeconds = 15; // Default to 15 seconds if (prefillRequestStorage) { for (uint8 i = 0; i < NUM_REQUESTS; i++) { @@ -49,12 +50,9 @@ abstract contract Pulse is IPulse, PulseState { function requestPriceUpdatesWithCallback( uint256 publishTime, bytes32[] calldata priceIds, - uint256 callbackGasLimit, - address provider + uint256 callbackGasLimit ) external payable override returns (uint64 requestSequenceNumber) { - if (provider == address(0)) { - provider = _state.defaultProvider; - } + address provider = _state.defaultProvider; require( _state.providers[provider].isRegistered, "Provider not registered" @@ -102,6 +100,16 @@ abstract contract Pulse is IPulse, PulseState { ) external payable override { Request storage req = findActiveRequest(sequenceNumber); + // Check provider exclusivity using configurable period + if ( + block.timestamp < req.publishTime + _state.exclusivityPeriodSeconds + ) { + require( + msg.sender == req.provider, + "Only assigned provider during exclusivity period" + ); + } + // Verify priceIds match require( priceIds.length == req.numPriceIds, @@ -357,4 +365,18 @@ abstract contract Pulse is IPulse, PulseState { _state.defaultProvider = provider; emit DefaultProviderUpdated(oldProvider, provider); } + + function setExclusivityPeriod(uint256 periodSeconds) external override { + require( + msg.sender == _state.admin, + "Only admin can set exclusivity period" + ); + uint256 oldPeriod = _state.exclusivityPeriodSeconds; + _state.exclusivityPeriodSeconds = periodSeconds; + emit ExclusivityPeriodUpdated(oldPeriod, periodSeconds); + } + + function getExclusivityPeriod() external view override returns (uint256) { + return _state.exclusivityPeriodSeconds; + } } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index b3a8e10fe0..b83a8c244d 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -39,4 +39,9 @@ interface PulseEvents { uint128 newFee ); event DefaultProviderUpdated(address oldProvider, address newProvider); + + event ExclusivityPeriodUpdated( + uint256 oldPeriodSeconds, + uint256 newPeriodSeconds + ); } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 02291ed77d..3d9fff9f76 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -33,6 +33,7 @@ contract PulseState { address pyth; uint64 currentSequenceNumber; address defaultProvider; + uint256 exclusivityPeriodSeconds; Request[NUM_REQUESTS] requests; mapping(bytes32 => Request) requestsOverflow; mapping(address => ProviderInfo) providers; diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 1e6d2ade1f..fb6ccfd5d0 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -179,8 +179,7 @@ contract PulseTest is Test, PulseEvents { sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}( publishTime, priceIds, - CALLBACK_GAS_LIMIT, - provider + CALLBACK_GAS_LIMIT ); return (sequenceNumber, priceIds, publishTime); @@ -226,8 +225,7 @@ contract PulseTest is Test, PulseEvents { pulse.requestPriceUpdatesWithCallback{value: totalFee}( publishTime, priceIds, - CALLBACK_GAS_LIMIT, - defaultProvider + CALLBACK_GAS_LIMIT ); // Additional assertions to verify event data was stored correctly @@ -260,8 +258,7 @@ contract PulseTest is Test, PulseEvents { pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee block.timestamp, priceIds, - CALLBACK_GAS_LIMIT, - defaultProvider + CALLBACK_GAS_LIMIT ); } @@ -277,7 +274,7 @@ contract PulseTest is Test, PulseEvents { vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee - }(publishTime, priceIds, CALLBACK_GAS_LIMIT, defaultProvider); + }(publishTime, priceIds, CALLBACK_GAS_LIMIT); // Step 2: Create mock price feeds and setup Pyth response PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -432,7 +429,7 @@ contract PulseTest is Test, PulseEvents { vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee - }(futureTime, priceIds, CALLBACK_GAS_LIMIT, defaultProvider); + }(futureTime, priceIds, CALLBACK_GAS_LIMIT); // Try to execute callback before the requested timestamp PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -470,8 +467,7 @@ contract PulseTest is Test, PulseEvents { pulse.requestPriceUpdatesWithCallback{value: totalFee}( farFutureTime, priceIds, - CALLBACK_GAS_LIMIT, - defaultProvider + CALLBACK_GAS_LIMIT ); } @@ -536,7 +532,7 @@ contract PulseTest is Test, PulseEvents { vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{ value: calculateTotalFee(defaultProvider) - }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT, defaultProvider); + }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT); // Get admin's balance before withdrawal uint256 adminBalanceBefore = admin.balance; @@ -584,7 +580,7 @@ contract PulseTest is Test, PulseEvents { vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{ value: calculateTotalFee(defaultProvider) - }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT, defaultProvider); + }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT); // Get provider's accrued fees instead of total fees PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( @@ -691,8 +687,7 @@ contract PulseTest is Test, PulseEvents { pulse.requestPriceUpdatesWithCallback{value: totalFee}( block.timestamp, priceIds, - CALLBACK_GAS_LIMIT, - defaultProvider + CALLBACK_GAS_LIMIT ); } @@ -755,9 +750,130 @@ contract PulseTest is Test, PulseEvents { vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee - }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT, provider); + }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT); PulseState.Request memory req = pulse.getRequest(sequenceNumber); assertEq(req.provider, provider); } + + function testExclusivityPeriod() public { + // Test initial value + assertEq( + pulse.getExclusivityPeriod(), + 15, + "Initial exclusivity period should be 15 seconds" + ); + + // Test setting new value + vm.prank(admin); + vm.expectEmit(); + emit ExclusivityPeriodUpdated(15, 30); + pulse.setExclusivityPeriod(30); + + assertEq( + pulse.getExclusivityPeriod(), + 30, + "Exclusivity period should be updated" + ); + } + + function testSetExclusivityPeriodUnauthorized() public { + vm.prank(address(0xdead)); + vm.expectRevert("Only admin can set exclusivity period"); + pulse.setExclusivityPeriod(30); + } + + function testExecuteCallbackDuringExclusivity() public { + // Register a second provider + address secondProvider = address(0x456); + vm.prank(secondProvider); + pulse.registerProvider(DEFAULT_PROVIDER_FEE); + + // Setup request + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer), defaultProvider); + + // Setup mock data + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Try to execute with second provider during exclusivity period + vm.prank(secondProvider); + vm.expectRevert("Only assigned provider during exclusivity period"); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + + // Original provider should succeed + vm.prank(defaultProvider); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + } + + function testExecuteCallbackAfterExclusivity() public { + // Register a second provider + address secondProvider = address(0x456); + vm.prank(secondProvider); + pulse.registerProvider(DEFAULT_PROVIDER_FEE); + + // Setup request + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer), defaultProvider); + + // Setup mock data + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Wait for exclusivity period to end + vm.warp(block.timestamp + pulse.getExclusivityPeriod() + 1); + + // Second provider should now succeed + vm.prank(secondProvider); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + } + + function testExecuteCallbackWithCustomExclusivityPeriod() public { + // Register a second provider + address secondProvider = address(0x456); + vm.prank(secondProvider); + pulse.registerProvider(DEFAULT_PROVIDER_FEE); + + // Set custom exclusivity period + vm.prank(admin); + pulse.setExclusivityPeriod(30); + + // Setup request + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer), defaultProvider); + + // Setup mock data + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Try at 29 seconds (should fail for second provider) + vm.warp(block.timestamp + 29); + vm.prank(secondProvider); + vm.expectRevert("Only assigned provider during exclusivity period"); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + + // Try at 31 seconds (should succeed for second provider) + vm.warp(block.timestamp + 2); + vm.prank(secondProvider); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + } } From fb8a7cd4e6d278bb05bcf0d9dcfc339892171d76 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Thu, 23 Jan 2025 14:36:15 +0900 Subject: [PATCH 5/9] make exclusivityPeriodSeconds configurable --- .../ethereum/contracts/contracts/pulse/Pulse.sol | 5 +++-- .../contracts/contracts/pulse/PulseUpgradeable.sol | 8 +++++--- .../ethereum/contracts/forge-test/Pulse.t.sol | 10 +++++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 55162499fa..6b5392c677 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -14,7 +14,8 @@ abstract contract Pulse is IPulse, PulseState { uint128 pythFeeInWei, address pythAddress, address defaultProvider, - bool prefillRequestStorage + bool prefillRequestStorage, + uint256 exclusivityPeriodSeconds ) internal { require(admin != address(0), "admin is zero address"); require(pythAddress != address(0), "pyth is zero address"); @@ -29,7 +30,7 @@ abstract contract Pulse is IPulse, PulseState { _state.pyth = pythAddress; _state.currentSequenceNumber = 1; _state.defaultProvider = defaultProvider; - _state.exclusivityPeriodSeconds = 15; // Default to 15 seconds + _state.exclusivityPeriodSeconds = exclusivityPeriodSeconds; if (prefillRequestStorage) { for (uint8 i = 0; i < NUM_REQUESTS; i++) { diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol index dc4f5e5a5b..f3ceafc5ed 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -24,8 +24,9 @@ contract PulseUpgradeable is uint128 pythFeeInWei, address pythAddress, address defaultProvider, - bool prefillRequestStorage - ) public initializer { + bool prefillRequestStorage, + uint256 exclusivityPeriodSeconds + ) external initializer { require(owner != address(0), "owner is zero address"); require(admin != address(0), "admin is zero address"); @@ -37,7 +38,8 @@ contract PulseUpgradeable is pythFeeInWei, pythAddress, defaultProvider, - prefillRequestStorage + prefillRequestStorage, + exclusivityPeriodSeconds ); _transferOwnership(owner); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index fb6ccfd5d0..7c053dec8d 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -92,7 +92,15 @@ contract PulseTest is Test, PulseEvents { proxy = new ERC1967Proxy(address(_pulse), ""); pulse = PulseUpgradeable(address(proxy)); - pulse.initialize(owner, admin, PYTH_FEE, pyth, defaultProvider, false); + pulse.initialize( + owner, + admin, + PYTH_FEE, + pyth, + defaultProvider, + false, + 15 + ); vm.prank(defaultProvider); pulse.registerProvider(DEFAULT_PROVIDER_FEE); consumer = new MockPulseConsumer(); From 5812d786946deee0428d671acf4adbb1e425f5b5 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Thu, 23 Jan 2025 14:41:16 +0900 Subject: [PATCH 6/9] remove provider arg from getFee --- .../contracts/contracts/pulse/IPulse.sol | 4 +- .../contracts/contracts/pulse/Pulse.sol | 12 ++-- .../ethereum/contracts/forge-test/Pulse.t.sol | 62 +++++++++---------- 3 files changed, 36 insertions(+), 42 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index e913e5e3dd..25c06d5517 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -62,12 +62,10 @@ interface IPulse is PulseEvents { * @notice Calculates the total fee required for a price update request * @dev Total fee = base Pyth protocol fee + gas costs for callback * @param callbackGasLimit The amount of gas allocated for callback execution - * @param provider The provider to use for the fee calculation * @return feeAmount The total fee in wei that must be provided as msg.value */ function getFee( - uint256 callbackGasLimit, - address provider + uint256 callbackGasLimit ) external view returns (uint128 feeAmount); function getAccruedFees() external view returns (uint128 accruedFeesInWei); diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 6b5392c677..67cb0dd02a 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -70,7 +70,7 @@ abstract contract Pulse is IPulse, PulseState { } requestSequenceNumber = _state.currentSequenceNumber++; - uint128 requiredFee = getFee(callbackGasLimit, provider); + uint128 requiredFee = getFee(callbackGasLimit); if (msg.value < requiredFee) revert InsufficientFee(); Request storage req = allocRequest(requestSequenceNumber); @@ -190,14 +190,12 @@ abstract contract Pulse is IPulse, PulseState { } function getFee( - uint256 callbackGasLimit, - address provider + uint256 callbackGasLimit ) public view override returns (uint128 feeAmount) { - if (provider == address(0)) { - provider = _state.defaultProvider; - } uint128 baseFee = _state.pythFeeInWei; - uint128 providerFeeInWei = _state.providers[provider].feeInWei; + uint128 providerFeeInWei = _state + .providers[_state.defaultProvider] + .feeInWei; uint256 gasFee = callbackGasLimit * providerFeeInWei; feeAmount = baseFee + SafeCast.toUint128(gasFee); } diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 7c053dec8d..039de8423f 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -159,16 +159,13 @@ contract PulseTest is Test, PulseEvents { } // Helper function to calculate total fee - function calculateTotalFee( - address provider - ) internal view returns (uint128) { - return pulse.getFee(CALLBACK_GAS_LIMIT, provider); + function calculateTotalFee() internal view returns (uint128) { + return pulse.getFee(CALLBACK_GAS_LIMIT); } // Helper function to setup consumer request function setupConsumerRequest( - address consumerAddress, - address provider + address consumerAddress ) internal returns ( @@ -181,7 +178,7 @@ contract PulseTest is Test, PulseEvents { publishTime = block.timestamp; vm.deal(consumerAddress, 1 gwei); - uint128 totalFee = calculateTotalFee(provider); + uint128 totalFee = calculateTotalFee(); vm.prank(consumerAddress); sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}( @@ -202,7 +199,7 @@ contract PulseTest is Test, PulseEvents { // Fund the consumer contract with enough ETH for higher gas price vm.deal(address(consumer), 1 ether); - uint128 totalFee = calculateTotalFee(defaultProvider); + uint128 totalFee = calculateTotalFee(); // Create the event data we expect to see PulseState.Request memory expectedRequest = PulseState.Request({ @@ -276,7 +273,7 @@ contract PulseTest is Test, PulseEvents { // Fund the consumer contract vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(defaultProvider); + uint128 totalFee = calculateTotalFee(); // Step 1: Make the request as consumer vm.prank(address(consumer)); @@ -353,7 +350,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(failingConsumer), defaultProvider); + ) = setupConsumerRequest(address(failingConsumer)); PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime @@ -381,7 +378,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(failingConsumer), defaultProvider); + ) = setupConsumerRequest(address(failingConsumer)); PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime @@ -408,7 +405,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(consumer), defaultProvider); + ) = setupConsumerRequest(address(consumer)); // Setup mock data PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -433,7 +430,7 @@ contract PulseTest is Test, PulseEvents { uint256 futureTime = block.timestamp + 10; // 10 seconds in future vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(defaultProvider); + uint128 totalFee = calculateTotalFee(); vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: totalFee @@ -468,7 +465,7 @@ contract PulseTest is Test, PulseEvents { uint256 farFutureTime = block.timestamp + 61; // Just over 1 minute vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(defaultProvider); + uint128 totalFee = calculateTotalFee(); vm.prank(address(consumer)); vm.expectRevert("Too far in future"); @@ -484,7 +481,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(consumer), defaultProvider); + ) = setupConsumerRequest(address(consumer)); PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime @@ -514,7 +511,7 @@ contract PulseTest is Test, PulseEvents { uint128 expectedFee = SafeCast.toUint128( DEFAULT_PROVIDER_FEE * gasLimit ) + PYTH_FEE; - uint128 actualFee = pulse.getFee(gasLimit, defaultProvider); + uint128 actualFee = pulse.getFee(gasLimit); assertEq( actualFee, expectedFee, @@ -524,7 +521,7 @@ contract PulseTest is Test, PulseEvents { // Test with zero gas limit uint128 expectedMinFee = PYTH_FEE; - uint128 actualMinFee = pulse.getFee(0, defaultProvider); + uint128 actualMinFee = pulse.getFee(0); assertEq( actualMinFee, expectedMinFee, @@ -538,9 +535,11 @@ contract PulseTest is Test, PulseEvents { vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{ - value: calculateTotalFee(defaultProvider) - }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); // Get admin's balance before withdrawal uint256 adminBalanceBefore = admin.balance; @@ -586,9 +585,11 @@ contract PulseTest is Test, PulseEvents { vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{ - value: calculateTotalFee(defaultProvider) - }(block.timestamp, priceIds, CALLBACK_GAS_LIMIT); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); // Get provider's accrued fees instead of total fees PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( @@ -645,10 +646,7 @@ contract PulseTest is Test, PulseEvents { uint256 publishTime = block.timestamp; // Setup request - (uint64 sequenceNumber, , ) = setupConsumerRequest( - address(consumer), - defaultProvider - ); + (uint64 sequenceNumber, , ) = setupConsumerRequest(address(consumer)); // Create different priceIds bytes32[] memory wrongPriceIds = new bytes32[](2); @@ -682,7 +680,7 @@ contract PulseTest is Test, PulseEvents { } vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(defaultProvider); + uint128 totalFee = calculateTotalFee(); vm.prank(address(consumer)); vm.expectRevert( @@ -752,7 +750,7 @@ contract PulseTest is Test, PulseEvents { bytes32[] memory priceIds = new bytes32[](1); priceIds[0] = bytes32(uint256(1)); - uint128 totalFee = pulse.getFee(CALLBACK_GAS_LIMIT, provider); + uint128 totalFee = pulse.getFee(CALLBACK_GAS_LIMIT); vm.deal(address(consumer), totalFee); vm.prank(address(consumer)); @@ -802,7 +800,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(consumer), defaultProvider); + ) = setupConsumerRequest(address(consumer)); // Setup mock data PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -832,7 +830,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(consumer), defaultProvider); + ) = setupConsumerRequest(address(consumer)); // Setup mock data PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -864,7 +862,7 @@ contract PulseTest is Test, PulseEvents { uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime - ) = setupConsumerRequest(address(consumer), defaultProvider); + ) = setupConsumerRequest(address(consumer)); // Setup mock data PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( From db135673ffcd1a47b778a29740e791156bbce715 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Fri, 24 Jan 2025 13:36:54 +0900 Subject: [PATCH 7/9] remove provider from callback args --- target_chains/ethereum/contracts/contracts/pulse/IPulse.sol | 1 - target_chains/ethereum/contracts/contracts/pulse/Pulse.sol | 2 +- target_chains/ethereum/contracts/forge-test/Pulse.t.sol | 6 ------ 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 25c06d5517..2dd8239381 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -9,7 +9,6 @@ import "./PulseState.sol"; interface IPulseConsumer { function pulseCallback( uint64 sequenceNumber, - address provider, PythStructs.PriceFeed[] memory priceFeeds ) external; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 67cb0dd02a..7f941dd221 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -136,7 +136,7 @@ abstract contract Pulse is IPulse, PulseState { try IPulseConsumer(req.requester).pulseCallback{ gas: req.callbackGasLimit - }(sequenceNumber, msg.sender, priceFeeds) + }(sequenceNumber, priceFeeds) { // Callback succeeded emitPriceUpdate(sequenceNumber, priceIds, priceFeeds); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 039de8423f..b6a0e4f51e 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -13,16 +13,13 @@ import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; - address public lastProvider; PythStructs.PriceFeed[] private _lastPriceFeeds; function pulseCallback( uint64 sequenceNumber, - address provider, PythStructs.PriceFeed[] memory priceFeeds ) external override { lastSequenceNumber = sequenceNumber; - lastProvider = provider; for (uint i = 0; i < priceFeeds.length; i++) { _lastPriceFeeds.push(priceFeeds[i]); } @@ -40,7 +37,6 @@ contract MockPulseConsumer is IPulseConsumer { contract FailingPulseConsumer is IPulseConsumer { function pulseCallback( uint64, - address, PythStructs.PriceFeed[] memory ) external pure override { revert("callback failed"); @@ -52,7 +48,6 @@ contract CustomErrorPulseConsumer is IPulseConsumer { function pulseCallback( uint64, - address, PythStructs.PriceFeed[] memory ) external pure override { revert CustomError("callback failed"); @@ -324,7 +319,6 @@ contract PulseTest is Test, PulseEvents { // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastProvider(), defaultProvider); // Compare price feeds array length PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); From 7365c1542f066975f7dd3279d8a65472476b6714 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Fri, 24 Jan 2025 13:41:26 +0900 Subject: [PATCH 8/9] add comments --- .../ethereum/contracts/contracts/pulse/Pulse.sol | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 7f941dd221..a87b924c95 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -192,12 +192,12 @@ abstract contract Pulse is IPulse, PulseState { function getFee( uint256 callbackGasLimit ) public view override returns (uint128 feeAmount) { - uint128 baseFee = _state.pythFeeInWei; + uint128 baseFee = _state.pythFeeInWei; // Fixed fee to Pyth uint128 providerFeeInWei = _state .providers[_state.defaultProvider] - .feeInWei; - uint256 gasFee = callbackGasLimit * providerFeeInWei; - feeAmount = baseFee + SafeCast.toUint128(gasFee); + .feeInWei; // Provider's per-gas rate + uint256 gasFee = callbackGasLimit * providerFeeInWei; // Total provider fee based on gas + feeAmount = baseFee + SafeCast.toUint128(gasFee); // Total fee user needs to pay } function getPythFeeInWei() From 6e9c08d9a0d0d77109bb6c7bd1e3924b4d05a726 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Fri, 24 Jan 2025 13:45:13 +0900 Subject: [PATCH 9/9] add comments --- target_chains/ethereum/contracts/contracts/pulse/Pulse.sol | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index a87b924c95..4e768c30ae 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -29,6 +29,11 @@ abstract contract Pulse is IPulse, PulseState { _state.pythFeeInWei = pythFeeInWei; _state.pyth = pythAddress; _state.currentSequenceNumber = 1; + + // Two-step initialization process: + // 1. Set the default provider address here + // 2. Provider must call registerProvider() in a separate transaction to set their fee + // This ensures the provider maintains control over their own fee settings _state.defaultProvider = defaultProvider; _state.exclusivityPeriodSeconds = exclusivityPeriodSeconds;