@@ -10,7 +10,7 @@ fn mock_call_simple() {
10
10
indoc!(
11
11
r#"
12
12
use result::ResultTrait;
13
- use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, stop_mock_call, MockCallData };
13
+ use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, stop_mock_call };
14
14
15
15
#[starknet::interface]
16
16
trait IMockChecker<TContractState> {
@@ -26,19 +26,13 @@ fn mock_call_simple() {
26
26
27
27
let dispatcher = IMockCheckerDispatcher { contract_address };
28
28
29
- let specific_mock_ret_data = 421;
30
- let default_mock_ret_data = 404;
31
- let expected_calldata = MockCallData::Values([].span());
32
- start_mock_call(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data);
33
- start_mock_call(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data);
34
- let thing = dispatcher.get_thing();
35
- assert(thing == specific_mock_ret_data, 'Incorrect thing');
29
+ let mock_ret_data = 421;
36
30
37
- stop_mock_call (contract_address, selector!("get_thing"), expected_calldata );
31
+ start_mock_call (contract_address, selector!("get_thing"), mock_ret_data );
38
32
let thing = dispatcher.get_thing();
39
- assert(thing == default_mock_ret_data , 'Incorrect thing');
33
+ assert(thing == 421 , 'Incorrect thing');
40
34
41
- stop_mock_call(contract_address, selector!("get_thing"), MockCallData::Any );
35
+ stop_mock_call(contract_address, selector!("get_thing"));
42
36
let thing = dispatcher.get_thing();
43
37
assert(thing == 420, 'Incorrect thing');
44
38
}
@@ -51,12 +45,12 @@ fn mock_call_simple() {
51
45
let (contract_address, _) = contract.deploy(@calldata).unwrap();
52
46
53
47
let mock_ret_data = 421;
54
- start_mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data);
48
+ start_mock_call(contract_address, selector!("get_thing"), mock_ret_data);
55
49
56
50
let dispatcher = IMockCheckerDispatcher { contract_address };
57
51
let thing = dispatcher.get_thing();
58
52
59
- assert(thing == 421, 'Incorrect thing all catch ');
53
+ assert(thing == 421, 'Incorrect thing');
60
54
}
61
55
"#
62
56
) ,
@@ -79,7 +73,7 @@ fn mock_call_complex_types() {
79
73
use result::ResultTrait;
80
74
use array::ArrayTrait;
81
75
use serde::Serde;
82
- use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, MockCallData };
76
+ use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call };
83
77
84
78
#[starknet::interface]
85
79
trait IMockChecker<TContractState> {
@@ -103,7 +97,7 @@ fn mock_call_complex_types() {
103
97
let dispatcher = IMockCheckerDispatcher { contract_address };
104
98
105
99
let mock_ret_data = StructThing {item_one: 412, item_two: 421};
106
- start_mock_call(contract_address, selector!("get_struct_thing"), MockCallData::Any, mock_ret_data);
100
+ start_mock_call(contract_address, selector!("get_struct_thing"), mock_ret_data);
107
101
108
102
let thing: StructThing = dispatcher.get_struct_thing();
109
103
@@ -121,7 +115,7 @@ fn mock_call_complex_types() {
121
115
let dispatcher = IMockCheckerDispatcher { contract_address };
122
116
123
117
let mock_ret_data = array![ StructThing {item_one: 112, item_two: 121}, StructThing {item_one: 412, item_two: 421} ];
124
- start_mock_call(contract_address, selector!("get_arr_thing"), MockCallData::Any, mock_ret_data);
118
+ start_mock_call(contract_address, selector!("get_arr_thing"), mock_ret_data);
125
119
126
120
let things: Array<StructThing> = dispatcher.get_arr_thing();
127
121
@@ -152,7 +146,7 @@ fn mock_calls() {
152
146
indoc!(
153
147
r#"
154
148
use result::ResultTrait;
155
- use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call, MockCallData };
149
+ use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call, start_mock_call, stop_mock_call };
156
150
157
151
#[starknet::interface]
158
152
trait IMockChecker<TContractState> {
@@ -170,7 +164,7 @@ fn mock_calls() {
170
164
171
165
let mock_ret_data = 421;
172
166
173
- mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 1);
167
+ mock_call(contract_address, selector!("get_thing"), mock_ret_data, 1);
174
168
175
169
let thing = dispatcher.get_thing();
176
170
assert_eq!(thing, 421);
@@ -190,7 +184,7 @@ fn mock_calls() {
190
184
191
185
let mock_ret_data = 421;
192
186
193
- mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 2);
187
+ mock_call(contract_address, selector!("get_thing"), mock_ret_data, 2);
194
188
195
189
let thing = dispatcher.get_thing();
196
190
assert_eq!(thing, 421);
@@ -213,3 +207,184 @@ fn mock_calls() {
213
207
let result = run_test_case ( & test) ;
214
208
assert_passed ( & result) ;
215
209
}
210
+
211
+ #[ test]
212
+ fn mock_call_when_simple ( ) {
213
+ let test = test_case ! (
214
+ indoc!(
215
+ r#"
216
+ use result::ResultTrait;
217
+ use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCallData };
218
+
219
+ #[starknet::interface]
220
+ trait IMockChecker<TContractState> {
221
+ fn get_thing(ref self: TContractState) -> felt252;
222
+ }
223
+
224
+ #[test]
225
+ fn mock_call_when_simple() {
226
+ let calldata = array![420];
227
+
228
+ let contract = declare("MockChecker").unwrap().contract_class();
229
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
230
+
231
+ let dispatcher = IMockCheckerDispatcher { contract_address };
232
+
233
+ let specific_mock_ret_data = 421;
234
+ let default_mock_ret_data = 404;
235
+ let expected_calldata = MockCallData::Values([].span());
236
+
237
+ start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data);
238
+ start_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data);
239
+ let thing = dispatcher.get_thing();
240
+ assert(thing == specific_mock_ret_data, 'Incorrect thing');
241
+
242
+ stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata);
243
+ let thing = dispatcher.get_thing();
244
+ assert(thing == default_mock_ret_data, 'Incorrect thing');
245
+
246
+ stop_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any);
247
+ let thing = dispatcher.get_thing();
248
+ assert(thing == 420, 'Incorrect thing');
249
+ }
250
+
251
+ #[test]
252
+ fn mock_call_when_simple_before_dispatcher_created() {
253
+ let calldata = array![420];
254
+
255
+ let contract = declare("MockChecker").unwrap().contract_class();
256
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
257
+
258
+ let specific_mock_ret_data = 421;
259
+ let default_mock_ret_data = 404;
260
+ let expected_calldata = MockCallData::Values([].span());
261
+
262
+ start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data);
263
+ start_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data);
264
+ let dispatcher = IMockCheckerDispatcher { contract_address };
265
+ let thing = dispatcher.get_thing();
266
+
267
+ assert(thing == specific_mock_ret_data, 'Incorrect thing');
268
+
269
+ stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata);
270
+ let thing = dispatcher.get_thing();
271
+ assert(thing == default_mock_ret_data, 'Incorrect thing');
272
+
273
+ stop_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any);
274
+ let thing = dispatcher.get_thing();
275
+ assert(thing == 420, 'Incorrect thing');
276
+ }
277
+ "#
278
+ ) ,
279
+ Contract :: from_code_path(
280
+ "MockChecker" . to_string( ) ,
281
+ Path :: new( "tests/data/contracts/mock_checker.cairo" ) ,
282
+ )
283
+ . unwrap( )
284
+ ) ;
285
+
286
+ let result = run_test_case ( & test) ;
287
+ assert_passed ( & result) ;
288
+ }
289
+
290
+ #[ test]
291
+ fn mock_call_when_complex_types ( ) {
292
+ let test = test_case ! (
293
+ indoc!(
294
+ r#"
295
+ use result::ResultTrait;
296
+ use array::ArrayTrait;
297
+ use serde::Serde;
298
+ use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCallData };
299
+
300
+ #[starknet::interface]
301
+ trait IMockChecker<TContractState> {
302
+ fn get_struct_thing(ref self: TContractState) -> StructThing;
303
+ fn get_arr_thing(ref self: TContractState) -> Array<StructThing>;
304
+ }
305
+
306
+ #[derive(Serde, Drop)]
307
+ struct StructThing {
308
+ item_one: felt252,
309
+ item_two: felt252,
310
+ }
311
+
312
+ #[test]
313
+ fn start_mock_call_when_return_struct() {
314
+ let calldata = array![420];
315
+
316
+ let contract = declare("MockChecker").unwrap().contract_class();
317
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
318
+
319
+ let dispatcher = IMockCheckerDispatcher { contract_address };
320
+
321
+ let default_mock_ret_data = StructThing {item_one: 412, item_two: 421};
322
+ let specific_mock_ret_data = StructThing {item_one: 404, item_two: 401};
323
+ let expected_calldata = MockCallData::Values([].span());
324
+
325
+ start_mock_call_when(contract_address, selector!("get_struct_thing"), MockCallData::Any, default_mock_ret_data);
326
+ start_mock_call_when(contract_address, selector!("get_struct_thing"), expected_calldata, specific_mock_ret_data);
327
+
328
+ let thing: StructThing = dispatcher.get_struct_thing();
329
+
330
+ assert(thing.item_one == 404, 'thing.item_one');
331
+ assert(thing.item_two == 401, 'thing.item_two');
332
+
333
+ stop_mock_call_when(contract_address, selector!("get_struct_thing"), expected_calldata);
334
+ let thing: StructThing = dispatcher.get_struct_thing();
335
+
336
+ assert(thing.item_one == 412, 'thing.item_one');
337
+ assert(thing.item_two == 421, 'thing.item_two');
338
+ }
339
+
340
+ #[test]
341
+ fn start_mock_call_when_return_arr() {
342
+ let calldata = array![420];
343
+
344
+ let contract = declare("MockChecker").unwrap().contract_class();
345
+ let (contract_address, _) = contract.deploy(@calldata).unwrap();
346
+
347
+ let dispatcher = IMockCheckerDispatcher { contract_address };
348
+
349
+ let default_mock_ret_data = array![ StructThing {item_one: 112, item_two: 121}, StructThing {item_one: 412, item_two: 421} ];
350
+ let specific_mock_ret_data = array![ StructThing {item_one: 212, item_two: 221}, StructThing {item_one: 512, item_two: 521} ];
351
+
352
+ let expected_calldata = MockCallData::Values([].span());
353
+
354
+ start_mock_call_when(contract_address, selector!("get_arr_thing"), MockCallData::Any, default_mock_ret_data);
355
+ start_mock_call_when(contract_address, selector!("get_arr_thing"), expected_calldata, specific_mock_ret_data);
356
+
357
+ let things: Array<StructThing> = dispatcher.get_arr_thing();
358
+
359
+ let thing = things.at(0);
360
+ assert(*thing.item_one == 212, 'thing1.item_one 1');
361
+ assert(*thing.item_two == 221, 'thing1.item_two');
362
+
363
+ let thing = things.at(1);
364
+ assert(*thing.item_one == 512, 'thing2.item_one 2');
365
+ assert(*thing.item_two == 521, 'thing2.item_two');
366
+
367
+ stop_mock_call_when(contract_address, selector!("get_arr_thing"), expected_calldata);
368
+
369
+ let things: Array<StructThing> = dispatcher.get_arr_thing();
370
+
371
+ let thing = things.at(0);
372
+ assert(*thing.item_one == 112, 'thing1.item_one 3');
373
+ assert(*thing.item_two == 121, 'thing1.item_two');
374
+
375
+ let thing = things.at(1);
376
+ assert(*thing.item_one == 412, 'thing2.item_one 4');
377
+ assert(*thing.item_two == 421, 'thing2.item_two');
378
+ }
379
+ "#
380
+ ) ,
381
+ Contract :: from_code_path(
382
+ "MockChecker" . to_string( ) ,
383
+ Path :: new( "tests/data/contracts/mock_checker.cairo" ) ,
384
+ )
385
+ . unwrap( )
386
+ ) ;
387
+
388
+ let result = run_test_case ( & test) ;
389
+ assert_passed ( & result) ;
390
+ }
0 commit comments