Skip to content
Merged
170 changes: 109 additions & 61 deletions target_chains/ethereum/contracts/contracts/pyth/Pyth.sol
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,11 @@ abstract contract Pyth is
view
returns (
/// @return newOffset The next position in the update data after processing this TWAP update
/// @return twapPriceInfo The extracted time-weighted average price information
/// @return priceId The unique identifier for this price feed
/// @return priceInfos Array of extracted TWAP price information
/// @return priceIds Array of corresponding price feed IDs
uint newOffset,
PythStructs.TwapPriceInfo memory twapPriceInfo,
bytes32 priceId
PythStructs.TwapPriceInfo[] memory twapPriceInfos,
bytes32[] memory priceIds
)
{
UpdateType updateType;
Expand Down Expand Up @@ -417,12 +417,22 @@ abstract contract Pyth is
revert PythErrors.InvalidUpdateData();
}

// Extract start TWAP data with robust error checking
(offset, twapPriceInfo, priceId) = extractTwapPriceInfoFromMerkleProof(
digest,
encoded,
offset
);
// Initialize arrays to store all price infos and ids from this update
twapPriceInfos = new PythStructs.TwapPriceInfo[](numUpdates);
priceIds = new bytes32[](numUpdates);

// Extract each TWAP price info from the merkle proof
for (uint i = 0; i < numUpdates; i++) {
PythStructs.TwapPriceInfo memory twapPriceInfo;
bytes32 priceId;
(
offset,
twapPriceInfo,
priceId
) = extractTwapPriceInfoFromMerkleProof(digest, encoded, offset);
twapPriceInfos[i] = twapPriceInfo;
priceIds[i] = priceId;
}

if (offset != encoded.length) {
revert PythErrors.InvalidTwapUpdateData();
Expand All @@ -439,71 +449,109 @@ abstract contract Pyth is
override
returns (PythStructs.TwapPriceFeed[] memory twapPriceFeeds)
{
// TWAP requires exactly 2 updates - one for the start point and one for the end point
// to calculate the time-weighted average price between those two points
// TWAP requires exactly 2 updates: one for the start point and one for the end point
if (updateData.length != 2) {
revert PythErrors.InvalidUpdateData();
}

uint requiredFee = getUpdateFee(updateData);
if (msg.value < requiredFee) revert PythErrors.InsufficientFee();

unchecked {
twapPriceFeeds = new PythStructs.TwapPriceFeed[](priceIds.length);
for (uint i = 0; i < updateData.length - 1; i++) {
// Process start update data
PythStructs.TwapPriceInfo[] memory startTwapPriceInfos;
bytes32[] memory startPriceIds;
{
uint offsetStart;
(
offsetStart,
startTwapPriceInfos,
startPriceIds
) = processSingleTwapUpdate(updateData[0]);
}

// Process end update data
PythStructs.TwapPriceInfo[] memory endTwapPriceInfos;
bytes32[] memory endPriceIds;
{
uint offsetEnd;
(
offsetEnd,
endTwapPriceInfos,
endPriceIds
) = processSingleTwapUpdate(updateData[1]);
}

// Verify that we have the same number of price feeds in start and end updates
if (startPriceIds.length != endPriceIds.length) {
revert PythErrors.InvalidTwapUpdateDataSet();
}

// Create a mapping to check that every startPriceId has a matching endPriceId
// This ensures price feed continuity between start and end points
bool[] memory endPriceIdMatched = new bool[](endPriceIds.length);
for (uint i = 0; i < startPriceIds.length; i++) {
bool foundMatch = false;
for (uint j = 0; j < endPriceIds.length; j++) {
if (
(updateData[i].length > 4 &&
UnsafeCalldataBytesLib.toUint32(updateData[i], 0) ==
ACCUMULATOR_MAGIC) &&
(updateData[i + 1].length > 4 &&
UnsafeCalldataBytesLib.toUint32(updateData[i + 1], 0) ==
ACCUMULATOR_MAGIC)
startPriceIds[i] == endPriceIds[j] && !endPriceIdMatched[j]
) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we dont need this check anymore because we call processSingleTwapUpdate and in the function we call extractUpdateTypeFromAccumulatorHeader which checks for the ACCUMULATOR_MAGIC

uint offsetStart;
uint offsetEnd;
bytes32 priceIdStart;
bytes32 priceIdEnd;
PythStructs.TwapPriceInfo memory twapPriceInfoStart;
PythStructs.TwapPriceInfo memory twapPriceInfoEnd;
(
offsetStart,
twapPriceInfoStart,
priceIdStart
) = processSingleTwapUpdate(updateData[i]);
(
offsetEnd,
twapPriceInfoEnd,
priceIdEnd
) = processSingleTwapUpdate(updateData[i + 1]);

if (priceIdStart != priceIdEnd)
revert PythErrors.InvalidTwapUpdateDataSet();

validateTwapPriceInfo(twapPriceInfoStart, twapPriceInfoEnd);

uint k = findIndexOfPriceId(priceIds, priceIdStart);

// If priceFeed[k].id != 0 then it means that there was a valid
// update for priceIds[k] and we don't need to process this one.
if (k == priceIds.length || twapPriceFeeds[k].id != 0) {
continue;
}

twapPriceFeeds[k] = calculateTwap(
priceIdStart,
twapPriceInfoStart,
twapPriceInfoEnd
);
} else {
revert PythErrors.InvalidUpdateData();
endPriceIdMatched[j] = true;
foundMatch = true;
break;
}
}
// If a price ID in start doesn't have a match in end, it's invalid
if (!foundMatch) {
revert PythErrors.InvalidTwapUpdateDataSet();
}
}

// Initialize the output array
twapPriceFeeds = new PythStructs.TwapPriceFeed[](priceIds.length);

for (uint k = 0; k < priceIds.length; k++) {
if (twapPriceFeeds[k].id == 0) {
revert PythErrors.PriceFeedNotFoundWithinRange();
// For each requested price ID, find matching start and end data points
for (uint i = 0; i < priceIds.length; i++) {
bytes32 requestedPriceId = priceIds[i];
int startIdx = -1;
int endIdx = -1;

// Find the index of this price ID in start and end arrays
for (uint j = 0; j < startPriceIds.length; j++) {
if (startPriceIds[j] == requestedPriceId) {
startIdx = int(j);
break;
}
}

for (uint j = 0; j < endPriceIds.length; j++) {
if (endPriceIds[j] == requestedPriceId) {
endIdx = int(j);
break;
}
}

// If we found both start and end data for this price ID
if (startIdx >= 0 && endIdx >= 0) {
// Validate the pair of price infos
validateTwapPriceInfo(
startTwapPriceInfos[uint(startIdx)],
endTwapPriceInfos[uint(endIdx)]
);

// Calculate TWAP from these data points
twapPriceFeeds[i] = calculateTwap(
requestedPriceId,
startTwapPriceInfos[uint(startIdx)],
endTwapPriceInfos[uint(endIdx)]
);
}
}

// Ensure all requested price IDs were found
for (uint k = 0; k < priceIds.length; k++) {
if (twapPriceFeeds[k].id == 0) {
revert PythErrors.PriceFeedNotFoundWithinRange();
}
}
}

Expand Down
Loading
Loading