Skip to content

Commit 2da194c

Browse files
committed
snforge_std: revert breaking changes on mock_call
Add `mock_call_when`, `start_mock_call_when` and `stop_mock_call_when`
1 parent d75e50b commit 2da194c

File tree

5 files changed

+247
-28
lines changed

5 files changed

+247
-28
lines changed

crates/forge/tests/integration/cheat_fork.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ fn mock_call_cairo0_contract() {
155155
let test = test_case!(formatdoc!(
156156
r#"
157157
use starknet::{{contract_address_const}};
158-
use snforge_std::{{start_mock_call, stop_mock_call, MockCallData}};
158+
use snforge_std::{{start_mock_call, stop_mock_call}};
159159
160160
#[starknet::interface]
161161
trait IERC20<TContractState> {{
@@ -173,11 +173,11 @@ fn mock_call_cairo0_contract() {
173173
174174
assert(eth_dispatcher.name() == 'Ether', 'invalid name');
175175
176-
start_mock_call(eth_dispatcher.contract_address, selector!("name"), MockCallData::Any, 'NotEther');
176+
start_mock_call(eth_dispatcher.contract_address, selector!("name"), 'NotEther');
177177
178178
assert(eth_dispatcher.name() == 'NotEther', 'invalid mocked name');
179179
180-
stop_mock_call(eth_dispatcher.contract_address, selector!("name"), MockCallData::Any);
180+
stop_mock_call(eth_dispatcher.contract_address, selector!("name"));
181181
182182
assert(eth_dispatcher.name() == 'Ether', 'invalid name after mock');
183183
}}

crates/forge/tests/integration/mock_call.rs

Lines changed: 194 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ fn mock_call_simple() {
1010
indoc!(
1111
r#"
1212
use result::ResultTrait;
13-
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, stop_mock_call, MockCallData };
13+
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, stop_mock_call };
1414
1515
#[starknet::interface]
1616
trait IMockChecker<TContractState> {
@@ -26,19 +26,13 @@ fn mock_call_simple() {
2626
2727
let dispatcher = IMockCheckerDispatcher { contract_address };
2828
29-
let specific_mock_ret_data = 421;
30-
let default_mock_ret_data = 404;
31-
let expected_calldata = MockCallData::Values([].span());
32-
start_mock_call(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data);
33-
start_mock_call(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data);
34-
let thing = dispatcher.get_thing();
35-
assert(thing == specific_mock_ret_data, 'Incorrect thing');
29+
let mock_ret_data = 421;
3630
37-
stop_mock_call(contract_address, selector!("get_thing"), expected_calldata);
31+
start_mock_call(contract_address, selector!("get_thing"), mock_ret_data);
3832
let thing = dispatcher.get_thing();
39-
assert(thing == default_mock_ret_data, 'Incorrect thing');
33+
assert(thing == 421, 'Incorrect thing');
4034
41-
stop_mock_call(contract_address, selector!("get_thing"), MockCallData::Any);
35+
stop_mock_call(contract_address, selector!("get_thing"));
4236
let thing = dispatcher.get_thing();
4337
assert(thing == 420, 'Incorrect thing');
4438
}
@@ -51,12 +45,12 @@ fn mock_call_simple() {
5145
let (contract_address, _) = contract.deploy(@calldata).unwrap();
5246
5347
let mock_ret_data = 421;
54-
start_mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data);
48+
start_mock_call(contract_address, selector!("get_thing"), mock_ret_data);
5549
5650
let dispatcher = IMockCheckerDispatcher { contract_address };
5751
let thing = dispatcher.get_thing();
5852
59-
assert(thing == 421, 'Incorrect thing all catch');
53+
assert(thing == 421, 'Incorrect thing');
6054
}
6155
"#
6256
),
@@ -79,7 +73,7 @@ fn mock_call_complex_types() {
7973
use result::ResultTrait;
8074
use array::ArrayTrait;
8175
use serde::Serde;
82-
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, MockCallData };
76+
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call };
8377
8478
#[starknet::interface]
8579
trait IMockChecker<TContractState> {
@@ -103,7 +97,7 @@ fn mock_call_complex_types() {
10397
let dispatcher = IMockCheckerDispatcher { contract_address };
10498
10599
let mock_ret_data = StructThing {item_one: 412, item_two: 421};
106-
start_mock_call(contract_address, selector!("get_struct_thing"), MockCallData::Any, mock_ret_data);
100+
start_mock_call(contract_address, selector!("get_struct_thing"), mock_ret_data);
107101
108102
let thing: StructThing = dispatcher.get_struct_thing();
109103
@@ -121,7 +115,7 @@ fn mock_call_complex_types() {
121115
let dispatcher = IMockCheckerDispatcher { contract_address };
122116
123117
let mock_ret_data = array![ StructThing {item_one: 112, item_two: 121}, StructThing {item_one: 412, item_two: 421} ];
124-
start_mock_call(contract_address, selector!("get_arr_thing"), MockCallData::Any, mock_ret_data);
118+
start_mock_call(contract_address, selector!("get_arr_thing"), mock_ret_data);
125119
126120
let things: Array<StructThing> = dispatcher.get_arr_thing();
127121
@@ -152,7 +146,7 @@ fn mock_calls() {
152146
indoc!(
153147
r#"
154148
use result::ResultTrait;
155-
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call, MockCallData };
149+
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call, start_mock_call, stop_mock_call };
156150
157151
#[starknet::interface]
158152
trait IMockChecker<TContractState> {
@@ -170,7 +164,7 @@ fn mock_calls() {
170164
171165
let mock_ret_data = 421;
172166
173-
mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 1);
167+
mock_call(contract_address, selector!("get_thing"), mock_ret_data, 1);
174168
175169
let thing = dispatcher.get_thing();
176170
assert_eq!(thing, 421);
@@ -190,7 +184,7 @@ fn mock_calls() {
190184
191185
let mock_ret_data = 421;
192186
193-
mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 2);
187+
mock_call(contract_address, selector!("get_thing"), mock_ret_data, 2);
194188
195189
let thing = dispatcher.get_thing();
196190
assert_eq!(thing, 421);
@@ -213,3 +207,184 @@ fn mock_calls() {
213207
let result = run_test_case(&test);
214208
assert_passed(&result);
215209
}
210+
211+
#[test]
212+
fn mock_call_when_simple() {
213+
let test = test_case!(
214+
indoc!(
215+
r#"
216+
use result::ResultTrait;
217+
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCallData };
218+
219+
#[starknet::interface]
220+
trait IMockChecker<TContractState> {
221+
fn get_thing(ref self: TContractState) -> felt252;
222+
}
223+
224+
#[test]
225+
fn mock_call_when_simple() {
226+
let calldata = array![420];
227+
228+
let contract = declare("MockChecker").unwrap().contract_class();
229+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
230+
231+
let dispatcher = IMockCheckerDispatcher { contract_address };
232+
233+
let specific_mock_ret_data = 421;
234+
let default_mock_ret_data = 404;
235+
let expected_calldata = MockCallData::Values([].span());
236+
237+
start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data);
238+
start_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data);
239+
let thing = dispatcher.get_thing();
240+
assert(thing == specific_mock_ret_data, 'Incorrect thing');
241+
242+
stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata);
243+
let thing = dispatcher.get_thing();
244+
assert(thing == default_mock_ret_data, 'Incorrect thing');
245+
246+
stop_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any);
247+
let thing = dispatcher.get_thing();
248+
assert(thing == 420, 'Incorrect thing');
249+
}
250+
251+
#[test]
252+
fn mock_call_when_simple_before_dispatcher_created() {
253+
let calldata = array![420];
254+
255+
let contract = declare("MockChecker").unwrap().contract_class();
256+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
257+
258+
let specific_mock_ret_data = 421;
259+
let default_mock_ret_data = 404;
260+
let expected_calldata = MockCallData::Values([].span());
261+
262+
start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data);
263+
start_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data);
264+
let dispatcher = IMockCheckerDispatcher { contract_address };
265+
let thing = dispatcher.get_thing();
266+
267+
assert(thing == specific_mock_ret_data, 'Incorrect thing');
268+
269+
stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata);
270+
let thing = dispatcher.get_thing();
271+
assert(thing == default_mock_ret_data, 'Incorrect thing');
272+
273+
stop_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any);
274+
let thing = dispatcher.get_thing();
275+
assert(thing == 420, 'Incorrect thing');
276+
}
277+
"#
278+
),
279+
Contract::from_code_path(
280+
"MockChecker".to_string(),
281+
Path::new("tests/data/contracts/mock_checker.cairo"),
282+
)
283+
.unwrap()
284+
);
285+
286+
let result = run_test_case(&test);
287+
assert_passed(&result);
288+
}
289+
290+
#[test]
291+
fn mock_call_when_complex_types() {
292+
let test = test_case!(
293+
indoc!(
294+
r#"
295+
use result::ResultTrait;
296+
use array::ArrayTrait;
297+
use serde::Serde;
298+
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCallData };
299+
300+
#[starknet::interface]
301+
trait IMockChecker<TContractState> {
302+
fn get_struct_thing(ref self: TContractState) -> StructThing;
303+
fn get_arr_thing(ref self: TContractState) -> Array<StructThing>;
304+
}
305+
306+
#[derive(Serde, Drop)]
307+
struct StructThing {
308+
item_one: felt252,
309+
item_two: felt252,
310+
}
311+
312+
#[test]
313+
fn start_mock_call_when_return_struct() {
314+
let calldata = array![420];
315+
316+
let contract = declare("MockChecker").unwrap().contract_class();
317+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
318+
319+
let dispatcher = IMockCheckerDispatcher { contract_address };
320+
321+
let default_mock_ret_data = StructThing {item_one: 412, item_two: 421};
322+
let specific_mock_ret_data = StructThing {item_one: 404, item_two: 401};
323+
let expected_calldata = MockCallData::Values([].span());
324+
325+
start_mock_call_when(contract_address, selector!("get_struct_thing"), MockCallData::Any, default_mock_ret_data);
326+
start_mock_call_when(contract_address, selector!("get_struct_thing"), expected_calldata, specific_mock_ret_data);
327+
328+
let thing: StructThing = dispatcher.get_struct_thing();
329+
330+
assert(thing.item_one == 404, 'thing.item_one');
331+
assert(thing.item_two == 401, 'thing.item_two');
332+
333+
stop_mock_call_when(contract_address, selector!("get_struct_thing"), expected_calldata);
334+
let thing: StructThing = dispatcher.get_struct_thing();
335+
336+
assert(thing.item_one == 412, 'thing.item_one');
337+
assert(thing.item_two == 421, 'thing.item_two');
338+
}
339+
340+
#[test]
341+
fn start_mock_call_when_return_arr() {
342+
let calldata = array![420];
343+
344+
let contract = declare("MockChecker").unwrap().contract_class();
345+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
346+
347+
let dispatcher = IMockCheckerDispatcher { contract_address };
348+
349+
let default_mock_ret_data = array![ StructThing {item_one: 112, item_two: 121}, StructThing {item_one: 412, item_two: 421} ];
350+
let specific_mock_ret_data = array![ StructThing {item_one: 212, item_two: 221}, StructThing {item_one: 512, item_two: 521} ];
351+
352+
let expected_calldata = MockCallData::Values([].span());
353+
354+
start_mock_call_when(contract_address, selector!("get_arr_thing"), MockCallData::Any, default_mock_ret_data);
355+
start_mock_call_when(contract_address, selector!("get_arr_thing"), expected_calldata, specific_mock_ret_data);
356+
357+
let things: Array<StructThing> = dispatcher.get_arr_thing();
358+
359+
let thing = things.at(0);
360+
assert(*thing.item_one == 212, 'thing1.item_one 1');
361+
assert(*thing.item_two == 221, 'thing1.item_two');
362+
363+
let thing = things.at(1);
364+
assert(*thing.item_one == 512, 'thing2.item_one 2');
365+
assert(*thing.item_two == 521, 'thing2.item_two');
366+
367+
stop_mock_call_when(contract_address, selector!("get_arr_thing"), expected_calldata);
368+
369+
let things: Array<StructThing> = dispatcher.get_arr_thing();
370+
371+
let thing = things.at(0);
372+
assert(*thing.item_one == 112, 'thing1.item_one 3');
373+
assert(*thing.item_two == 121, 'thing1.item_two');
374+
375+
let thing = things.at(1);
376+
assert(*thing.item_one == 412, 'thing2.item_one 4');
377+
assert(*thing.item_two == 421, 'thing2.item_two');
378+
}
379+
"#
380+
),
381+
Contract::from_code_path(
382+
"MockChecker".to_string(),
383+
Path::new("tests/data/contracts/mock_checker.cairo"),
384+
)
385+
.unwrap()
386+
);
387+
388+
let result = run_test_case(&test);
389+
assert_passed(&result);
390+
}

crates/forge/tests/integration/test_state.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ fn inconsistent_syscall_pointers() {
678678
r#"
679679
use starknet::ContractAddress;
680680
use starknet::info::get_block_number;
681-
use snforge_std::{start_mock_call, MockCallData};
681+
use snforge_std::start_mock_call;
682682
683683
#[starknet::interface]
684684
trait IContract<TContractState> {
@@ -689,7 +689,7 @@ fn inconsistent_syscall_pointers() {
689689
fn inconsistent_syscall_pointers() {
690690
// verifies if SyscallHandler.syscal_ptr is incremented correctly when calling a contract
691691
let address = 'address'.try_into().unwrap();
692-
start_mock_call(address, selector!("get_value"), MockCallData::Any, 55);
692+
start_mock_call(address, selector!("get_value"), 55);
693693
let contract = IContractDispatcher { contract_address: address };
694694
contract.get_value(address);
695695
get_block_number();

0 commit comments

Comments
 (0)