Skip to content

Commit 1caff31

Browse files
committed
framework for tests
1 parent 47479f8 commit 1caff31

File tree

2 files changed

+236
-13
lines changed

2 files changed

+236
-13
lines changed
Lines changed: 231 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,94 @@
1+
12
#[cfg(test)]
23
mod test {
34
use crate::PythReceiver;
5+
use crate::error::PythReceiverError;
46
use alloy_primitives::{address, U256};
57
use stylus_sdk::testing::*;
8+
use pythnet_sdk::wire::v1::PYTHNET_ACCUMULATOR_UPDATE_MAGIC;
9+
10+
fn initialize_test_contract(vm: &TestVM) -> PythReceiver {
11+
let mut contract = PythReceiver::from(vm);
12+
let wormhole_address = address!("0x3F38404A2e3Cb949bcDfA19a5C3bDf3fE375fEb0");
13+
let single_update_fee = U256::from(100u64);
14+
let valid_time_period = U256::from(3600u64);
15+
16+
let data_source_chain_ids = vec![1u16, 2u16];
17+
let data_source_emitter_addresses = vec![
18+
[1u8; 32],
19+
[2u8; 32],
20+
];
21+
22+
let governance_chain_id = 1u16;
23+
let governance_emitter_address = [3u8; 32];
24+
let governance_initial_sequence = 0u64;
25+
let data = vec![];
26+
27+
contract.initialize(
28+
wormhole_address,
29+
single_update_fee,
30+
valid_time_period,
31+
data_source_chain_ids,
32+
data_source_emitter_addresses,
33+
governance_chain_id,
34+
governance_emitter_address,
35+
governance_initial_sequence,
36+
data,
37+
);
38+
contract
39+
}
40+
41+
fn create_valid_update_data() -> Vec<u8> {
42+
let mut data = Vec::new();
43+
data.extend_from_slice(PYTHNET_ACCUMULATOR_UPDATE_MAGIC);
44+
data.extend_from_slice(&[0u8; 100]);
45+
data
46+
}
47+
48+
fn create_invalid_magic_data() -> Vec<u8> {
49+
let mut data = Vec::new();
50+
data.extend_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF]); // Invalid magic
51+
data.extend_from_slice(&[0u8; 100]);
52+
data
53+
}
54+
55+
fn create_short_data() -> Vec<u8> {
56+
vec![0u8; 2] // Too short for magic header
57+
}
58+
59+
fn create_invalid_vaa_data() -> Vec<u8> {
60+
let mut data = Vec::new();
61+
data.extend_from_slice(PYTHNET_ACCUMULATOR_UPDATE_MAGIC);
62+
data.extend_from_slice(&[0u8; 50]);
63+
data
64+
}
65+
66+
fn create_invalid_merkle_data() -> Vec<u8> {
67+
let mut data = Vec::new();
68+
data.extend_from_slice(PYTHNET_ACCUMULATOR_UPDATE_MAGIC);
69+
data.extend_from_slice(&[1u8; 80]);
70+
data
71+
}
672

