Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 125 additions & 66 deletions target_chains/ethereum/contracts/contracts/pyth/Pyth.sol
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,37 @@ abstract contract Pyth is
return getTotalFee(totalNumUpdates);
}

function getTwapUpdateFee(
bytes[] calldata updateData
) public view override returns (uint feeAmount) {
uint totalNumUpdates = 0;
// For TWAP updates, updateData is always length 2 (start and end points),
// but each VAA can contain multiple price feeds. We only need to count
// the number of updates in the first VAA since both VAAs will have the
// same number of price feeds.
if (
updateData[0].length > 4 &&
UnsafeCalldataBytesLib.toUint32(updateData[0], 0) ==
ACCUMULATOR_MAGIC
) {
(
uint offset,
UpdateType updateType
) = extractUpdateTypeFromAccumulatorHeader(updateData[0]);
if (updateType != UpdateType.WormholeMerkle) {
revert PythErrors.InvalidUpdateData();
}
totalNumUpdates += parseWormholeMerkleHeaderNumUpdates(
updateData[0],
offset
);
} else {
revert PythErrors.InvalidUpdateData();
}

return getTotalFee(totalNumUpdates);
}

// This is an overwrite of the same method in AbstractPyth.sol
// to be more gas efficient.
function updatePriceFeedsIfNecessary(
Expand Down Expand Up @@ -372,18 +403,18 @@ abstract contract Pyth is
);
}

