Skip to content

Commit 19b77e2

Browse files
authored
[ethereum] - charge updateFee per number of updates (#878)
* feat(ethereum): charge update fee per numUpdates for accumulator updates * refactor(ethereum): refactor, add benchmarks for getUpdateFee * refactor(ethereum): add back parseWormholeMerkleHeaderNumUpdates * refactor: increment totalNumUdpates by 1 for batch prices * test(ethereum): add test for checking getUpdateFee for accumulator, clean up unused code
1 parent 75abeb1 commit 19b77e2

File tree

6 files changed

+175
-89
lines changed

6 files changed

+175
-89
lines changed

target_chains/ethereum/contracts/contracts/pyth/Pyth.sol

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -71,24 +71,26 @@ abstract contract Pyth is
7171
function updatePriceFeeds(
7272
bytes[] calldata updateData
7373
) public payable override {
74-
// TODO: Is this fee model still good for accumulator?
75-
uint requiredFee = getUpdateFee(updateData);
76-
if (msg.value < requiredFee) revert PythErrors.InsufficientFee();
77-
74+
uint totalNumUpdates = 0;
7875
for (uint i = 0; i < updateData.length; ) {
7976
if (
8077
updateData[i].length > 4 &&
8178
UnsafeBytesLib.toUint32(updateData[i], 0) == ACCUMULATOR_MAGIC
8279
) {
83-
updatePriceInfosFromAccumulatorUpdate(updateData[i]);
80+
totalNumUpdates += updatePriceInfosFromAccumulatorUpdate(
81+
updateData[i]
82+
);
8483
} else {
8584
updatePriceBatchFromVm(updateData[i]);
85+
totalNumUpdates += 1;
8686
}
8787

8888
unchecked {
8989
i++;
9090
}
9191
}
92+
uint requiredFee = getTotalFee(totalNumUpdates);
93+
if (msg.value < requiredFee) revert PythErrors.InsufficientFee();
9294
}
9395

9496
/// This method is deprecated, please use the `getUpdateFee(bytes[])` instead.
@@ -101,7 +103,28 @@ abstract contract Pyth is
101103
function getUpdateFee(
102104
bytes[] calldata updateData
103105
) public view override returns (uint feeAmount) {
104-
return singleUpdateFeeInWei() * updateData.length;
106+
uint totalNumUpdates = 0;
107+
for (uint i = 0; i < updateData.length; i++) {
108+
if (
109+
updateData[i].length > 4 &&
110+
UnsafeBytesLib.toUint32(updateData[i], 0) == ACCUMULATOR_MAGIC
111+
) {
112+
(
113+
uint offset,
114+
UpdateType updateType
115+
) = extractUpdateTypeFromAccumulatorHeader(updateData[i]);
116+
if (updateType != UpdateType.WormholeMerkle) {
117+
revert PythErrors.InvalidUpdateData();
118+
}
119+
totalNumUpdates += parseWormholeMerkleHeaderNumUpdates(
120+
updateData[i],
121+
offset
122+
);
123+
} else {
124+
totalNumUpdates += 1;
125+
}
126+
}
127+
return getTotalFee(totalNumUpdates);
105128
}
106129

107130
function verifyPythVM(
@@ -425,52 +448,45 @@ abstract contract Pyth is
425448
returns (PythStructs.PriceFeed[] memory priceFeeds)
426449
{
427450
unchecked {
428-
{
429-
uint requiredFee = getUpdateFee(updateData);
430-
if (msg.value < requiredFee)
431-
revert PythErrors.InsufficientFee();
432-
}
433-
451+
uint totalNumUpdates = 0;
434452
priceFeeds = new PythStructs.PriceFeed[](priceIds.length);
435453
for (uint i = 0; i < updateData.length; i++) {
436454
if (
437455
updateData[i].length > 4 &&
438456
UnsafeBytesLib.toUint32(updateData[i], 0) ==
439457
ACCUMULATOR_MAGIC
440458
) {
441-
bytes memory accumulatorUpdate = updateData[i];
442459
uint offset;
443460
{
444461
UpdateType updateType;
445462
(
446463
offset,
447464
updateType
448465
) = extractUpdateTypeFromAccumulatorHeader(
449-
accumulatorUpdate
466+
updateData[i]
450467
);
451468

452469
if (updateType != UpdateType.WormholeMerkle) {
453470
revert PythErrors.InvalidUpdateData();
454471
}
455472
}
473+
456474
bytes20 digest;
457475
uint8 numUpdates;
458-
bytes memory encoded = UnsafeBytesLib.slice(
459-
accumulatorUpdate,
460-
offset,
461-
accumulatorUpdate.length - offset
462-
);
463-
476+
bytes memory encoded;
464477
(
465478
offset,
466479
digest,
467-
numUpdates
468-
) = extractWormholeMerkleHeaderDigestAndNumUpdates(encoded);
480+
numUpdates,
481+
encoded
482+
) = extractWormholeMerkleHeaderDigestAndNumUpdatesAndEncodedFromAccumulatorUpdate(
483+
updateData[i],
484+
offset
485+
);
469486

470487
for (uint j = 0; j < numUpdates; j++) {
471488
PythInternalStructs.PriceInfo memory info;
472489
bytes32 priceId;
473-
474490
(
475491
offset,
476492
info,
@@ -509,6 +525,7 @@ abstract contract Pyth is
509525
}
510526
}
511527
}
528+
totalNumUpdates += numUpdates;
512529
if (offset != encoded.length)
513530
revert PythErrors.InvalidUpdateData();
514531
} else {
@@ -583,6 +600,7 @@ abstract contract Pyth is
583600

584601
index += attestationSize;
585602
}
603+
totalNumUpdates += 1;
586604
}
587605
}
588606

@@ -591,9 +609,21 @@ abstract contract Pyth is
591609
revert PythErrors.PriceFeedNotFoundWithinRange();
592610
}
593611
}
612+
613+
{
614+
uint requiredFee = getTotalFee(totalNumUpdates);
615+
if (msg.value < requiredFee)
616+
revert PythErrors.InsufficientFee();
617+
}
594618
}
595619
}
596620