773
#[test]
874
fn test_initialize() {
9-
// Set up test environment
1075
let vm = TestVM::default();
11-
// Initialize your contract
1276
let mut contract = PythReceiver::from(&vm);
1377

1478
let wormhole_address = address!("0x3F38404A2e3Cb949bcDfA19a5C3bDf3fE375fEb0");
1579
let single_update_fee = U256::from(100u64);
16-
let valid_time_period = U256::from(3600u64); // 1 hour
80+
let valid_time_period = U256::from(3600u64);
1781

18-
let data_source_chain_ids = vec![1u16, 2u16]; // Ethereum and other chain
82+
let data_source_chain_ids = vec![1u16, 2u16];
1983
let data_source_emitter_addresses = vec![
20-
[1u8; 32], // First emitter address
21-
[2u8; 32], // Second emitter address
84+
[1u8; 32],
85+
[2u8; 32],
2286
];
2387

2488
let governance_chain_id = 1u16;
2589
let governance_emitter_address = [3u8; 32];
2690
let governance_initial_sequence = 0u64;
27-
let data = vec![]; // Empty data for this test
91+
let data = vec![];
2892

2993
contract.initialize(
3094
wormhole_address,
@@ -39,13 +103,170 @@ mod test {
39103
);
40104

41105
let fee = contract.get_update_fee(vec![]);
42-
assert_eq!(fee, U256::from(0u8)); // Should return 0 as per implementation
106+
assert_eq!(fee, U256::from(0u8)); // Fee calculation not implemented yet
43107

44108
let twap_fee = contract.get_twap_update_fee(vec![]);
45-
assert_eq!(twap_fee, U256::from(0u8)); // Should return 0 as per implementation
109+
assert_eq!(twap_fee, U256::from(0u8)); // Fee calculation not implemented yet
46110

47111
let test_price_id = [0u8; 32];
48112
let price_result = contract.get_price_unsafe(test_price_id);
49-
assert!(price_result.is_err()); // Should return error for non-existent price
113+
assert!(price_result.is_err());
114+
assert!(matches!(price_result.unwrap_err(), PythReceiverError::PriceUnavailable));
115+
}
116+
117+
#[test]
118+
fn test_update_new_price_feed() {
119+
let vm = TestVM::default();
120+
let mut contract = initialize_test_contract(&vm);
121+
122+
let test_price_id = [1u8; 32];
123+
124+
let update_data = create_valid_update_data();
125+
let result = contract.update_price_feeds(
126+
update_data,
127+
);
128+
129+
130+
let price_result = contract.get_price_unsafe(test_price_id);
131+
assert!(price_result.is_err());
132+
assert!(matches!(price_result.unwrap_err(), PythReceiverError::PriceUnavailable));
133+
}
134+
135+
#[test]
136+
fn test_update_existing_price_feed() {
137+
let vm = TestVM::default();
138+
let mut contract = initialize_test_contract(&vm);
139+
140+
let test_price_id = [1u8; 32];
141+
142+
let update_data1 = create_valid_update_data();
143+
let result1 = contract.update_price_feeds_internal(
144+
update_data1,
145+
vec![],
146+
0,
147+
u64::MAX,
148+
false
149+
);
150+
151+
let update_data2 = create_valid_update_data();
152+
let result2 = contract.update_price_feeds_internal(
153+
update_data2,
154+
vec![],
155+
0,
156+
u64::MAX,
157+
false
158+
);
159+
160+
}
161+
162+
#[test]
163+
fn test_invalid_magic_header() {
164+
let vm = TestVM::default();
165+
let mut contract = initialize_test_contract(&vm);
166+
167+
let invalid_data = create_invalid_magic_data();
168+
let result = contract.update_price_feeds_internal(
169+
invalid_data,
170+
vec![],
171+
0,
172+
u64::MAX,
173+
false
174+
);
175+
176+
assert!(result.is_err());
177+
assert!(matches!(result.unwrap_err(), PythReceiverError::InvalidAccumulatorMessage));
178+
}
179+
180+
#[test]
181+
fn test_invalid_wire_format() {
182+
let vm = TestVM::default();
183+
let mut contract = initialize_test_contract(&vm);
184+
185+
let short_data = create_short_data();
186+
let result = contract.update_price_feeds_internal(
187+
short_data,
188+
vec![],
189+
0,
190+
u64::MAX,
191+
false
192+
);
193+
194+
assert!(result.is_err());
195+
assert!(matches!(result.unwrap_err(), PythReceiverError::InvalidUpdateData));
196+
}
197+
198+
#[test]
199+
fn test_invalid_wormhole_vaa() {
200+
let vm = TestVM::default();
201+
let mut contract = initialize_test_contract(&vm);
202+
203+
let invalid_vaa_data = create_invalid_vaa_data();
204+
let result = contract.update_price_feeds_internal(
205+
invalid_vaa_data,
206+
vec![],
207+
0,
208+
u64::MAX,
209+
false
210+
);
211+
212+
assert!(result.is_err());
213+
}
214+
215+
#[test]
216+
fn test_invalid_merkle_proof() {
217+
let vm = TestVM::default();
218+
let mut contract = initialize_test_contract(&vm);
219+
220+
let invalid_merkle_data = create_invalid_merkle_data();
221+
let result = contract.update_price_feeds_internal(
222+
invalid_merkle_data,
223+
vec![],
224+
0,
225+
u64::MAX,
226+
false
227+
);
228+
229+
assert!(result.is_err());
230+
}
231+
232+
#[test]
233+
fn test_stale_price_rejection() {
234+
let vm = TestVM::default();
235+
let mut contract = initialize_test_contract(&vm);
236+
237+
let test_price_id = [1u8; 32];
238+
let price_result = contract.get_price_unsafe(test_price_id);
239+
assert!(price_result.is_err());
240+
assert!(matches!(price_result.unwrap_err(), PythReceiverError::PriceUnavailable));
241+
242+
}
243+
244+
#[test]
245+
fn test_get_price_no_older_than_error() {
246+
let vm = TestVM::default();
247+
let mut contract = initialize_test_contract(&vm);
248+
249+
let test_price_id = [1u8; 32];
250+
let result = contract.get_price_no_older_than(test_price_id, 1);
251+
252+
assert!(result.is_err());
253+
assert!(matches!(result.unwrap_err(), PythReceiverError::PriceUnavailable));
254+
255+
}
256+
257+
#[test]
258+
fn test_contract_state_after_init() {
259+
let vm = TestVM::default();
260+
let contract = initialize_test_contract(&vm);
261+
262+
let fee = contract.get_update_fee(vec![]);
263+
assert_eq!(fee, U256::from(0u8));
264+
265+
let random_price_id = [42u8; 32];
266+
let price_result = contract.get_price_unsafe(random_price_id);
267+
assert!(price_result.is_err());
268+
269+
let price_no_older_result = contract.get_price_no_older_than(random_price_id, 3600);
270+
assert!(price_no_older_result.is_err());
50271
}
51272
}