function processSingleTwapUpdate(
function extractTwapPriceInfos(
bytes calldata updateData
)
private
view
returns (
/// @return newOffset The next position in the update data after processing this TWAP update
/// @return twapPriceInfo The extracted time-weighted average price information
/// @return priceId The unique identifier for this price feed
/// @return priceInfos Array of extracted TWAP price information
/// @return priceIds Array of corresponding price feed IDs
uint newOffset,
PythStructs.TwapPriceInfo memory twapPriceInfo,
bytes32 priceId
PythStructs.TwapPriceInfo[] memory twapPriceInfos,
bytes32[] memory priceIds
)
{
UpdateType updateType;
Expand Down Expand Up @@ -417,12 +448,22 @@ abstract contract Pyth is
revert PythErrors.InvalidUpdateData();
}

// Extract start TWAP data with robust error checking
(offset, twapPriceInfo, priceId) = extractTwapPriceInfoFromMerkleProof(
digest,
encoded,
offset
);
// Initialize arrays to store all price infos and ids from this update
twapPriceInfos = new PythStructs.TwapPriceInfo[](numUpdates);
priceIds = new bytes32[](numUpdates);

// Extract each TWAP price info from the merkle proof
for (uint i = 0; i < numUpdates; i++) {
PythStructs.TwapPriceInfo memory twapPriceInfo;
bytes32 priceId;
(
offset,
twapPriceInfo,
priceId
) = extractTwapPriceInfoFromMerkleProof(digest, encoded, offset);
twapPriceInfos[i] = twapPriceInfo;
priceIds[i] = priceId;
}

if (offset != encoded.length) {
revert PythErrors.InvalidTwapUpdateData();
Expand All @@ -439,71 +480,89 @@ abstract contract Pyth is
override
returns (PythStructs.TwapPriceFeed[] memory twapPriceFeeds)
{
// TWAP requires exactly 2 updates - one for the start point and one for the end point
// to calculate the time-weighted average price between those two points
// TWAP requires exactly 2 updates: one for the start point and one for the end point
if (updateData.length != 2) {
revert PythErrors.InvalidUpdateData();
}

uint requiredFee = getUpdateFee(updateData);
uint requiredFee = getTwapUpdateFee(updateData);
if (msg.value < requiredFee) revert PythErrors.InsufficientFee();

unchecked {
twapPriceFeeds = new PythStructs.TwapPriceFeed[](priceIds.length);
for (uint i = 0; i < updateData.length - 1; i++) {
if (
(updateData[i].length > 4 &&
UnsafeCalldataBytesLib.toUint32(updateData[i], 0) ==
ACCUMULATOR_MAGIC) &&
(updateData[i + 1].length > 4 &&
UnsafeCalldataBytesLib.toUint32(updateData[i + 1], 0) ==
ACCUMULATOR_MAGIC)
) {
uint offsetStart;
uint offsetEnd;
bytes32 priceIdStart;
bytes32 priceIdEnd;
PythStructs.TwapPriceInfo memory twapPriceInfoStart;
PythStructs.TwapPriceInfo memory twapPriceInfoEnd;
(
offsetStart,
twapPriceInfoStart,
priceIdStart
) = processSingleTwapUpdate(updateData[i]);
(
offsetEnd,
twapPriceInfoEnd,
priceIdEnd
) = processSingleTwapUpdate(updateData[i + 1]);

if (priceIdStart != priceIdEnd)
revert PythErrors.InvalidTwapUpdateDataSet();

validateTwapPriceInfo(twapPriceInfoStart, twapPriceInfoEnd);

uint k = findIndexOfPriceId(priceIds, priceIdStart);

// If priceFeed[k].id != 0 then it means that there was a valid
// update for priceIds[k] and we don't need to process this one.
if (k == priceIds.length || twapPriceFeeds[k].id != 0) {
continue;
}

twapPriceFeeds[k] = calculateTwap(
priceIdStart,
twapPriceInfoStart,
twapPriceInfoEnd
);
} else {
revert PythErrors.InvalidUpdateData();
}
// Process start update data
PythStructs.TwapPriceInfo[] memory startTwapPriceInfos;
bytes32[] memory startPriceIds;
{
uint offsetStart;
(
offsetStart,
startTwapPriceInfos,
startPriceIds
) = extractTwapPriceInfos(updateData[0]);
}

// Process end update data
PythStructs.TwapPriceInfo[] memory endTwapPriceInfos;
bytes32[] memory endPriceIds;
{
uint offsetEnd;
(offsetEnd, endTwapPriceInfos, endPriceIds) = extractTwapPriceInfos(
updateData[1]
);
}

// Verify that we have the same number of price feeds in start and end updates
if (startPriceIds.length != endPriceIds.length) {
revert PythErrors.InvalidTwapUpdateDataSet();
}

// Hermes always returns price feeds in the same order for start and end updates
// This allows us to assume startPriceIds[i] == endPriceIds[i] for efficiency
for (uint i = 0; i < startPriceIds.length; i++) {
if (startPriceIds[i] != endPriceIds[i]) {
revert PythErrors.InvalidTwapUpdateDataSet();
}
}

for (uint k = 0; k < priceIds.length; k++) {
if (twapPriceFeeds[k].id == 0) {
revert PythErrors.PriceFeedNotFoundWithinRange();
// Initialize the output array
twapPriceFeeds = new PythStructs.TwapPriceFeed[](priceIds.length);

// For each requested price ID, find matching start and end data points
for (uint i = 0; i < priceIds.length; i++) {
bytes32 requestedPriceId = priceIds[i];
int startIdx = -1;

// Find the index of this price ID in the startPriceIds array
// (which is the same as the endPriceIds array based on our validation above)
for (uint j = 0; j < startPriceIds.length; j++) {
if (startPriceIds[j] == requestedPriceId) {
startIdx = int(j);
break;
}
}

// If we found the price ID
if (startIdx >= 0) {
uint idx = uint(startIdx);
// Validate the pair of price infos
validateTwapPriceInfo(
startTwapPriceInfos[idx],
endTwapPriceInfos[idx]
);

// Calculate TWAP from these data points
twapPriceFeeds[i] = calculateTwap(
requestedPriceId,
startTwapPriceInfos[idx],
endTwapPriceInfos[idx]
);
}
}

// Ensure all requested price IDs were found
for (uint k = 0; k < priceIds.length; k++) {
if (twapPriceFeeds[k].id == 0) {
revert PythErrors.PriceFeedNotFoundWithinRange();
}
}
}

Expand Down
Loading
Loading