Skip to content

Commit 39213ba

Browse files
committed
fix(cheatnet): default to any calldata entry when specific CheatSpan::TargetCalls is 0
1 parent 2da194c commit 39213ba

File tree

2 files changed

+201
-6
lines changed
  • crates

2 files changed

+201
-6
lines changed

crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::cell::RefCell;
22
use super::cairo1_execution::execute_entry_point_call_cairo1;
33
use crate::runtime_extensions::call_to_blockifier_runtime_extension::execution::deprecated::cairo0_execution::execute_entry_point_call_cairo0;
44
use crate::runtime_extensions::call_to_blockifier_runtime_extension::CheatnetState;
5-
use crate::state::{CallTrace, CallTraceNode, CheatStatus, EncounteredError};
5+
use crate::state::{CallTrace, CallTraceNode, CheatSpan, CheatStatus, EncounteredError};
66
use blockifier::execution::call_info::{CallExecution, Retdata};
77
use blockifier::{
88
execution::{
@@ -278,11 +278,14 @@ fn get_mocked_function_cheat_status<'a>(
278278
Some(contract_functions) => {
279279
let call_data_hash = poseidon_hash_many(call.calldata.0.iter());
280280
let key = (call.entry_point_selector, call_data_hash);
281-
if contract_functions.contains_key(&key) {
282-
contract_functions.get_mut(&key)
283-
} else {
284-
let key_zero = (call.entry_point_selector, Felt::zero());
285-
contract_functions.get_mut(&key_zero)
281+
let key_zero = (call.entry_point_selector, Felt::zero());
282+
283+
match contract_functions.get(&key) {
284+
Some(CheatStatus::Cheated(_, CheatSpan::TargetCalls(0))) => {
285+
contract_functions.get_mut(&key_zero)
286+
}
287+
Some(CheatStatus::Cheated(_, _)) => contract_functions.get_mut(&key),
288+
_ => contract_functions.get_mut(&key_zero),
286289
}
287290
}
288291
}

crates/forge/tests/integration/mock_call.rs

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,195 @@ fn mock_call_when_complex_types() {
388388
let result = run_test_case(&test);
389389
assert_passed(&result);
390390
}
391+
392+
#[test]
393+
fn mock_calls_when() {
394+
let test = test_case!(
395+
indoc!(
396+
r#"
397+
use result::ResultTrait;
398+
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCallData};
399+
400+
#[starknet::interface]
401+
trait IMockChecker<TContractState> {
402+
fn get_thing(ref self: TContractState) -> felt252;
403+
}
404+
405+
#[test]
406+
fn mock_call_when_one_specific() {
407+
let calldata = array![420];
408+
409+
let contract = declare("MockChecker").unwrap().contract_class();
410+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
411+
412+
let dispatcher = IMockCheckerDispatcher { contract_address };
413+
414+
let mock_ret_data = 421;
415+
let expected_calldata = MockCallData::Values([].span());
416+
mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1);
417+
418+
let thing = dispatcher.get_thing();
419+
assert_eq!(thing, 421);
420+
421+
let thing = dispatcher.get_thing();
422+
assert_eq!(thing, 420);
423+
}
424+
425+
#[test]
426+
fn mock_call_when_twice_specific() {
427+
let calldata = array![420];
428+
429+
let contract = declare("MockChecker").unwrap().contract_class();
430+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
431+
432+
let dispatcher = IMockCheckerDispatcher { contract_address };
433+
434+
let mock_ret_data = 421;
435+
let expected_calldata = MockCallData::Values([].span());
436+
mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 2);
437+
438+
let thing = dispatcher.get_thing();
439+
assert_eq!(thing, 421);
440+
441+
let thing = dispatcher.get_thing();
442+
assert_eq!(thing, 421);
443+
444+
let thing = dispatcher.get_thing();
445+
assert_eq!(thing, 420);
446+
}
447+
448+
#[test]
449+
fn mock_call_when_one_any() {
450+
let calldata = array![420];
451+
452+
let contract = declare("MockChecker").unwrap().contract_class();
453+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
454+
455+
let dispatcher = IMockCheckerDispatcher { contract_address };
456+
457+
let mock_ret_data = 421;
458+
mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 1);
459+
460+
let thing = dispatcher.get_thing();
461+
assert_eq!(thing, 421);
462+
463+
let thing = dispatcher.get_thing();
464+
assert_eq!(thing, 420);
465+
}
466+
467+
#[test]
468+
fn mock_call_when_twice_any() {
469+
let calldata = array![420];
470+
471+
let contract = declare("MockChecker").unwrap().contract_class();
472+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
473+
474+
let dispatcher = IMockCheckerDispatcher { contract_address };
475+
476+
let mock_ret_data = 421;
477+
mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 2);
478+
479+
let thing = dispatcher.get_thing();
480+
assert_eq!(thing, 421);
481+
482+
let thing = dispatcher.get_thing();
483+
assert_eq!(thing, 421);
484+
485+
let thing = dispatcher.get_thing();
486+
assert_eq!(thing, 420);
487+
}
488+
489+
"#
490+
),
491+
Contract::from_code_path(
492+
"MockChecker".to_string(),
493+
Path::new("tests/data/contracts/mock_checker.cairo"),
494+
)
495+
.unwrap()
496+
);
497+
498+
let result = run_test_case(&test);
499+
assert_passed(&result);
500+
}
501+
502+
#[test]
503+
fn mock_calls_when_mixed() {
504+
let test = test_case!(
505+
indoc!(
506+
r#"
507+
use result::ResultTrait;
508+
use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCallData};
509+
510+
#[starknet::interface]
511+
trait IMockChecker<TContractState> {
512+
fn get_thing(ref self: TContractState) -> felt252;
513+
}
514+
515+
#[test]
516+
fn mock_call_when_one() {
517+
let calldata = array![420];
518+
519+
let contract = declare("MockChecker").unwrap().contract_class();
520+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
521+
522+
let dispatcher = IMockCheckerDispatcher { contract_address };
523+
524+
let mock_ret_data = 421;
525+
let expected_calldata = MockCallData::Values([].span());
526+
mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1);
527+
mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, 422, 1);
528+
529+
let thing = dispatcher.get_thing();
530+
assert_eq!(thing, 421, "Specific calldata");
531+
532+
let thing = dispatcher.get_thing();
533+
assert_eq!(thing, 422, "Any calldata");
534+
535+
let thing = dispatcher.get_thing();
536+
assert_eq!(thing, 420);
537+
}
538+
539+
#[test]
540+
fn mock_call_when_multi() {
541+
let calldata = array![420];
542+
543+
let contract = declare("MockChecker").unwrap().contract_class();
544+
let (contract_address, _) = contract.deploy(@calldata).unwrap();
545+
546+
let dispatcher = IMockCheckerDispatcher { contract_address };
547+
548+
let mock_ret_data = 421;
549+
let expected_calldata = MockCallData::Values([].span());
550+
mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 3);
551+
mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, 422, 2);
552+
553+
let thing = dispatcher.get_thing();
554+
assert_eq!(thing, 421, "1st Specific calldata");
555+
556+
let thing = dispatcher.get_thing();
557+
assert_eq!(thing, 421, "2nd Specific calldata");
558+
559+
let thing = dispatcher.get_thing();
560+
assert_eq!(thing, 421, "3rd Specific calldata");
561+
562+
let thing = dispatcher.get_thing();
563+
assert_eq!(thing, 422, "1st Any calldata");
564+
565+
let thing = dispatcher.get_thing();
566+
assert_eq!(thing, 422, "2nd Any calldata");
567+
568+
let thing = dispatcher.get_thing();
569+
assert_eq!(thing, 420);
570+
}
571+
"#
572+
),
573+
Contract::from_code_path(
574+
"MockChecker".to_string(),
575+
Path::new("tests/data/contracts/mock_checker.cairo"),
576+
)
577+
.unwrap()
578+
);
579+
580+
let result = run_test_case(&test);
581+
assert_passed(&result);
582+
}

0 commit comments

Comments
 (0)