1
+
1
2
#[ cfg( test) ]
2
3
mod test {
3
4
use crate :: PythReceiver ;
5
+ use crate :: error:: PythReceiverError ;
4
6
use alloy_primitives:: { address, U256 } ;
5
7
use stylus_sdk:: testing:: * ;
8
+ use pythnet_sdk:: wire:: v1:: PYTHNET_ACCUMULATOR_UPDATE_MAGIC ;
9
+
10
+ fn initialize_test_contract ( vm : & TestVM ) -> PythReceiver {
11
+ let mut contract = PythReceiver :: from ( vm) ;
12
+ let wormhole_address = address ! ( "0x3F38404A2e3Cb949bcDfA19a5C3bDf3fE375fEb0" ) ;
13
+ let single_update_fee = U256 :: from ( 100u64 ) ;
14
+ let valid_time_period = U256 :: from ( 3600u64 ) ;
15
+
16
+ let data_source_chain_ids = vec ! [ 1u16 , 2u16 ] ;
17
+ let data_source_emitter_addresses = vec ! [
18
+ [ 1u8 ; 32 ] ,
19
+ [ 2u8 ; 32 ] ,
20
+ ] ;
21
+
22
+ let governance_chain_id = 1u16 ;
23
+ let governance_emitter_address = [ 3u8 ; 32 ] ;
24
+ let governance_initial_sequence = 0u64 ;
25
+ let data = vec ! [ ] ;
26
+
27
+ contract. initialize (
28
+ wormhole_address,
29
+ single_update_fee,
30
+ valid_time_period,
31
+ data_source_chain_ids,
32
+ data_source_emitter_addresses,
33
+ governance_chain_id,
34
+ governance_emitter_address,
35
+ governance_initial_sequence,
36
+ data,
37
+ ) ;
38
+ contract
39
+ }
40
+
41
+ fn create_valid_update_data ( ) -> Vec < u8 > {
42
+ let mut data = Vec :: new ( ) ;
43
+ data. extend_from_slice ( PYTHNET_ACCUMULATOR_UPDATE_MAGIC ) ;
44
+ data. extend_from_slice ( & [ 0u8 ; 100 ] ) ;
45
+ data
46
+ }
47
+
48
+ fn create_invalid_magic_data ( ) -> Vec < u8 > {
49
+ let mut data = Vec :: new ( ) ;
50
+ data. extend_from_slice ( & [ 0xFF , 0xFF , 0xFF , 0xFF ] ) ; // Invalid magic
51
+ data. extend_from_slice ( & [ 0u8 ; 100 ] ) ;
52
+ data
53
+ }
54
+
55
+ fn create_short_data ( ) -> Vec < u8 > {
56
+ vec ! [ 0u8 ; 2 ] // Too short for magic header
57
+ }
58
+
59
+ fn create_invalid_vaa_data ( ) -> Vec < u8 > {
60
+ let mut data = Vec :: new ( ) ;
61
+ data. extend_from_slice ( PYTHNET_ACCUMULATOR_UPDATE_MAGIC ) ;
62
+ data. extend_from_slice ( & [ 0u8 ; 50 ] ) ;
63
+ data
64
+ }
65
+
66
+ fn create_invalid_merkle_data ( ) -> Vec < u8 > {
67
+ let mut data = Vec :: new ( ) ;
68
+ data. extend_from_slice ( PYTHNET_ACCUMULATOR_UPDATE_MAGIC ) ;
69
+ data. extend_from_slice ( & [ 1u8 ; 80 ] ) ;
70
+ data
71
+ }
6
72
7
73
#[ test]
8
74
fn test_initialize ( ) {
9
- // Set up test environment
10
75
let vm = TestVM :: default ( ) ;
11
- // Initialize your contract
12
76
let mut contract = PythReceiver :: from ( & vm) ;
13
77
14
78
let wormhole_address = address ! ( "0x3F38404A2e3Cb949bcDfA19a5C3bDf3fE375fEb0" ) ;
15
79
let single_update_fee = U256 :: from ( 100u64 ) ;
16
- let valid_time_period = U256 :: from ( 3600u64 ) ; // 1 hour
80
+ let valid_time_period = U256 :: from ( 3600u64 ) ;
17
81
18
- let data_source_chain_ids = vec ! [ 1u16 , 2u16 ] ; // Ethereum and other chain
82
+ let data_source_chain_ids = vec ! [ 1u16 , 2u16 ] ;
19
83
let data_source_emitter_addresses = vec ! [
20
- [ 1u8 ; 32 ] , // First emitter address
21
- [ 2u8 ; 32 ] , // Second emitter address
84
+ [ 1u8 ; 32 ] ,
85
+ [ 2u8 ; 32 ] ,
22
86
] ;
23
87
24
88
let governance_chain_id = 1u16 ;
25
89
let governance_emitter_address = [ 3u8 ; 32 ] ;
26
90
let governance_initial_sequence = 0u64 ;
27
- let data = vec ! [ ] ; // Empty data for this test
91
+ let data = vec ! [ ] ;
28
92
29
93
contract. initialize (
30
94
wormhole_address,
@@ -39,13 +103,170 @@ mod test {
39
103
) ;
40
104
41
105
let fee = contract. get_update_fee ( vec ! [ ] ) ;
42
- assert_eq ! ( fee, U256 :: from( 0u8 ) ) ; // Should return 0 as per implementation
106
+ assert_eq ! ( fee, U256 :: from( 0u8 ) ) ; // Fee calculation not implemented yet
43
107
44
108
let twap_fee = contract. get_twap_update_fee ( vec ! [ ] ) ;
45
- assert_eq ! ( twap_fee, U256 :: from( 0u8 ) ) ; // Should return 0 as per implementation
109
+ assert_eq ! ( twap_fee, U256 :: from( 0u8 ) ) ; // Fee calculation not implemented yet
46
110
47
111
let test_price_id = [ 0u8 ; 32 ] ;
48
112
let price_result = contract. get_price_unsafe ( test_price_id) ;
49
- assert ! ( price_result. is_err( ) ) ; // Should return error for non-existent price
113
+ assert ! ( price_result. is_err( ) ) ;
114
+ assert ! ( matches!( price_result. unwrap_err( ) , PythReceiverError :: PriceUnavailable ) ) ;
115
+ }
116
+
117
+ #[ test]
118
+ fn test_update_new_price_feed ( ) {
119
+ let vm = TestVM :: default ( ) ;
120
+ let mut contract = initialize_test_contract ( & vm) ;
121
+
122
+ let test_price_id = [ 1u8 ; 32 ] ;
123
+
124
+ let update_data = create_valid_update_data ( ) ;
125
+ let result = contract. update_price_feeds (
126
+ update_data,
127
+ ) ;
128
+
129
+
130
+ let price_result = contract. get_price_unsafe ( test_price_id) ;
131
+ assert ! ( price_result. is_err( ) ) ;
132
+ assert ! ( matches!( price_result. unwrap_err( ) , PythReceiverError :: PriceUnavailable ) ) ;
133
+ }
134
+
135
+ #[ test]
136
+ fn test_update_existing_price_feed ( ) {
137
+ let vm = TestVM :: default ( ) ;
138
+ let mut contract = initialize_test_contract ( & vm) ;
139
+
140
+ let test_price_id = [ 1u8 ; 32 ] ;
141
+
142
+ let update_data1 = create_valid_update_data ( ) ;
143
+ let result1 = contract. update_price_feeds_internal (
144
+ update_data1,
145
+ vec ! [ ] ,
146
+ 0 ,
147
+ u64:: MAX ,
148
+ false
149
+ ) ;
150
+
151
+ let update_data2 = create_valid_update_data ( ) ;
152
+ let result2 = contract. update_price_feeds_internal (
153
+ update_data2,
154
+ vec ! [ ] ,
155
+ 0 ,
156
+ u64:: MAX ,
157
+ false
158
+ ) ;
159
+
160
+ }
161
+
162
+ #[ test]
163
+ fn test_invalid_magic_header ( ) {
164
+ let vm = TestVM :: default ( ) ;
165
+ let mut contract = initialize_test_contract ( & vm) ;
166
+
167
+ let invalid_data = create_invalid_magic_data ( ) ;
168
+ let result = contract. update_price_feeds_internal (
169
+ invalid_data,
170
+ vec ! [ ] ,
171
+ 0 ,
172
+ u64:: MAX ,
173
+ false
174
+ ) ;
175
+
176
+ assert ! ( result. is_err( ) ) ;
177
+ assert ! ( matches!( result. unwrap_err( ) , PythReceiverError :: InvalidAccumulatorMessage ) ) ;
178
+ }
179
+
180
+ #[ test]
181
+ fn test_invalid_wire_format ( ) {
182
+ let vm = TestVM :: default ( ) ;
183
+ let mut contract = initialize_test_contract ( & vm) ;
184
+
185
+ let short_data = create_short_data ( ) ;
186
+ let result = contract. update_price_feeds_internal (
187
+ short_data,
188
+ vec ! [ ] ,
189
+ 0 ,
190
+ u64:: MAX ,
191
+ false
192
+ ) ;
193
+
194
+ assert ! ( result. is_err( ) ) ;
195
+ assert ! ( matches!( result. unwrap_err( ) , PythReceiverError :: InvalidUpdateData ) ) ;
196
+ }
197
+
198
+ #[ test]
199
+ fn test_invalid_wormhole_vaa ( ) {
200
+ let vm = TestVM :: default ( ) ;
201
+ let mut contract = initialize_test_contract ( & vm) ;
202
+
203
+ let invalid_vaa_data = create_invalid_vaa_data ( ) ;
204
+ let result = contract. update_price_feeds_internal (
205
+ invalid_vaa_data,
206
+ vec ! [ ] ,
207
+ 0 ,
208
+ u64:: MAX ,
209
+ false
210
+ ) ;
211
+
212
+ assert ! ( result. is_err( ) ) ;
213
+ }
214
+
215
+ #[ test]
216
+ fn test_invalid_merkle_proof ( ) {
217
+ let vm = TestVM :: default ( ) ;
218
+ let mut contract = initialize_test_contract ( & vm) ;
219
+
220
+ let invalid_merkle_data = create_invalid_merkle_data ( ) ;
221
+ let result = contract. update_price_feeds_internal (
222
+ invalid_merkle_data,
223
+ vec ! [ ] ,
224
+ 0 ,
225
+ u64:: MAX ,
226
+ false
227
+ ) ;
228
+
229
+ assert ! ( result. is_err( ) ) ;
230
+ }
231
+
232
+ #[ test]
233
+ fn test_stale_price_rejection ( ) {
234
+ let vm = TestVM :: default ( ) ;
235
+ let mut contract = initialize_test_contract ( & vm) ;
236
+
237
+ let test_price_id = [ 1u8 ; 32 ] ;
238
+ let price_result = contract. get_price_unsafe ( test_price_id) ;
239
+ assert ! ( price_result. is_err( ) ) ;
240
+ assert ! ( matches!( price_result. unwrap_err( ) , PythReceiverError :: PriceUnavailable ) ) ;
241
+
242
+ }
243
+
244
+ #[ test]
245
+ fn test_get_price_no_older_than_error ( ) {
246
+ let vm = TestVM :: default ( ) ;
247
+ let mut contract = initialize_test_contract ( & vm) ;
248
+
249
+ let test_price_id = [ 1u8 ; 32 ] ;
250
+ let result = contract. get_price_no_older_than ( test_price_id, 1 ) ;
251
+
252
+ assert ! ( result. is_err( ) ) ;
253
+ assert ! ( matches!( result. unwrap_err( ) , PythReceiverError :: PriceUnavailable ) ) ;
254
+
255
+ }
256
+
257
+ #[ test]
258
+ fn test_contract_state_after_init ( ) {
259
+ let vm = TestVM :: default ( ) ;
260
+ let contract = initialize_test_contract ( & vm) ;
261
+
262
+ let fee = contract. get_update_fee ( vec ! [ ] ) ;
263
+ assert_eq ! ( fee, U256 :: from( 0u8 ) ) ;
264
+
265
+ let random_price_id = [ 42u8 ; 32 ] ;
266
+ let price_result = contract. get_price_unsafe ( random_price_id) ;
267
+ assert ! ( price_result. is_err( ) ) ;
268
+
269
+ let price_no_older_result = contract. get_price_no_older_than ( random_price_id, 3600 ) ;
270
+ assert ! ( price_no_older_result. is_err( ) ) ;
50
271
}
51
272
}
0 commit comments