Skip to content

Commit 28b402a

Browse files
committed
added fee handling and a test for it
1 parent ae38e3a commit 28b402a

File tree

3 files changed

+71
-6
lines changed

3 files changed

+71
-6
lines changed

target_chains/stylus/contracts/pyth-receiver/src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub enum PythReceiverError {
1414
InvalidUnknownSource,
1515
NewPriceUnavailable,
1616
InvalidAccumulatorMessageType,
17+
InsufficientFee,
1718
}
1819

1920
impl core::fmt::Debug for PythReceiverError {
@@ -37,6 +38,7 @@ impl From<PythReceiverError> for Vec<u8> {
3738
PythReceiverError::InvalidUnknownSource => 10,
3839
PythReceiverError::NewPriceUnavailable => 11,
3940
PythReceiverError::InvalidAccumulatorMessageType => 12,
41+
PythReceiverError::InsufficientFee => 13,
4042
}]
4143
}
4244
}

target_chains/stylus/contracts/pyth-receiver/src/integration_tests.rs

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,11 @@ mod test {
111111
data,
112112
);
113113

114+
alice.fund(U256::from(200));
115+
114116
let update_data = test_data::good_update1();
115117

116-
let result = pyth_contract.sender(alice).update_price_feeds(update_data);
118+
let result = pyth_contract.sender_and_value(alice, U256::from(100)).update_price_feeds(update_data);
117119
assert!(result.is_ok());
118120

119121
let price_result = pyth_contract.sender(alice).get_price_unsafe(TEST_PRICE_ID);
@@ -130,6 +132,46 @@ mod test {
130132

131133
}
132134

135+
#[motsu::test]
136+
fn test_update_price_feed_insufficient_fee(pyth_contract: Contract<PythReceiver>, wormhole_contract: Contract<WormholeContract>, alice: Address) {
137+
let guardians = current_guardians();
138+
let governance_contract = Address::from_slice(&GOVERNANCE_CONTRACT.to_be_bytes::<32>()[12..32]);
139+
wormhole_contract.sender(alice).initialize(guardians, CHAIN_ID, GOVERNANCE_CHAIN_ID, governance_contract).unwrap();
140+
// let result = wormhole_contract.sender(alice).store_gs(4, current_guardians(), 0);
141+
142+
let single_update_fee = U256::from(100u64);
143+
let valid_time_period = U256::from(3600u64);
144+
145+
let data_source_chain_ids = vec![PYTHNET_CHAIN_ID];
146+
let data_source_emitter_addresses = vec![PYTHNET_EMITTER_ADDRESS];
147+
148+
let governance_chain_id = 1u16;
149+
let governance_emitter_address = [3u8; 32];
150+
let governance_initial_sequence = 0u64;
151+
let data = vec![];
152+
153+
pyth_contract.sender(alice).initialize(
154+
wormhole_contract.address(),
155+
single_update_fee,
156+
valid_time_period,
157+
data_source_chain_ids,
158+
data_source_emitter_addresses,
159+
governance_chain_id,
160+
governance_emitter_address,
161+
governance_initial_sequence,
162+
data,
163+
);
164+
165+
alice.fund(U256::from(50));
166+
167+
let update_data = test_data::good_update1();
168+
169+
let result = pyth_contract.sender_and_value(alice, U256::from(50)).update_price_feeds(update_data);
170+
assert!(result.is_err());
171+
assert_eq!(result.unwrap_err(), PythReceiverError::InsufficientFee);
172+
173+
}
174+
133175
#[motsu::test]
134176
fn test_get_price_after_multiple_updates(pyth_contract: Contract<PythReceiver>, wormhole_contract: Contract<WormholeContract>, alice: Address) {
135177
let guardians = current_guardians();
@@ -159,12 +201,14 @@ mod test {
159201
data,
160202
);
161203

204+
alice.fund(U256::from(200));
205+
162206
let update_data1 = test_data::good_update1();
163-
let result1 = pyth_contract.sender(alice).update_price_feeds(update_data1);
207+
let result1 = pyth_contract.sender_and_value(alice, U256::from(100)).update_price_feeds(update_data1);
164208
assert!(result1.is_ok());
165209

166210
let update_data2 = test_data::good_update2();
167-
let result2 = pyth_contract.sender(alice).update_price_feeds(update_data2);
211+
let result2 = pyth_contract.sender_and_value(alice, U256::from(100)).update_price_feeds(update_data2);
168212
assert!(result2.is_ok());
169213

170214
let price_result = pyth_contract.sender(alice).get_price_unsafe(TEST_PRICE_ID);
@@ -281,8 +325,10 @@ mod test {
281325
data,
282326
);
283327

328+
alice.fund(U256::from(200));
329+
284330
let update_data = test_data::good_update2();
285-
let result = pyth_contract.sender(alice).update_price_feeds(update_data);
331+
let result = pyth_contract.sender_and_value(alice, U256::from(100)).update_price_feeds(update_data);
286332
assert!(result.is_ok());
287333

288334
let price_result = pyth_contract.sender(alice).get_price_no_older_than(TEST_PRICE_ID, u64::MAX);
@@ -325,9 +371,10 @@ mod test {
325371
governance_initial_sequence,
326372
data,
327373
);
374+
alice.fund(U256::from(200));
328375

329376
let update_data = test_data::good_update2();
330-
let result = pyth_contract.sender(alice).update_price_feeds(update_data);
377+
let result = pyth_contract.sender_and_value(alice, U256::from(100)).update_price_feeds(update_data);
331378
assert!(result.is_ok());
332379

333380
let price_result = pyth_contract.sender(alice).get_price_no_older_than(TEST_PRICE_ID, 1);
@@ -363,9 +410,11 @@ mod test {
363410
governance_initial_sequence,
364411
data,
365412
);
413+
414+
alice.fund(U256::from(200));
366415

367416
let update_data = test_data::multiple_updates();
368-
let result = pyth_contract.sender(alice).update_price_feeds(update_data);
417+
let result = pyth_contract.sender_and_value(alice, U256::from(200)).update_price_feeds(update_data);
369418
assert!(result.is_ok());
370419

371420
let first_id: [u8; 32] = [

target_chains/stylus/contracts/pyth-receiver/src/lib.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ impl PythReceiver {
149149
Ok(price_info)
150150
}
151151

152+
#[payable]
152153
pub fn update_price_feeds(&mut self, update_data: Vec<u8>) -> Result<(), PythReceiverError> {
153154
self.update_price_feeds_internal(update_data)?;
154155
Ok(())
@@ -202,6 +203,15 @@ impl PythReceiver {
202203

203204
let root_digest: MerkleRoot<Keccak160> = parse_wormhole_proof(vaa).unwrap();
204205

206+
let num_updates = u8::try_from(updates.len()).expect("value doesn't fit in u8");
207+
let total_fee = self.get_total_fee(num_updates);
208+
209+
let value = self.vm().msg_value();
210+
211+
if value < total_fee {
212+
return Err(PythReceiverError::InsufficientFee);
213+
}
214+
205215
for update in updates {
206216

207217
let message_vec = Vec::from(update.message);
@@ -242,6 +252,10 @@ impl PythReceiver {
242252
Ok(())
243253
}
244254

255+
fn get_total_fee(&self, num_updates: u8) -> U256 {
256+
U256::from(num_updates).saturating_mul(self.single_update_fee_in_wei.get())
257+
}
258+
245259
// pub fn get_update_fee(&self, _update_data: Vec<Vec<u8>>) -> U256 {
246260
// U256::from(0u8)
247261
// }

0 commit comments

Comments
 (0)