From 6d0f2e665182cd2a4f450fc29610f031d201718c Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 4 Nov 2024 23:09:59 +0900 Subject: [PATCH 01/29] add initial contracts --- .../contracts/contracts/pulse/IPulse.sol | 144 ++++++ .../contracts/contracts/pulse/Pulse.sol | 411 ++++++++++++++++++ .../contracts/contracts/pulse/PulseErrors.sol | 10 + .../contracts/contracts/pulse/PulseState.sol | 39 ++ .../contracts/pulse/PulseUpgradeable.sol | 70 +++ .../ethereum/contracts/forge-test/Pulse.t.sol | 401 +++++++++++++++++ 6 files changed, 1075 insertions(+) create mode 100644 target_chains/ethereum/contracts/contracts/pulse/IPulse.sol create mode 100644 target_chains/ethereum/contracts/contracts/pulse/Pulse.sol create mode 100644 target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol create mode 100644 target_chains/ethereum/contracts/contracts/pulse/PulseState.sol create mode 100644 target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol create mode 100644 target_chains/ethereum/contracts/forge-test/Pulse.t.sol diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol new file mode 100644 index 0000000000..8bcf6f263d --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "./PulseState.sol"; + +interface IPulseConsumer { + function pulseCallback( + uint64 sequenceNumber, + address provider, + uint256 publishTime, + bytes32[] calldata priceIds + ) external; +} + +interface IPulse { + // Events + event PriceUpdateRequested( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds, + address requester + ); + + event PriceUpdateExecuted( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds + ); + + event ProviderRegistered( + address indexed provider, + uint128 feeInWei, + bytes uri + ); + + event ProviderFeeUpdated( + address indexed provider, + uint128 oldFeeInWei, + uint128 newFeeInWei + ); + + event ProviderWithdrawn( + address indexed provider, + address indexed recipient, + uint128 amount + ); + + event ProviderFeeManagerUpdated( + address indexed provider, + address oldFeeManager, + address newFeeManager + ); + + event ProviderUriUpdated( + address indexed provider, + bytes oldUri, + bytes newUri + ); + + event ProviderMaxNumPricesUpdated( + address indexed provider, + uint32 oldMaxNumPrices, + uint32 maxNumPrices + ); + + event PriceUpdateCallbackFailed( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds, + address requester, + string reason + ); + + // Core functions + function requestPriceUpdatesWithCallback( + address provider, + uint256 publishTime, + bytes32[] calldata priceIds, + bytes[] calldata updateData, + uint256 callbackGasLimit + ) external payable returns (uint64 sequenceNumber); + + function executeCallback( + uint64 sequenceNumber, + uint256 publishTime, + bytes32[] calldata priceIds, + bytes[] calldata updateData, + uint256 callbackGasLimit + ) external; + + // Provider management + function register(uint128 feeInWei, bytes calldata uri) external; + + function setProviderFee(uint128 newFeeInWei) external; + + function withdraw(uint128 amount) external; + + // Add to interface + function withdrawAsFeeManager(address provider, uint128 amount) external; + + // Add to Provider management section + function setProviderUri(bytes calldata uri) external; + + // Getters + function getFee(address provider) external view returns (uint128 feeAmount); + + function getDefaultProvider() external view returns (address); + + // Add to interface + function setFeeManager(address manager) external; + + // Add to interface + function setProviderFeeAsFeeManager( + address provider, + uint128 newFeeInWei + ) external; + + // Add to Getters section + function getAccruedPythFees() + external + view + returns (uint128 accruedPythFeesInWei); + + // Add to Getters section + function getProviderInfo( + address provider + ) external view returns (PulseState.ProviderInfo memory info); + + function getAdmin() external view returns (address admin); + + function getPythFeeInWei() external view returns (uint128 pythFeeInWei); + + function setMaxNumPrices(uint32 maxNumPrices) external; + + // Add to Getters section + function getRequest( + address provider, + uint64 sequenceNumber + ) external view returns (PulseState.Request memory req); +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol new file mode 100644 index 0000000000..e17c4da49f --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -0,0 +1,411 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts/security/ReentrancyGuard.sol"; +import "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import "./IPulse.sol"; +import "./PulseState.sol"; +import "./PulseErrors.sol"; + +contract Pulse is IPulse, ReentrancyGuard, PulseState { + using SafeCast for uint256; + + function _initialize( + address admin, + uint128 pythFeeInWei, + address defaultProvider, + bool prefillRequestStorage + ) internal { + require(admin != address(0), "admin is zero address"); + require( + defaultProvider != address(0), + "defaultProvider is zero address" + ); + + _state.admin = admin; + _state.pythFeeInWei = pythFeeInWei; + _state.accruedPythFeesInWei = 0; + _state.defaultProvider = defaultProvider; + + if (prefillRequestStorage) { + // Prefill storage slots to make future requests use less gas + for (uint8 i = 0; i < NUM_REQUESTS; i++) { + Request storage req = _state.requests[i]; + req.provider = address(1); + req.sequenceNumber = 0; // Keep it inactive + req.publishTime = 1; + // No need to prefill dynamic arrays (priceIds, updateData) + req.callbackGasLimit = 1; + req.requester = address(1); + } + } + } + + function requestPriceUpdatesWithCallback( + address provider, + uint256 publishTime, + bytes32[] calldata priceIds, + bytes[] calldata updateData, + uint256 callbackGasLimit + ) + external + payable + override + nonReentrant + returns (uint64 requestSequenceNumber) + { + ProviderInfo storage providerInfo = _state.providers[provider]; + if (providerInfo.sequenceNumber == 0) revert NoSuchProvider(); + + if ( + providerInfo.maxNumPrices > 0 && + priceIds.length > providerInfo.maxNumPrices + ) { + revert("Exceeds max number of prices"); + } + + // Assign sequence number and increment + requestSequenceNumber = providerInfo.sequenceNumber++; + + // Verify fee payment + uint128 requiredFee = getFee(provider); + if (msg.value < requiredFee) revert InsufficientFee(); + + // Store request for callback execution + Request storage req = allocRequest(provider, requestSequenceNumber); + req.provider = provider; + req.sequenceNumber = requestSequenceNumber; + req.publishTime = publishTime; + req.priceIds = priceIds; + req.updateData = updateData; + req.callbackGasLimit = callbackGasLimit; + req.requester = msg.sender; + + // Update fee balances + providerInfo.accruedFeesInWei += providerInfo.feeInWei; + _state.accruedPythFeesInWei += (msg.value.toUint128() - + providerInfo.feeInWei); + + emit PriceUpdateRequested( + requestSequenceNumber, + provider, + publishTime, + priceIds, + msg.sender + ); + } + + function executeCallback( + uint64 sequenceNumber, + uint256 publishTime, + bytes32[] calldata priceIds, + bytes[] calldata updateData, + uint256 callbackGasLimit + ) external override nonReentrant { + Request storage req = findActiveRequest(msg.sender, sequenceNumber); + + // Verify request parameters match + require(req.publishTime == publishTime, "Invalid publish time"); + require( + keccak256(abi.encode(req.priceIds)) == + keccak256(abi.encode(priceIds)), + "Invalid price IDs" + ); + require( + keccak256(abi.encode(req.updateData)) == + keccak256(abi.encode(updateData)), + "Invalid update data" + ); + require( + req.callbackGasLimit == callbackGasLimit, + "Invalid callback gas limit" + ); + + // Execute callback but don't revert if it fails + try + IPulseConsumer(req.requester).pulseCallback( + sequenceNumber, + msg.sender, + publishTime, + priceIds + ) + { + // Callback succeeded + emit PriceUpdateExecuted( + sequenceNumber, + msg.sender, + publishTime, + priceIds + ); + } catch Error(string memory reason) { + // Explicit revert/require + emit PriceUpdateCallbackFailed( + sequenceNumber, + msg.sender, + publishTime, + priceIds, + req.requester, + reason + ); + } catch { + // Out of gas or other low-level errors + emit PriceUpdateCallbackFailed( + sequenceNumber, + msg.sender, + publishTime, + priceIds, + req.requester, + "low-level error (possibly out of gas)" + ); + } + + // Clear request regardless of callback success + clearRequest(msg.sender, sequenceNumber); + } + + function register(uint128 feeInWei, bytes calldata uri) public override { + ProviderInfo storage provider = _state.providers[msg.sender]; + + provider.feeInWei = feeInWei; + provider.uri = uri; + + if (provider.sequenceNumber == 0) { + provider.sequenceNumber = 1; + } + + emit ProviderRegistered(msg.sender, feeInWei, uri); + } + + function setProviderFee(uint128 newFeeInWei) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) revert NoSuchProvider(); + + uint128 oldFeeInWei = provider.feeInWei; + provider.feeInWei = newFeeInWei; + + emit ProviderFeeUpdated(msg.sender, oldFeeInWei, newFeeInWei); + } + + function getFee( + address provider + ) public view override returns (uint128 feeAmount) { + feeAmount = _state.providers[provider].feeInWei + _state.pythFeeInWei; + } + + function getDefaultProvider() + external + view + override + returns (address defaultProvider) + { + defaultProvider = _state.defaultProvider; + } + + // Internal helper functions + function findActiveRequest( + address provider, + uint64 sequenceNumber + ) internal view returns (Request storage activeRequest) { + activeRequest = findRequest(provider, sequenceNumber); + if ( + !isActive(activeRequest) || + activeRequest.provider != provider || + activeRequest.sequenceNumber != sequenceNumber + ) { + revert NoSuchRequest(); + } + } + + function findRequest( + address provider, + uint64 sequenceNumber + ) internal view returns (Request storage foundRequest) { + (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); + foundRequest = _state.requests[shortKey]; + + if ( + foundRequest.provider == provider && + foundRequest.sequenceNumber == sequenceNumber + ) { + return foundRequest; + } else { + foundRequest = _state.requestsOverflow[key]; + } + } + + function clearRequest(address provider, uint64 sequenceNumber) internal { + (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); + Request storage req = _state.requests[shortKey]; + + if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + req.sequenceNumber = 0; + } else { + delete _state.requestsOverflow[key]; + } + } + + function allocRequest( + address provider, + uint64 sequenceNumber + ) internal returns (Request storage newRequest) { + (, uint8 shortKey) = requestKey(provider, sequenceNumber); + newRequest = _state.requests[shortKey]; + + if (isActive(newRequest)) { + (bytes32 reqKey, ) = requestKey( + newRequest.provider, + newRequest.sequenceNumber + ); + _state.requestsOverflow[reqKey] = newRequest; + } + } + + function requestKey( + address provider, + uint64 sequenceNumber + ) internal pure returns (bytes32 hashKey, uint8 shortHashKey) { + hashKey = keccak256(abi.encodePacked(provider, sequenceNumber)); + shortHashKey = uint8(hashKey[0] & NUM_REQUESTS_MASK); + } + + function isActive( + Request storage req + ) internal view returns (bool isRequestActive) { + isRequestActive = req.sequenceNumber != 0; + } + + function withdraw(uint128 amount) public override { + ProviderInfo storage providerInfo = _state.providers[msg.sender]; + + // Use checks-effects-interactions pattern to prevent reentrancy attacks + require( + providerInfo.accruedFeesInWei >= amount, + "Insufficient balance" + ); + providerInfo.accruedFeesInWei -= amount; + + // Interaction with an external contract or token transfer + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "withdrawal to msg.sender failed"); + + emit ProviderWithdrawn(msg.sender, msg.sender, amount); + } + + function withdrawAsFeeManager( + address provider, + uint128 amount + ) external override { + ProviderInfo storage providerInfo = _state.providers[provider]; + + if (providerInfo.sequenceNumber == 0) { + revert NoSuchProvider(); + } + + if (providerInfo.feeManager != msg.sender) { + revert Unauthorized(); + } + + // Use checks-effects-interactions pattern to prevent reentrancy attacks + require( + providerInfo.accruedFeesInWei >= amount, + "Insufficient balance" + ); + providerInfo.accruedFeesInWei -= amount; + + // Interaction with an external contract or token transfer + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "withdrawal to msg.sender failed"); + + emit ProviderWithdrawn(provider, msg.sender, amount); + } + + function setFeeManager(address manager) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) revert NoSuchProvider(); + + address oldFeeManager = provider.feeManager; + provider.feeManager = manager; + + emit ProviderFeeManagerUpdated(msg.sender, oldFeeManager, manager); + } + + function setProviderFeeAsFeeManager( + address provider, + uint128 newFeeInWei + ) external override { + ProviderInfo storage providerInfo = _state.providers[provider]; + + if (providerInfo.sequenceNumber == 0) { + revert NoSuchProvider(); + } + + if (providerInfo.feeManager != msg.sender) { + revert Unauthorized(); + } + + uint128 oldFeeInWei = providerInfo.feeInWei; + providerInfo.feeInWei = newFeeInWei; + + emit ProviderFeeUpdated(provider, oldFeeInWei, newFeeInWei); + } + + function getAccruedPythFees() + public + view + override + returns (uint128 accruedPythFeesInWei) + { + accruedPythFeesInWei = _state.accruedPythFeesInWei; + } + + function getProviderInfo( + address provider + ) public view override returns (ProviderInfo memory info) { + info = _state.providers[provider]; + } + + function getAdmin() external view override returns (address adminAddress) { + adminAddress = _state.admin; + } + + function getPythFeeInWei() + external + view + override + returns (uint128 pythFee) + { + pythFee = _state.pythFeeInWei; + } + + function setProviderUri(bytes calldata uri) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) revert NoSuchProvider(); + + bytes memory oldUri = provider.uri; + provider.uri = uri; + + emit ProviderUriUpdated(msg.sender, oldUri, uri); + } + + function setMaxNumPrices(uint32 maxNumPrices) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) revert NoSuchProvider(); + + uint32 oldMaxNumPrices = provider.maxNumPrices; + provider.maxNumPrices = maxNumPrices; + + emit ProviderMaxNumPricesUpdated( + msg.sender, + oldMaxNumPrices, + maxNumPrices + ); + } + + function getRequest( + address provider, + uint64 sequenceNumber + ) public view override returns (Request memory req) { + req = findRequest(provider, sequenceNumber); + } +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol new file mode 100644 index 0000000000..187ccf00a2 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +error NoSuchProvider(); +error NoSuchRequest(); +error InsufficientFee(); +error Unauthorized(); +error InvalidCallbackGas(); +error CallbackFailed(); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol new file mode 100644 index 0000000000..839a6a0d21 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +contract PulseState { + uint8 public constant NUM_REQUESTS = 32; + bytes1 public constant NUM_REQUESTS_MASK = 0x1f; + + struct Request { + address provider; + uint64 sequenceNumber; + uint256 publishTime; + bytes32[] priceIds; + bytes[] updateData; + uint256 callbackGasLimit; + address requester; + } + + struct ProviderInfo { + uint64 sequenceNumber; + uint128 feeInWei; + uint128 accruedFeesInWei; + bytes uri; + address feeManager; + uint32 maxNumPrices; + } + + struct State { + address admin; + uint128 pythFeeInWei; + uint128 accruedPythFeesInWei; + address defaultProvider; + Request[32] requests; + mapping(bytes32 => Request) requestsOverflow; + mapping(address => ProviderInfo) providers; + } + + State internal _state; +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol new file mode 100644 index 0000000000..191ef15889 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol"; +import "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; +import "@openzeppelin/contracts-upgradeable/access/Ownable2StepUpgradeable.sol"; +import "./Pulse.sol"; + +contract PulseUpgradeable is + Initializable, + Ownable2StepUpgradeable, + UUPSUpgradeable, + Pulse +{ + event ContractUpgraded( + address oldImplementation, + address newImplementation + ); + + function initialize( + address owner, + address admin, + uint128 pythFeeInWei, + address defaultProvider, + bool prefillRequestStorage + ) public initializer { + require(owner != address(0), "owner is zero address"); + + __Ownable_init(); + __UUPSUpgradeable_init(); + + Pulse._initialize( + admin, + pythFeeInWei, + defaultProvider, + prefillRequestStorage + ); + + _transferOwnership(owner); + } + + /// @custom:oz-upgrades-unsafe-allow constructor + constructor() initializer {} + + function _authorizeUpgrade(address) internal override onlyOwner {} + + function upgradeTo(address newImplementation) external override onlyProxy { + address oldImplementation = _getImplementation(); + _authorizeUpgrade(newImplementation); + _upgradeToAndCallUUPS(newImplementation, new bytes(0), false); + + emit ContractUpgraded(oldImplementation, _getImplementation()); + } + + function upgradeToAndCall( + address newImplementation, + bytes memory data + ) external payable override onlyProxy { + address oldImplementation = _getImplementation(); + _authorizeUpgrade(newImplementation); + _upgradeToAndCallUUPS(newImplementation, data, true); + + emit ContractUpgraded(oldImplementation, _getImplementation()); + } + + function version() public pure returns (string memory) { + return "1.0.0"; + } +} diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol new file mode 100644 index 0000000000..615d45eb7f --- /dev/null +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -0,0 +1,401 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "forge-std/Test.sol"; +import "../contracts/pulse/PulseUpgradeable.sol"; +import "../contracts/pulse/IPulse.sol"; +import "../contracts/pulse/PulseState.sol"; + +contract MockPulseConsumer is IPulseConsumer { + uint64 public lastSequenceNumber; + address public lastProvider; + uint256 public lastPublishTime; + bytes32[] public lastPriceIds; + + function pulseCallback( + uint64 sequenceNumber, + address provider, + uint256 publishTime, + bytes32[] calldata priceIds + ) external override { + lastSequenceNumber = sequenceNumber; + lastProvider = provider; + lastPublishTime = publishTime; + lastPriceIds = priceIds; + } +} + +contract PulseTest is Test { + PulseUpgradeable public pulse; + MockPulseConsumer public consumer; + address public owner; + address public admin; + address public provider; + uint128 constant PYTH_FEE = 0.001 ether; + uint128 constant PROVIDER_FEE = 0.002 ether; + + function setUp() public { + owner = address(1); + admin = address(2); + provider = address(3); + + // Deploy contracts + pulse = new PulseUpgradeable(); + pulse.initialize(owner, admin, PYTH_FEE, provider); + consumer = new MockPulseConsumer(); + + // Register provider + vm.prank(provider); + pulse.register(PROVIDER_FEE, "https://provider.com"); + } + + function testRequestPriceUpdate() public { + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = bytes32("BTC/USD"); + priceIds[1] = bytes32("ETH/USD"); + + bytes[] memory updateData = new bytes[](2); + updateData[0] = bytes("data1"); + updateData[1] = bytes("data2"); + + uint256 publishTime = block.timestamp; + uint256 callbackGasLimit = 500000; + + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: PYTH_FEE + PROVIDER_FEE + }(provider, publishTime, priceIds, updateData, callbackGasLimit); + + assertEq(sequenceNumber, 1); + } + + function testExecuteCallback() public { + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = bytes32("BTC/USD"); + priceIds[1] = bytes32("ETH/USD"); + + bytes[] memory updateData = new bytes[](2); + updateData[0] = bytes("data1"); + updateData[1] = bytes("data2"); + + uint256 publishTime = block.timestamp; + uint256 callbackGasLimit = 500000; + + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: PYTH_FEE + PROVIDER_FEE + }(provider, publishTime, priceIds, updateData, callbackGasLimit); + + vm.prank(provider); + pulse.executeCallback( + sequenceNumber, + publishTime, + priceIds, + updateData, + callbackGasLimit + ); + + assertEq(consumer.lastSequenceNumber(), sequenceNumber); + assertEq(consumer.lastProvider(), provider); + assertEq(consumer.lastPublishTime(), publishTime); + assertEq(consumer.lastPriceIds()[0], priceIds[0]); + assertEq(consumer.lastPriceIds()[1], priceIds[1]); + } + + function testProviderRegistration() public { + address newProvider = address(4); + vm.prank(newProvider); + pulse.register(PROVIDER_FEE, "https://newprovider.com"); + + uint128 fee = pulse.getFee(newProvider); + assertEq(fee, PYTH_FEE + PROVIDER_FEE); + } + + function testUpdateProviderFee() public { + uint128 newFee = 0.003 ether; + vm.prank(provider); + pulse.setProviderFee(newFee); + + uint128 fee = pulse.getFee(provider); + assertEq(fee, PYTH_FEE + newFee); + } + + function testFailInsufficientFee() public { + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Not paying provider fee + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + } + + function testFailUnregisteredProvider() public { + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + address(99), // Unregistered provider + block.timestamp, + priceIds, + updateData, + 500000 + ); + } + + function testGasCostsWithPrefill() public { + // Deploy with prefill + PulseUpgradeable pulseWithPrefill = new PulseUpgradeable(); + pulseWithPrefill.initialize(owner, admin, PYTH_FEE, provider, true); + + // Measure gas for first request + uint256 gasBefore = gasleft(); + makeRequest(address(pulseWithPrefill)); + uint256 gasUsed = gasBefore - gasleft(); + + // Should be lower due to prefill + assertLt(gasUsed, 30000); + } + + function testGasCostsWithoutPrefill() public { + // Deploy without prefill + PulseUpgradeable pulseWithoutPrefill = new PulseUpgradeable(); + pulseWithoutPrefill.initialize(owner, admin, PYTH_FEE, provider, false); + + // Measure gas for first request + uint256 gasBefore = gasleft(); + makeRequest(address(pulseWithoutPrefill)); + uint256 gasUsed = gasBefore - gasleft(); + + // Should be higher without prefill + assertGt(gasUsed, 35000); + } + + function makeRequest(address pulseAddress) internal { + // Helper to make a standard request + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + IPulse(pulseAddress).requestPriceUpdatesWithCallback{ + value: PYTH_FEE + PROVIDER_FEE + }(provider, block.timestamp, priceIds, updateData, 500000); + } + + function testWithdraw() public { + // Setup - make a request to accrue some fees + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + + // Check provider balance before withdrawal + uint256 providerBalanceBefore = address(provider).balance; + + // Provider withdraws fees + vm.prank(provider); + pulse.withdraw(PROVIDER_FEE); + + // Verify balance increased + assertEq( + address(provider).balance, + providerBalanceBefore + PROVIDER_FEE + ); + } + + function testFailWithdrawTooMuch() public { + vm.prank(provider); + pulse.withdraw(1 ether); // Try to withdraw more than accrued + } + + function testFailWithdrawUnregistered() public { + vm.prank(address(99)); // Unregistered provider + pulse.withdraw(1 ether); + } + + function testWithdrawAsFeeManager() public { + // Setup fee manager + vm.prank(provider); + pulse.setFeeManager(address(99)); + + // Setup fees + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + + // Check fee manager balance before withdrawal + uint256 managerBalanceBefore = address(99).balance; + + // Fee manager withdraws + vm.prank(address(99)); + pulse.withdrawAsFeeManager(provider, PROVIDER_FEE); + + // Verify balance increased + assertEq(address(99).balance, managerBalanceBefore + PROVIDER_FEE); + } + + function testFailWithdrawAsFeeManagerUnauthorized() public { + vm.prank(address(88)); // Not the fee manager + pulse.withdrawAsFeeManager(provider, PROVIDER_FEE); + } + + function testSetProviderFeeAsFeeManager() public { + // Setup fee manager + vm.prank(provider); + pulse.setFeeManager(address(99)); + + uint128 newFee = 0.005 ether; + + // Fee manager updates fee + vm.prank(address(99)); + pulse.setProviderFeeAsFeeManager(provider, newFee); + + // Verify fee was updated + uint128 fee = pulse.getFee(provider); + assertEq(fee, PYTH_FEE + newFee); + } + + function testFailSetProviderFeeAsFeeManagerUnauthorized() public { + vm.prank(address(88)); // Not the fee manager + pulse.setProviderFeeAsFeeManager(provider, 0.005 ether); + } + + function testGetAccruedPythFees() public { + // Setup - make a request to accrue some fees + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + + // Verify accrued fees + assertEq(pulse.getAccruedPythFees(), PYTH_FEE); + } + + function testGetProviderInfo() public { + // Get provider info + PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); + + // Verify initial values + assertEq(info.sequenceNumber, 1); // Set during registration + assertEq(info.feeInWei, PROVIDER_FEE); + assertEq(info.accruedFeesInWei, 0); + assertEq(string(info.uri), "https://provider.com"); + assertEq(info.feeManager, address(0)); + assertEq(info.maxNumPrices, 0); + } + + function testGetAdmin() public { + assertEq(pulse.getAdmin(), admin); + } + + function testGetPythFeeInWei() public { + assertEq(pulse.getPythFeeInWei(), PYTH_FEE); + } + + function testSetProviderUri() public { + bytes memory newUri = bytes("https://new-provider-endpoint.com"); + + vm.prank(provider); + pulse.setProviderUri(newUri); + + // Get provider info and verify URI was updated + (, , , bytes memory uri, ) = pulse.getProviderInfo(provider); + assertEq(string(uri), string(newUri)); + } + + function testFailSetProviderUriUnregistered() public { + vm.prank(address(99)); // Unregistered provider + pulse.setProviderUri(bytes("https://new-uri.com")); + } + + function testSetMaxNumPrices() public { + uint32 maxPrices = 5; + + vm.prank(provider); + pulse.setMaxNumPrices(maxPrices); + + // Get provider info and verify maxNumPrices was updated + (, , , , address feeManager) = pulse.getProviderInfo(provider); + assertEq(uint256(maxPrices), uint256(maxPrices)); + } + + function testFailExceedMaxNumPrices() public { + // Set max prices to 2 + vm.prank(provider); + pulse.setMaxNumPrices(2); + + // Try to request 3 prices + bytes32[] memory priceIds = new bytes32[](3); + bytes[] memory updateData = new bytes[](3); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + } + + function testGetRequest() public { + // Setup - make a request + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = bytes32("BTC/USD"); + priceIds[1] = bytes32("ETH/USD"); + + bytes[] memory updateData = new bytes[](2); + updateData[0] = bytes("data1"); + updateData[1] = bytes("data2"); + + uint256 publishTime = block.timestamp; + uint256 callbackGasLimit = 500000; + + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: PYTH_FEE + PROVIDER_FEE + }(provider, publishTime, priceIds, updateData, callbackGasLimit); + + // Get request and verify + PulseState.Request memory req = pulse.getRequest( + provider, + sequenceNumber + ); + + assertEq(req.provider, provider); + assertEq(req.sequenceNumber, sequenceNumber); + assertEq(req.publishTime, publishTime); + assertEq(req.priceIds[0], priceIds[0]); + assertEq(req.priceIds[1], priceIds[1]); + assertEq(string(req.updateData[0]), string(updateData[0])); + assertEq(string(req.updateData[1]), string(updateData[1])); + assertEq(req.callbackGasLimit, callbackGasLimit); + assertEq(req.requester, address(consumer)); + } +} From 0ac734761924ca18b8493804913060f3c6542fc2 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 4 Nov 2024 23:13:44 +0900 Subject: [PATCH 02/29] refactor --- .../contracts/contracts/pulse/IPulse.sol | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 8bcf6f263d..4fed40092c 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -99,10 +99,8 @@ interface IPulse { function withdraw(uint128 amount) external; - // Add to interface function withdrawAsFeeManager(address provider, uint128 amount) external; - // Add to Provider management section function setProviderUri(bytes calldata uri) external; // Getters @@ -110,22 +108,11 @@ interface IPulse { function getDefaultProvider() external view returns (address); - // Add to interface - function setFeeManager(address manager) external; - - // Add to interface - function setProviderFeeAsFeeManager( - address provider, - uint128 newFeeInWei - ) external; - - // Add to Getters section function getAccruedPythFees() external view returns (uint128 accruedPythFeesInWei); - // Add to Getters section function getProviderInfo( address provider ) external view returns (PulseState.ProviderInfo memory info); @@ -134,11 +121,18 @@ interface IPulse { function getPythFeeInWei() external view returns (uint128 pythFeeInWei); - function setMaxNumPrices(uint32 maxNumPrices) external; - - // Add to Getters section function getRequest( address provider, uint64 sequenceNumber ) external view returns (PulseState.Request memory req); + + // Setters + function setFeeManager(address manager) external; + + function setProviderFeeAsFeeManager( + address provider, + uint128 newFeeInWei + ) external; + + function setMaxNumPrices(uint32 maxNumPrices) external; } From 24e9f486368fc79457b6e9dc57c5fbea43bea4ab Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 6 Nov 2024 15:01:40 +0900 Subject: [PATCH 03/29] fix --- .../contracts/contracts/pulse/IPulse.sol | 51 +-- .../contracts/contracts/pulse/Pulse.sol | 390 +++++++++--------- .../ethereum/contracts/forge-test/Pulse.t.sol | 99 +++-- 3 files changed, 271 insertions(+), 269 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 4fed40092c..e7fac6da7c 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -15,13 +15,9 @@ interface IPulseConsumer { interface IPulse { // Events - event PriceUpdateRequested( - uint64 indexed sequenceNumber, - address indexed provider, - uint256 publishTime, - bytes32[] priceIds, - address requester - ); + event ProviderRegistered(PulseState.ProviderInfo providerInfo); + + event PriceUpdateRequested(PulseState.Request request); event PriceUpdateExecuted( uint64 indexed sequenceNumber, @@ -30,18 +26,18 @@ interface IPulse { bytes32[] priceIds ); - event ProviderRegistered( - address indexed provider, - uint128 feeInWei, - bytes uri - ); - event ProviderFeeUpdated( address indexed provider, uint128 oldFeeInWei, uint128 newFeeInWei ); + event ProviderUriUpdated( + address indexed provider, + bytes oldUri, + bytes newUri + ); + event ProviderWithdrawn( address indexed provider, address indexed recipient, @@ -54,12 +50,6 @@ interface IPulse { address newFeeManager ); - event ProviderUriUpdated( - address indexed provider, - bytes oldUri, - bytes newUri - ); - event ProviderMaxNumPricesUpdated( address indexed provider, uint32 oldMaxNumPrices, @@ -80,7 +70,6 @@ interface IPulse { address provider, uint256 publishTime, bytes32[] calldata priceIds, - bytes[] calldata updateData, uint256 callbackGasLimit ) external payable returns (uint64 sequenceNumber); @@ -97,30 +86,33 @@ interface IPulse { function setProviderFee(uint128 newFeeInWei) external; + function setProviderFeeAsFeeManager( + address provider, + uint128 newFeeInWei + ) external; + + function setProviderUri(bytes calldata uri) external; + function withdraw(uint128 amount) external; function withdrawAsFeeManager(address provider, uint128 amount) external; - function setProviderUri(bytes calldata uri) external; - // Getters function getFee(address provider) external view returns (uint128 feeAmount); - function getDefaultProvider() external view returns (address); + function getPythFeeInWei() external view returns (uint128 pythFeeInWei); function getAccruedPythFees() external view returns (uint128 accruedPythFeesInWei); + function getDefaultProvider() external view returns (address); + function getProviderInfo( address provider ) external view returns (PulseState.ProviderInfo memory info); - function getAdmin() external view returns (address admin); - - function getPythFeeInWei() external view returns (uint128 pythFeeInWei); - function getRequest( address provider, uint64 sequenceNumber @@ -129,10 +121,5 @@ interface IPulse { // Setters function setFeeManager(address manager) external; - function setProviderFeeAsFeeManager( - address provider, - uint128 newFeeInWei - ) external; - function setMaxNumPrices(uint32 maxNumPrices) external; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index e17c4da49f..e6397a2f65 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -2,15 +2,12 @@ pragma solidity ^0.8.0; -import "@openzeppelin/contracts/security/ReentrancyGuard.sol"; import "@openzeppelin/contracts/utils/math/SafeCast.sol"; import "./IPulse.sol"; import "./PulseState.sol"; import "./PulseErrors.sol"; -contract Pulse is IPulse, ReentrancyGuard, PulseState { - using SafeCast for uint256; - +abstract contract Pulse is IPulse, PulseState { function _initialize( address admin, uint128 pythFeeInWei, @@ -24,15 +21,16 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { ); _state.admin = admin; - _state.pythFeeInWei = pythFeeInWei; _state.accruedPythFeesInWei = 0; + _state.pythFeeInWei = pythFeeInWei; _state.defaultProvider = defaultProvider; if (prefillRequestStorage) { - // Prefill storage slots to make future requests use less gas + // Write some data to every storage slot in the requests array such that new requests + // use a more consistent amount of gas. + // Note that these requests are not live because their sequenceNumber is 0. for (uint8 i = 0; i < NUM_REQUESTS; i++) { Request storage req = _state.requests[i]; - req.provider = address(1); req.sequenceNumber = 0; // Keep it inactive req.publishTime = 1; // No need to prefill dynamic arrays (priceIds, updateData) @@ -42,19 +40,67 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { } } + function register(uint128 feeInWei, bytes calldata uri) public override { + ProviderInfo storage providerInfo = _state.providers[msg.sender]; + + providerInfo.feeInWei = feeInWei; + providerInfo.uri = uri; + providerInfo.sequenceNumber += 1; + + emit ProviderRegistered(providerInfo); + } + + function withdraw(uint128 amount) public override { + ProviderInfo storage providerInfo = _state.providers[msg.sender]; + + // Use checks-effects-interactions pattern to prevent reentrancy attacks. + require( + providerInfo.accruedFeesInWei >= amount, + "Insufficient balance" + ); + providerInfo.accruedFeesInWei -= amount; + + // Interaction with an external contract or token transfer + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "withdrawal to msg.sender failed"); + + emit ProviderWithdrawn(msg.sender, msg.sender, amount); + } + + function withdrawAsFeeManager( + address provider, + uint128 amount + ) external override { + ProviderInfo storage providerInfo = _state.providers[provider]; + + if (providerInfo.sequenceNumber == 0) { + revert NoSuchProvider(); + } + + if (providerInfo.feeManager != msg.sender) { + revert Unauthorized(); + } + + // Use checks-effects-interactions pattern to prevent reentrancy attacks. + require( + providerInfo.accruedFeesInWei >= amount, + "Insufficient balance" + ); + providerInfo.accruedFeesInWei -= amount; + + // Interaction with an external contract or token transfer + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "withdrawal to msg.sender failed"); + + emit ProviderWithdrawn(provider, msg.sender, amount); + } + function requestPriceUpdatesWithCallback( address provider, uint256 publishTime, bytes32[] calldata priceIds, - bytes[] calldata updateData, uint256 callbackGasLimit - ) - external - payable - override - nonReentrant - returns (uint64 requestSequenceNumber) - { + ) external payable override returns (uint64 requestSequenceNumber) { ProviderInfo storage providerInfo = _state.providers[provider]; if (providerInfo.sequenceNumber == 0) revert NoSuchProvider(); @@ -78,22 +124,16 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { req.sequenceNumber = requestSequenceNumber; req.publishTime = publishTime; req.priceIds = priceIds; - req.updateData = updateData; req.callbackGasLimit = callbackGasLimit; req.requester = msg.sender; // Update fee balances providerInfo.accruedFeesInWei += providerInfo.feeInWei; - _state.accruedPythFeesInWei += (msg.value.toUint128() - - providerInfo.feeInWei); - - emit PriceUpdateRequested( - requestSequenceNumber, - provider, - publishTime, - priceIds, - msg.sender - ); + _state.accruedPythFeesInWei += + SafeCast.toUint128(msg.value) - + providerInfo.feeInWei; + + emit PriceUpdateRequested(req); } function executeCallback( @@ -102,7 +142,7 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit - ) external override nonReentrant { + ) external override { Request storage req = findActiveRequest(msg.sender, sequenceNumber); // Verify request parameters match @@ -164,137 +204,67 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { clearRequest(msg.sender, sequenceNumber); } - function register(uint128 feeInWei, bytes calldata uri) public override { - ProviderInfo storage provider = _state.providers[msg.sender]; - - provider.feeInWei = feeInWei; - provider.uri = uri; - - if (provider.sequenceNumber == 0) { - provider.sequenceNumber = 1; - } - - emit ProviderRegistered(msg.sender, feeInWei, uri); - } - - function setProviderFee(uint128 newFeeInWei) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) revert NoSuchProvider(); - - uint128 oldFeeInWei = provider.feeInWei; - provider.feeInWei = newFeeInWei; - - emit ProviderFeeUpdated(msg.sender, oldFeeInWei, newFeeInWei); - } - - function getFee( + function getProviderInfo( address provider - ) public view override returns (uint128 feeAmount) { - feeAmount = _state.providers[provider].feeInWei + _state.pythFeeInWei; + ) public view override returns (ProviderInfo memory info) { + info = _state.providers[provider]; } function getDefaultProvider() - external + public view override - returns (address defaultProvider) + returns (address provider) { - defaultProvider = _state.defaultProvider; - } - - // Internal helper functions - function findActiveRequest( - address provider, - uint64 sequenceNumber - ) internal view returns (Request storage activeRequest) { - activeRequest = findRequest(provider, sequenceNumber); - if ( - !isActive(activeRequest) || - activeRequest.provider != provider || - activeRequest.sequenceNumber != sequenceNumber - ) { - revert NoSuchRequest(); - } + provider = _state.defaultProvider; } - function findRequest( + function getRequest( address provider, uint64 sequenceNumber - ) internal view returns (Request storage foundRequest) { - (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); - foundRequest = _state.requests[shortKey]; - - if ( - foundRequest.provider == provider && - foundRequest.sequenceNumber == sequenceNumber - ) { - return foundRequest; - } else { - foundRequest = _state.requestsOverflow[key]; - } - } - - function clearRequest(address provider, uint64 sequenceNumber) internal { - (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); - Request storage req = _state.requests[shortKey]; - - if (req.provider == provider && req.sequenceNumber == sequenceNumber) { - req.sequenceNumber = 0; - } else { - delete _state.requestsOverflow[key]; - } + ) public view override returns (Request memory req) { + req = findRequest(provider, sequenceNumber); } - function allocRequest( - address provider, - uint64 sequenceNumber - ) internal returns (Request storage newRequest) { - (, uint8 shortKey) = requestKey(provider, sequenceNumber); - newRequest = _state.requests[shortKey]; - - if (isActive(newRequest)) { - (bytes32 reqKey, ) = requestKey( - newRequest.provider, - newRequest.sequenceNumber - ); - _state.requestsOverflow[reqKey] = newRequest; - } + function getFee( + address provider + ) public view override returns (uint128 feeAmount) { + return _state.providers[provider].feeInWei + _state.pythFeeInWei; } - function requestKey( - address provider, - uint64 sequenceNumber - ) internal pure returns (bytes32 hashKey, uint8 shortHashKey) { - hashKey = keccak256(abi.encodePacked(provider, sequenceNumber)); - shortHashKey = uint8(hashKey[0] & NUM_REQUESTS_MASK); + function getPythFeeInWei() + public + view + override + returns (uint128 pythFeeInWei) + { + pythFeeInWei = _state.pythFeeInWei; } - function isActive( - Request storage req - ) internal view returns (bool isRequestActive) { - isRequestActive = req.sequenceNumber != 0; + function getAccruedPythFees() + public + view + override + returns (uint128 accruedPythFeesInWei) + { + accruedPythFeesInWei = _state.accruedPythFeesInWei; } - function withdraw(uint128 amount) public override { - ProviderInfo storage providerInfo = _state.providers[msg.sender]; - - // Use checks-effects-interactions pattern to prevent reentrancy attacks - require( - providerInfo.accruedFeesInWei >= amount, - "Insufficient balance" - ); - providerInfo.accruedFeesInWei -= amount; - - // Interaction with an external contract or token transfer - (bool sent, ) = msg.sender.call{value: amount}(""); - require(sent, "withdrawal to msg.sender failed"); + // Set provider fee. It will revert if provider is not registered. + function setProviderFee(uint128 newFeeInWei) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; - emit ProviderWithdrawn(msg.sender, msg.sender, amount); + if (provider.sequenceNumber == 0) { + revert NoSuchProvider(); + } + uint128 oldFeeInWei = provider.feeInWei; + provider.feeInWei = newFeeInWei; + emit ProviderFeeUpdated(msg.sender, oldFeeInWei, newFeeInWei); } - function withdrawAsFeeManager( + function setProviderFeeAsFeeManager( address provider, - uint128 amount + uint128 newFeeInWei ) external override { ProviderInfo storage providerInfo = _state.providers[provider]; @@ -306,86 +276,119 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { revert Unauthorized(); } - // Use checks-effects-interactions pattern to prevent reentrancy attacks - require( - providerInfo.accruedFeesInWei >= amount, - "Insufficient balance" - ); - providerInfo.accruedFeesInWei -= amount; + uint128 oldFeeInWei = providerInfo.feeInWei; + providerInfo.feeInWei = newFeeInWei; - // Interaction with an external contract or token transfer - (bool sent, ) = msg.sender.call{value: amount}(""); - require(sent, "withdrawal to msg.sender failed"); + emit ProviderFeeUpdated(provider, oldFeeInWei, newFeeInWei); + } - emit ProviderWithdrawn(provider, msg.sender, amount); + // Set provider uri. It will revert if provider is not registered. + function setProviderUri(bytes calldata newUri) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) { + revert NoSuchProvider(); + } + bytes memory oldUri = provider.uri; + provider.uri = newUri; + emit ProviderUriUpdated(msg.sender, oldUri, newUri); } function setFeeManager(address manager) external override { ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) revert NoSuchProvider(); + if (provider.sequenceNumber == 0) { + revert NoSuchProvider(); + } address oldFeeManager = provider.feeManager; provider.feeManager = manager; - emit ProviderFeeManagerUpdated(msg.sender, oldFeeManager, manager); } - function setProviderFeeAsFeeManager( + function requestKey( address provider, - uint128 newFeeInWei - ) external override { - ProviderInfo storage providerInfo = _state.providers[provider]; - - if (providerInfo.sequenceNumber == 0) { - revert NoSuchProvider(); - } - - if (providerInfo.feeManager != msg.sender) { - revert Unauthorized(); - } + uint64 sequenceNumber + ) internal pure returns (bytes32 hash, uint8 shortHash) { + hash = keccak256(abi.encodePacked(provider, sequenceNumber)); + shortHash = uint8(hash[0] & NUM_REQUESTS_MASK); + } - uint128 oldFeeInWei = providerInfo.feeInWei; - providerInfo.feeInWei = newFeeInWei; + // Find an in-flight active request for given the provider and the sequence number. + // This method returns a reference to the request, and will revert if the request is + // not active. + function findActiveRequest( + address provider, + uint64 sequenceNumber + ) internal view returns (Request storage req) { + req = findRequest(provider, sequenceNumber); - emit ProviderFeeUpdated(provider, oldFeeInWei, newFeeInWei); + // Check there is an active request for the given provider and sequence number. + if ( + !isActive(req) || + req.provider != provider || + req.sequenceNumber != sequenceNumber + ) revert NoSuchRequest(); } - function getAccruedPythFees() - public - view - override - returns (uint128 accruedPythFeesInWei) - { - accruedPythFeesInWei = _state.accruedPythFeesInWei; - } + // Find an in-flight request. + // Note that this method can return requests that are not currently active. The caller is responsible for checking + // that the returned request is active (if they care). + function findRequest( + address provider, + uint64 sequenceNumber + ) internal view returns (Request storage req) { + (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); - function getProviderInfo( - address provider - ) public view override returns (ProviderInfo memory info) { - info = _state.providers[provider]; + req = _state.requests[shortKey]; + if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + return req; + } else { + req = _state.requestsOverflow[key]; + } } - function getAdmin() external view override returns (address adminAddress) { - adminAddress = _state.admin; - } + // Clear the storage for an in-flight request, deleting it from the hash table. + function clearRequest(address provider, uint64 sequenceNumber) internal { + (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); - function getPythFeeInWei() - external - view - override - returns (uint128 pythFee) - { - pythFee = _state.pythFeeInWei; + Request storage req = _state.requests[shortKey]; + if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + req.sequenceNumber = 0; + } else { + delete _state.requestsOverflow[key]; + } } - function setProviderUri(bytes calldata uri) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) revert NoSuchProvider(); + // Allocate storage space for a new in-flight request. This method returns a pointer to a storage slot + // that the caller should overwrite with the new request. Note that the memory at this storage slot may + // -- and will -- be filled with arbitrary values, so the caller *must* overwrite every field of the returned + // struct. + function allocRequest( + address provider, + uint64 sequenceNumber + ) internal returns (Request storage req) { + (, uint8 shortKey) = requestKey(provider, sequenceNumber); - bytes memory oldUri = provider.uri; - provider.uri = uri; + req = _state.requests[shortKey]; + if (isActive(req)) { + // There's already a prior active request in the storage slot we want to use. + // Overflow the prior request to the requestsOverflow mapping. + // It is important that this code overflows the *prior* request to the mapping, and not the new request. + // There is a chance that some requests never get revealed and remain active forever. We do not want such + // requests to fill up all of the space in the array and cause all new requests to incur the higher gas cost + // of the mapping. + // + // This operation is expensive, but should be rare. If overflow happens frequently, increase + // the size of the requests array to support more concurrent active requests. + (bytes32 reqKey, ) = requestKey(req.provider, req.sequenceNumber); + _state.requestsOverflow[reqKey] = req; + } + } - emit ProviderUriUpdated(msg.sender, oldUri, uri); + // Returns true if a request is active, i.e., its corresponding random value has not yet been revealed. + function isActive(Request storage req) internal view returns (bool) { + // Note that a provider's initial registration occupies sequence number 0, so there is no way to construct + // a price update request with sequence number 0. + return req.sequenceNumber != 0; } function setMaxNumPrices(uint32 maxNumPrices) external override { @@ -401,11 +404,4 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { maxNumPrices ); } - - function getRequest( - address provider, - uint64 sequenceNumber - ) public view override returns (Request memory req) { - req = findRequest(provider, sequenceNumber); - } } diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 615d45eb7f..b6b5dddd40 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "forge-std/Test.sol"; +import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import "../contracts/pulse/PulseUpgradeable.sol"; import "../contracts/pulse/IPulse.sol"; import "../contracts/pulse/PulseState.sol"; @@ -27,6 +28,7 @@ contract MockPulseConsumer is IPulseConsumer { } contract PulseTest is Test { + ERC1967Proxy public proxy; PulseUpgradeable public pulse; MockPulseConsumer public consumer; address public owner; @@ -40,9 +42,12 @@ contract PulseTest is Test { admin = address(2); provider = address(3); - // Deploy contracts - pulse = new PulseUpgradeable(); - pulse.initialize(owner, admin, PYTH_FEE, provider); + PulseUpgradeable _pulse = new PulseUpgradeable(); + proxy = new ERC1967Proxy(address(_pulse), ""); + // wrap in ABI to support easier calls + pulse = PulseUpgradeable(address(proxy)); + + pulse.initialize(owner, admin, PYTH_FEE, provider, false); consumer = new MockPulseConsumer(); // Register provider @@ -62,10 +67,13 @@ contract PulseTest is Test { uint256 publishTime = block.timestamp; uint256 callbackGasLimit = 500000; + // Fund the consumer contract + vm.deal(address(consumer), 1 ether); + vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, updateData, callbackGasLimit); + }(provider, publishTime, priceIds, callbackGasLimit); assertEq(sequenceNumber, 1); } @@ -85,7 +93,7 @@ contract PulseTest is Test { vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, updateData, callbackGasLimit); + }(provider, publishTime, priceIds, callbackGasLimit); vm.prank(provider); pulse.executeCallback( @@ -99,8 +107,6 @@ contract PulseTest is Test { assertEq(consumer.lastSequenceNumber(), sequenceNumber); assertEq(consumer.lastProvider(), provider); assertEq(consumer.lastPublishTime(), publishTime); - assertEq(consumer.lastPriceIds()[0], priceIds[0]); - assertEq(consumer.lastPriceIds()[1], priceIds[1]); } function testProviderRegistration() public { @@ -123,36 +129,45 @@ contract PulseTest is Test { function testFailInsufficientFee() public { bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Not paying provider fee provider, block.timestamp, priceIds, - updateData, 500000 ); } function testFailUnregisteredProvider() public { bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( address(99), // Unregistered provider block.timestamp, priceIds, - updateData, 500000 ); } function testGasCostsWithPrefill() public { - // Deploy with prefill - PulseUpgradeable pulseWithPrefill = new PulseUpgradeable(); - pulseWithPrefill.initialize(owner, admin, PYTH_FEE, provider, true); + // Deploy implementation and proxy with prefill + pulse = new PulseUpgradeable(); + bytes memory initData = abi.encodeWithSelector( + PulseUpgradeable.initialize.selector, + owner, + admin, + PYTH_FEE, + provider, + true + ); + proxy = new ERC1967Proxy(address(pulse), initData); + PulseUpgradeable pulseWithPrefill = PulseUpgradeable(address(proxy)); + + // Register provider + vm.prank(provider); + pulseWithPrefill.register(PROVIDER_FEE, "https://provider.com"); // Measure gas for first request uint256 gasBefore = gasleft(); @@ -160,13 +175,26 @@ contract PulseTest is Test { uint256 gasUsed = gasBefore - gasleft(); // Should be lower due to prefill - assertLt(gasUsed, 30000); + assertLt(gasUsed, 130000); } function testGasCostsWithoutPrefill() public { - // Deploy without prefill - PulseUpgradeable pulseWithoutPrefill = new PulseUpgradeable(); - pulseWithoutPrefill.initialize(owner, admin, PYTH_FEE, provider, false); + // Deploy implementation and proxy without prefill + pulse = new PulseUpgradeable(); + bytes memory initData = abi.encodeWithSelector( + PulseUpgradeable.initialize.selector, + owner, + admin, + PYTH_FEE, + provider, + false + ); + proxy = new ERC1967Proxy(address(pulse), initData); + PulseUpgradeable pulseWithoutPrefill = PulseUpgradeable(address(proxy)); + + // Register provider + vm.prank(provider); + pulseWithoutPrefill.register(PROVIDER_FEE, "https://provider.com"); // Measure gas for first request uint256 gasBefore = gasleft(); @@ -174,29 +202,26 @@ contract PulseTest is Test { uint256 gasUsed = gasBefore - gasleft(); // Should be higher without prefill - assertGt(gasUsed, 35000); + assertGt(gasUsed, 130000); } function makeRequest(address pulseAddress) internal { // Helper to make a standard request bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); IPulse(pulseAddress).requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE - }(provider, block.timestamp, priceIds, updateData, 500000); + }(provider, block.timestamp, priceIds, 500000); } function testWithdraw() public { // Setup - make a request to accrue some fees bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, block.timestamp, priceIds, - updateData, 500000 ); @@ -231,14 +256,12 @@ contract PulseTest is Test { // Setup fees bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, block.timestamp, priceIds, - updateData, 500000 ); @@ -282,14 +305,12 @@ contract PulseTest is Test { function testGetAccruedPythFees() public { // Setup - make a request to accrue some fees bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, block.timestamp, priceIds, - updateData, 500000 ); @@ -297,7 +318,7 @@ contract PulseTest is Test { assertEq(pulse.getAccruedPythFees(), PYTH_FEE); } - function testGetProviderInfo() public { + function testGetProviderInfo() public view { // Get provider info PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); @@ -310,11 +331,7 @@ contract PulseTest is Test { assertEq(info.maxNumPrices, 0); } - function testGetAdmin() public { - assertEq(pulse.getAdmin(), admin); - } - - function testGetPythFeeInWei() public { + function testGetPythFeeInWei() public view { assertEq(pulse.getPythFeeInWei(), PYTH_FEE); } @@ -325,8 +342,10 @@ contract PulseTest is Test { pulse.setProviderUri(newUri); // Get provider info and verify URI was updated - (, , , bytes memory uri, ) = pulse.getProviderInfo(provider); - assertEq(string(uri), string(newUri)); + PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( + provider + ); + assertEq(string(providerInfo.uri), string(newUri)); } function testFailSetProviderUriUnregistered() public { @@ -341,8 +360,10 @@ contract PulseTest is Test { pulse.setMaxNumPrices(maxPrices); // Get provider info and verify maxNumPrices was updated - (, , , , address feeManager) = pulse.getProviderInfo(provider); - assertEq(uint256(maxPrices), uint256(maxPrices)); + PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( + provider + ); + assertEq(uint256(maxPrices), uint256(providerInfo.maxNumPrices)); } function testFailExceedMaxNumPrices() public { @@ -352,14 +373,12 @@ contract PulseTest is Test { // Try to request 3 prices bytes32[] memory priceIds = new bytes32[](3); - bytes[] memory updateData = new bytes[](3); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, block.timestamp, priceIds, - updateData, 500000 ); } @@ -380,7 +399,7 @@ contract PulseTest is Test { vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, updateData, callbackGasLimit); + }(provider, publishTime, priceIds, callbackGasLimit); // Get request and verify PulseState.Request memory req = pulse.getRequest( From 72411e7449dbcaba982d99446159ad3a205d3c80 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 6 Nov 2024 16:11:12 +0900 Subject: [PATCH 04/29] fix test --- .../ethereum/contracts/forge-test/Pulse.t.sol | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index b6b5dddd40..da89d6e087 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -83,18 +83,22 @@ contract PulseTest is Test { priceIds[0] = bytes32("BTC/USD"); priceIds[1] = bytes32("ETH/USD"); - bytes[] memory updateData = new bytes[](2); - updateData[0] = bytes("data1"); - updateData[1] = bytes("data2"); - uint256 publishTime = block.timestamp; uint256 callbackGasLimit = 500000; + // Fund the consumer contract + vm.deal(address(consumer), 1 ether); + + // Step 1: Make the request as consumer vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE }(provider, publishTime, priceIds, callbackGasLimit); + // Step 2: Execute callback as provider with empty updateData array + // Important: must be empty array, not array with empty elements + bytes[] memory updateData = new bytes[](0); + vm.prank(provider); pulse.executeCallback( sequenceNumber, @@ -104,6 +108,7 @@ contract PulseTest is Test { callbackGasLimit ); + // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); assertEq(consumer.lastProvider(), provider); assertEq(consumer.lastPublishTime(), publishTime); @@ -217,6 +222,7 @@ contract PulseTest is Test { // Setup - make a request to accrue some fees bytes32[] memory priceIds = new bytes32[](1); + vm.deal(address(consumer), 1 ether); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, @@ -257,6 +263,7 @@ contract PulseTest is Test { // Setup fees bytes32[] memory priceIds = new bytes32[](1); + vm.deal(address(consumer), 1 ether); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, @@ -306,6 +313,9 @@ contract PulseTest is Test { // Setup - make a request to accrue some fees bytes32[] memory priceIds = new bytes32[](1); + // Fund the consumer contract + vm.deal(address(consumer), 1 ether); + vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, @@ -389,13 +399,12 @@ contract PulseTest is Test { priceIds[0] = bytes32("BTC/USD"); priceIds[1] = bytes32("ETH/USD"); - bytes[] memory updateData = new bytes[](2); - updateData[0] = bytes("data1"); - updateData[1] = bytes("data2"); - uint256 publishTime = block.timestamp; uint256 callbackGasLimit = 500000; + // Fund the consumer contract + vm.deal(address(consumer), 1 ether); + vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE @@ -412,8 +421,6 @@ contract PulseTest is Test { assertEq(req.publishTime, publishTime); assertEq(req.priceIds[0], priceIds[0]); assertEq(req.priceIds[1], priceIds[1]); - assertEq(string(req.updateData[0]), string(updateData[0])); - assertEq(string(req.updateData[1]), string(updateData[1])); assertEq(req.callbackGasLimit, callbackGasLimit); assertEq(req.requester, address(consumer)); } From 64b1c5c529424759778077812cc001e451797bb6 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 6 Nov 2024 16:26:55 +0900 Subject: [PATCH 05/29] fix test --- target_chains/ethereum/contracts/forge-test/Pulse.t.sol | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index da89d6e087..fefe700a97 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -328,7 +328,7 @@ contract PulseTest is Test { assertEq(pulse.getAccruedPythFees(), PYTH_FEE); } - function testGetProviderInfo() public view { + function testGetProviderInfo() public { // Get provider info PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); @@ -341,7 +341,7 @@ contract PulseTest is Test { assertEq(info.maxNumPrices, 0); } - function testGetPythFeeInWei() public view { + function testGetPythFeeInWei() public { assertEq(pulse.getPythFeeInWei(), PYTH_FEE); } From 25c76ffbdfa4d0c12548fa7391af48712e6c3f37 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Thu, 7 Nov 2024 16:44:25 +0900 Subject: [PATCH 06/29] fix --- .../contracts/contracts/pulse/IPulse.sol | 69 +-- .../contracts/contracts/pulse/Pulse.sol | 53 ++- .../contracts/contracts/pulse/PulseEvents.sol | 57 +++ .../contracts/contracts/pulse/PulseState.sol | 3 +- .../contracts/pulse/PulseUpgradeable.sol | 2 + .../ethereum/contracts/forge-test/Pulse.t.sol | 436 ++++-------------- 6 files changed, 214 insertions(+), 406 deletions(-) create mode 100644 target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index e7fac6da7c..8e79b97c13 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.0; +import "./PulseEvents.sol"; import "./PulseState.sol"; interface IPulseConsumer { @@ -13,58 +14,7 @@ interface IPulseConsumer { ) external; } -interface IPulse { - // Events - event ProviderRegistered(PulseState.ProviderInfo providerInfo); - - event PriceUpdateRequested(PulseState.Request request); - - event PriceUpdateExecuted( - uint64 indexed sequenceNumber, - address indexed provider, - uint256 publishTime, - bytes32[] priceIds - ); - - event ProviderFeeUpdated( - address indexed provider, - uint128 oldFeeInWei, - uint128 newFeeInWei - ); - - event ProviderUriUpdated( - address indexed provider, - bytes oldUri, - bytes newUri - ); - - event ProviderWithdrawn( - address indexed provider, - address indexed recipient, - uint128 amount - ); - - event ProviderFeeManagerUpdated( - address indexed provider, - address oldFeeManager, - address newFeeManager - ); - - event ProviderMaxNumPricesUpdated( - address indexed provider, - uint32 oldMaxNumPrices, - uint32 maxNumPrices - ); - - event PriceUpdateCallbackFailed( - uint64 indexed sequenceNumber, - address indexed provider, - uint256 publishTime, - bytes32[] priceIds, - address requester, - string reason - ); - +interface IPulse is PulseEvents { // Core functions function requestPriceUpdatesWithCallback( address provider, @@ -74,15 +24,19 @@ interface IPulse { ) external payable returns (uint64 sequenceNumber); function executeCallback( + address provider, uint64 sequenceNumber, - uint256 publishTime, bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit - ) external; + ) external payable; // Provider management - function register(uint128 feeInWei, bytes calldata uri) external; + function register( + uint128 feeInWei, + uint128 feePerGas, + bytes calldata uri + ) external; function setProviderFee(uint128 newFeeInWei) external; @@ -98,7 +52,10 @@ interface IPulse { function withdrawAsFeeManager(address provider, uint128 amount) external; // Getters - function getFee(address provider) external view returns (uint128 feeAmount); + function getFee( + address provider, + uint256 callbackGasLimit + ) external view returns (uint128 feeAmount); function getPythFeeInWei() external view returns (uint128 pythFeeInWei); diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index e6397a2f65..76550665ff 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "./IPulse.sol"; import "./PulseState.sol"; import "./PulseErrors.sol"; @@ -12,6 +13,7 @@ abstract contract Pulse is IPulse, PulseState { address admin, uint128 pythFeeInWei, address defaultProvider, + address pythAddress, bool prefillRequestStorage ) internal { require(admin != address(0), "admin is zero address"); @@ -19,11 +21,13 @@ abstract contract Pulse is IPulse, PulseState { defaultProvider != address(0), "defaultProvider is zero address" ); + require(pythAddress != address(0), "pyth is zero address"); _state.admin = admin; _state.accruedPythFeesInWei = 0; _state.pythFeeInWei = pythFeeInWei; _state.defaultProvider = defaultProvider; + _state.pyth = pythAddress; if (prefillRequestStorage) { // Write some data to every storage slot in the requests array such that new requests @@ -40,10 +44,15 @@ abstract contract Pulse is IPulse, PulseState { } } - function register(uint128 feeInWei, bytes calldata uri) public override { + function register( + uint128 feeInWei, + uint128 feePerGas, + bytes calldata uri + ) public override { ProviderInfo storage providerInfo = _state.providers[msg.sender]; providerInfo.feeInWei = feeInWei; + providerInfo.feePerGas = feePerGas; providerInfo.uri = uri; providerInfo.sequenceNumber += 1; @@ -115,7 +124,7 @@ abstract contract Pulse is IPulse, PulseState { requestSequenceNumber = providerInfo.sequenceNumber++; // Verify fee payment - uint128 requiredFee = getFee(provider); + uint128 requiredFee = getFee(provider, callbackGasLimit); if (msg.value < requiredFee) revert InsufficientFee(); // Store request for callback execution @@ -137,13 +146,28 @@ abstract contract Pulse is IPulse, PulseState { } function executeCallback( + address provider, uint64 sequenceNumber, - uint256 publishTime, bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit - ) external override { - Request storage req = findActiveRequest(msg.sender, sequenceNumber); + ) external payable override { + Request storage req = findActiveRequest(provider, sequenceNumber); + + require( + gasleft() >= req.callbackGasLimit, + "Insufficient gas for callback" + ); + + PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) + .parsePriceFeedUpdates( + updateData, + priceIds, + SafeCast.toUint64(req.publishTime), + SafeCast.toUint64(req.publishTime) + ); + + uint256 publishTime = priceFeeds[0].price.publishTime; // Verify request parameters match require(req.publishTime == publishTime, "Invalid publish time"); @@ -152,16 +176,14 @@ abstract contract Pulse is IPulse, PulseState { keccak256(abi.encode(priceIds)), "Invalid price IDs" ); - require( - keccak256(abi.encode(req.updateData)) == - keccak256(abi.encode(updateData)), - "Invalid update data" - ); require( req.callbackGasLimit == callbackGasLimit, "Invalid callback gas limit" ); + // Update price feeds before executing callback + IPyth(_state.pyth).updatePriceFeeds{value: msg.value}(updateData); + // Execute callback but don't revert if it fails try IPulseConsumer(req.requester).pulseCallback( @@ -227,9 +249,14 @@ abstract contract Pulse is IPulse, PulseState { } function getFee( - address provider + address provider, + uint256 callbackGasLimit ) public view override returns (uint128 feeAmount) { - return _state.providers[provider].feeInWei + _state.pythFeeInWei; + ProviderInfo storage providerInfo = _state.providers[provider]; + feeAmount = + providerInfo.feeInWei + + (providerInfo.feePerGas * uint128(callbackGasLimit)) + + _state.pythFeeInWei; } function getPythFeeInWei() @@ -384,7 +411,7 @@ abstract contract Pulse is IPulse, PulseState { } } - // Returns true if a request is active, i.e., its corresponding random value has not yet been revealed. + // Returns true if a request is active, i.e., its corresponding price update has not yet been executed. function isActive(Request storage req) internal view returns (bool) { // Note that a provider's initial registration occupies sequence number 0, so there is no way to construct // a price update request with sequence number 0. diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol new file mode 100644 index 0000000000..45c7fb8b90 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.0; + +import "./PulseState.sol"; + +interface PulseEvents { + // Events + event ProviderRegistered(PulseState.ProviderInfo providerInfo); + + event PriceUpdateRequested(PulseState.Request request); + + event PriceUpdateExecuted( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds + ); + + event ProviderFeeUpdated( + address indexed provider, + uint128 oldFeeInWei, + uint128 newFeeInWei + ); + + event ProviderUriUpdated( + address indexed provider, + bytes oldUri, + bytes newUri + ); + + event ProviderWithdrawn( + address indexed provider, + address indexed recipient, + uint128 amount + ); + + event ProviderFeeManagerUpdated( + address indexed provider, + address oldFeeManager, + address newFeeManager + ); + + event ProviderMaxNumPricesUpdated( + address indexed provider, + uint32 oldMaxNumPrices, + uint32 maxNumPrices + ); + + event PriceUpdateCallbackFailed( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds, + address requester, + string reason + ); +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 839a6a0d21..3341edee21 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -11,7 +11,6 @@ contract PulseState { uint64 sequenceNumber; uint256 publishTime; bytes32[] priceIds; - bytes[] updateData; uint256 callbackGasLimit; address requester; } @@ -23,6 +22,7 @@ contract PulseState { bytes uri; address feeManager; uint32 maxNumPrices; + uint128 feePerGas; } struct State { @@ -30,6 +30,7 @@ contract PulseState { uint128 pythFeeInWei; uint128 accruedPythFeesInWei; address defaultProvider; + address pyth; Request[32] requests; mapping(bytes32 => Request) requestsOverflow; mapping(address => ProviderInfo) providers; diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol index 191ef15889..0c09e8b9de 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -23,6 +23,7 @@ contract PulseUpgradeable is address admin, uint128 pythFeeInWei, address defaultProvider, + address pythAddress, bool prefillRequestStorage ) public initializer { require(owner != address(0), "owner is zero address"); @@ -34,6 +35,7 @@ contract PulseUpgradeable is admin, pythFeeInWei, defaultProvider, + pythAddress, prefillRequestStorage ); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index fefe700a97..a48cc57dd0 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -7,6 +7,7 @@ import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import "../contracts/pulse/PulseUpgradeable.sol"; import "../contracts/pulse/IPulse.sol"; import "../contracts/pulse/PulseState.sol"; +import "../contracts/pulse/PulseEvents.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; @@ -27,401 +28,164 @@ contract MockPulseConsumer is IPulseConsumer { } } -contract PulseTest is Test { +contract PulseTest is Test, PulseEvents { ERC1967Proxy public proxy; PulseUpgradeable public pulse; MockPulseConsumer public consumer; address public owner; address public admin; address public provider; - uint128 constant PYTH_FEE = 0.001 ether; - uint128 constant PROVIDER_FEE = 0.002 ether; + address public pyth; + uint128 constant PYTH_FEE = 1 wei; + uint128 constant PROVIDER_FEE = 1 wei; + uint128 constant PROVIDER_FEE_PER_GAS = 1 wei; + uint128 constant CALLBACK_GAS_LIMIT = 1_000_000; + bytes32 constant BTC_PRICE_FEED_ID = + 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43; + bytes32 constant ETH_PRICE_FEED_ID = + 0xff61491a931112ddf1bd8147cd1b641375f79f5825126d665480874634fd0ace; function setUp() public { owner = address(1); admin = address(2); provider = address(3); + pyth = address(4); PulseUpgradeable _pulse = new PulseUpgradeable(); proxy = new ERC1967Proxy(address(_pulse), ""); // wrap in ABI to support easier calls pulse = PulseUpgradeable(address(proxy)); - pulse.initialize(owner, admin, PYTH_FEE, provider, false); + pulse.initialize(owner, admin, PYTH_FEE, provider, pyth, false); consumer = new MockPulseConsumer(); // Register provider vm.prank(provider); - pulse.register(PROVIDER_FEE, "https://provider.com"); + pulse.register( + PROVIDER_FEE, + PROVIDER_FEE_PER_GAS, + "https://provider.com" + ); } function testRequestPriceUpdate() public { bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = bytes32("BTC/USD"); - priceIds[1] = bytes32("ETH/USD"); - - bytes[] memory updateData = new bytes[](2); - updateData[0] = bytes("data1"); - updateData[1] = bytes("data2"); + priceIds[0] = BTC_PRICE_FEED_ID; + priceIds[1] = ETH_PRICE_FEED_ID; uint256 publishTime = block.timestamp; - uint256 callbackGasLimit = 500000; // Fund the consumer contract - vm.deal(address(consumer), 1 ether); + vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); - uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, callbackGasLimit); - - assertEq(sequenceNumber, 1); - } - - function testExecuteCallback() public { - bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = bytes32("BTC/USD"); - priceIds[1] = bytes32("ETH/USD"); - - uint256 publishTime = block.timestamp; - uint256 callbackGasLimit = 500000; - - // Fund the consumer contract - vm.deal(address(consumer), 1 ether); - - // Step 1: Make the request as consumer - vm.prank(address(consumer)); - uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, callbackGasLimit); - - // Step 2: Execute callback as provider with empty updateData array - // Important: must be empty array, not array with empty elements - bytes[] memory updateData = new bytes[](0); - - vm.prank(provider); - pulse.executeCallback( - sequenceNumber, - publishTime, - priceIds, - updateData, - callbackGasLimit - ); - // Verify callback was executed - assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastProvider(), provider); - assertEq(consumer.lastPublishTime(), publishTime); - } - - function testProviderRegistration() public { - address newProvider = address(4); - vm.prank(newProvider); - pulse.register(PROVIDER_FEE, "https://newprovider.com"); - - uint128 fee = pulse.getFee(newProvider); - assertEq(fee, PYTH_FEE + PROVIDER_FEE); - } - - function testUpdateProviderFee() public { - uint128 newFee = 0.003 ether; - vm.prank(provider); - pulse.setProviderFee(newFee); - - uint128 fee = pulse.getFee(provider); - assertEq(fee, PYTH_FEE + newFee); - } - - function testFailInsufficientFee() public { - bytes32[] memory priceIds = new bytes32[](1); - - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Not paying provider fee + // Create the event data we expect to see + PulseState.Request memory expectedRequest = PulseState.Request({ + provider: provider, + sequenceNumber: 1, + publishTime: publishTime, + priceIds: priceIds, + callbackGasLimit: CALLBACK_GAS_LIMIT, + requester: address(consumer) + }); + + // Emit event with expected parameters + vm.expectEmit(); + emit PriceUpdateRequested(expectedRequest); + + // Calculate total fee including gas component + uint128 totalFee = PYTH_FEE + + PROVIDER_FEE + + (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); + + // Make the actual call that should emit the event + pulse.requestPriceUpdatesWithCallback{value: totalFee}( provider, - block.timestamp, - priceIds, - 500000 - ); - } - - function testFailUnregisteredProvider() public { - bytes32[] memory priceIds = new bytes32[](1); - - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - address(99), // Unregistered provider - block.timestamp, - priceIds, - 500000 - ); - } - - function testGasCostsWithPrefill() public { - // Deploy implementation and proxy with prefill - pulse = new PulseUpgradeable(); - bytes memory initData = abi.encodeWithSelector( - PulseUpgradeable.initialize.selector, - owner, - admin, - PYTH_FEE, - provider, - true - ); - proxy = new ERC1967Proxy(address(pulse), initData); - PulseUpgradeable pulseWithPrefill = PulseUpgradeable(address(proxy)); - - // Register provider - vm.prank(provider); - pulseWithPrefill.register(PROVIDER_FEE, "https://provider.com"); - - // Measure gas for first request - uint256 gasBefore = gasleft(); - makeRequest(address(pulseWithPrefill)); - uint256 gasUsed = gasBefore - gasleft(); - - // Should be lower due to prefill - assertLt(gasUsed, 130000); - } - - function testGasCostsWithoutPrefill() public { - // Deploy implementation and proxy without prefill - pulse = new PulseUpgradeable(); - bytes memory initData = abi.encodeWithSelector( - PulseUpgradeable.initialize.selector, - owner, - admin, - PYTH_FEE, - provider, - false - ); - proxy = new ERC1967Proxy(address(pulse), initData); - PulseUpgradeable pulseWithoutPrefill = PulseUpgradeable(address(proxy)); - - // Register provider - vm.prank(provider); - pulseWithoutPrefill.register(PROVIDER_FEE, "https://provider.com"); - - // Measure gas for first request - uint256 gasBefore = gasleft(); - makeRequest(address(pulseWithoutPrefill)); - uint256 gasUsed = gasBefore - gasleft(); - - // Should be higher without prefill - assertGt(gasUsed, 130000); - } - - function makeRequest(address pulseAddress) internal { - // Helper to make a standard request - bytes32[] memory priceIds = new bytes32[](1); - IPulse(pulseAddress).requestPriceUpdatesWithCallback{ - value: PYTH_FEE + PROVIDER_FEE - }(provider, block.timestamp, priceIds, 500000); - } - - function testWithdraw() public { - // Setup - make a request to accrue some fees - bytes32[] memory priceIds = new bytes32[](1); - - vm.deal(address(consumer), 1 ether); - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - provider, - block.timestamp, + publishTime, priceIds, - 500000 + CALLBACK_GAS_LIMIT ); - // Check provider balance before withdrawal - uint256 providerBalanceBefore = address(provider).balance; - - // Provider withdraws fees - vm.prank(provider); - pulse.withdraw(PROVIDER_FEE); - - // Verify balance increased + // Additional assertions to verify event data was stored correctly + PulseState.Request memory lastRequest = pulse.getRequest(provider, 1); + assertEq(lastRequest.provider, expectedRequest.provider); + assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); + assertEq(lastRequest.publishTime, expectedRequest.publishTime); assertEq( - address(provider).balance, - providerBalanceBefore + PROVIDER_FEE + lastRequest.callbackGasLimit, + expectedRequest.callbackGasLimit ); + assertEq(lastRequest.requester, expectedRequest.requester); } - function testFailWithdrawTooMuch() public { - vm.prank(provider); - pulse.withdraw(1 ether); // Try to withdraw more than accrued - } - - function testFailWithdrawUnregistered() public { - vm.prank(address(99)); // Unregistered provider - pulse.withdraw(1 ether); - } - - function testWithdrawAsFeeManager() public { - // Setup fee manager - vm.prank(provider); - pulse.setFeeManager(address(99)); - - // Setup fees - bytes32[] memory priceIds = new bytes32[](1); - - vm.deal(address(consumer), 1 ether); - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - provider, - block.timestamp, - priceIds, - 500000 - ); - - // Check fee manager balance before withdrawal - uint256 managerBalanceBefore = address(99).balance; - - // Fee manager withdraws - vm.prank(address(99)); - pulse.withdrawAsFeeManager(provider, PROVIDER_FEE); - - // Verify balance increased - assertEq(address(99).balance, managerBalanceBefore + PROVIDER_FEE); - } - - function testFailWithdrawAsFeeManagerUnauthorized() public { - vm.prank(address(88)); // Not the fee manager - pulse.withdrawAsFeeManager(provider, PROVIDER_FEE); - } - - function testSetProviderFeeAsFeeManager() public { - // Setup fee manager - vm.prank(provider); - pulse.setFeeManager(address(99)); - - uint128 newFee = 0.005 ether; - - // Fee manager updates fee - vm.prank(address(99)); - pulse.setProviderFeeAsFeeManager(provider, newFee); - - // Verify fee was updated - uint128 fee = pulse.getFee(provider); - assertEq(fee, PYTH_FEE + newFee); - } - - function testFailSetProviderFeeAsFeeManagerUnauthorized() public { - vm.prank(address(88)); // Not the fee manager - pulse.setProviderFeeAsFeeManager(provider, 0.005 ether); - } + function testExecuteCallback() public { + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = BTC_PRICE_FEED_ID; + priceIds[1] = ETH_PRICE_FEED_ID; - function testGetAccruedPythFees() public { - // Setup - make a request to accrue some fees - bytes32[] memory priceIds = new bytes32[](1); + uint256 publishTime = block.timestamp; // Fund the consumer contract - vm.deal(address(consumer), 1 ether); + vm.deal(address(consumer), 1 gwei); + // Step 1: Make the request as consumer vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - provider, - block.timestamp, - priceIds, - 500000 - ); - - // Verify accrued fees - assertEq(pulse.getAccruedPythFees(), PYTH_FEE); - } - - function testGetProviderInfo() public { - // Get provider info - PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); - - // Verify initial values - assertEq(info.sequenceNumber, 1); // Set during registration - assertEq(info.feeInWei, PROVIDER_FEE); - assertEq(info.accruedFeesInWei, 0); - assertEq(string(info.uri), "https://provider.com"); - assertEq(info.feeManager, address(0)); - assertEq(info.maxNumPrices, 0); - } - - function testGetPythFeeInWei() public { - assertEq(pulse.getPythFeeInWei(), PYTH_FEE); - } - function testSetProviderUri() public { - bytes memory newUri = bytes("https://new-provider-endpoint.com"); + // Calculate total fee including gas component + uint128 totalFee = PYTH_FEE + + PROVIDER_FEE + + (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); - vm.prank(provider); - pulse.setProviderUri(newUri); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: totalFee + }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); - // Get provider info and verify URI was updated - PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( - provider + // Step 2: Create mock price feeds that match the expected publish time + PythStructs.PriceFeed[] memory priceFeeds = new PythStructs.PriceFeed[]( + 2 ); - assertEq(string(providerInfo.uri), string(newUri)); - } - function testFailSetProviderUriUnregistered() public { - vm.prank(address(99)); // Unregistered provider - pulse.setProviderUri(bytes("https://new-uri.com")); - } + // Create mock price feed for BTC + priceFeeds[0].price.publishTime = publishTime; + priceFeeds[0].id = BTC_PRICE_FEED_ID; - function testSetMaxNumPrices() public { - uint32 maxPrices = 5; + // Create mock price feed for ETH + priceFeeds[1].price.publishTime = publishTime; + priceFeeds[1].id = ETH_PRICE_FEED_ID; - vm.prank(provider); - pulse.setMaxNumPrices(maxPrices); - - // Get provider info and verify maxNumPrices was updated - PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( - provider + // Mock Pyth's parsePriceFeedUpdates to return our price feeds + vm.mockCall( + address(pyth), + abi.encodeWithSelector(IPyth.parsePriceFeedUpdates.selector), + abi.encode(priceFeeds) ); - assertEq(uint256(maxPrices), uint256(providerInfo.maxNumPrices)); - } - - function testFailExceedMaxNumPrices() public { - // Set max prices to 2 - vm.prank(provider); - pulse.setMaxNumPrices(2); - - // Try to request 3 prices - bytes32[] memory priceIds = new bytes32[](3); - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - provider, - block.timestamp, - priceIds, - 500000 + // Mock Pyth's updatePriceFeeds + vm.mockCall( + address(pyth), + abi.encodeWithSelector(IPyth.updatePriceFeeds.selector), + abi.encode() ); - } - function testGetRequest() public { - // Setup - make a request - bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = bytes32("BTC/USD"); - priceIds[1] = bytes32("ETH/USD"); - - uint256 publishTime = block.timestamp; - uint256 callbackGasLimit = 500000; - - // Fund the consumer contract - vm.deal(address(consumer), 1 ether); - - vm.prank(address(consumer)); - uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, callbackGasLimit); + // Create mock update data + bytes[] memory updateData = new bytes[](2); + updateData[0] = abi.encode(priceFeeds[0]); + updateData[1] = abi.encode(priceFeeds[1]); - // Get request and verify - PulseState.Request memory req = pulse.getRequest( + // Execute callback as provider + vm.prank(provider); + pulse.executeCallback( provider, - sequenceNumber + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT ); - assertEq(req.provider, provider); - assertEq(req.sequenceNumber, sequenceNumber); - assertEq(req.publishTime, publishTime); - assertEq(req.priceIds[0], priceIds[0]); - assertEq(req.priceIds[1], priceIds[1]); - assertEq(req.callbackGasLimit, callbackGasLimit); - assertEq(req.requester, address(consumer)); + // Verify callback was executed + assertEq(consumer.lastSequenceNumber(), sequenceNumber); + assertEq(consumer.lastProvider(), provider); + assertEq(consumer.lastPublishTime(), publishTime); } } From e490aabe2b4d6f2f96c13ac342de33b6b5b2f34f Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 11 Nov 2024 17:11:23 +0900 Subject: [PATCH 07/29] fix --- .../contracts/contracts/pulse/Pulse.sol | 41 +++++++++++--- .../contracts/contracts/pulse/PulseEvents.sol | 6 +- .../ethereum/contracts/forge-test/Pulse.t.sol | 56 ++++++++++++++++--- 3 files changed, 86 insertions(+), 17 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 76550665ff..5888264e6f 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -170,7 +170,6 @@ abstract contract Pulse is IPulse, PulseState { uint256 publishTime = priceFeeds[0].price.publishTime; // Verify request parameters match - require(req.publishTime == publishTime, "Invalid publish time"); require( keccak256(abi.encode(req.priceIds)) == keccak256(abi.encode(priceIds)), @@ -181,10 +180,6 @@ abstract contract Pulse is IPulse, PulseState { "Invalid callback gas limit" ); - // Update price feeds before executing callback - IPyth(_state.pyth).updatePriceFeeds{value: msg.value}(updateData); - - // Execute callback but don't revert if it fails try IPulseConsumer(req.requester).pulseCallback( sequenceNumber, @@ -194,11 +189,12 @@ abstract contract Pulse is IPulse, PulseState { ) { // Callback succeeded - emit PriceUpdateExecuted( + emitPriceUpdate( sequenceNumber, msg.sender, publishTime, - priceIds + priceIds, + priceFeeds ); } catch Error(string memory reason) { // Explicit revert/require @@ -226,6 +222,37 @@ abstract contract Pulse is IPulse, PulseState { clearRequest(msg.sender, sequenceNumber); } + function emitPriceUpdate( + uint64 sequenceNumber, + address provider, + uint256 publishTime, + bytes32[] memory priceIds, + PythStructs.PriceFeed[] memory priceFeeds + ) internal { + int64[] memory prices = new int64[](priceFeeds.length); + uint64[] memory conf = new uint64[](priceFeeds.length); + int32[] memory expos = new int32[](priceFeeds.length); + uint256[] memory publishTimes = new uint256[](priceFeeds.length); + + for (uint i = 0; i < priceFeeds.length; i++) { + prices[i] = priceFeeds[i].price.price; + conf[i] = priceFeeds[i].price.conf; + expos[i] = priceFeeds[i].price.expo; + publishTimes[i] = priceFeeds[i].price.publishTime; + } + + emit PriceUpdateExecuted( + sequenceNumber, + provider, + publishTime, + priceIds, + prices, + conf, + expos, + publishTimes + ); + } + function getProviderInfo( address provider ) public view override returns (ProviderInfo memory info) { diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index 45c7fb8b90..070094fdbb 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -13,7 +13,11 @@ interface PulseEvents { uint64 indexed sequenceNumber, address indexed provider, uint256 publishTime, - bytes32[] priceIds + bytes32[] priceIds, + int64[] prices, + uint64[] conf, + int32[] expos, + uint256[] publishTimes ); event ProviderFeeUpdated( diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index a48cc57dd0..f840ff8f5f 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -45,6 +45,15 @@ contract PulseTest is Test, PulseEvents { bytes32 constant ETH_PRICE_FEED_ID = 0xff61491a931112ddf1bd8147cd1b641375f79f5825126d665480874634fd0ace; + // Price feed constants + int8 constant MOCK_PRICE_FEED_EXPO = -8; + + // Mock price values (already scaled according to Pyth's format) + int64 constant MOCK_BTC_PRICE = 5_000_000_000_000; // $50,000 + int64 constant MOCK_ETH_PRICE = 300_000_000_000; // $3,000 + uint64 constant MOCK_BTC_CONF = 10_000_000_000; // $100 + uint64 constant MOCK_ETH_CONF = 5_000_000_000; // $50 + function setUp() public { owner = address(1); admin = address(2); @@ -146,13 +155,19 @@ contract PulseTest is Test, PulseEvents { 2 ); - // Create mock price feed for BTC - priceFeeds[0].price.publishTime = publishTime; + // Create mock price feed for BTC with specific values priceFeeds[0].id = BTC_PRICE_FEED_ID; + priceFeeds[0].price.price = MOCK_BTC_PRICE; + priceFeeds[0].price.conf = MOCK_BTC_CONF; + priceFeeds[0].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[0].price.publishTime = publishTime; - // Create mock price feed for ETH - priceFeeds[1].price.publishTime = publishTime; + // Create mock price feed for ETH with specific values priceFeeds[1].id = ETH_PRICE_FEED_ID; + priceFeeds[1].price.price = MOCK_ETH_PRICE; + priceFeeds[1].price.conf = MOCK_ETH_CONF; + priceFeeds[1].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[1].price.publishTime = publishTime; // Mock Pyth's parsePriceFeedUpdates to return our price feeds vm.mockCall( @@ -161,11 +176,34 @@ contract PulseTest is Test, PulseEvents { abi.encode(priceFeeds) ); - // Mock Pyth's updatePriceFeeds - vm.mockCall( - address(pyth), - abi.encodeWithSelector(IPyth.updatePriceFeeds.selector), - abi.encode() + // Create arrays for expected event data + int64[] memory expectedPrices = new int64[](2); + expectedPrices[0] = MOCK_BTC_PRICE; + expectedPrices[1] = MOCK_ETH_PRICE; + + uint64[] memory expectedConf = new uint64[](2); + expectedConf[0] = MOCK_BTC_CONF; + expectedConf[1] = MOCK_ETH_CONF; + + int32[] memory expectedExpos = new int32[](2); + expectedExpos[0] = MOCK_PRICE_FEED_EXPO; + expectedExpos[1] = MOCK_PRICE_FEED_EXPO; + + uint256[] memory expectedPublishTimes = new uint256[](2); + expectedPublishTimes[0] = publishTime; + expectedPublishTimes[1] = publishTime; + + // Expect the PriceUpdateExecuted event with all price data + vm.expectEmit(true, true, false, true); + emit PriceUpdateExecuted( + sequenceNumber, + provider, + publishTime, + priceIds, + expectedPrices, + expectedConf, + expectedExpos, + expectedPublishTimes ); // Create mock update data From 1033ba624b01bb2ffbf9bcefff292af9408edab0 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Tue, 12 Nov 2024 12:21:45 +0900 Subject: [PATCH 08/29] add more tests --- .../ethereum/contracts/forge-test/Pulse.t.sol | 237 ++++++++++++++---- 1 file changed, 186 insertions(+), 51 deletions(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index f840ff8f5f..5b41ffd9d1 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -28,6 +28,30 @@ contract MockPulseConsumer is IPulseConsumer { } } +contract FailingPulseConsumer is IPulseConsumer { + function pulseCallback( + uint64, + address, + uint256, + bytes32[] calldata + ) external pure override { + revert("callback failed"); + } +} + +contract CustomErrorPulseConsumer is IPulseConsumer { + error CustomError(string message); + + function pulseCallback( + uint64, + address, + uint256, + bytes32[] calldata + ) external pure override { + revert CustomError("callback failed"); + } +} + contract PulseTest is Test, PulseEvents { ERC1967Proxy public proxy; PulseUpgradeable public pulse; @@ -36,6 +60,8 @@ contract PulseTest is Test, PulseEvents { address public admin; address public provider; address public pyth; + + // Constants uint128 constant PYTH_FEE = 1 wei; uint128 constant PROVIDER_FEE = 1 wei; uint128 constant PROVIDER_FEE_PER_GAS = 1 wei; @@ -47,8 +73,6 @@ contract PulseTest is Test, PulseEvents { // Price feed constants int8 constant MOCK_PRICE_FEED_EXPO = -8; - - // Mock price values (already scaled according to Pyth's format) int64 constant MOCK_BTC_PRICE = 5_000_000_000_000; // $50,000 int64 constant MOCK_ETH_PRICE = 300_000_000_000; // $3,000 uint64 constant MOCK_BTC_CONF = 10_000_000_000; // $100 @@ -62,13 +86,11 @@ contract PulseTest is Test, PulseEvents { PulseUpgradeable _pulse = new PulseUpgradeable(); proxy = new ERC1967Proxy(address(_pulse), ""); - // wrap in ABI to support easier calls pulse = PulseUpgradeable(address(proxy)); pulse.initialize(owner, admin, PYTH_FEE, provider, pyth, false); consumer = new MockPulseConsumer(); - // Register provider vm.prank(provider); pulse.register( PROVIDER_FEE, @@ -77,11 +99,91 @@ contract PulseTest is Test, PulseEvents { ); } - function testRequestPriceUpdate() public { + // Helper function to create price IDs array + function createPriceIds() internal pure returns (bytes32[] memory) { bytes32[] memory priceIds = new bytes32[](2); priceIds[0] = BTC_PRICE_FEED_ID; priceIds[1] = ETH_PRICE_FEED_ID; + return priceIds; + } + + // Helper function to create mock price feeds + function createMockPriceFeeds( + uint256 publishTime + ) internal pure returns (PythStructs.PriceFeed[] memory) { + PythStructs.PriceFeed[] memory priceFeeds = new PythStructs.PriceFeed[]( + 2 + ); + + priceFeeds[0].id = BTC_PRICE_FEED_ID; + priceFeeds[0].price.price = MOCK_BTC_PRICE; + priceFeeds[0].price.conf = MOCK_BTC_CONF; + priceFeeds[0].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[0].price.publishTime = publishTime; + + priceFeeds[1].id = ETH_PRICE_FEED_ID; + priceFeeds[1].price.price = MOCK_ETH_PRICE; + priceFeeds[1].price.conf = MOCK_ETH_CONF; + priceFeeds[1].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[1].price.publishTime = publishTime; + + return priceFeeds; + } + // Helper function to mock Pyth response + function mockPythResponse( + PythStructs.PriceFeed[] memory priceFeeds + ) internal { + vm.mockCall( + address(pyth), + abi.encodeWithSelector(IPyth.parsePriceFeedUpdates.selector), + abi.encode(priceFeeds) + ); + } + + // Helper function to create update data + function createUpdateData( + PythStructs.PriceFeed[] memory priceFeeds + ) internal pure returns (bytes[] memory) { + bytes[] memory updateData = new bytes[](2); + updateData[0] = abi.encode(priceFeeds[0]); + updateData[1] = abi.encode(priceFeeds[1]); + return updateData; + } + + // Helper function to calculate total fee + function calculateTotalFee() internal pure returns (uint128) { + return + PYTH_FEE + + PROVIDER_FEE + + (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); + } + + // Helper function to setup consumer request + function setupConsumerRequest( + address consumerAddress + ) + internal + returns ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) + { + priceIds = createPriceIds(); + publishTime = block.timestamp; + vm.deal(consumerAddress, 1 gwei); + + vm.prank(consumerAddress); + sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: calculateTotalFee() + }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); + + return (sequenceNumber, priceIds, publishTime); + } + + function testRequestPriceUpdate() public { + bytes32[] memory priceIds = createPriceIds(); uint256 publishTime = block.timestamp; // Fund the consumer contract @@ -103,13 +205,8 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(); emit PriceUpdateRequested(expectedRequest); - // Calculate total fee including gas component - uint128 totalFee = PYTH_FEE + - PROVIDER_FEE + - (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); - // Make the actual call that should emit the event - pulse.requestPriceUpdatesWithCallback{value: totalFee}( + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( provider, publishTime, priceIds, @@ -129,10 +226,7 @@ contract PulseTest is Test, PulseEvents { } function testExecuteCallback() public { - bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = BTC_PRICE_FEED_ID; - priceIds[1] = ETH_PRICE_FEED_ID; - + bytes32[] memory priceIds = createPriceIds(); uint256 publishTime = block.timestamp; // Fund the consumer contract @@ -140,41 +234,15 @@ contract PulseTest is Test, PulseEvents { // Step 1: Make the request as consumer vm.prank(address(consumer)); - - // Calculate total fee including gas component - uint128 totalFee = PYTH_FEE + - PROVIDER_FEE + - (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); - uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: totalFee + value: calculateTotalFee() }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); - // Step 2: Create mock price feeds that match the expected publish time - PythStructs.PriceFeed[] memory priceFeeds = new PythStructs.PriceFeed[]( - 2 - ); - - // Create mock price feed for BTC with specific values - priceFeeds[0].id = BTC_PRICE_FEED_ID; - priceFeeds[0].price.price = MOCK_BTC_PRICE; - priceFeeds[0].price.conf = MOCK_BTC_CONF; - priceFeeds[0].price.expo = MOCK_PRICE_FEED_EXPO; - priceFeeds[0].price.publishTime = publishTime; - - // Create mock price feed for ETH with specific values - priceFeeds[1].id = ETH_PRICE_FEED_ID; - priceFeeds[1].price.price = MOCK_ETH_PRICE; - priceFeeds[1].price.conf = MOCK_ETH_CONF; - priceFeeds[1].price.expo = MOCK_PRICE_FEED_EXPO; - priceFeeds[1].price.publishTime = publishTime; - - // Mock Pyth's parsePriceFeedUpdates to return our price feeds - vm.mockCall( - address(pyth), - abi.encodeWithSelector(IPyth.parsePriceFeedUpdates.selector), - abi.encode(priceFeeds) + // Step 2: Create mock price feeds and setup Pyth response + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime ); + mockPythResponse(priceFeeds); // Create arrays for expected event data int64[] memory expectedPrices = new int64[](2); @@ -206,12 +274,9 @@ contract PulseTest is Test, PulseEvents { expectedPublishTimes ); - // Create mock update data - bytes[] memory updateData = new bytes[](2); - updateData[0] = abi.encode(priceFeeds[0]); - updateData[1] = abi.encode(priceFeeds[1]); + // Create mock update data and execute callback + bytes[] memory updateData = createUpdateData(priceFeeds); - // Execute callback as provider vm.prank(provider); pulse.executeCallback( provider, @@ -226,4 +291,74 @@ contract PulseTest is Test, PulseEvents { assertEq(consumer.lastProvider(), provider); assertEq(consumer.lastPublishTime(), publishTime); } + + function testExecuteCallbackFailure() public { + FailingPulseConsumer failingConsumer = new FailingPulseConsumer(); + + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(failingConsumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createUpdateData(priceFeeds); + + vm.expectEmit(true, true, true, true); + emit PriceUpdateCallbackFailed( + sequenceNumber, + provider, + publishTime, + priceIds, + address(failingConsumer), + "callback failed" + ); + + vm.prank(provider); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + + function testExecuteCallbackCustomErrorFailure() public { + CustomErrorPulseConsumer failingConsumer = new CustomErrorPulseConsumer(); + + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(failingConsumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createUpdateData(priceFeeds); + + vm.expectEmit(true, true, true, true); + emit PriceUpdateCallbackFailed( + sequenceNumber, + provider, + publishTime, + priceIds, + address(failingConsumer), + "low-level error (possibly out of gas)" + ); + + vm.prank(provider); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } } From 6f0de41a635919bca8e80450be855e9c6f558b3e Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 13:51:19 +0900 Subject: [PATCH 09/29] add test for getFee --- .../ethereum/contracts/forge-test/Pulse.t.sol | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 5b41ffd9d1..9cf4666685 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -361,4 +361,48 @@ contract PulseTest is Test, PulseEvents { CALLBACK_GAS_LIMIT ); } + + function testGetFee() public { + // Test with different gas limits to verify fee calculation + uint256[] memory gasLimits = new uint256[](3); + gasLimits[0] = 100_000; + gasLimits[1] = 500_000; + gasLimits[2] = 1_000_000; + + for (uint256 i = 0; i < gasLimits.length; i++) { + uint256 gasLimit = gasLimits[i]; + uint128 expectedFee = PROVIDER_FEE + // Base provider fee + (PROVIDER_FEE_PER_GAS * uint128(gasLimit)) + // Gas-based fee + PYTH_FEE; // Pyth oracle fee + + uint128 actualFee = pulse.getFee(provider, gasLimit); + + assertEq( + actualFee, + expectedFee, + "Fee calculation incorrect for gas limit" + ); + } + + // Test with zero gas limit + uint128 expectedMinFee = PROVIDER_FEE + PYTH_FEE; + uint128 actualMinFee = pulse.getFee(provider, 0); + assertEq( + actualMinFee, + expectedMinFee, + "Minimum fee calculation incorrect" + ); + + // Test with unregistered provider (should return 0 fees) + address unregisteredProvider = address(0x123); + uint128 unregisteredFee = pulse.getFee( + unregisteredProvider, + gasLimits[0] + ); + assertEq( + unregisteredFee, + PYTH_FEE, + "Unregistered provider fee should only include Pyth fee" + ); + } } From 03d506972c67a58b8d9d0f96742ad625d54221ea Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 14:24:31 +0900 Subject: [PATCH 10/29] add testWithdraw --- .../ethereum/contracts/forge-test/Pulse.t.sol | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 9cf4666685..1657505095 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -405,4 +405,45 @@ contract PulseTest is Test, PulseEvents { "Unregistered provider fee should only include Pyth fee" ); } + + function testWithdraw() public { + // Setup: Request price update to accrue some fees + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + provider, + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + + // Get provider's balance before withdrawal + uint256 providerBalanceBefore = provider.balance; + PulseState.ProviderInfo memory infoBefore = pulse.getProviderInfo( + provider + ); + + // Withdraw fees + vm.prank(provider); + pulse.withdraw(infoBefore.accruedFeesInWei); + + // Verify balances + assertEq( + provider.balance, + providerBalanceBefore + infoBefore.accruedFeesInWei + ); + + PulseState.ProviderInfo memory infoAfter = pulse.getProviderInfo( + provider + ); + assertEq(infoAfter.accruedFeesInWei, 0); + } + + function testWithdrawInsufficientBalance() public { + vm.prank(provider); + vm.expectRevert("Insufficient balance"); + pulse.withdraw(1 ether); + } } From cd4da75591bcd9ffc94349402316a42cf3aade64 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 14:37:02 +0900 Subject: [PATCH 11/29] add testSetAndWithdrawAssFeeManager --- .../ethereum/contracts/forge-test/Pulse.t.sol | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 1657505095..98b11e2dc2 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -446,4 +446,40 @@ contract PulseTest is Test, PulseEvents { vm.expectRevert("Insufficient balance"); pulse.withdraw(1 ether); } + + function testSetAndWithdrawAsFeeManager() public { + address feeManager = address(0x789); + + // Set fee manager + vm.prank(provider); + pulse.setFeeManager(feeManager); + + // Verify fee manager was set + PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); + assertEq(info.feeManager, feeManager); + + // Setup: Request price update to accrue some fees + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + provider, + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + + // Test withdrawal as fee manager + uint256 managerBalanceBefore = feeManager.balance; + info = pulse.getProviderInfo(provider); + + vm.prank(feeManager); + pulse.withdrawAsFeeManager(provider, info.accruedFeesInWei); + + assertEq( + feeManager.balance, + managerBalanceBefore + info.accruedFeesInWei + ); + } } From b1808dfee31f6ce771ace860389098c82b422fd3 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 14:37:13 +0900 Subject: [PATCH 12/29] add testMaxNumPrices --- .../ethereum/contracts/forge-test/Pulse.t.sol | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 98b11e2dc2..1eecca0797 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -482,4 +482,26 @@ contract PulseTest is Test, PulseEvents { managerBalanceBefore + info.accruedFeesInWei ); } + + function testMaxNumPrices() public { + // Set max number of prices + vm.prank(provider); + pulse.setMaxNumPrices(1); + + // Try to request more prices than allowed + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = BTC_PRICE_FEED_ID; + priceIds[1] = ETH_PRICE_FEED_ID; + + vm.deal(address(consumer), 1 gwei); + vm.prank(address(consumer)); + + vm.expectRevert("Exceeds max number of prices"); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + provider, + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + } } From 142b964d293f524ab4be449f44bdf1281dcfdd8b Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 14:37:55 +0900 Subject: [PATCH 13/29] add testSetProviderUri --- .../ethereum/contracts/forge-test/Pulse.t.sol | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 1eecca0797..ccf78c062f 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -504,4 +504,14 @@ contract PulseTest is Test, PulseEvents { CALLBACK_GAS_LIMIT ); } + + function testSetProviderUri() public { + bytes memory newUri = "https://updated-provider.com"; + + vm.prank(provider); + pulse.setProviderUri(newUri); + + PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); + assertEq(info.uri, newUri); + } } From 54fa61b8dd9d2231ea2cde57085191e7370c261d Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 15:55:37 +0900 Subject: [PATCH 14/29] add more test --- .../contracts/contracts/pulse/Pulse.sol | 33 ++--- .../contracts/contracts/pulse/PulseErrors.sol | 3 + .../ethereum/contracts/forge-test/Pulse.t.sol | 120 +++++++++++++++++- 3 files changed, 134 insertions(+), 22 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 5888264e6f..5ec1454e73 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -117,7 +117,10 @@ abstract contract Pulse is IPulse, PulseState { providerInfo.maxNumPrices > 0 && priceIds.length > providerInfo.maxNumPrices ) { - revert("Exceeds max number of prices"); + revert ExceedsMaxPrices( + uint32(priceIds.length), + providerInfo.maxNumPrices + ); } // Assign sequence number and increment @@ -154,10 +157,19 @@ abstract contract Pulse is IPulse, PulseState { ) external payable override { Request storage req = findActiveRequest(provider, sequenceNumber); - require( - gasleft() >= req.callbackGasLimit, - "Insufficient gas for callback" - ); + if ( + keccak256(abi.encode(req.priceIds)) != + keccak256(abi.encode(priceIds)) + ) { + revert InvalidPriceIds(priceIds, req.priceIds); + } + + if (req.callbackGasLimit != callbackGasLimit) { + revert InvalidCallbackGasLimit( + callbackGasLimit, + req.callbackGasLimit + ); + } PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) .parsePriceFeedUpdates( @@ -169,17 +181,6 @@ abstract contract Pulse is IPulse, PulseState { uint256 publishTime = priceFeeds[0].price.publishTime; - // Verify request parameters match - require( - keccak256(abi.encode(req.priceIds)) == - keccak256(abi.encode(priceIds)), - "Invalid price IDs" - ); - require( - req.callbackGasLimit == callbackGasLimit, - "Invalid callback gas limit" - ); - try IPulseConsumer(req.requester).pulseCallback( sequenceNumber, diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol index 187ccf00a2..535ad4d746 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -8,3 +8,6 @@ error InsufficientFee(); error Unauthorized(); error InvalidCallbackGas(); error CallbackFailed(); +error InvalidPriceIds(bytes32[] requested, bytes32[] stored); +error InvalidCallbackGasLimit(uint256 requested, uint256 stored); +error ExceedsMaxPrices(uint32 requested, uint32 maxAllowed); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index ccf78c062f..d78de073cf 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -8,6 +8,7 @@ import "../contracts/pulse/PulseUpgradeable.sol"; import "../contracts/pulse/IPulse.sol"; import "../contracts/pulse/PulseState.sol"; import "../contracts/pulse/PulseEvents.sol"; +import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; @@ -141,8 +142,8 @@ contract PulseTest is Test, PulseEvents { ); } - // Helper function to create update data - function createUpdateData( + // Helper function to create mock update data + function createMockUpdateData( PythStructs.PriceFeed[] memory priceFeeds ) internal pure returns (bytes[] memory) { bytes[] memory updateData = new bytes[](2); @@ -225,6 +226,20 @@ contract PulseTest is Test, PulseEvents { assertEq(lastRequest.requester, expectedRequest.requester); } + function testRequestWithInsufficientFee() public { + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + vm.expectRevert(InsufficientFee.selector); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee + provider, + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + } + function testExecuteCallback() public { bytes32[] memory priceIds = createPriceIds(); uint256 publishTime = block.timestamp; @@ -275,7 +290,7 @@ contract PulseTest is Test, PulseEvents { ); // Create mock update data and execute callback - bytes[] memory updateData = createUpdateData(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(provider); pulse.executeCallback( @@ -305,7 +320,7 @@ contract PulseTest is Test, PulseEvents { publishTime ); mockPythResponse(priceFeeds); - bytes[] memory updateData = createUpdateData(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.expectEmit(true, true, true, true); emit PriceUpdateCallbackFailed( @@ -340,7 +355,7 @@ contract PulseTest is Test, PulseEvents { publishTime ); mockPythResponse(priceFeeds); - bytes[] memory updateData = createUpdateData(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.expectEmit(true, true, true, true); emit PriceUpdateCallbackFailed( @@ -362,6 +377,97 @@ contract PulseTest is Test, PulseEvents { ); } + // Test executing callback with mismatched price IDs + function testExecuteCallbackWithMismatchedPriceIds() public { + ( + uint64 sequenceNumber, + bytes32[] memory originalPriceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + // Create different price IDs array + bytes32[] memory differentPriceIds = new bytes32[](1); + differentPriceIds[0] = bytes32(uint256(1)); // Different price ID + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(provider); + vm.expectRevert( + abi.encodeWithSelector( + InvalidPriceIds.selector, + differentPriceIds, + originalPriceIds + ) + ); + pulse.executeCallback( + provider, + sequenceNumber, + differentPriceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + + function testExecuteCallbackWithInsufficientGas() public { + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(provider); + vm.expectRevert(); + pulse.executeCallback{gas: 10000}( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + + function testExecuteCallbackWithInvalidGasLimit() public { + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Try to execute with different gas limit than what was requested + uint256 differentGasLimit = CALLBACK_GAS_LIMIT + 1000; + vm.prank(provider); + vm.expectRevert( + abi.encodeWithSelector( + InvalidCallbackGasLimit.selector, + differentGasLimit, + CALLBACK_GAS_LIMIT + ) + ); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + differentGasLimit + ); + } + function testGetFee() public { // Test with different gas limits to verify fee calculation uint256[] memory gasLimits = new uint256[](3); @@ -496,7 +602,9 @@ contract PulseTest is Test, PulseEvents { vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); - vm.expectRevert("Exceeds max number of prices"); + vm.expectRevert( + abi.encodeWithSelector(ExceedsMaxPrices.selector, 2, 1) + ); pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( provider, block.timestamp, From 907decb74f7fad07a82d66d0f272627b5e466d65 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 16:25:30 +0900 Subject: [PATCH 15/29] add testExecuteCallbackWithFutureTimestamp --- .../ethereum/contracts/forge-test/Pulse.t.sol | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index d78de073cf..a6b2d365a9 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -219,6 +219,10 @@ contract PulseTest is Test, PulseEvents { assertEq(lastRequest.provider, expectedRequest.provider); assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); assertEq(lastRequest.publishTime, expectedRequest.publishTime); + assertEq( + keccak256(abi.encode(lastRequest.priceIds)), + keccak256(abi.encode(expectedRequest.priceIds)) + ); assertEq( lastRequest.callbackGasLimit, expectedRequest.callbackGasLimit @@ -468,6 +472,38 @@ contract PulseTest is Test, PulseEvents { ); } + function testExecuteCallbackWithFutureTimestamp() public { + // Setup request with future timestamp + bytes32[] memory priceIds = createPriceIds(); + uint256 futureTime = block.timestamp + 1 days; + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: calculateTotalFee() + }(provider, futureTime, priceIds, CALLBACK_GAS_LIMIT); + + // Try to execute callback before the requested timestamp + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + futureTime // Mock price feeds with future timestamp + ); + mockPythResponse(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(provider); + // Should succeed because we're simulating receiving future-dated price updates + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + + // Verify the callback was executed with future timestamp + assertEq(consumer.lastPublishTime(), futureTime); + } + function testGetFee() public { // Test with different gas limits to verify fee calculation uint256[] memory gasLimits = new uint256[](3); From 19d550020a75e20916422cbdd9c65e6662b7b183 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 16:41:56 +0900 Subject: [PATCH 16/29] update tests --- .../ethereum/contracts/forge-test/Pulse.t.sol | 76 +++++++++++++++++-- 1 file changed, 68 insertions(+), 8 deletions(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index a6b2d365a9..cbc15f417d 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -132,7 +132,7 @@ contract PulseTest is Test, PulseEvents { } // Helper function to mock Pyth response - function mockPythResponse( + function mockParsePriceFeedUpdates( PythStructs.PriceFeed[] memory priceFeeds ) internal { vm.mockCall( @@ -261,7 +261,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); // Create arrays for expected event data int64[] memory expectedPrices = new int64[](2); @@ -323,7 +323,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.expectEmit(true, true, true, true); @@ -358,7 +358,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.expectEmit(true, true, true, true); @@ -396,7 +396,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(provider); @@ -426,7 +426,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(provider); @@ -450,7 +450,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); // Try to execute with different gas limit than what was requested @@ -487,7 +487,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( futureTime // Mock price feeds with future timestamp ); - mockPythResponse(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices + mockParsePriceFeedUpdates(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(provider); @@ -504,6 +504,66 @@ contract PulseTest is Test, PulseEvents { assertEq(consumer.lastPublishTime(), futureTime); } + function testExecuteCallbackWithWrongProvider() public { + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + address wrongProvider = address(0x999); + vm.prank(wrongProvider); + vm.expectRevert(NoSuchRequest.selector); + pulse.executeCallback( + wrongProvider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + + function testDoubleExecuteCallback() public { + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // First execution + vm.prank(provider); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + + // Second execution should fail + vm.prank(provider); + vm.expectRevert(NoSuchRequest.selector); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + function testGetFee() public { // Test with different gas limits to verify fee calculation uint256[] memory gasLimits = new uint256[](3); From f8e398d179b315bfffaae9fb2c6a9c397d5c7beb Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 18 Nov 2024 16:18:35 +0900 Subject: [PATCH 17/29] remove provider --- .../contracts/contracts/pulse/IPulse.sol | 41 +-- .../contracts/contracts/pulse/Pulse.sol | 303 ++++-------------- .../contracts/contracts/pulse/PulseEvents.sol | 43 +-- .../contracts/contracts/pulse/PulseState.sol | 17 +- .../contracts/pulse/PulseUpgradeable.sol | 3 +- .../ethereum/contracts/forge-test/Pulse.t.sol | 275 +++++++--------- 6 files changed, 184 insertions(+), 498 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 8e79b97c13..125d06f357 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -8,7 +8,7 @@ import "./PulseState.sol"; interface IPulseConsumer { function pulseCallback( uint64 sequenceNumber, - address provider, + address updater, uint256 publishTime, bytes32[] calldata priceIds ) external; @@ -17,66 +17,33 @@ interface IPulseConsumer { interface IPulse is PulseEvents { // Core functions function requestPriceUpdatesWithCallback( - address provider, uint256 publishTime, bytes32[] calldata priceIds, uint256 callbackGasLimit ) external payable returns (uint64 sequenceNumber); function executeCallback( - address provider, uint64 sequenceNumber, bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit ) external payable; - // Provider management - function register( - uint128 feeInWei, - uint128 feePerGas, - bytes calldata uri - ) external; - - function setProviderFee(uint128 newFeeInWei) external; - - function setProviderFeeAsFeeManager( - address provider, - uint128 newFeeInWei - ) external; - - function setProviderUri(bytes calldata uri) external; - - function withdraw(uint128 amount) external; - - function withdrawAsFeeManager(address provider, uint128 amount) external; - // Getters function getFee( - address provider, uint256 callbackGasLimit ) external view returns (uint128 feeAmount); function getPythFeeInWei() external view returns (uint128 pythFeeInWei); - function getAccruedPythFees() - external - view - returns (uint128 accruedPythFeesInWei); - - function getDefaultProvider() external view returns (address); - - function getProviderInfo( - address provider - ) external view returns (PulseState.ProviderInfo memory info); + function getAccruedFees() external view returns (uint128 accruedFeesInWei); function getRequest( - address provider, uint64 sequenceNumber ) external view returns (PulseState.Request memory req); - // Setters + // Add these functions to the IPulse interface function setFeeManager(address manager) external; - function setMaxNumPrices(uint32 maxNumPrices) external; + function withdrawAsFeeManager(uint128 amount) external; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 5ec1454e73..167f3e9f7a 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -12,150 +12,58 @@ abstract contract Pulse is IPulse, PulseState { function _initialize( address admin, uint128 pythFeeInWei, - address defaultProvider, address pythAddress, bool prefillRequestStorage ) internal { require(admin != address(0), "admin is zero address"); - require( - defaultProvider != address(0), - "defaultProvider is zero address" - ); require(pythAddress != address(0), "pyth is zero address"); _state.admin = admin; - _state.accruedPythFeesInWei = 0; + _state.accruedFeesInWei = 0; _state.pythFeeInWei = pythFeeInWei; - _state.defaultProvider = defaultProvider; _state.pyth = pythAddress; + _state.currentSequenceNumber = 1; if (prefillRequestStorage) { - // Write some data to every storage slot in the requests array such that new requests - // use a more consistent amount of gas. - // Note that these requests are not live because their sequenceNumber is 0. for (uint8 i = 0; i < NUM_REQUESTS; i++) { Request storage req = _state.requests[i]; - req.sequenceNumber = 0; // Keep it inactive + req.sequenceNumber = 0; req.publishTime = 1; - // No need to prefill dynamic arrays (priceIds, updateData) req.callbackGasLimit = 1; req.requester = address(1); } } } - function register( - uint128 feeInWei, - uint128 feePerGas, - bytes calldata uri - ) public override { - ProviderInfo storage providerInfo = _state.providers[msg.sender]; - - providerInfo.feeInWei = feeInWei; - providerInfo.feePerGas = feePerGas; - providerInfo.uri = uri; - providerInfo.sequenceNumber += 1; - - emit ProviderRegistered(providerInfo); - } - - function withdraw(uint128 amount) public override { - ProviderInfo storage providerInfo = _state.providers[msg.sender]; - - // Use checks-effects-interactions pattern to prevent reentrancy attacks. - require( - providerInfo.accruedFeesInWei >= amount, - "Insufficient balance" - ); - providerInfo.accruedFeesInWei -= amount; - - // Interaction with an external contract or token transfer - (bool sent, ) = msg.sender.call{value: amount}(""); - require(sent, "withdrawal to msg.sender failed"); - - emit ProviderWithdrawn(msg.sender, msg.sender, amount); - } - - function withdrawAsFeeManager( - address provider, - uint128 amount - ) external override { - ProviderInfo storage providerInfo = _state.providers[provider]; - - if (providerInfo.sequenceNumber == 0) { - revert NoSuchProvider(); - } - - if (providerInfo.feeManager != msg.sender) { - revert Unauthorized(); - } - - // Use checks-effects-interactions pattern to prevent reentrancy attacks. - require( - providerInfo.accruedFeesInWei >= amount, - "Insufficient balance" - ); - providerInfo.accruedFeesInWei -= amount; - - // Interaction with an external contract or token transfer - (bool sent, ) = msg.sender.call{value: amount}(""); - require(sent, "withdrawal to msg.sender failed"); - - emit ProviderWithdrawn(provider, msg.sender, amount); - } - function requestPriceUpdatesWithCallback( - address provider, uint256 publishTime, bytes32[] calldata priceIds, uint256 callbackGasLimit ) external payable override returns (uint64 requestSequenceNumber) { - ProviderInfo storage providerInfo = _state.providers[provider]; - if (providerInfo.sequenceNumber == 0) revert NoSuchProvider(); - - if ( - providerInfo.maxNumPrices > 0 && - priceIds.length > providerInfo.maxNumPrices - ) { - revert ExceedsMaxPrices( - uint32(priceIds.length), - providerInfo.maxNumPrices - ); - } - - // Assign sequence number and increment - requestSequenceNumber = providerInfo.sequenceNumber++; + requestSequenceNumber = _state.currentSequenceNumber++; - // Verify fee payment - uint128 requiredFee = getFee(provider, callbackGasLimit); + uint128 requiredFee = getFee(callbackGasLimit); if (msg.value < requiredFee) revert InsufficientFee(); - // Store request for callback execution - Request storage req = allocRequest(provider, requestSequenceNumber); - req.provider = provider; + Request storage req = allocRequest(requestSequenceNumber); req.sequenceNumber = requestSequenceNumber; req.publishTime = publishTime; req.priceIds = priceIds; req.callbackGasLimit = callbackGasLimit; req.requester = msg.sender; - // Update fee balances - providerInfo.accruedFeesInWei += providerInfo.feeInWei; - _state.accruedPythFeesInWei += - SafeCast.toUint128(msg.value) - - providerInfo.feeInWei; + _state.accruedFeesInWei += SafeCast.toUint128(msg.value); emit PriceUpdateRequested(req); } function executeCallback( - address provider, uint64 sequenceNumber, bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit ) external payable override { - Request storage req = findActiveRequest(provider, sequenceNumber); + Request storage req = findActiveRequest(sequenceNumber); if ( keccak256(abi.encode(req.priceIds)) != @@ -190,13 +98,7 @@ abstract contract Pulse is IPulse, PulseState { ) { // Callback succeeded - emitPriceUpdate( - sequenceNumber, - msg.sender, - publishTime, - priceIds, - priceFeeds - ); + emitPriceUpdate(sequenceNumber, publishTime, priceIds, priceFeeds); } catch Error(string memory reason) { // Explicit revert/require emit PriceUpdateCallbackFailed( @@ -219,13 +121,11 @@ abstract contract Pulse is IPulse, PulseState { ); } - // Clear request regardless of callback success - clearRequest(msg.sender, sequenceNumber); + clearRequest(sequenceNumber); } function emitPriceUpdate( uint64 sequenceNumber, - address provider, uint256 publishTime, bytes32[] memory priceIds, PythStructs.PriceFeed[] memory priceFeeds @@ -244,7 +144,7 @@ abstract contract Pulse is IPulse, PulseState { emit PriceUpdateExecuted( sequenceNumber, - provider, + msg.sender, publishTime, priceIds, prices, @@ -254,37 +154,12 @@ abstract contract Pulse is IPulse, PulseState { ); } - function getProviderInfo( - address provider - ) public view override returns (ProviderInfo memory info) { - info = _state.providers[provider]; - } - - function getDefaultProvider() - public - view - override - returns (address provider) - { - provider = _state.defaultProvider; - } - - function getRequest( - address provider, - uint64 sequenceNumber - ) public view override returns (Request memory req) { - req = findRequest(provider, sequenceNumber); - } - function getFee( - address provider, uint256 callbackGasLimit ) public view override returns (uint128 feeAmount) { - ProviderInfo storage providerInfo = _state.providers[provider]; - feeAmount = - providerInfo.feeInWei + - (providerInfo.feePerGas * uint128(callbackGasLimit)) + - _state.pythFeeInWei; + uint128 baseFee = _state.pythFeeInWei; + uint256 gasFee = callbackGasLimit * tx.gasprice; + feeAmount = baseFee + SafeCast.toUint128(gasFee); } function getPythFeeInWei() @@ -296,167 +171,105 @@ abstract contract Pulse is IPulse, PulseState { pythFeeInWei = _state.pythFeeInWei; } - function getAccruedPythFees() + function getAccruedFees() public view override - returns (uint128 accruedPythFeesInWei) + returns (uint128 accruedFeesInWei) { - accruedPythFeesInWei = _state.accruedPythFeesInWei; + accruedFeesInWei = _state.accruedFeesInWei; } - // Set provider fee. It will revert if provider is not registered. - function setProviderFee(uint128 newFeeInWei) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - - if (provider.sequenceNumber == 0) { - revert NoSuchProvider(); - } - uint128 oldFeeInWei = provider.feeInWei; - provider.feeInWei = newFeeInWei; - emit ProviderFeeUpdated(msg.sender, oldFeeInWei, newFeeInWei); + function getRequest( + uint64 sequenceNumber + ) public view override returns (Request memory req) { + req = findRequest(sequenceNumber); } - function setProviderFeeAsFeeManager( - address provider, - uint128 newFeeInWei - ) external override { - ProviderInfo storage providerInfo = _state.providers[provider]; - - if (providerInfo.sequenceNumber == 0) { - revert NoSuchProvider(); - } - - if (providerInfo.feeManager != msg.sender) { - revert Unauthorized(); - } - - uint128 oldFeeInWei = providerInfo.feeInWei; - providerInfo.feeInWei = newFeeInWei; - - emit ProviderFeeUpdated(provider, oldFeeInWei, newFeeInWei); + function requestKey( + uint64 sequenceNumber + ) internal pure returns (bytes32 hash, uint8 shortHash) { + hash = keccak256(abi.encodePacked(sequenceNumber)); + shortHash = uint8(hash[0] & NUM_REQUESTS_MASK); } - // Set provider uri. It will revert if provider is not registered. - function setProviderUri(bytes calldata newUri) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) { - revert NoSuchProvider(); - } - bytes memory oldUri = provider.uri; - provider.uri = newUri; - emit ProviderUriUpdated(msg.sender, oldUri, newUri); - } + function withdrawFees(uint128 amount) external { + require(msg.sender == _state.admin, "Only admin can withdraw fees"); + require(_state.accruedFeesInWei >= amount, "Insufficient balance"); - function setFeeManager(address manager) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) { - revert NoSuchProvider(); - } + _state.accruedFeesInWei -= amount; - address oldFeeManager = provider.feeManager; - provider.feeManager = manager; - emit ProviderFeeManagerUpdated(msg.sender, oldFeeManager, manager); - } + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "Failed to send fees"); - function requestKey( - address provider, - uint64 sequenceNumber - ) internal pure returns (bytes32 hash, uint8 shortHash) { - hash = keccak256(abi.encodePacked(provider, sequenceNumber)); - shortHash = uint8(hash[0] & NUM_REQUESTS_MASK); + emit FeesWithdrawn(msg.sender, amount); } - // Find an in-flight active request for given the provider and the sequence number. - // This method returns a reference to the request, and will revert if the request is - // not active. function findActiveRequest( - address provider, uint64 sequenceNumber ) internal view returns (Request storage req) { - req = findRequest(provider, sequenceNumber); + req = findRequest(sequenceNumber); - // Check there is an active request for the given provider and sequence number. - if ( - !isActive(req) || - req.provider != provider || - req.sequenceNumber != sequenceNumber - ) revert NoSuchRequest(); + if (!isActive(req) || req.sequenceNumber != sequenceNumber) + revert NoSuchRequest(); } - // Find an in-flight request. - // Note that this method can return requests that are not currently active. The caller is responsible for checking - // that the returned request is active (if they care). function findRequest( - address provider, uint64 sequenceNumber ) internal view returns (Request storage req) { - (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); + (bytes32 key, uint8 shortKey) = requestKey(sequenceNumber); req = _state.requests[shortKey]; - if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + if (req.sequenceNumber == sequenceNumber) { return req; } else { req = _state.requestsOverflow[key]; } } - // Clear the storage for an in-flight request, deleting it from the hash table. - function clearRequest(address provider, uint64 sequenceNumber) internal { - (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); + function clearRequest(uint64 sequenceNumber) internal { + (bytes32 key, uint8 shortKey) = requestKey(sequenceNumber); Request storage req = _state.requests[shortKey]; - if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + if (req.sequenceNumber == sequenceNumber) { req.sequenceNumber = 0; } else { delete _state.requestsOverflow[key]; } } - // Allocate storage space for a new in-flight request. This method returns a pointer to a storage slot - // that the caller should overwrite with the new request. Note that the memory at this storage slot may - // -- and will -- be filled with arbitrary values, so the caller *must* overwrite every field of the returned - // struct. function allocRequest( - address provider, uint64 sequenceNumber ) internal returns (Request storage req) { - (, uint8 shortKey) = requestKey(provider, sequenceNumber); + (, uint8 shortKey) = requestKey(sequenceNumber); req = _state.requests[shortKey]; if (isActive(req)) { - // There's already a prior active request in the storage slot we want to use. - // Overflow the prior request to the requestsOverflow mapping. - // It is important that this code overflows the *prior* request to the mapping, and not the new request. - // There is a chance that some requests never get revealed and remain active forever. We do not want such - // requests to fill up all of the space in the array and cause all new requests to incur the higher gas cost - // of the mapping. - // - // This operation is expensive, but should be rare. If overflow happens frequently, increase - // the size of the requests array to support more concurrent active requests. - (bytes32 reqKey, ) = requestKey(req.provider, req.sequenceNumber); + (bytes32 reqKey, ) = requestKey(req.sequenceNumber); _state.requestsOverflow[reqKey] = req; } } - // Returns true if a request is active, i.e., its corresponding price update has not yet been executed. function isActive(Request storage req) internal view returns (bool) { - // Note that a provider's initial registration occupies sequence number 0, so there is no way to construct - // a price update request with sequence number 0. return req.sequenceNumber != 0; } - function setMaxNumPrices(uint32 maxNumPrices) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) revert NoSuchProvider(); + function setFeeManager(address manager) external override { + require(msg.sender == _state.admin, "Only admin can set fee manager"); + address oldFeeManager = _state.feeManager; + _state.feeManager = manager; + emit FeeManagerUpdated(_state.admin, oldFeeManager, manager); + } + + function withdrawAsFeeManager(uint128 amount) external override { + require(msg.sender == _state.feeManager, "Only fee manager"); + require(_state.accruedFeesInWei >= amount, "Insufficient balance"); - uint32 oldMaxNumPrices = provider.maxNumPrices; - provider.maxNumPrices = maxNumPrices; + _state.accruedFeesInWei -= amount; - emit ProviderMaxNumPricesUpdated( - msg.sender, - oldMaxNumPrices, - maxNumPrices - ); + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "Failed to send fees"); + + emit FeesWithdrawn(msg.sender, amount); } } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index 070094fdbb..1c96797b39 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -4,14 +4,11 @@ pragma solidity ^0.8.0; import "./PulseState.sol"; interface PulseEvents { - // Events - event ProviderRegistered(PulseState.ProviderInfo providerInfo); - event PriceUpdateRequested(PulseState.Request request); event PriceUpdateExecuted( uint64 indexed sequenceNumber, - address indexed provider, + address indexed updater, uint256 publishTime, bytes32[] priceIds, int64[] prices, @@ -20,42 +17,20 @@ interface PulseEvents { uint256[] publishTimes ); - event ProviderFeeUpdated( - address indexed provider, - uint128 oldFeeInWei, - uint128 newFeeInWei - ); - - event ProviderUriUpdated( - address indexed provider, - bytes oldUri, - bytes newUri - ); - - event ProviderWithdrawn( - address indexed provider, - address indexed recipient, - uint128 amount - ); - - event ProviderFeeManagerUpdated( - address indexed provider, - address oldFeeManager, - address newFeeManager - ); - - event ProviderMaxNumPricesUpdated( - address indexed provider, - uint32 oldMaxNumPrices, - uint32 maxNumPrices - ); + event FeesWithdrawn(address indexed recipient, uint128 amount); event PriceUpdateCallbackFailed( uint64 indexed sequenceNumber, - address indexed provider, + address indexed updater, uint256 publishTime, bytes32[] priceIds, address requester, string reason ); + + event FeeManagerUpdated( + address indexed admin, + address oldFeeManager, + address newFeeManager + ); } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 3341edee21..08052dd1e3 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -7,7 +7,6 @@ contract PulseState { bytes1 public constant NUM_REQUESTS_MASK = 0x1f; struct Request { - address provider; uint64 sequenceNumber; uint256 publishTime; bytes32[] priceIds; @@ -15,25 +14,15 @@ contract PulseState { address requester; } - struct ProviderInfo { - uint64 sequenceNumber; - uint128 feeInWei; - uint128 accruedFeesInWei; - bytes uri; - address feeManager; - uint32 maxNumPrices; - uint128 feePerGas; - } - struct State { address admin; uint128 pythFeeInWei; - uint128 accruedPythFeesInWei; - address defaultProvider; + uint128 accruedFeesInWei; address pyth; + uint64 currentSequenceNumber; + address feeManager; Request[32] requests; mapping(bytes32 => Request) requestsOverflow; - mapping(address => ProviderInfo) providers; } State internal _state; diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol index 0c09e8b9de..48fc694e69 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -22,11 +22,11 @@ contract PulseUpgradeable is address owner, address admin, uint128 pythFeeInWei, - address defaultProvider, address pythAddress, bool prefillRequestStorage ) public initializer { require(owner != address(0), "owner is zero address"); + require(admin != address(0), "admin is zero address"); __Ownable_init(); __UUPSUpgradeable_init(); @@ -34,7 +34,6 @@ contract PulseUpgradeable is Pulse._initialize( admin, pythFeeInWei, - defaultProvider, pythAddress, prefillRequestStorage ); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index cbc15f417d..b0963b0429 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -12,18 +12,18 @@ import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; - address public lastProvider; + address public lastUpdater; uint256 public lastPublishTime; bytes32[] public lastPriceIds; function pulseCallback( uint64 sequenceNumber, - address provider, + address updater, uint256 publishTime, bytes32[] calldata priceIds ) external override { lastSequenceNumber = sequenceNumber; - lastProvider = provider; + lastUpdater = updater; lastPublishTime = publishTime; lastPriceIds = priceIds; } @@ -59,13 +59,11 @@ contract PulseTest is Test, PulseEvents { MockPulseConsumer public consumer; address public owner; address public admin; - address public provider; + address public updater; address public pyth; // Constants uint128 constant PYTH_FEE = 1 wei; - uint128 constant PROVIDER_FEE = 1 wei; - uint128 constant PROVIDER_FEE_PER_GAS = 1 wei; uint128 constant CALLBACK_GAS_LIMIT = 1_000_000; bytes32 constant BTC_PRICE_FEED_ID = 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43; @@ -82,22 +80,15 @@ contract PulseTest is Test, PulseEvents { function setUp() public { owner = address(1); admin = address(2); - provider = address(3); + updater = address(3); pyth = address(4); PulseUpgradeable _pulse = new PulseUpgradeable(); proxy = new ERC1967Proxy(address(_pulse), ""); pulse = PulseUpgradeable(address(proxy)); - pulse.initialize(owner, admin, PYTH_FEE, provider, pyth, false); + pulse.initialize(owner, admin, PYTH_FEE, pyth, false); consumer = new MockPulseConsumer(); - - vm.prank(provider); - pulse.register( - PROVIDER_FEE, - PROVIDER_FEE_PER_GAS, - "https://provider.com" - ); } // Helper function to create price IDs array @@ -153,11 +144,8 @@ contract PulseTest is Test, PulseEvents { } // Helper function to calculate total fee - function calculateTotalFee() internal pure returns (uint128) { - return - PYTH_FEE + - PROVIDER_FEE + - (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); + function calculateTotalFee() internal view returns (uint128) { + return pulse.getFee(CALLBACK_GAS_LIMIT); } // Helper function to setup consumer request @@ -175,26 +163,31 @@ contract PulseTest is Test, PulseEvents { publishTime = block.timestamp; vm.deal(consumerAddress, 1 gwei); + uint128 totalFee = calculateTotalFee(); + vm.prank(consumerAddress); - sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: calculateTotalFee() - }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); + sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}( + publishTime, + priceIds, + CALLBACK_GAS_LIMIT + ); return (sequenceNumber, priceIds, publishTime); } function testRequestPriceUpdate() public { + // Set a realistic gas price + vm.txGasPrice(30 gwei); + bytes32[] memory priceIds = createPriceIds(); uint256 publishTime = block.timestamp; - // Fund the consumer contract - vm.deal(address(consumer), 1 gwei); - - vm.prank(address(consumer)); + // Fund the consumer contract with enough ETH for higher gas price + vm.deal(address(consumer), 1 ether); + uint128 totalFee = calculateTotalFee(); // Create the event data we expect to see PulseState.Request memory expectedRequest = PulseState.Request({ - provider: provider, sequenceNumber: 1, publishTime: publishTime, priceIds: priceIds, @@ -202,21 +195,18 @@ contract PulseTest is Test, PulseEvents { requester: address(consumer) }); - // Emit event with expected parameters vm.expectEmit(); emit PriceUpdateRequested(expectedRequest); - // Make the actual call that should emit the event - pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - provider, + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: totalFee}( publishTime, priceIds, CALLBACK_GAS_LIMIT ); // Additional assertions to verify event data was stored correctly - PulseState.Request memory lastRequest = pulse.getRequest(provider, 1); - assertEq(lastRequest.provider, expectedRequest.provider); + PulseState.Request memory lastRequest = pulse.getRequest(1); assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); assertEq(lastRequest.publishTime, expectedRequest.publishTime); assertEq( @@ -227,17 +217,22 @@ contract PulseTest is Test, PulseEvents { lastRequest.callbackGasLimit, expectedRequest.callbackGasLimit ); - assertEq(lastRequest.requester, expectedRequest.requester); + assertEq( + lastRequest.requester, + expectedRequest.requester, + "Requester mismatch" + ); } function testRequestWithInsufficientFee() public { - bytes32[] memory priceIds = createPriceIds(); - vm.deal(address(consumer), 1 gwei); + // Set a realistic gas price + vm.txGasPrice(30 gwei); + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 ether); vm.prank(address(consumer)); vm.expectRevert(InsufficientFee.selector); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee - provider, block.timestamp, priceIds, CALLBACK_GAS_LIMIT @@ -251,11 +246,13 @@ contract PulseTest is Test, PulseEvents { // Fund the consumer contract vm.deal(address(consumer), 1 gwei); + uint128 totalFee = calculateTotalFee(); + // Step 1: Make the request as consumer vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: calculateTotalFee() - }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); + value: totalFee + }(publishTime, priceIds, CALLBACK_GAS_LIMIT); // Step 2: Create mock price feeds and setup Pyth response PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -284,7 +281,7 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(true, true, false, true); emit PriceUpdateExecuted( sequenceNumber, - provider, + updater, publishTime, priceIds, expectedPrices, @@ -296,9 +293,8 @@ contract PulseTest is Test, PulseEvents { // Create mock update data and execute callback bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(provider); + vm.prank(updater); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -307,7 +303,6 @@ contract PulseTest is Test, PulseEvents { // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastProvider(), provider); assertEq(consumer.lastPublishTime(), publishTime); } @@ -329,16 +324,15 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(true, true, true, true); emit PriceUpdateCallbackFailed( sequenceNumber, - provider, + updater, publishTime, priceIds, address(failingConsumer), "callback failed" ); - vm.prank(provider); + vm.prank(updater); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -364,16 +358,15 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(true, true, true, true); emit PriceUpdateCallbackFailed( sequenceNumber, - provider, + updater, publishTime, priceIds, address(failingConsumer), "low-level error (possibly out of gas)" ); - vm.prank(provider); + vm.prank(updater); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -399,7 +392,7 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(provider); + vm.prank(updater); vm.expectRevert( abi.encodeWithSelector( InvalidPriceIds.selector, @@ -408,7 +401,6 @@ contract PulseTest is Test, PulseEvents { ) ); pulse.executeCallback( - provider, sequenceNumber, differentPriceIds, updateData, @@ -429,10 +421,9 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(provider); + vm.prank(updater); vm.expectRevert(); pulse.executeCallback{gas: 10000}( - provider, sequenceNumber, priceIds, updateData, @@ -455,7 +446,7 @@ contract PulseTest is Test, PulseEvents { // Try to execute with different gas limit than what was requested uint256 differentGasLimit = CALLBACK_GAS_LIMIT + 1000; - vm.prank(provider); + vm.prank(updater); vm.expectRevert( abi.encodeWithSelector( InvalidCallbackGasLimit.selector, @@ -464,7 +455,6 @@ contract PulseTest is Test, PulseEvents { ) ); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -478,10 +468,11 @@ contract PulseTest is Test, PulseEvents { uint256 futureTime = block.timestamp + 1 days; vm.deal(address(consumer), 1 gwei); + uint128 totalFee = calculateTotalFee(); vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: calculateTotalFee() - }(provider, futureTime, priceIds, CALLBACK_GAS_LIMIT); + value: totalFee + }(futureTime, priceIds, CALLBACK_GAS_LIMIT); // Try to execute callback before the requested timestamp PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -490,10 +481,9 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(provider); + vm.prank(updater); // Should succeed because we're simulating receiving future-dated price updates pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -504,31 +494,6 @@ contract PulseTest is Test, PulseEvents { assertEq(consumer.lastPublishTime(), futureTime); } - function testExecuteCallbackWithWrongProvider() public { - ( - uint64 sequenceNumber, - bytes32[] memory priceIds, - uint256 publishTime - ) = setupConsumerRequest(address(consumer)); - - PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( - publishTime - ); - mockParsePriceFeedUpdates(priceFeeds); - bytes[] memory updateData = createMockUpdateData(priceFeeds); - - address wrongProvider = address(0x999); - vm.prank(wrongProvider); - vm.expectRevert(NoSuchRequest.selector); - pulse.executeCallback( - wrongProvider, - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); - } - function testDoubleExecuteCallback() public { ( uint64 sequenceNumber, @@ -543,9 +508,8 @@ contract PulseTest is Test, PulseEvents { bytes[] memory updateData = createMockUpdateData(priceFeeds); // First execution - vm.prank(provider); + vm.prank(updater); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -553,10 +517,9 @@ contract PulseTest is Test, PulseEvents { ); // Second execution should fail - vm.prank(provider); + vm.prank(updater); vm.expectRevert(NoSuchRequest.selector); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -573,12 +536,9 @@ contract PulseTest is Test, PulseEvents { for (uint256 i = 0; i < gasLimits.length; i++) { uint256 gasLimit = gasLimits[i]; - uint128 expectedFee = PROVIDER_FEE + // Base provider fee - (PROVIDER_FEE_PER_GAS * uint128(gasLimit)) + // Gas-based fee - PYTH_FEE; // Pyth oracle fee - - uint128 actualFee = pulse.getFee(provider, gasLimit); - + uint128 expectedFee = SafeCast.toUint128(tx.gasprice * gasLimit) + + PYTH_FEE; + uint128 actualFee = pulse.getFee(gasLimit); assertEq( actualFee, expectedFee, @@ -587,86 +547,73 @@ contract PulseTest is Test, PulseEvents { } // Test with zero gas limit - uint128 expectedMinFee = PROVIDER_FEE + PYTH_FEE; - uint128 actualMinFee = pulse.getFee(provider, 0); + uint128 expectedMinFee = PYTH_FEE; + uint128 actualMinFee = pulse.getFee(0); assertEq( actualMinFee, expectedMinFee, "Minimum fee calculation incorrect" ); - - // Test with unregistered provider (should return 0 fees) - address unregisteredProvider = address(0x123); - uint128 unregisteredFee = pulse.getFee( - unregisteredProvider, - gasLimits[0] - ); - assertEq( - unregisteredFee, - PYTH_FEE, - "Unregistered provider fee should only include Pyth fee" - ); } - function testWithdraw() public { + function testWithdrawFees() public { // Setup: Request price update to accrue some fees bytes32[] memory priceIds = createPriceIds(); vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - provider, block.timestamp, priceIds, CALLBACK_GAS_LIMIT ); - // Get provider's balance before withdrawal - uint256 providerBalanceBefore = provider.balance; - PulseState.ProviderInfo memory infoBefore = pulse.getProviderInfo( - provider - ); + // Get admin's balance before withdrawal + uint256 adminBalanceBefore = admin.balance; + uint128 accruedFees = pulse.getAccruedFees(); - // Withdraw fees - vm.prank(provider); - pulse.withdraw(infoBefore.accruedFeesInWei); + // Withdraw fees as admin + vm.prank(admin); + pulse.withdrawFees(accruedFees); // Verify balances assertEq( - provider.balance, - providerBalanceBefore + infoBefore.accruedFeesInWei + admin.balance, + adminBalanceBefore + accruedFees, + "Admin balance should increase by withdrawn amount" ); - - PulseState.ProviderInfo memory infoAfter = pulse.getProviderInfo( - provider + assertEq( + pulse.getAccruedFees(), + 0, + "Contract should have no fees after withdrawal" ); - assertEq(infoAfter.accruedFeesInWei, 0); } - function testWithdrawInsufficientBalance() public { - vm.prank(provider); + function testWithdrawFeesUnauthorized() public { + vm.prank(address(0xdead)); + vm.expectRevert("Only admin can withdraw fees"); + pulse.withdrawFees(1 ether); + } + + function testWithdrawFeesInsufficientBalance() public { + vm.prank(admin); vm.expectRevert("Insufficient balance"); - pulse.withdraw(1 ether); + pulse.withdrawFees(1 ether); } function testSetAndWithdrawAsFeeManager() public { address feeManager = address(0x789); - // Set fee manager - vm.prank(provider); + // Set fee manager as admin + vm.prank(admin); pulse.setFeeManager(feeManager); - // Verify fee manager was set - PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); - assertEq(info.feeManager, feeManager); - // Setup: Request price update to accrue some fees bytes32[] memory priceIds = createPriceIds(); vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - provider, block.timestamp, priceIds, CALLBACK_GAS_LIMIT @@ -674,48 +621,44 @@ contract PulseTest is Test, PulseEvents { // Test withdrawal as fee manager uint256 managerBalanceBefore = feeManager.balance; - info = pulse.getProviderInfo(provider); + uint128 accruedFees = pulse.getAccruedFees(); vm.prank(feeManager); - pulse.withdrawAsFeeManager(provider, info.accruedFeesInWei); + pulse.withdrawAsFeeManager(accruedFees); assertEq( feeManager.balance, - managerBalanceBefore + info.accruedFeesInWei + managerBalanceBefore + accruedFees, + "Fee manager balance should increase by withdrawn amount" + ); + assertEq( + pulse.getAccruedFees(), + 0, + "Contract should have no fees after withdrawal" ); } - function testMaxNumPrices() public { - // Set max number of prices - vm.prank(provider); - pulse.setMaxNumPrices(1); - - // Try to request more prices than allowed - bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = BTC_PRICE_FEED_ID; - priceIds[1] = ETH_PRICE_FEED_ID; - - vm.deal(address(consumer), 1 gwei); - vm.prank(address(consumer)); - - vm.expectRevert( - abi.encodeWithSelector(ExceedsMaxPrices.selector, 2, 1) - ); - pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - provider, - block.timestamp, - priceIds, - CALLBACK_GAS_LIMIT - ); + function testSetFeeManagerUnauthorized() public { + address feeManager = address(0x789); + vm.prank(address(0xdead)); + vm.expectRevert("Only admin can set fee manager"); + pulse.setFeeManager(feeManager); } - function testSetProviderUri() public { - bytes memory newUri = "https://updated-provider.com"; + function testWithdrawAsFeeManagerUnauthorized() public { + vm.prank(address(0xdead)); + vm.expectRevert("Only fee manager"); + pulse.withdrawAsFeeManager(1 ether); + } - vm.prank(provider); - pulse.setProviderUri(newUri); + function testWithdrawAsFeeManagerInsufficientBalance() public { + // Set up fee manager first + address feeManager = address(0x789); + vm.prank(admin); + pulse.setFeeManager(feeManager); - PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); - assertEq(info.uri, newUri); + vm.prank(feeManager); + vm.expectRevert("Insufficient balance"); + pulse.withdrawAsFeeManager(1 ether); } } From fdc06cda07465be1acd1888aa72fa94f8856f2c5 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 18 Nov 2024 20:05:25 +0900 Subject: [PATCH 18/29] address comments --- .../contracts/contracts/pulse/IPulse.sol | 3 +- .../contracts/contracts/pulse/Pulse.sol | 31 ++-- .../contracts/contracts/pulse/PulseErrors.sol | 3 +- .../contracts/contracts/pulse/PulseState.sol | 2 +- .../ethereum/contracts/forge-test/Pulse.t.sol | 168 ++++++------------ 5 files changed, 67 insertions(+), 140 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 125d06f357..22891e98a3 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -24,9 +24,8 @@ interface IPulse is PulseEvents { function executeCallback( uint64 sequenceNumber, - bytes32[] calldata priceIds, bytes[] calldata updateData, - uint256 callbackGasLimit + bytes32[] calldata priceIds ) external payable; // Getters diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 167f3e9f7a..6ede25997b 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -48,7 +48,7 @@ abstract contract Pulse is IPulse, PulseState { Request storage req = allocRequest(requestSequenceNumber); req.sequenceNumber = requestSequenceNumber; req.publishTime = publishTime; - req.priceIds = priceIds; + req.priceIdsHash = keccak256(abi.encode(priceIds)); req.callbackGasLimit = callbackGasLimit; req.requester = msg.sender; @@ -59,24 +59,20 @@ abstract contract Pulse is IPulse, PulseState { function executeCallback( uint64 sequenceNumber, - bytes32[] calldata priceIds, bytes[] calldata updateData, - uint256 callbackGasLimit + bytes32[] calldata priceIds ) external payable override { Request storage req = findActiveRequest(sequenceNumber); + bytes32 providedPriceIdsHash = keccak256(abi.encode(priceIds)); + bytes32 storedPriceIdsHash = req.priceIdsHash; - if ( - keccak256(abi.encode(req.priceIds)) != - keccak256(abi.encode(priceIds)) - ) { - revert InvalidPriceIds(priceIds, req.priceIds); + if (providedPriceIdsHash != storedPriceIdsHash) { + revert InvalidPriceIds(providedPriceIdsHash, storedPriceIdsHash); } - if (req.callbackGasLimit != callbackGasLimit) { - revert InvalidCallbackGasLimit( - callbackGasLimit, - req.callbackGasLimit - ); + // Check if there's enough gas left for the callback + if (gasleft() < req.callbackGasLimit) { + revert InsufficientGas(); } PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) @@ -90,12 +86,9 @@ abstract contract Pulse is IPulse, PulseState { uint256 publishTime = priceFeeds[0].price.publishTime; try - IPulseConsumer(req.requester).pulseCallback( - sequenceNumber, - msg.sender, - publishTime, - priceIds - ) + IPulseConsumer(req.requester).pulseCallback{ + gas: req.callbackGasLimit + }(sequenceNumber, msg.sender, publishTime, priceIds) { // Callback succeeded emitPriceUpdate(sequenceNumber, publishTime, priceIds, priceFeeds); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol index 535ad4d746..c2fe41ccb6 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -8,6 +8,7 @@ error InsufficientFee(); error Unauthorized(); error InvalidCallbackGas(); error CallbackFailed(); -error InvalidPriceIds(bytes32[] requested, bytes32[] stored); +error InvalidPriceIds(bytes32 providedPriceIdsHash, bytes32 storedPriceIdsHash); error InvalidCallbackGasLimit(uint256 requested, uint256 stored); error ExceedsMaxPrices(uint32 requested, uint32 maxAllowed); +error InsufficientGas(); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 08052dd1e3..6b0b48fc61 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -9,7 +9,7 @@ contract PulseState { struct Request { uint64 sequenceNumber; uint256 publishTime; - bytes32[] priceIds; + bytes32 priceIdsHash; uint256 callbackGasLimit; address requester; } diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index b0963b0429..7e7d276e43 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -190,7 +190,7 @@ contract PulseTest is Test, PulseEvents { PulseState.Request memory expectedRequest = PulseState.Request({ sequenceNumber: 1, publishTime: publishTime, - priceIds: priceIds, + priceIdsHash: keccak256(abi.encode(priceIds)), callbackGasLimit: CALLBACK_GAS_LIMIT, requester: address(consumer) }); @@ -209,10 +209,7 @@ contract PulseTest is Test, PulseEvents { PulseState.Request memory lastRequest = pulse.getRequest(1); assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); assertEq(lastRequest.publishTime, expectedRequest.publishTime); - assertEq( - keccak256(abi.encode(lastRequest.priceIds)), - keccak256(abi.encode(expectedRequest.priceIds)) - ); + assertEq(lastRequest.priceIdsHash, expectedRequest.priceIdsHash); assertEq( lastRequest.callbackGasLimit, expectedRequest.callbackGasLimit @@ -245,7 +242,6 @@ contract PulseTest is Test, PulseEvents { // Fund the consumer contract vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); // Step 1: Make the request as consumer @@ -278,7 +274,7 @@ contract PulseTest is Test, PulseEvents { expectedPublishTimes[1] = publishTime; // Expect the PriceUpdateExecuted event with all price data - vm.expectEmit(true, true, false, true); + vm.expectEmit(); emit PriceUpdateExecuted( sequenceNumber, updater, @@ -294,12 +290,7 @@ contract PulseTest is Test, PulseEvents { bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(updater); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); @@ -321,7 +312,7 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.expectEmit(true, true, true, true); + vm.expectEmit(); emit PriceUpdateCallbackFailed( sequenceNumber, updater, @@ -332,12 +323,7 @@ contract PulseTest is Test, PulseEvents { ); vm.prank(updater); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); } function testExecuteCallbackCustomErrorFailure() public { @@ -355,7 +341,7 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.expectEmit(true, true, true, true); + vm.expectEmit(); emit PriceUpdateCallbackFailed( sequenceNumber, updater, @@ -366,100 +352,28 @@ contract PulseTest is Test, PulseEvents { ); vm.prank(updater); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); - } - - // Test executing callback with mismatched price IDs - function testExecuteCallbackWithMismatchedPriceIds() public { - ( - uint64 sequenceNumber, - bytes32[] memory originalPriceIds, - uint256 publishTime - ) = setupConsumerRequest(address(consumer)); - - // Create different price IDs array - bytes32[] memory differentPriceIds = new bytes32[](1); - differentPriceIds[0] = bytes32(uint256(1)); // Different price ID - - PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( - publishTime - ); - mockParsePriceFeedUpdates(priceFeeds); - bytes[] memory updateData = createMockUpdateData(priceFeeds); - - vm.prank(updater); - vm.expectRevert( - abi.encodeWithSelector( - InvalidPriceIds.selector, - differentPriceIds, - originalPriceIds - ) - ); - pulse.executeCallback( - sequenceNumber, - differentPriceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); } function testExecuteCallbackWithInsufficientGas() public { + // Setup request with 1M gas limit ( uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime ) = setupConsumerRequest(address(consumer)); + // Setup mock data PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); + // Try executing with only 10K gas when 1M is required vm.prank(updater); - vm.expectRevert(); - pulse.executeCallback{gas: 10000}( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); - } - - function testExecuteCallbackWithInvalidGasLimit() public { - ( - uint64 sequenceNumber, - bytes32[] memory priceIds, - uint256 publishTime - ) = setupConsumerRequest(address(consumer)); - - PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( - publishTime - ); - mockParsePriceFeedUpdates(priceFeeds); - bytes[] memory updateData = createMockUpdateData(priceFeeds); - - // Try to execute with different gas limit than what was requested - uint256 differentGasLimit = CALLBACK_GAS_LIMIT + 1000; - vm.prank(updater); - vm.expectRevert( - abi.encodeWithSelector( - InvalidCallbackGasLimit.selector, - differentGasLimit, - CALLBACK_GAS_LIMIT - ) - ); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - differentGasLimit - ); + vm.expectRevert(InsufficientGas.selector); + pulse.executeCallback{gas: 10000}(sequenceNumber, updateData, priceIds); // Will fail because gasleft() < callbackGasLimit } function testExecuteCallbackWithFutureTimestamp() public { @@ -483,12 +397,7 @@ contract PulseTest is Test, PulseEvents { vm.prank(updater); // Should succeed because we're simulating receiving future-dated price updates - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); // Verify the callback was executed with future timestamp assertEq(consumer.lastPublishTime(), futureTime); @@ -509,22 +418,12 @@ contract PulseTest is Test, PulseEvents { // First execution vm.prank(updater); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); // Second execution should fail vm.prank(updater); vm.expectRevert(NoSuchRequest.selector); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); } function testGetFee() public { @@ -661,4 +560,39 @@ contract PulseTest is Test, PulseEvents { vm.expectRevert("Insufficient balance"); pulse.withdrawAsFeeManager(1 ether); } + + // Add new test for invalid priceIds + function testExecuteCallbackWithInvalidPriceIds() public { + bytes32[] memory priceIds = createPriceIds(); + uint256 publishTime = block.timestamp; + + // Setup request + (uint64 sequenceNumber, , ) = setupConsumerRequest(address(consumer)); + + // Create different priceIds + bytes32[] memory wrongPriceIds = new bytes32[](2); + wrongPriceIds[0] = bytes32(uint256(1)); // Different price IDs + wrongPriceIds[1] = bytes32(uint256(2)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Calculate hashes for both arrays + bytes32 providedPriceIdsHash = keccak256(abi.encode(wrongPriceIds)); + bytes32 storedPriceIdsHash = keccak256(abi.encode(priceIds)); + + // Should revert when trying to execute with wrong priceIds + vm.prank(updater); + vm.expectRevert( + abi.encodeWithSelector( + InvalidPriceIds.selector, + providedPriceIdsHash, + storedPriceIdsHash + ) + ); + pulse.executeCallback(sequenceNumber, updateData, wrongPriceIds); + } } From 1131e3b4982ff89b7bbbf2883a914bc2df31264c Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 18 Nov 2024 20:54:12 +0900 Subject: [PATCH 19/29] address comments --- .../contracts/contracts/pulse/IPulse.sol | 4 +- .../contracts/contracts/pulse/Pulse.sol | 22 +++--- .../contracts/contracts/pulse/PulseEvents.sol | 2 - .../ethereum/contracts/forge-test/Pulse.t.sol | 67 ++++++++++++++----- 4 files changed, 60 insertions(+), 35 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 22891e98a3..6e5d44eaf9 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.0; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "./PulseEvents.sol"; import "./PulseState.sol"; @@ -9,8 +10,7 @@ interface IPulseConsumer { function pulseCallback( uint64 sequenceNumber, address updater, - uint256 publishTime, - bytes32[] calldata priceIds + PythStructs.PriceFeed[] memory priceFeeds ) external; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 6ede25997b..eeb1cccf2c 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -63,18 +63,15 @@ abstract contract Pulse is IPulse, PulseState { bytes32[] calldata priceIds ) external payable override { Request storage req = findActiveRequest(sequenceNumber); + + // Verify priceIds match bytes32 providedPriceIdsHash = keccak256(abi.encode(priceIds)); bytes32 storedPriceIdsHash = req.priceIdsHash; - if (providedPriceIdsHash != storedPriceIdsHash) { revert InvalidPriceIds(providedPriceIdsHash, storedPriceIdsHash); } - // Check if there's enough gas left for the callback - if (gasleft() < req.callbackGasLimit) { - revert InsufficientGas(); - } - + // Parse price feeds first to measure gas usage PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) .parsePriceFeedUpdates( updateData, @@ -83,21 +80,23 @@ abstract contract Pulse is IPulse, PulseState { SafeCast.toUint64(req.publishTime) ); - uint256 publishTime = priceFeeds[0].price.publishTime; + // Check if enough gas remains for the callback + if (gasleft() < req.callbackGasLimit) { + revert InsufficientGas(); + } try IPulseConsumer(req.requester).pulseCallback{ gas: req.callbackGasLimit - }(sequenceNumber, msg.sender, publishTime, priceIds) + }(sequenceNumber, msg.sender, priceFeeds) { // Callback succeeded - emitPriceUpdate(sequenceNumber, publishTime, priceIds, priceFeeds); + emitPriceUpdate(sequenceNumber, priceIds, priceFeeds); } catch Error(string memory reason) { // Explicit revert/require emit PriceUpdateCallbackFailed( sequenceNumber, msg.sender, - publishTime, priceIds, req.requester, reason @@ -107,7 +106,6 @@ abstract contract Pulse is IPulse, PulseState { emit PriceUpdateCallbackFailed( sequenceNumber, msg.sender, - publishTime, priceIds, req.requester, "low-level error (possibly out of gas)" @@ -119,7 +117,6 @@ abstract contract Pulse is IPulse, PulseState { function emitPriceUpdate( uint64 sequenceNumber, - uint256 publishTime, bytes32[] memory priceIds, PythStructs.PriceFeed[] memory priceFeeds ) internal { @@ -138,7 +135,6 @@ abstract contract Pulse is IPulse, PulseState { emit PriceUpdateExecuted( sequenceNumber, msg.sender, - publishTime, priceIds, prices, conf, diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index 1c96797b39..4b7abfbbc3 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -9,7 +9,6 @@ interface PulseEvents { event PriceUpdateExecuted( uint64 indexed sequenceNumber, address indexed updater, - uint256 publishTime, bytes32[] priceIds, int64[] prices, uint64[] conf, @@ -22,7 +21,6 @@ interface PulseEvents { event PriceUpdateCallbackFailed( uint64 indexed sequenceNumber, address indexed updater, - uint256 publishTime, bytes32[] priceIds, address requester, string reason diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 7e7d276e43..81edc4115a 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "forge-std/Test.sol"; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import "../contracts/pulse/PulseUpgradeable.sol"; import "../contracts/pulse/IPulse.sol"; @@ -13,19 +14,26 @@ import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; address public lastUpdater; - uint256 public lastPublishTime; - bytes32[] public lastPriceIds; + PythStructs.PriceFeed[] private _lastPriceFeeds; function pulseCallback( uint64 sequenceNumber, address updater, - uint256 publishTime, - bytes32[] calldata priceIds + PythStructs.PriceFeed[] memory priceFeeds ) external override { lastSequenceNumber = sequenceNumber; lastUpdater = updater; - lastPublishTime = publishTime; - lastPriceIds = priceIds; + for (uint i = 0; i < priceFeeds.length; i++) { + _lastPriceFeeds.push(priceFeeds[i]); + } + } + + function lastPriceFeeds() + external + view + returns (PythStructs.PriceFeed[] memory) + { + return _lastPriceFeeds; } } @@ -33,8 +41,7 @@ contract FailingPulseConsumer is IPulseConsumer { function pulseCallback( uint64, address, - uint256, - bytes32[] calldata + PythStructs.PriceFeed[] memory ) external pure override { revert("callback failed"); } @@ -46,8 +53,7 @@ contract CustomErrorPulseConsumer is IPulseConsumer { function pulseCallback( uint64, address, - uint256, - bytes32[] calldata + PythStructs.PriceFeed[] memory ) external pure override { revert CustomError("callback failed"); } @@ -278,7 +284,6 @@ contract PulseTest is Test, PulseEvents { emit PriceUpdateExecuted( sequenceNumber, updater, - publishTime, priceIds, expectedPrices, expectedConf, @@ -294,7 +299,22 @@ contract PulseTest is Test, PulseEvents { // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastPublishTime(), publishTime); + + // Compare price feeds array length + PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); + assertEq(lastFeeds.length, priceFeeds.length); + + // Compare each price feed + for (uint i = 0; i < priceFeeds.length; i++) { + assertEq(lastFeeds[i].id, priceFeeds[i].id); + assertEq(lastFeeds[i].price.price, priceFeeds[i].price.price); + assertEq(lastFeeds[i].price.conf, priceFeeds[i].price.conf); + assertEq(lastFeeds[i].price.expo, priceFeeds[i].price.expo); + assertEq( + lastFeeds[i].price.publishTime, + priceFeeds[i].price.publishTime + ); + } } function testExecuteCallbackFailure() public { @@ -316,7 +336,6 @@ contract PulseTest is Test, PulseEvents { emit PriceUpdateCallbackFailed( sequenceNumber, updater, - publishTime, priceIds, address(failingConsumer), "callback failed" @@ -345,7 +364,6 @@ contract PulseTest is Test, PulseEvents { emit PriceUpdateCallbackFailed( sequenceNumber, updater, - publishTime, priceIds, address(failingConsumer), "low-level error (possibly out of gas)" @@ -370,10 +388,14 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - // Try executing with only 10K gas when 1M is required + // Try executing with only 100K gas when 1M is required vm.prank(updater); vm.expectRevert(InsufficientGas.selector); - pulse.executeCallback{gas: 10000}(sequenceNumber, updateData, priceIds); // Will fail because gasleft() < callbackGasLimit + pulse.executeCallback{gas: 100000}( + sequenceNumber, + updateData, + priceIds + ); // Will fail because gasleft() < callbackGasLimit } function testExecuteCallbackWithFutureTimestamp() public { @@ -399,8 +421,17 @@ contract PulseTest is Test, PulseEvents { // Should succeed because we're simulating receiving future-dated price updates pulse.executeCallback(sequenceNumber, updateData, priceIds); - // Verify the callback was executed with future timestamp - assertEq(consumer.lastPublishTime(), futureTime); + // Compare price feeds array length + PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); + assertEq(lastFeeds.length, priceFeeds.length); + + // Compare each price feed publish time + for (uint i = 0; i < priceFeeds.length; i++) { + assertEq( + lastFeeds[i].price.publishTime, + priceFeeds[i].price.publishTime + ); + } } function testDoubleExecuteCallback() public { From 9c118cdd2a4c7d9a817a0223cd801d6efe551326 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 6 Jan 2025 16:07:23 +0900 Subject: [PATCH 20/29] prevent requests 60 mins in the future that could exploit gas price difference --- target_chains/ethereum/contracts/contracts/pulse/Pulse.sol | 1 + 1 file changed, 1 insertion(+) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index eeb1cccf2c..159493a073 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -40,6 +40,7 @@ abstract contract Pulse is IPulse, PulseState { bytes32[] calldata priceIds, uint256 callbackGasLimit ) external payable override returns (uint64 requestSequenceNumber) { + require(publishTime <= block.timestamp + 60, "Too far in future"); requestSequenceNumber = _state.currentSequenceNumber++; uint128 requiredFee = getFee(callbackGasLimit); From 4f99ff0e8415da3629a0bd1d88b3fd359ba94fef Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 6 Jan 2025 18:00:01 +0900 Subject: [PATCH 21/29] fix test --- .../ethereum/contracts/forge-test/Pulse.t.sol | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 81edc4115a..06debf6aa1 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -401,7 +401,7 @@ contract PulseTest is Test, PulseEvents { function testExecuteCallbackWithFutureTimestamp() public { // Setup request with future timestamp bytes32[] memory priceIds = createPriceIds(); - uint256 futureTime = block.timestamp + 1 days; + uint256 futureTime = block.timestamp + 10; // 10 seconds in future vm.deal(address(consumer), 1 gwei); uint128 totalFee = calculateTotalFee(); @@ -434,6 +434,22 @@ contract PulseTest is Test, PulseEvents { } } + function testRevertOnTooFarFutureTimestamp() public { + bytes32[] memory priceIds = createPriceIds(); + uint256 farFutureTime = block.timestamp + 61; // Just over 1 minute + vm.deal(address(consumer), 1 gwei); + + uint128 totalFee = calculateTotalFee(); + vm.prank(address(consumer)); + + vm.expectRevert("Too far in future"); + pulse.requestPriceUpdatesWithCallback{value: totalFee}( + farFutureTime, + priceIds, + CALLBACK_GAS_LIMIT + ); + } + function testDoubleExecuteCallback() public { ( uint64 sequenceNumber, From 225a31f87e16327986e6fcbcda182af0b3d52562 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 6 Jan 2025 18:17:26 +0900 Subject: [PATCH 22/29] add priceIds to PriceUpdateRequested event --- target_chains/ethereum/contracts/contracts/pulse/Pulse.sol | 2 +- .../ethereum/contracts/contracts/pulse/PulseEvents.sol | 2 +- target_chains/ethereum/contracts/forge-test/Pulse.t.sol | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 159493a073..593e706dcb 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -55,7 +55,7 @@ abstract contract Pulse is IPulse, PulseState { _state.accruedFeesInWei += SafeCast.toUint128(msg.value); - emit PriceUpdateRequested(req); + emit PriceUpdateRequested(req, priceIds); } function executeCallback( diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index 4b7abfbbc3..c3d29f168d 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -4,7 +4,7 @@ pragma solidity ^0.8.0; import "./PulseState.sol"; interface PulseEvents { - event PriceUpdateRequested(PulseState.Request request); + event PriceUpdateRequested(PulseState.Request request, bytes32[] priceIds); event PriceUpdateExecuted( uint64 indexed sequenceNumber, diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 06debf6aa1..020c59bc92 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -202,7 +202,7 @@ contract PulseTest is Test, PulseEvents { }); vm.expectEmit(); - emit PriceUpdateRequested(expectedRequest); + emit PriceUpdateRequested(expectedRequest, priceIds); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: totalFee}( From bf49a4a7f942bfcb3999222c95c1bf0c0a509f57 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Tue, 7 Jan 2025 12:51:05 +0900 Subject: [PATCH 23/29] add 50% overhead to gas for cross-contract calls --- target_chains/ethereum/contracts/contracts/pulse/Pulse.sol | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 593e706dcb..9c10b10e49 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -81,8 +81,8 @@ abstract contract Pulse is IPulse, PulseState { SafeCast.toUint64(req.publishTime) ); - // Check if enough gas remains for the callback - if (gasleft() < req.callbackGasLimit) { + // Check if enough gas remains for the callback plus 50% overhead for cross-contract call (uses integer arithmetic to avoid floating point) + if (gasleft() < (req.callbackGasLimit * 3) / 2) { revert InsufficientGas(); } From c63441977b4b7c7726e50dd1850b9f8549d1bb94 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Tue, 7 Jan 2025 13:16:07 +0900 Subject: [PATCH 24/29] feat: add test for executing callback with gas overhead --- .../ethereum/contracts/forge-test/Pulse.t.sol | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 020c59bc92..86c99caa34 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -642,4 +642,41 @@ contract PulseTest is Test, PulseEvents { ); pulse.executeCallback(sequenceNumber, updateData, wrongPriceIds); } + + function testExecuteCallbackGasOverhead() public { + // Setup request with 1M gas limit + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + // Setup mock data + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Should fail with exactly 1.4x gas (less than required 1.5x) + vm.prank(updater); + vm.expectRevert(InsufficientGas.selector); + pulse.executeCallback{gas: (CALLBACK_GAS_LIMIT * 14) / 10}( + sequenceNumber, + updateData, + priceIds + ); + + // Should succeed with 1.6x gas + vm.prank(updater); + pulse.executeCallback{gas: (CALLBACK_GAS_LIMIT * 16) / 10}( + sequenceNumber, + updateData, + priceIds + ); + + // Verify callback was executed successfully + assertEq(consumer.lastSequenceNumber(), sequenceNumber); + assertEq(consumer.lastUpdater(), updater); + } } From d3bc7cdb0e777cbc4e48fa80300cc47135816038 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Tue, 7 Jan 2025 13:24:54 +0900 Subject: [PATCH 25/29] feat: add docs for requestPriceUpdatesWithCallback and executeCallback functions to IPulse interface --- .../contracts/contracts/pulse/IPulse.sol | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 6e5d44eaf9..9cd2fde3ed 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -16,12 +16,30 @@ interface IPulseConsumer { interface IPulse is PulseEvents { // Core functions + /** + * @notice Requests price updates with a callback + * @dev The msg.value must cover both the Pyth fee and gas costs + * Note: The actual gas required for execution will be 1.5x the callbackGasLimit + * to account for cross-contract call overhead + some gas for some other operations in the function before the callback + * @param publishTime The minimum publish time for price updates + * @param priceIds The price feed IDs to update + * @param callbackGasLimit Gas limit for the callback execution + * @return sequenceNumber The sequence number assigned to this request + */ function requestPriceUpdatesWithCallback( uint256 publishTime, bytes32[] calldata priceIds, uint256 callbackGasLimit ) external payable returns (uint64 sequenceNumber); + /** + * @notice Executes the callback for a price update request + * @dev Requires 1.5x the callback gas limit to account for cross-contract call overhead + * For example, if callbackGasLimit is 1M, the transaction needs at least 1.5M gas + some gas for some other operations in the function before the callback + * @param sequenceNumber The sequence number of the request + * @param updateData The raw price update data from Pyth + * @param priceIds The price feed IDs to update, must match the request + */ function executeCallback( uint64 sequenceNumber, bytes[] calldata updateData, From de603b9ffa1279e55a30b0f860b211361dcce1c6 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 8 Jan 2025 12:08:07 +0900 Subject: [PATCH 26/29] fix: use fixed-length array for priceIds in req --- .../contracts/contracts/pulse/Pulse.sol | 27 +++++++++++++++---- .../contracts/contracts/pulse/PulseErrors.sol | 1 + .../contracts/contracts/pulse/PulseState.sol | 4 ++- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 9c10b10e49..64ccecaed3 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -31,6 +31,11 @@ abstract contract Pulse is IPulse, PulseState { req.publishTime = 1; req.callbackGasLimit = 1; req.requester = address(1); + req.numPriceIds = 0; + // Pre-warm the priceIds array storage + for (uint8 j = 0; j < MAX_PRICE_IDS; j++) { + req.priceIds[j] = bytes32(0); + } } } } @@ -41,6 +46,9 @@ abstract contract Pulse is IPulse, PulseState { uint256 callbackGasLimit ) external payable override returns (uint64 requestSequenceNumber) { require(publishTime <= block.timestamp + 60, "Too far in future"); + if (priceIds.length > MAX_PRICE_IDS) { + revert TooManyPriceIds(priceIds.length, MAX_PRICE_IDS); + } requestSequenceNumber = _state.currentSequenceNumber++; uint128 requiredFee = getFee(callbackGasLimit); @@ -49,9 +57,14 @@ abstract contract Pulse is IPulse, PulseState { Request storage req = allocRequest(requestSequenceNumber); req.sequenceNumber = requestSequenceNumber; req.publishTime = publishTime; - req.priceIdsHash = keccak256(abi.encode(priceIds)); req.callbackGasLimit = callbackGasLimit; req.requester = msg.sender; + req.numPriceIds = uint8(priceIds.length); + + // Copy price IDs to storage + for (uint8 i = 0; i < priceIds.length; i++) { + req.priceIds[i] = priceIds[i]; + } _state.accruedFeesInWei += SafeCast.toUint128(msg.value); @@ -66,10 +79,14 @@ abstract contract Pulse is IPulse, PulseState { Request storage req = findActiveRequest(sequenceNumber); // Verify priceIds match - bytes32 providedPriceIdsHash = keccak256(abi.encode(priceIds)); - bytes32 storedPriceIdsHash = req.priceIdsHash; - if (providedPriceIdsHash != storedPriceIdsHash) { - revert InvalidPriceIds(providedPriceIdsHash, storedPriceIdsHash); + require( + priceIds.length == req.numPriceIds, + "Price IDs length mismatch" + ); + for (uint8 i = 0; i < req.numPriceIds; i++) { + if (priceIds[i] != req.priceIds[i]) { + revert InvalidPriceIds(priceIds[i], req.priceIds[i]); + } } // Parse price feeds first to measure gas usage diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol index c2fe41ccb6..aacb123ba5 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -12,3 +12,4 @@ error InvalidPriceIds(bytes32 providedPriceIdsHash, bytes32 storedPriceIdsHash); error InvalidCallbackGasLimit(uint256 requested, uint256 stored); error ExceedsMaxPrices(uint32 requested, uint32 maxAllowed); error InsufficientGas(); +error TooManyPriceIds(uint256 provided, uint256 maximum); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 6b0b48fc61..82c1fa7967 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -5,11 +5,13 @@ pragma solidity ^0.8.0; contract PulseState { uint8 public constant NUM_REQUESTS = 32; bytes1 public constant NUM_REQUESTS_MASK = 0x1f; + uint8 public constant MAX_PRICE_IDS = 10; struct Request { uint64 sequenceNumber; uint256 publishTime; - bytes32 priceIdsHash; + bytes32[MAX_PRICE_IDS] priceIds; + uint8 numPriceIds; // Actual number of price IDs used uint256 callbackGasLimit; address requester; } From c30fbe1ca32fd996b110d022e47a26d9c57b4807 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 8 Jan 2025 12:19:39 +0900 Subject: [PATCH 27/29] add test --- .../ethereum/contracts/forge-test/Pulse.t.sol | 53 ++++++++++++++++--- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 86c99caa34..1b0af6aade 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -196,7 +196,19 @@ contract PulseTest is Test, PulseEvents { PulseState.Request memory expectedRequest = PulseState.Request({ sequenceNumber: 1, publishTime: publishTime, - priceIdsHash: keccak256(abi.encode(priceIds)), + priceIds: [ + priceIds[0], + priceIds[1], + bytes32(0), // Fill remaining slots with zero + bytes32(0), + bytes32(0), + bytes32(0), + bytes32(0), + bytes32(0), + bytes32(0), + bytes32(0) + ], + numPriceIds: 2, callbackGasLimit: CALLBACK_GAS_LIMIT, requester: address(consumer) }); @@ -215,7 +227,10 @@ contract PulseTest is Test, PulseEvents { PulseState.Request memory lastRequest = pulse.getRequest(1); assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); assertEq(lastRequest.publishTime, expectedRequest.publishTime); - assertEq(lastRequest.priceIdsHash, expectedRequest.priceIdsHash); + assertEq(lastRequest.numPriceIds, expectedRequest.numPriceIds); + for (uint8 i = 0; i < lastRequest.numPriceIds; i++) { + assertEq(lastRequest.priceIds[i], expectedRequest.priceIds[i]); + } assertEq( lastRequest.callbackGasLimit, expectedRequest.callbackGasLimit @@ -627,17 +642,13 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - // Calculate hashes for both arrays - bytes32 providedPriceIdsHash = keccak256(abi.encode(wrongPriceIds)); - bytes32 storedPriceIdsHash = keccak256(abi.encode(priceIds)); - // Should revert when trying to execute with wrong priceIds vm.prank(updater); vm.expectRevert( abi.encodeWithSelector( InvalidPriceIds.selector, - providedPriceIdsHash, - storedPriceIdsHash + wrongPriceIds[0], + priceIds[0] ) ); pulse.executeCallback(sequenceNumber, updateData, wrongPriceIds); @@ -679,4 +690,30 @@ contract PulseTest is Test, PulseEvents { assertEq(consumer.lastSequenceNumber(), sequenceNumber); assertEq(consumer.lastUpdater(), updater); } + + function testRevertOnTooManyPriceIds() public { + uint256 maxPriceIds = uint256(pulse.MAX_PRICE_IDS()); + // Create array with MAX_PRICE_IDS + 1 price IDs + bytes32[] memory priceIds = new bytes32[](maxPriceIds + 1); + for (uint i = 0; i < priceIds.length; i++) { + priceIds[i] = bytes32(uint256(i + 1)); + } + + vm.deal(address(consumer), 1 gwei); + uint128 totalFee = calculateTotalFee(); + + vm.prank(address(consumer)); + vm.expectRevert( + abi.encodeWithSelector( + TooManyPriceIds.selector, + maxPriceIds + 1, + maxPriceIds + ) + ); + pulse.requestPriceUpdatesWithCallback{value: totalFee}( + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + } } From 38ab62e412f0508cabba5b92425486be2a94171c Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Fri, 17 Jan 2025 12:46:32 +0900 Subject: [PATCH 28/29] address comments --- .gitignore | 1 + .../contracts/contracts/pulse/IPulse.sol | 31 ++++++++++++++----- .../contracts/contracts/pulse/Pulse.sol | 15 +++++++-- .../contracts/contracts/pulse/PulseState.sol | 4 ++- 4 files changed, 39 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 47f80f33a2..10176ca152 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ __pycache__ .direnv .next .turbo/ +.cursorrules diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 9cd2fde3ed..f3d06a7704 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -18,13 +18,17 @@ interface IPulse is PulseEvents { // Core functions /** * @notice Requests price updates with a callback - * @dev The msg.value must cover both the Pyth fee and gas costs - * Note: The actual gas required for execution will be 1.5x the callbackGasLimit - * to account for cross-contract call overhead + some gas for some other operations in the function before the callback - * @param publishTime The minimum publish time for price updates - * @param priceIds The price feed IDs to update - * @param callbackGasLimit Gas limit for the callback execution + * @dev The msg.value must be equal to getFee(callbackGasLimit) + * @param callbackGasLimit The amount of gas allocated for the callback execution + * @param publishTime The minimum publish time for price updates, it should be less than or equal to block.timestamp + 60 + * @param priceIds The price feed IDs to update. Maximum 10 price feeds per request. + * Requests requiring more feeds should be split into multiple calls. * @return sequenceNumber The sequence number assigned to this request + * @dev Security note: The 60-second future limit on publishTime prevents a DoS vector where + * attackers could submit many low-fee requests for far-future updates when gas prices + * are low, forcing executors to fulfill them later when gas prices might be much higher. + * Since tx.gasprice is used to calculate fees, allowing far-future requests would make + * the fee estimation unreliable. */ function requestPriceUpdatesWithCallback( uint256 publishTime, @@ -47,12 +51,23 @@ interface IPulse is PulseEvents { ) external payable; // Getters + /** + * @notice Gets the base fee charged by Pyth protocol + * @dev This is a fixed fee per request that goes to the Pyth protocol, separate from gas costs + * @return pythFeeInWei The base fee in wei that every request must pay + */ + function getPythFeeInWei() external view returns (uint128 pythFeeInWei); + + /** + * @notice Calculates the total fee required for a price update request + * @dev Total fee = base Pyth protocol fee + gas costs for callback + * @param callbackGasLimit The amount of gas allocated for callback execution + * @return feeAmount The total fee in wei that must be provided as msg.value + */ function getFee( uint256 callbackGasLimit ) external view returns (uint128 feeAmount); - function getPythFeeInWei() external view returns (uint128 pythFeeInWei); - function getAccruedFees() external view returns (uint128 accruedFeesInWei); function getRequest( diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 64ccecaed3..dcd84211fc 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -45,6 +45,11 @@ abstract contract Pulse is IPulse, PulseState { bytes32[] calldata priceIds, uint256 callbackGasLimit ) external payable override returns (uint64 requestSequenceNumber) { + // NOTE: The 60-second future limit on publishTime prevents a DoS vector where + // attackers could submit many low-fee requests for far-future updates when gas prices + // are low, forcing executors to fulfill them later when gas prices might be much higher. + // Since tx.gasprice is used to calculate fees, allowing far-future requests would make + // the fee estimation unreliable. require(publishTime <= block.timestamp + 60, "Too far in future"); if (priceIds.length > MAX_PRICE_IDS) { revert TooManyPriceIds(priceIds.length, MAX_PRICE_IDS); @@ -98,7 +103,13 @@ abstract contract Pulse is IPulse, PulseState { SafeCast.toUint64(req.publishTime) ); - // Check if enough gas remains for the callback plus 50% overhead for cross-contract call (uses integer arithmetic to avoid floating point) + clearRequest(sequenceNumber); + + // Check if enough gas remains for callback + events/cleanup + // We need extra gas beyond callbackGasLimit for: + // 1. Emitting success/failure events + // 2. Error handling in catch blocks + // 3. State cleanup operations if (gasleft() < (req.callbackGasLimit * 3) / 2) { revert InsufficientGas(); } @@ -129,8 +140,6 @@ abstract contract Pulse is IPulse, PulseState { "low-level error (possibly out of gas)" ); } - - clearRequest(sequenceNumber); } function emitPriceUpdate( diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 82c1fa7967..50ef0147cd 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -5,6 +5,8 @@ pragma solidity ^0.8.0; contract PulseState { uint8 public constant NUM_REQUESTS = 32; bytes1 public constant NUM_REQUESTS_MASK = 0x1f; + // Maximum number of price feeds per request. This limit keeps gas costs predictable and reasonable. 10 is a reasonable number for most use cases. + // Requests with more than 10 price feeds should be split into multiple requests uint8 public constant MAX_PRICE_IDS = 10; struct Request { @@ -23,7 +25,7 @@ contract PulseState { address pyth; uint64 currentSequenceNumber; address feeManager; - Request[32] requests; + Request[NUM_REQUESTS] requests; mapping(bytes32 => Request) requestsOverflow; } From 394ed839bf7116ae44254d6eb490ceef6645fd1f Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Sun, 19 Jan 2025 11:55:51 +0900 Subject: [PATCH 29/29] bump pyth-sdk-solana to v0.10.3 --- apps/hermes/server/Cargo.lock | 4 ++-- apps/hermes/server/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/apps/hermes/server/Cargo.lock b/apps/hermes/server/Cargo.lock index e5f1866b06..f337a41801 100644 --- a/apps/hermes/server/Cargo.lock +++ b/apps/hermes/server/Cargo.lock @@ -3214,9 +3214,9 @@ dependencies = [ [[package]] name = "pyth-sdk-solana" -version = "0.10.2" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9301b3c3db3766fd1dd0c5769a706d886265b3f56772ac279365bdd807d3076" +checksum = "65841f3d9d5a025ca32999463f720ba16da2dfe9ffe5b668b2e26dd4239dfa9e" dependencies = [ "borsh 0.10.3", "borsh-derive 0.10.3", diff --git a/apps/hermes/server/Cargo.toml b/apps/hermes/server/Cargo.toml index 8bc506af48..6466a21a60 100644 --- a/apps/hermes/server/Cargo.toml +++ b/apps/hermes/server/Cargo.toml @@ -30,7 +30,7 @@ nonzero_ext = { version = "0.3.0" } prometheus-client = { version = "0.21.2" } prost = { version = "0.12.1" } pyth-sdk = { version = "0.8.0" } -pyth-sdk-solana = { version = "0.10.2" } +pyth-sdk-solana = "0.10.3" pythnet-sdk = { path = "../../../pythnet/pythnet_sdk/", version = "2.0.0", features = ["strum"] } rand = { version = "0.8.5" } reqwest = { version = "0.11.14", features = ["blocking", "json"] }