Skip to content

Commit 7a5ac7d

Browse files
authored
feat(target_chains/starknet): add get_update_fee method (#1613)
1 parent 8887a09 commit 7a5ac7d

File tree

7 files changed

+54
-47
lines changed

7 files changed

+54
-47
lines changed

target_chains/starknet/contracts/src/pyth.cairo

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pub use interface::{IPyth, IPythDispatcher, IPythDispatcherTrait, DataSource, Pr
1515
#[starknet::contract]
1616
mod pyth {
1717
use super::price_update::{
18-
PriceInfo, PriceFeedMessage, read_and_verify_message, read_header_and_wormhole_proof,
18+
PriceInfo, PriceFeedMessage, read_and_verify_message, read_and_verify_header,
1919
parse_wormhole_proof
2020
};
2121
use pyth::reader::{Reader, ReaderImpl};
@@ -177,7 +177,10 @@ mod pyth {
177177

178178
fn update_price_feeds(ref self: ContractState, data: ByteArray) {
179179
let mut reader = ReaderImpl::new(data);
180-
let wormhole_proof = read_header_and_wormhole_proof(ref reader);
180+
read_and_verify_header(ref reader);
181+
let wormhole_proof_size = reader.read_u16();
182+
let wormhole_proof = reader.read_byte_array(wormhole_proof_size.into());
183+
181184
let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() };
182185
let vm = wormhole.parse_and_verify_vm(wormhole_proof);
183186

@@ -217,6 +220,15 @@ mod pyth {
217220
}
218221
}
219222

223+
fn get_update_fee(self: @ContractState, data: ByteArray) -> u256 {
224+
let mut reader = ReaderImpl::new(data);
225+
read_and_verify_header(ref reader);
226+
let wormhole_proof_size = reader.read_u16();
227+
reader.skip(wormhole_proof_size.into());
228+
let num_updates = reader.read_u8();
229+
self.get_total_fee(num_updates)
230+
}
231+
220232
fn execute_governance_instruction(ref self: ContractState, data: ByteArray) {
221233
let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() };
222234
let vm = wormhole.parse_and_verify_vm(data.clone());
@@ -323,7 +335,7 @@ mod pyth {
323335
}
324336
}
325337

326-
fn get_total_fee(ref self: ContractState, num_updates: u8) -> u256 {
338+
fn get_total_fee(self: @ContractState, num_updates: u8) -> u256 {
327339
self.single_update_fee.read() * num_updates.into()
328340
}
329341

target_chains/starknet/contracts/src/pyth/fake_upgrades.cairo

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1+
use pyth::pyth::{GetPriceUnsafeError, DataSource, Price};
2+
13
// Only used for tests.
24

5+
#[starknet::interface]
6+
pub trait IFakePyth<T> {
7+
fn get_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
8+
fn pyth_upgradable_magic(self: @T) -> u32;
9+
}
10+
311
#[starknet::contract]
412
mod pyth_fake_upgrade1 {
5-
use pyth::pyth::{IPyth, GetPriceUnsafeError, DataSource, Price};
13+
use pyth::pyth::{GetPriceUnsafeError, DataSource, Price};
614
use pyth::byte_array::ByteArray;
715

816
#[storage]
@@ -12,24 +20,13 @@ mod pyth_fake_upgrade1 {
1220
fn constructor(ref self: ContractState) {}
1321

1422
#[abi(embed_v0)]
15-
impl PythImpl of IPyth<ContractState> {
23+
impl PythImpl of super::IFakePyth<ContractState> {
1624
fn get_price_unsafe(
1725
self: @ContractState, price_id: u256
1826
) -> Result<Price, GetPriceUnsafeError> {
1927
let price = Price { price: 42, conf: 2, expo: -5, publish_time: 101, };
2028
Result::Ok(price)
2129
}
22-
fn get_ema_price_unsafe(
23-
self: @ContractState, price_id: u256
24-
) -> Result<Price, GetPriceUnsafeError> {
25-
panic!("unsupported")
26-
}
27-
fn update_price_feeds(ref self: ContractState, data: ByteArray) {
28-
panic!("unsupported")
29-
}
30-
fn execute_governance_instruction(ref self: ContractState, data: ByteArray) {
31-
panic!("unsupported")
32-
}
3330
fn pyth_upgradable_magic(self: @ContractState) -> u32 {
3431
0x97a6f304
3532
}
@@ -38,7 +35,7 @@ mod pyth_fake_upgrade1 {
3835

3936
#[starknet::contract]
4037
mod pyth_fake_upgrade_wrong_magic {
41-
use pyth::pyth::{IPyth, GetPriceUnsafeError, DataSource, Price};
38+
use pyth::pyth::{GetPriceUnsafeError, DataSource, Price};
4239
use pyth::byte_array::ByteArray;
4340

4441
#[storage]
@@ -48,23 +45,12 @@ mod pyth_fake_upgrade_wrong_magic {
4845
fn constructor(ref self: ContractState) {}
4946

5047
#[abi(embed_v0)]
51-
impl PythImpl of IPyth<ContractState> {
48+
impl PythImpl of super::IFakePyth<ContractState> {
5249
fn get_price_unsafe(
5350
self: @ContractState, price_id: u256
5451
) -> Result<Price, GetPriceUnsafeError> {
5552
panic!("unsupported")
5653
}
57-
fn get_ema_price_unsafe(
58-
self: @ContractState, price_id: u256
59-
) -> Result<Price, GetPriceUnsafeError> {
60-
panic!("unsupported")
61-
}
62-
fn update_price_feeds(ref self: ContractState, data: ByteArray) {
63-
panic!("unsupported")
64-
}
65-
fn execute_governance_instruction(ref self: ContractState, data: ByteArray) {
66-
panic!("unsupported")
67-
}
6854
fn pyth_upgradable_magic(self: @ContractState) -> u32 {
6955
606
7056
}

target_chains/starknet/contracts/src/pyth/interface.cairo

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub trait IPyth<T> {
66
fn get_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
77
fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
88
fn update_price_feeds(ref self: T, data: ByteArray);
9+
fn get_update_fee(self: @T, data: ByteArray) -> u256;
910
fn execute_governance_instruction(ref self: T, data: ByteArray);
1011
fn pyth_upgradable_magic(self: @T) -> u32;
1112
}

target_chains/starknet/contracts/src/pyth/price_update.cairo

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ pub struct PriceFeedMessage {
6464
pub ema_conf: u64,
6565
}
6666

67-
pub fn read_header_and_wormhole_proof(ref reader: Reader) -> ByteArray {
67+
pub fn read_and_verify_header(ref reader: Reader) {
6868
if reader.read_u32() != ACCUMULATOR_MAGIC {
6969
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
7070
}
@@ -76,7 +76,7 @@ pub fn read_header_and_wormhole_proof(ref reader: Reader) -> ByteArray {
7676
}
7777

7878
let trailing_header_size = reader.read_u8();
79-
reader.skip(trailing_header_size);
79+
reader.skip(trailing_header_size.into());
8080

8181
let update_type: UpdateType = reader
8282
.read_u8()
@@ -86,9 +86,6 @@ pub fn read_header_and_wormhole_proof(ref reader: Reader) -> ByteArray {
8686
match update_type {
8787
UpdateType::WormholeMerkle => {}
8888
}
89-
90-
let wormhole_proof_size = reader.read_u16();
91-
reader.read_byte_array(wormhole_proof_size.into())
9289
}
9390

9491
pub fn parse_wormhole_proof(payload: ByteArray) -> u256 {

target_chains/starknet/contracts/src/reader.cairo

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,15 @@ pub impl ReaderImpl of ReaderTrait {
8686
}
8787

8888
// TODO: skip without calculating values
89-
fn skip(ref self: Reader, mut num_bytes: u8) {
89+
fn skip(ref self: Reader, mut num_bytes: usize) {
9090
while num_bytes > 0 {
9191
if num_bytes > 16 {
9292
self.read_num_bytes(16);
9393
num_bytes -= 16;
9494
} else {
95-
self.read_num_bytes(num_bytes);
95+
// num_bytes <= 16 so it shouldn't overflow.
96+
self.read_num_bytes(num_bytes.try_into().expect(UNEXPECTED_OVERFLOW));
97+
break;
9698
}
9799
}
98100
}

target_chains/starknet/contracts/tests/data.cairo

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -404,12 +404,12 @@ pub fn pyth_set_fee_alt_emitter() -> ByteArray {
404404
// A Pyth governance instruction to upgrade the contract signed by the test guardian #1.
405405
pub fn pyth_upgrade_fake1() -> ByteArray {
406406
let bytes = array![
407-
1766847064779997312831656888004304663648863693096357069129843988620764542,
408-
372087717591229137403366610731035855939366039700111817924253553748324215495,
409-
182456855699527413949626507253841174034899423702925950617111971827806109696,
407+
1766847064779994791169214817472264547450542145364282319310439743685771618,
408+
175385590228001769706203572954671062839210335359545531991708252078677402742,
409+
338282801975945534678621806212670914146735662234331326855531973960850735104,
410410
49565958604199796163020368,
411-
148907253453589022235416579439991212386300560409198472807590534281503440988,
412-
7311947531350894019,
411+
148907253453589022320407306335457538262203456299261498528172020674942501293,
412+
9624434269354675143,
413413
];
414414
ByteArrayImpl::new(array_try_into(bytes), 8)
415415
}
@@ -430,12 +430,12 @@ pub fn pyth_upgrade_not_pyth() -> ByteArray {
430430
// A Pyth governance instruction to upgrade the contract signed by the test guardian #1.
431431
pub fn pyth_upgrade_wrong_magic() -> ByteArray {
432432
let bytes = array![
433-
1766847064779991597204876565434227784957683976813807912095437759207426783,
434-
116073945271795196915694593374349818132109707660825728503969190995475470190,
435-
402267564262237040559156170656516235895865250461329916178126910059500797952,
433+
1766847064779993581380818181711092803131812037068363180730038764700119064,
434+
43179698701133869693008541869474965453366967663087320291846878688486859828,
435+
257191826617037171240065659464096594985467828231875472974396182656981139456,
436436
49565958604199796163020368,
437-
148907253453589022397792005599092877068906138702361966208625267621388965397,
438-
10856656060318424790,
437+
148907253453589022340563264373887392414227070562033595690783947835630084766,
438+
5698494087895763928,
439439
];
440440
ByteArrayImpl::new(array_try_into(bytes), 8)
441441
}

target_chains/starknet/contracts/tests/pyth.cairo

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,11 @@ fn update_price_feeds_works() {
8989
let fee_contract = deploy_fee_contract(user);
9090
let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
9191

92+
let fee = pyth.get_update_fee(data::good_update1());
93+
assert!(fee == 1000);
94+
9295
start_prank(CheatTarget::One(fee_contract.contract_address), user.try_into().unwrap());
93-
fee_contract.approve(pyth.contract_address, 10000);
96+
fee_contract.approve(pyth.contract_address, fee);
9497
stop_prank(CheatTarget::One(fee_contract.contract_address));
9598

9699
let mut spy = spy_events(SpyOn::One(pyth.contract_address));
@@ -136,6 +139,9 @@ fn test_governance_set_fee_works() {
136139
let fee_contract = deploy_fee_contract(user);
137140
let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
138141

142+
let fee1 = pyth.get_update_fee(data::test_price_update1());
143+
assert!(fee1 == 1000);
144+
139145
start_prank(CheatTarget::One(fee_contract.contract_address), user);
140146
fee_contract.approve(pyth.contract_address, 10000);
141147
stop_prank(CheatTarget::One(fee_contract.contract_address));
@@ -164,6 +170,9 @@ fn test_governance_set_fee_works() {
164170
let expected = FeeSet { old_fee: 1000, new_fee: 4200, };
165171
assert!(event == PythEvent::FeeSet(expected));
166172

173+
let fee2 = pyth.get_update_fee(data::test_price_update2());
174+
assert!(fee2 == 4200);
175+
167176
start_prank(CheatTarget::One(pyth.contract_address), user);
168177
pyth.update_price_feeds(data::test_price_update2());
169178
stop_prank(CheatTarget::One(pyth.contract_address));

0 commit comments

Comments
 (0)