target_chains/stylus/contracts/pyth-receiver/src/lib.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,9 @@ impl PythReceiver {
128128
(U64::ZERO, I32::ZERO, I64::ZERO, U64::ZERO, I64::ZERO, U64::ZERO)
129129
}
130130

131-
pub fn update_price_feeds(&mut self, update_data: Vec<u8>) {
132-
131+
pub fn update_price_feeds(&mut self, update_data: Vec<u8>) -> Result<(), PythReceiverError> {
132+
self.update_price_feeds_internal(update_data)?;
133+
Ok(())
133134
}
134135

135136
pub fn update_price_feeds_if_necessary(
@@ -141,7 +142,8 @@ impl PythReceiver {
141142
// dummy implementation
142143
}
143144

144-
fn update_price_feeds_internal(&mut self, update_data: Vec<u8>, price_ids: Vec<Address>, min_publish_time: u64, max_publish_time: u64, unique: bool) -> Result<(), PythReceiverError> {
145+
// fn update_price_feeds_internal(&mut self, update_data: Vec<u8>, price_ids: Vec<Address>, min_publish_time: u64, max_publish_time: u64, unique: bool) -> Result<(), PythReceiverError> {
146+
fn update_price_feeds_internal(&mut self, update_data: Vec<u8>) -> Result<(), PythReceiverError> {
145147
let update_data_array: &[u8] = &update_data;
146148
// Check the first 4 bytes of the update_data_array for the magic header
147149
if update_data_array.len() < 4 {

0 commit comments

Comments
 (0)