621+
function getTotalFee(
622+
uint totalNumUpdates
623+
) private view returns (uint requiredFee) {
624+
return totalNumUpdates * singleUpdateFeeInWei();
625+
}
626+
597627
function findIndexOfPriceId(
598628
bytes32[] calldata priceIds,
599629
bytes32 targetPriceId

target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol

Lines changed: 53 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,24 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
107107
}
108108
}
109109

110-
function extractWormholeMerkleHeaderDigestAndNumUpdates(
111-
bytes memory encoded
112-
) internal view returns (uint offset, bytes20 digest, uint8 numUpdates) {
110+
function extractWormholeMerkleHeaderDigestAndNumUpdatesAndEncodedFromAccumulatorUpdate(
111+
bytes calldata accumulatorUpdate,
112+
uint encodedOffset
113+
)
114+
internal
115+
view
116+
returns (
117+
uint offset,
118+
bytes20 digest,
119+
uint8 numUpdates,
120+
bytes memory encoded
121+
)
122+
{
123+
encoded = UnsafeBytesLib.slice(
124+
accumulatorUpdate,
125+
encodedOffset,
126+
accumulatorUpdate.length - encodedOffset
127+
);
113128
unchecked {
114129
offset = 0;
115130

@@ -170,6 +185,19 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
170185
}
171186
}
172187

188+
function parseWormholeMerkleHeaderNumUpdates(
189+
bytes calldata wormholeMerkleUpdate,
190+
uint offset
191+
) internal view returns (uint8 numUpdates) {
192+
uint16 whProofSize = UnsafeBytesLib.toUint16(
193+
wormholeMerkleUpdate,
194+
offset
195+
);
196+
offset += 2;
197+
offset += whProofSize;
198+
numUpdates = UnsafeBytesLib.toUint8(wormholeMerkleUpdate, offset);
199+
}
200+
173201
function extractPriceInfoFromMerkleProof(
174202
bytes20 digest,
175203
bytes memory encoded,
@@ -185,62 +213,28 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
185213
{
186214
unchecked {
187215
bytes memory encodedMessage;
188-
(endOffset, encodedMessage) = extractMessageFromProof(
189-
encoded,
190-
offset,
191-
digest
192-
);
193-
194-
(priceInfo, priceId) = extractPriceInfoAndIdFromPriceFeedMessage(
195-
encodedMessage
196-
);
197-
198-
return (endOffset, priceInfo, priceId);
199-
}
200-
}
201-
202-
function extractMessageFromProof(
203-
bytes memory encodedProof,
204-
uint offset,
205-
bytes20 merkleRoot
206-
) private pure returns (uint endOffset, bytes memory encodedMessage) {
207-
unchecked {
208-
uint16 messageSize = UnsafeBytesLib.toUint16(encodedProof, offset);
216+
uint16 messageSize = UnsafeBytesLib.toUint16(encoded, offset);
209217
offset += 2;
210218

211-
encodedMessage = UnsafeBytesLib.slice(
212-
encodedProof,
213-
offset,
214-
messageSize
215-
);
219+
encodedMessage = UnsafeBytesLib.slice(encoded, offset, messageSize);
216220
offset += messageSize;
217221

218222
bool valid;
219223
(valid, endOffset) = MerkleTree.isProofValid(
220-
encodedProof,
224+
encoded,
221225
offset,
222-
merkleRoot,
226+
digest,
223227
encodedMessage
224228
);
225229
if (!valid) {
226230
revert PythErrors.InvalidUpdateData();
227231
}
228-
}
229-
}
230232

231-
function extractPriceInfoAndIdFromPriceFeedMessage(
232-
bytes memory encodedMessage
233-
)
234-
private
235-
pure
236-
returns (PythInternalStructs.PriceInfo memory info, bytes32 priceId)
237-
{
238-
unchecked {
239233
MessageType messageType = MessageType(
240234
UnsafeBytesLib.toUint8(encodedMessage, 0)
241235
);
242236
if (messageType == MessageType.PriceFeed) {
243-
(info, priceId) = parsePriceFeedMessage(
237+
(priceInfo, priceId) = parsePriceFeedMessage(
244238
UnsafeBytesLib.slice(
245239
encodedMessage,
246240
1,
@@ -250,6 +244,8 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
250244
} else {
251245
revert PythErrors.InvalidUpdateData();
252246
}
247+
248+
return (endOffset, priceInfo, priceId);
253249
}
254250
}
255251

@@ -315,32 +311,30 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
315311

316312
function updatePriceInfosFromAccumulatorUpdate(
317313
bytes calldata accumulatorUpdate
318-
) internal {
314+
) internal returns (uint8 numUpdates) {
319315
(
320-
uint offset,
316+
uint encodedOffset,
321317
UpdateType updateType
322318
) = extractUpdateTypeFromAccumulatorHeader(accumulatorUpdate);
323319

324320
if (updateType != UpdateType.WormholeMerkle) {
325321
revert PythErrors.InvalidUpdateData();
326322
}
327-
updatePriceInfosFromWormholeMerkle(
328-
UnsafeBytesLib.slice(
329-
accumulatorUpdate,
330-
offset,
331-
accumulatorUpdate.length - offset
332-
)
323+
324+
uint offset;
325+
bytes20 digest;
326+
bytes memory encoded;
327+
(
328+
offset,
329+
digest,
330+
numUpdates,
331+
encoded
332+
) = extractWormholeMerkleHeaderDigestAndNumUpdatesAndEncodedFromAccumulatorUpdate(
333+
accumulatorUpdate,
334+
encodedOffset
333335
);
334-
}
335336

336-
function updatePriceInfosFromWormholeMerkle(bytes memory encoded) private {
337337
unchecked {
338-
(
339-
uint offset,
340-
bytes20 digest,
341-
uint8 numUpdates
342-
) = extractWormholeMerkleHeaderDigestAndNumUpdates(encoded);
343-
344338
for (uint i = 0; i < numUpdates; i++) {
345339
PythInternalStructs.PriceInfo memory priceInfo;
346340
bytes32 priceId;
@@ -360,7 +354,7 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
360354
);
361355
}
362356
}
363-
if (offset != encoded.length) revert PythErrors.InvalidUpdateData();
364357
}
358+
if (offset != encoded.length) revert PythErrors.InvalidUpdateData();
365359
}
366360
}

target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,12 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
340340
ids[0] = priceIds[4];
341341
ids[1] = priceIds[2];
342342
ids[2] = priceIds[0];
343-
pyth.parsePriceFeedUpdates{
344-
value: freshPricesWhMerkleUpdateFee[numIds - 1]
345-
}(freshPricesWhMerkleUpdateData[4], ids, 0, 50);
343+
pyth.parsePriceFeedUpdates{value: freshPricesWhMerkleUpdateFee[4]}( // updateFee based on number of priceFeeds in updateData
344+
freshPricesWhMerkleUpdateData[4],
345+
ids,
346+
0,
347+
50
348+
);
346349
}
347350

348351
function testBenchmarkParsePriceFeedUpdatesWhMerkleForOnePriceFeedNotWithinRange()
@@ -391,7 +394,27 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
391394
pyth.getEmaPrice(priceIds[0]);
392395
}
393396

394-
function testBenchmarkGetUpdateFee() public view {
397+
function testBenchmarkGetUpdateFeeWhBatch() public view {
395398
pyth.getUpdateFee(freshPricesWhBatchUpdateData);
396399
}
400+
401+
function testBenchmarkGetUpdateFeeWhMerkle1() public view {
402+
pyth.getUpdateFee(freshPricesWhMerkleUpdateData[0]);
403+
}
404+
405+
function testBenchmarkGetUpdateFeeWhMerkle2() public view {
406+
pyth.getUpdateFee(freshPricesWhMerkleUpdateData[1]);
407+
}
408+
409+
function testBenchmarkGetUpdateFeeWhMerkle3() public view {
410+
pyth.getUpdateFee(freshPricesWhMerkleUpdateData[2]);
411+
}
412+
413+
function testBenchmarkGetUpdateFeeWhMerkle4() public view {
414+
pyth.getUpdateFee(freshPricesWhMerkleUpdateData[3]);
415+
}
416+
417+
function testBenchmarkGetUpdateFeeWhMerkle5() public view {
418+
pyth.getUpdateFee(freshPricesWhMerkleUpdateData[4]);
419+
}
397420
}

0 commit comments

Comments
 (0)