Skip to content

Commit 5812d78

Browse files
committed
remove provider arg from getFee
1 parent fb8a7cd commit 5812d78

File tree

3 files changed

+36
-42
lines changed

3 files changed

+36
-42
lines changed

target_chains/ethereum/contracts/contracts/pulse/IPulse.sol

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,10 @@ interface IPulse is PulseEvents {
6262
* @notice Calculates the total fee required for a price update request
6363
* @dev Total fee = base Pyth protocol fee + gas costs for callback
6464
* @param callbackGasLimit The amount of gas allocated for callback execution
65-
* @param provider The provider to use for the fee calculation
6665
* @return feeAmount The total fee in wei that must be provided as msg.value
6766
*/
6867
function getFee(
69-
uint256 callbackGasLimit,
70-
address provider
68+
uint256 callbackGasLimit
7169
) external view returns (uint128 feeAmount);
7270

7371
function getAccruedFees() external view returns (uint128 accruedFeesInWei);

target_chains/ethereum/contracts/contracts/pulse/Pulse.sol

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ abstract contract Pulse is IPulse, PulseState {
7070
}
7171
requestSequenceNumber = _state.currentSequenceNumber++;
7272

73-
uint128 requiredFee = getFee(callbackGasLimit, provider);
73+
uint128 requiredFee = getFee(callbackGasLimit);
7474
if (msg.value < requiredFee) revert InsufficientFee();
7575

7676
Request storage req = allocRequest(requestSequenceNumber);
@@ -190,14 +190,12 @@ abstract contract Pulse is IPulse, PulseState {
190190
}
191191

192192
function getFee(
193-
uint256 callbackGasLimit,
194-
address provider
193+
uint256 callbackGasLimit
195194
) public view override returns (uint128 feeAmount) {
196-
if (provider == address(0)) {
197-
provider = _state.defaultProvider;
198-
}
199195
uint128 baseFee = _state.pythFeeInWei;
200-
uint128 providerFeeInWei = _state.providers[provider].feeInWei;
196+
uint128 providerFeeInWei = _state
197+
.providers[_state.defaultProvider]
198+
.feeInWei;
201199
uint256 gasFee = callbackGasLimit * providerFeeInWei;
202200
feeAmount = baseFee + SafeCast.toUint128(gasFee);
203201
}

target_chains/ethereum/contracts/forge-test/Pulse.t.sol

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -159,16 +159,13 @@ contract PulseTest is Test, PulseEvents {
159159
}
160160

161161
// Helper function to calculate total fee
162-
function calculateTotalFee(
163-
address provider
164-
) internal view returns (uint128) {
165-
return pulse.getFee(CALLBACK_GAS_LIMIT, provider);
162+
function calculateTotalFee() internal view returns (uint128) {
163+
return pulse.getFee(CALLBACK_GAS_LIMIT);
166164
}
167165

168166
// Helper function to setup consumer request
169167
function setupConsumerRequest(
170-
address consumerAddress,
171-
address provider
168+
address consumerAddress
172169
)
173170
internal
174171
returns (
@@ -181,7 +178,7 @@ contract PulseTest is Test, PulseEvents {
181178
publishTime = block.timestamp;
182179
vm.deal(consumerAddress, 1 gwei);
183180

184-
uint128 totalFee = calculateTotalFee(provider);
181+
uint128 totalFee = calculateTotalFee();
185182

186183
vm.prank(consumerAddress);
187184
sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}(
@@ -202,7 +199,7 @@ contract PulseTest is Test, PulseEvents {
202199

203200
// Fund the consumer contract with enough ETH for higher gas price
204201
vm.deal(address(consumer), 1 ether);
205-
uint128 totalFee = calculateTotalFee(defaultProvider);
202+
uint128 totalFee = calculateTotalFee();
206203

207204
// Create the event data we expect to see
208205
PulseState.Request memory expectedRequest = PulseState.Request({
@@ -276,7 +273,7 @@ contract PulseTest is Test, PulseEvents {
276273

277274
// Fund the consumer contract
278275
vm.deal(address(consumer), 1 gwei);
279-
uint128 totalFee = calculateTotalFee(defaultProvider);
276+
uint128 totalFee = calculateTotalFee();
280277

281278
// Step 1: Make the request as consumer
282279
vm.prank(address(consumer));
@@ -353,7 +350,7 @@ contract PulseTest is Test, PulseEvents {
353350
uint64 sequenceNumber,
354351
bytes32[] memory priceIds,
355352
uint256 publishTime
356-
) = setupConsumerRequest(address(failingConsumer), defaultProvider);
353+
) = setupConsumerRequest(address(failingConsumer));
357354

358355
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
359356
publishTime
@@ -381,7 +378,7 @@ contract PulseTest is Test, PulseEvents {
381378
uint64 sequenceNumber,
382379
bytes32[] memory priceIds,
383380
uint256 publishTime
384-
) = setupConsumerRequest(address(failingConsumer), defaultProvider);
381+
) = setupConsumerRequest(address(failingConsumer));
385382

386383
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
387384
publishTime
@@ -408,7 +405,7 @@ contract PulseTest is Test, PulseEvents {
408405
uint64 sequenceNumber,
409406
bytes32[] memory priceIds,
410407
uint256 publishTime
411-
) = setupConsumerRequest(address(consumer), defaultProvider);
408+
) = setupConsumerRequest(address(consumer));
412409

413410
// Setup mock data
414411
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
@@ -433,7 +430,7 @@ contract PulseTest is Test, PulseEvents {
433430
uint256 futureTime = block.timestamp + 10; // 10 seconds in future
434431
vm.deal(address(consumer), 1 gwei);
435432

436-
uint128 totalFee = calculateTotalFee(defaultProvider);
433+
uint128 totalFee = calculateTotalFee();
437434
vm.prank(address(consumer));
438435
uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{
439436
value: totalFee
@@ -468,7 +465,7 @@ contract PulseTest is Test, PulseEvents {
468465
uint256 farFutureTime = block.timestamp + 61; // Just over 1 minute
469466
vm.deal(address(consumer), 1 gwei);
470467

471-
uint128 totalFee = calculateTotalFee(defaultProvider);
468+
uint128 totalFee = calculateTotalFee();
472469
vm.prank(address(consumer));
473470

474471
vm.expectRevert("Too far in future");
@@ -484,7 +481,7 @@ contract PulseTest is Test, PulseEvents {
484481
uint64 sequenceNumber,
485482
bytes32[] memory priceIds,
486483
uint256 publishTime
487-
) = setupConsumerRequest(address(consumer), defaultProvider);
484+
) = setupConsumerRequest(address(consumer));
488485

489486
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
490487
publishTime
@@ -514,7 +511,7 @@ contract PulseTest is Test, PulseEvents {
514511
uint128 expectedFee = SafeCast.toUint128(
515512
DEFAULT_PROVIDER_FEE * gasLimit
516513
) + PYTH_FEE;
517-
uint128 actualFee = pulse.getFee(gasLimit, defaultProvider);
514+
uint128 actualFee = pulse.getFee(gasLimit);
518515
assertEq(
519516
actualFee,
520517
expectedFee,
@@ -524,7 +521,7 @@ contract PulseTest is Test, PulseEvents {
524521

525522
// Test with zero gas limit
526523
uint128 expectedMinFee = PYTH_FEE;
527-
uint128 actualMinFee = pulse.getFee(0, defaultProvider);
524+
uint128 actualMinFee = pulse.getFee(0);
528525
assertEq(
529526
actualMinFee,
530527
expectedMinFee,
@@ -538,9 +535,11 @@ contract PulseTest is Test, PulseEvents {
538535
vm.deal(address(consumer), 1 gwei);
539536

540537
vm.prank(address(consumer));
541-
pulse.requestPriceUpdatesWithCallback{
542-
value: calculateTotalFee(defaultProvider)
543-
}(block.timestamp, priceIds, CALLBACK_GAS_LIMIT);
538+
pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}(
539+
block.timestamp,
540+
priceIds,
541+
CALLBACK_GAS_LIMIT
542+
);
544543

545544
// Get admin's balance before withdrawal
546545
uint256 adminBalanceBefore = admin.balance;
@@ -586,9 +585,11 @@ contract PulseTest is Test, PulseEvents {
586585
vm.deal(address(consumer), 1 gwei);
587586

588587
vm.prank(address(consumer));
589-
pulse.requestPriceUpdatesWithCallback{
590-
value: calculateTotalFee(defaultProvider)
591-
}(block.timestamp, priceIds, CALLBACK_GAS_LIMIT);
588+
pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}(
589+
block.timestamp,
590+
priceIds,
591+
CALLBACK_GAS_LIMIT
592+
);
592593

593594
// Get provider's accrued fees instead of total fees
594595
PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo(
@@ -645,10 +646,7 @@ contract PulseTest is Test, PulseEvents {
645646
uint256 publishTime = block.timestamp;
646647

647648
// Setup request
648-
(uint64 sequenceNumber, , ) = setupConsumerRequest(
649-
address(consumer),
650-
defaultProvider
651-
);
649+
(uint64 sequenceNumber, , ) = setupConsumerRequest(address(consumer));
652650

653651
// Create different priceIds
654652
bytes32[] memory wrongPriceIds = new bytes32[](2);
@@ -682,7 +680,7 @@ contract PulseTest is Test, PulseEvents {
682680
}
683681

684682
vm.deal(address(consumer), 1 gwei);
685-
uint128 totalFee = calculateTotalFee(defaultProvider);
683+
uint128 totalFee = calculateTotalFee();
686684

687685
vm.prank(address(consumer));
688686
vm.expectRevert(
@@ -752,7 +750,7 @@ contract PulseTest is Test, PulseEvents {
752750
bytes32[] memory priceIds = new bytes32[](1);
753751
priceIds[0] = bytes32(uint256(1));
754752

755-
uint128 totalFee = pulse.getFee(CALLBACK_GAS_LIMIT, provider);
753+
uint128 totalFee = pulse.getFee(CALLBACK_GAS_LIMIT);
756754

757755
vm.deal(address(consumer), totalFee);
758756
vm.prank(address(consumer));
@@ -802,7 +800,7 @@ contract PulseTest is Test, PulseEvents {
802800
uint64 sequenceNumber,
803801
bytes32[] memory priceIds,
804802
uint256 publishTime
805-
) = setupConsumerRequest(address(consumer), defaultProvider);
803+
) = setupConsumerRequest(address(consumer));
806804

807805
// Setup mock data
808806
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
@@ -832,7 +830,7 @@ contract PulseTest is Test, PulseEvents {
832830
uint64 sequenceNumber,
833831
bytes32[] memory priceIds,
834832
uint256 publishTime
835-
) = setupConsumerRequest(address(consumer), defaultProvider);
833+
) = setupConsumerRequest(address(consumer));
836834

837835
// Setup mock data
838836
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(
@@ -864,7 +862,7 @@ contract PulseTest is Test, PulseEvents {
864862
uint64 sequenceNumber,
865863
bytes32[] memory priceIds,
866864
uint256 publishTime
867-
) = setupConsumerRequest(address(consumer), defaultProvider);
865+
) = setupConsumerRequest(address(consumer));
868866

869867
// Setup mock data
870868
PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds(

0 commit comments

Comments
 (0)