Skip to content

Commit 919f71e

Browse files
authored
[eth] - optimize ReceiverMessages parseAndVerifyVM (#901)
* feat(eth): optimize ReceiverMessages parseAndVerifyVM * test(eth): update test setups to use wormholeReceiver * chore(eth): remove console logging * feat(eth): optimize & revert return type for parseAndVerifyVM * fix(eth): add index boundary checks * perf(eth): optimize verifySignature by passing in primitives instead of structs * test(eth): add wormhole tests related to guardian set validity * test(eth): add more parseAndVerify failure test cases * test(eth): add more failure tests for parseAndVerify * test(eth): add empty forge test, refactor/deduplicate
1 parent d07cc9d commit 919f71e

File tree

10 files changed

+657
-59
lines changed

10 files changed

+657
-59
lines changed

target_chains/ethereum/contracts/contracts/libraries/external/UnsafeCalldataBytesLib.sol

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ library UnsafeCalldataBytesLib {
2020
return _bytes[_start:_start + _length];
2121
}
2222

23+
function sliceFrom(
24+
bytes calldata _bytes,
25+
uint256 _start
26+
) internal pure returns (bytes calldata) {
27+
return _bytes[_start:_bytes.length];
28+
}
29+
2330
function toAddress(
2431
bytes calldata _bytes,
2532
uint256 _start

target_chains/ethereum/contracts/contracts/pyth/Pyth.sol

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ abstract contract Pyth is
7575
for (uint i = 0; i < updateData.length; ) {
7676
if (
7777
updateData[i].length > 4 &&
78-
UnsafeBytesLib.toUint32(updateData[i], 0) == ACCUMULATOR_MAGIC
78+
UnsafeCalldataBytesLib.toUint32(updateData[i], 0) ==
79+
ACCUMULATOR_MAGIC
7980
) {
8081
totalNumUpdates += updatePriceInfosFromAccumulatorUpdate(
8182
updateData[i]
@@ -143,7 +144,6 @@ abstract contract Pyth is
143144
// operations have proper require.
144145
unchecked {
145146
bytes memory encoded = vm.payload;
146-
147147
(
148148
uint index,
149149
uint nAttestations,

target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
3131
// This method is also used by batch attestation but moved here
3232
// as the batch attestation will deprecate soon.
3333
function parseAndVerifyPythVM(
34-
bytes memory encodedVm
34+
bytes calldata encodedVm
3535
) internal view returns (IWormhole.VM memory vm) {
3636
{
3737
bool valid;
@@ -152,7 +152,6 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
152152

153153
// TODO: Do we need to emit an update for accumulator update? If so what should we emit?
154154
// emit AccumulatorUpdate(vm.chainId, vm.sequence);
155-
156155
encodedPayload = vm.payload;
157156
}
158157

@@ -200,16 +199,19 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
200199
}
201200

202201
function parseWormholeMerkleHeaderNumUpdates(
203-
bytes memory wormholeMerkleUpdate,
202+
bytes calldata wormholeMerkleUpdate,
204203
uint offset
205204
) internal pure returns (uint8 numUpdates) {
206-
uint16 whProofSize = UnsafeBytesLib.toUint16(
205+
uint16 whProofSize = UnsafeCalldataBytesLib.toUint16(
207206
wormholeMerkleUpdate,
208207
offset
209208
);
210209
offset += 2;
211210
offset += whProofSize;
212-
numUpdates = UnsafeBytesLib.toUint8(wormholeMerkleUpdate, offset);
211+
numUpdates = UnsafeCalldataBytesLib.toUint8(
212+
wormholeMerkleUpdate,
213+
offset
214+
);
213215
}
214216

215217
function extractPriceInfoFromMerkleProof(

target_chains/ethereum/contracts/contracts/wormhole-receiver/ReceiverMessages.sol

Lines changed: 197 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,179 @@ pragma experimental ABIEncoderV2;
77
import "./ReceiverGetters.sol";
88
import "./ReceiverStructs.sol";
99
import "../libraries/external/BytesLib.sol";
10+
import "../libraries/external/UnsafeCalldataBytesLib.sol";
11+
12+
error VmVersionIncompatible();
13+
error SignatureIndexesNotAscending();
1014

1115
contract ReceiverMessages is ReceiverGetters {
1216
using BytesLib for bytes;
1317

1418
/// @dev parseAndVerifyVM serves to parse an encodedVM and wholy validate it for consumption
19+
/// WARNING: it intentionally sets vm.signatures to an empty array since it is not needed after it is validated in this function
20+
/// since it not used anywhere. If you need to use vm.signatures, use parseVM and verifyVM separately.
1521
function parseAndVerifyVM(
1622
bytes calldata encodedVM
1723
)
1824
public
1925
view
2026
returns (ReceiverStructs.VM memory vm, bool valid, string memory reason)
2127
{
22-
vm = parseVM(encodedVM);
23-
(valid, reason) = verifyVM(vm);
28+
uint index = 0;
29+
unchecked {
30+
{
31+
vm.version = UnsafeCalldataBytesLib.toUint8(encodedVM, index);
32+
index += 1;
33+
if (vm.version != 1) {
34+
revert VmVersionIncompatible();
35+
}
36+
}
37+
38+
ReceiverStructs.GuardianSet memory guardianSet;
39+
{
40+
vm.guardianSetIndex = UnsafeCalldataBytesLib.toUint32(
41+
encodedVM,
42+
index
43+
);
44+
index += 4;
45+
guardianSet = getGuardianSet(vm.guardianSetIndex);
46+
47+
/**
48+
* @dev Checks whether the guardianSet has zero keys
49+
* WARNING: This keys check is critical to ensure the guardianSet has keys present AND to ensure
50+
* that guardianSet key size doesn't fall to zero and negatively impact quorum assessment. If guardianSet
51+
* key length is 0 and vm.signatures length is 0, this could compromise the integrity of both vm and
52+
* signature verification.
53+
*/
54+
if (guardianSet.keys.length == 0) {
55+
return (vm, false, "invalid guardian set");
56+
}
57+
58+
/// @dev Checks if VM guardian set index matches the current index (unless the current set is expired).
59+
if (
60+
vm.guardianSetIndex != getCurrentGuardianSetIndex() &&
61+
guardianSet.expirationTime < block.timestamp
62+
) {
63+
return (vm, false, "guardian set has expired");
64+
}
65+
}
66+
67+
// Parse Signatures
68+
uint256 signersLen = UnsafeCalldataBytesLib.toUint8(
69+
encodedVM,
70+
index
71+
);
72+
index += 1;
73+
{
74+
// 66 is the length of each signature
75+
// 1 (guardianIndex) + 32 (r) + 32 (s) + 1 (v)
76+
uint hashIndex = index + (signersLen * 66);
77+
if (hashIndex > encodedVM.length) {
78+
return (vm, false, "invalid signature length");
79+
}
80+
// Hash the body
81+
vm.hash = keccak256(
82+
abi.encodePacked(
83+
keccak256(
84+
UnsafeCalldataBytesLib.sliceFrom(
85+
encodedVM,
86+
hashIndex
87+
)
88+
)
89+
)
90+
);
91+
}
92+
93+
{
94+
uint8 lastIndex = 0;
95+
for (uint i = 0; i < signersLen; i++) {
96+
ReceiverStructs.Signature memory sig;
97+
sig.guardianIndex = UnsafeCalldataBytesLib.toUint8(
98+
encodedVM,
99+
index
100+
);
101+
index += 1;
102+
103+
sig.r = UnsafeCalldataBytesLib.toBytes32(encodedVM, index);
104+
index += 32;
105+
sig.s = UnsafeCalldataBytesLib.toBytes32(encodedVM, index);
106+
index += 32;
107+
sig.v =
108+
UnsafeCalldataBytesLib.toUint8(encodedVM, index) +
109+
27;
110+
index += 1;
111+
bool signatureValid;
112+
string memory invalidReason;
113+
(signatureValid, invalidReason) = verifySignature(
114+
i,
115+
lastIndex,
116+
vm.hash,
117+
sig.guardianIndex,
118+
sig.r,
119+
sig.s,
120+
sig.v,
121+
guardianSet.keys[sig.guardianIndex]
122+
);
123+
if (!signatureValid) {
124+
return (vm, false, invalidReason);
125+
}
126+
lastIndex = sig.guardianIndex;
127+
}
128+
}
129+
130+
/**
131+
* @dev We're using a fixed point number transformation with 1 decimal to deal with rounding.
132+
* WARNING: This quorum check is critical to assessing whether we have enough Guardian signatures to validate a VM
133+
* if making any changes to this, obtain additional peer review. If guardianSet key length is 0 and
134+
* vm.signatures length is 0, this could compromise the integrity of both vm and signature verification.
135+
*/
136+
137+
if (
138+
(((guardianSet.keys.length * 10) / 3) * 2) / 10 + 1 > signersLen
139+
) {
140+
return (vm, false, "no quorum");
141+
}
142+
143+
// purposely setting vm.signatures to empty array since we don't need it anymore
144+
// and we've already verified it above
145+
vm.signatures = new ReceiverStructs.Signature[](0);
146+
147+
// Parse the body
148+
vm.timestamp = UnsafeCalldataBytesLib.toUint32(encodedVM, index);
149+
index += 4;
150+
151+
vm.nonce = UnsafeCalldataBytesLib.toUint32(encodedVM, index);
152+
index += 4;
153+
154+
vm.emitterChainId = UnsafeCalldataBytesLib.toUint16(
155+
encodedVM,
156+
index
157+
);
158+
index += 2;
159+
160+
vm.emitterAddress = UnsafeCalldataBytesLib.toBytes32(
161+
encodedVM,
162+
index
163+
);
164+
index += 32;
165+
166+
vm.sequence = UnsafeCalldataBytesLib.toUint64(encodedVM, index);
167+
index += 8;
168+
169+
vm.consistencyLevel = UnsafeCalldataBytesLib.toUint8(
170+
encodedVM,
171+
index
172+
);
173+
index += 1;
174+
175+
if (index > encodedVM.length) {
176+
return (vm, false, "invalid payload length");
177+
}
178+
179+
vm.payload = UnsafeCalldataBytesLib.sliceFrom(encodedVM, index);
180+
181+
return (vm, true, "");
182+
}
24183
}
25184

26185
/**
@@ -84,6 +243,27 @@ contract ReceiverMessages is ReceiverGetters {
84243
return (true, "");
85244
}
86245

246+
function verifySignature(
247+
uint i,
248+
uint8 lastIndex,
249+
bytes32 hash,
250+
uint8 guardianIndex,
251+
bytes32 r,
252+
bytes32 s,
253+
uint8 v,
254+
address guardianSetKey
255+
) private pure returns (bool valid, string memory reason) {
256+
/// Ensure that provided signature indices are ascending only
257+
if (i != 0 && guardianIndex <= lastIndex) {
258+
revert SignatureIndexesNotAscending();
259+
}
260+
/// Check to see if the signer of the signature does not match a specific Guardian key at the provided index
261+
if (ecrecover(hash, v, r, s) != guardianSetKey) {
262+
return (false, "VM signature invalid");
263+
}
264+
return (true, "");
265+
}
266+
87267
/**
88268
* @dev verifySignatures serves to validate arbitrary sigatures against an arbitrary guardianSet
89269
* - it intentionally does not solve for expectations within guardianSet (you should use verifyVM if you need these protections)
@@ -98,21 +278,20 @@ contract ReceiverMessages is ReceiverGetters {
98278
uint8 lastIndex = 0;
99279
for (uint i = 0; i < signatures.length; i++) {
100280
ReceiverStructs.Signature memory sig = signatures[i];
101-
102-
/// Ensure that provided signature indices are ascending only
103-
require(
104-
i == 0 || sig.guardianIndex > lastIndex,
105-
"signature indices must be ascending"
106-
);
107-
lastIndex = sig.guardianIndex;
108-
109-
/// Check to see if the signer of the signature does not match a specific Guardian key at the provided index
110-
if (
111-
ecrecover(hash, sig.v, sig.r, sig.s) !=
281+
(valid, reason) = verifySignature(
282+
i,
283+
lastIndex,
284+
hash,
285+
sig.guardianIndex,
286+
sig.r,
287+
sig.s,
288+
sig.v,
112289
guardianSet.keys[sig.guardianIndex]
113-
) {
114-
return (false, "VM signature invalid");
290+
);
291+
if (!valid) {
292+
return (false, reason);
115293
}
294+
lastIndex = sig.guardianIndex;
116295
}
117296

118297
/// If we are here, we've validated that the provided signatures are valid for the provided guardianSet
@@ -130,7 +309,9 @@ contract ReceiverMessages is ReceiverGetters {
130309

131310
vm.version = encodedVM.toUint8(index);
132311
index += 1;
133-
require(vm.version == 1, "VM version incompatible");
312+
if (vm.version != 1) {
313+
revert VmVersionIncompatible();
314+
}
134315

135316
vm.guardianSetIndex = encodedVM.toUint32(index);
136317
index += 4;

target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
2525
// We will have less than 512 price for a foreseeable future.
2626
uint8 constant MERKLE_TREE_DEPTH = 9;
2727

28+
IWormhole public wormhole;
2829
IPyth public pyth;
2930

3031
bytes32[] priceIds;
@@ -51,7 +52,9 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
5152
uint randSeed;
5253

5354
function setUp() public {
54-
pyth = IPyth(setUpPyth(setUpWormhole(NUM_GUARDIANS)));
55+
address wormholeAddr = setUpWormholeReceiver(NUM_GUARDIANS);
56+
wormhole = IWormhole(wormholeAddr);
57+
pyth = IPyth(setUpPyth(wormholeAddr));
5558

5659
priceIds = new bytes32[](NUM_PRICES);
5760
priceIds[0] = bytes32(
@@ -101,7 +104,6 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
101104
freshPricesWhMerkleUpdateData.push(updateData);
102105
freshPricesWhMerkleUpdateFee.push(updateFee);
103106
}
104-
105107
// Populate the contract with the initial prices
106108
(
107109
cachedPricesWhBatchUpdateData,
@@ -417,4 +419,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
417419
function testBenchmarkGetUpdateFeeWhMerkle5() public view {
418420
pyth.getUpdateFee(freshPricesWhMerkleUpdateData[4]);
419421
}
422+
423+
function testBenchmarkWormholeParseAndVerifyVMBatchAttestation() public {
424+
wormhole.parseAndVerifyVM(freshPricesWhBatchUpdateData[0]);
425+
}
420426
}

target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ contract PythWormholeMerkleAccumulatorTest is
2626
uint64 constant MAX_UINT64 = uint64(int64(-1));
2727

2828
function setUp() public {
29-
pyth = IPyth(setUpPyth(setUpWormhole(1)));
29+
pyth = IPyth(setUpPyth(setUpWormholeReceiver(1)));
3030
}
3131

3232
function assertPriceFeedMessageStored(
@@ -476,13 +476,6 @@ contract PythWormholeMerkleAccumulatorTest is
476476
);
477477
}
478478

479-
function isNotMatch(
480-
bytes memory a,
481-
bytes memory b
482-
) public pure returns (bool) {
483-
return keccak256(a) != keccak256(b);
484-
}
485-
486479
/// @notice This method creates a forged invalid wormhole update data.
487480
/// The caller should pass the forgeItem as string and if it matches the
488481
/// expected value, that item will be forged to be invalid.

target_chains/ethereum/contracts/forge-test/Pyth.t.sol

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils, RandTestUtils {
1919
uint64 constant MAX_UINT64 = uint64(int64(-1));
2020

2121
function setUp() public {
22-
pyth = IPyth(setUpPyth(setUpWormhole(1)));
22+
pyth = IPyth(setUpPyth(setUpWormholeReceiver(1)));
2323
}
2424

2525
function generateRandomPriceAttestations(

0 commit comments

Comments
 (0)