diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a77127cc4..a53cac0d77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -324,6 +324,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Rust is no longer required to use `snforge` if using Scarb >= 2.10.0 on supported platforms - precompiled `snforge_scarb_plugin` plugin binaries are now published to [package registry](https://scarbs.xyz) for new versions. - Added a suggestion for using the `--max-n-steps` flag when the Cairo VM returns the error: `Could not reach the end of the program. RunResources has no remaining steps`. +- `mock_call_when`, `start_mock_call_when`, `stop_mock_call_when` cheatcodes. #### Fixed diff --git a/Cargo.lock b/Cargo.lock index 82149d7198..4f9a8f20a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1563,6 +1563,7 @@ dependencies = [ "serde_json", "shared", "starknet", + "starknet-crypto 0.7.1", "starknet-types-core", "starknet_api", "tempfile", diff --git a/crates/cheatnet/Cargo.toml b/crates/cheatnet/Cargo.toml index 1386b37eda..3649e99467 100644 --- a/crates/cheatnet/Cargo.toml +++ b/crates/cheatnet/Cargo.toml @@ -13,6 +13,7 @@ bimap.workspace = true camino.workspace = true starknet_api.workspace = true starknet-types-core.workspace = true +starknet-crypto.workspace = true cairo-lang-casm.workspace = true cairo-lang-utils.workspace = true cairo-lang-starknet-classes.workspace = true diff --git a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs index 39f3dc9e5e..2b211d7a18 100644 --- a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs +++ b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs @@ -31,12 +31,14 @@ use cairo_vm::vm::runners::cairo_runner::{CairoRunner, ExecutionResources}; use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry; use conversions::FromConv; use conversions::string::TryFromHexStr; +use num_traits::Zero; use shared::vm::VirtualMachineExt; use starknet_api::{ contract_class::EntryPointType, core::ClassHash, transaction::{TransactionVersion, fields::Calldata}, }; +use starknet_crypto::poseidon_hash_many; use starknet_types_core::felt::Felt; use std::collections::{HashMap, HashSet}; use thiserror::Error; @@ -415,11 +417,22 @@ fn get_mocked_function_cheat_status<'a>( if call.call_type == CallType::Delegate { return None; } - - cheatnet_state + match cheatnet_state .mocked_functions .get_mut(&call.storage_address) - .and_then(|contract_functions| contract_functions.get_mut(&call.entry_point_selector)) + { + None => None, + Some(contract_functions) => { + let calldata_hash = poseidon_hash_many(call.calldata.0.iter()); + let key = (call.entry_point_selector, calldata_hash); + let key_zero = (call.entry_point_selector, Felt::zero()); + + match contract_functions.get(&key) { + Some(CheatStatus::Cheated(_, _)) => contract_functions.get_mut(&key), + _ => contract_functions.get_mut(&key_zero), + } + } + } } fn mocked_call_info( diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs index 11588add18..3dabb1868f 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs @@ -1,6 +1,8 @@ use crate::CheatnetState; -use crate::state::{CheatSpan, CheatStatus}; +use crate::state::{CheatSpan, CheatStatus, MockCalldata}; +use num_traits::Zero; use starknet_api::core::{ContractAddress, EntryPointSelector}; +use starknet_crypto::poseidon_hash_many; use starknet_types_core::felt::Felt; use std::collections::hash_map::Entry; @@ -9,26 +11,30 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, + calldata: MockCalldata, ret_data: &[Felt], span: CheatSpan, ) { let contract_mocked_functions = self.mocked_functions.entry(contract_address).or_default(); - - contract_mocked_functions.insert( - function_selector, - CheatStatus::Cheated(ret_data.to_vec(), span), - ); + let calldata_hash = match calldata { + MockCalldata::Values(data) => poseidon_hash_many(data.iter()), + MockCalldata::Any => Felt::zero(), + }; + let key = (function_selector, calldata_hash); + contract_mocked_functions.insert(key, CheatStatus::Cheated(ret_data.to_vec(), span)); } pub fn start_mock_call( &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, + calldata: MockCalldata, ret_data: &[Felt], ) { self.mock_call( contract_address, function_selector, + calldata, ret_data, CheatSpan::Indefinite, ); @@ -38,10 +44,15 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, + calldata: MockCalldata, ) { if let Entry::Occupied(mut e) = self.mocked_functions.entry(contract_address) { let contract_mocked_functions = e.get_mut(); - contract_mocked_functions.remove(&function_selector); + let calldata_hash = match calldata { + MockCalldata::Values(data) => poseidon_hash_many(data.iter()), + MockCalldata::Any => Felt::zero(), + }; + contract_mocked_functions.remove(&(function_selector, calldata_hash)); } } } diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs index 61119ee603..dd67ef02b9 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs @@ -95,26 +95,31 @@ impl<'a> ExtensionLogic for ForgeExtension<'a> { "mock_call" => { let contract_address = input_reader.read()?; let function_selector = input_reader.read()?; + let calldata = input_reader.read()?; let span = input_reader.read()?; - let ret_data: Vec<_> = input_reader.read()?; - extended_runtime .extended_runtime .extension .cheatnet_state - .mock_call(contract_address, function_selector, &ret_data, span); + .mock_call( + contract_address, + function_selector, + calldata, + &ret_data, + span, + ); Ok(CheatcodeHandlingResult::from_serializable(())) } "stop_mock_call" => { let contract_address = input_reader.read()?; let function_selector = input_reader.read()?; - + let calldata = input_reader.read()?; extended_runtime .extended_runtime .extension .cheatnet_state - .stop_mock_call(contract_address, function_selector); + .stop_mock_call(contract_address, function_selector, calldata); Ok(CheatcodeHandlingResult::from_serializable(())) } "replace_bytecode" => { diff --git a/crates/cheatnet/src/state.rs b/crates/cheatnet/src/state.rs index 0537d263d4..c4bf604835 100644 --- a/crates/cheatnet/src/state.rs +++ b/crates/cheatnet/src/state.rs @@ -47,6 +47,12 @@ pub enum CheatSpan { TargetCalls(NonZeroUsize), } +#[derive(CairoDeserialize, Clone, Debug, PartialEq, Eq)] +pub enum MockCalldata { + Any, + Values(Vec), +} + #[derive(Debug)] pub struct ExtendedStateReader { pub dict_state_reader: DictStateReader, @@ -369,12 +375,13 @@ pub struct TraceData { pub is_vm_trace_needed: bool, } +type MockedFunctionKey = (EntryPointSelector, Felt); pub struct CheatnetState { pub cheated_execution_info_contracts: HashMap, pub global_cheated_execution_info: ExecutionInfoMock, pub mocked_functions: - HashMap>>>, + HashMap>>>, pub replaced_bytecode_contracts: HashMap, pub detected_events: Vec, pub detected_messages_to_l1: Vec, diff --git a/crates/cheatnet/tests/cheatcodes/mock_call.rs b/crates/cheatnet/tests/cheatcodes/mock_call.rs index 88b5faa0fd..02a276b611 100644 --- a/crates/cheatnet/tests/cheatcodes/mock_call.rs +++ b/crates/cheatnet/tests/cheatcodes/mock_call.rs @@ -9,7 +9,7 @@ use crate::{ }; use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::declare::declare; use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::storage::selector_from_name; -use cheatnet::state::{CheatSpan, CheatnetState}; +use cheatnet::state::{CheatSpan, CheatnetState, MockCalldata}; use conversions::IntoConv; use starknet::core::utils::get_selector_from_name; use starknet_api::core::ContractAddress; @@ -40,6 +40,7 @@ impl MockCallTrait for TestEnvironment { self.cheatnet_state.mock_call( *contract_address, function_selector.into_(), + MockCalldata::Any, &ret_data, span, ); @@ -47,8 +48,11 @@ impl MockCallTrait for TestEnvironment { fn stop_mock_call(&mut self, contract_address: &ContractAddress, function_name: &str) { let function_selector = get_selector_from_name(function_name).unwrap(); - self.cheatnet_state - .stop_mock_call(*contract_address, function_selector.into_()); + self.cheatnet_state.stop_mock_call( + *contract_address, + function_selector.into_(), + MockCalldata::Any, + ); } } @@ -67,7 +71,12 @@ fn mock_call_simple() { let selector = selector_from_name("get_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector_from_name("get_thing"), &ret_data); + cheatnet_state.start_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + &ret_data, + ); let output = call_contract( &mut cached_state, @@ -95,7 +104,12 @@ fn mock_call_stop() { let selector = selector_from_name("get_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector_from_name("get_thing"), &ret_data); + cheatnet_state.start_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + &ret_data, + ); let output = call_contract( &mut cached_state, @@ -107,7 +121,11 @@ fn mock_call_stop() { assert_success(output, &ret_data); - cheatnet_state.stop_mock_call(contract_address, selector_from_name("get_thing")); + cheatnet_state.stop_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + ); let output = call_contract( &mut cached_state, @@ -134,7 +152,11 @@ fn mock_call_stop_no_start() { let selector = selector_from_name("get_thing"); - cheatnet_state.stop_mock_call(contract_address, selector_from_name("get_thing")); + cheatnet_state.stop_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + ); let output = call_contract( &mut cached_state, @@ -162,10 +184,10 @@ fn mock_call_double() { let selector = selector_from_name("get_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector, &ret_data); + cheatnet_state.start_mock_call(contract_address, selector, MockCalldata::Any, &ret_data); let ret_data = [Felt::from(999)]; - cheatnet_state.start_mock_call(contract_address, selector, &ret_data); + cheatnet_state.start_mock_call(contract_address, selector, MockCalldata::Any, &ret_data); let output = call_contract( &mut cached_state, @@ -177,7 +199,7 @@ fn mock_call_double() { assert_success(output, &ret_data); - cheatnet_state.stop_mock_call(contract_address, selector); + cheatnet_state.stop_mock_call(contract_address, selector, MockCalldata::Any); let output = call_contract( &mut cached_state, @@ -205,7 +227,12 @@ fn mock_call_double_call() { let selector = selector_from_name("get_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector_from_name("get_thing"), &ret_data); + cheatnet_state.start_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + &ret_data, + ); let output = call_contract( &mut cached_state, @@ -242,7 +269,12 @@ fn mock_call_proxy() { let selector = selector_from_name("get_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector_from_name("get_thing"), &ret_data); + cheatnet_state.start_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + &ret_data, + ); let output = call_contract( &mut cached_state, @@ -286,7 +318,12 @@ fn mock_call_proxy_with_other_syscall() { let selector = selector_from_name("get_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector_from_name("get_thing"), &ret_data); + cheatnet_state.start_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + &ret_data, + ); let output = call_contract( &mut cached_state, @@ -331,7 +368,12 @@ fn mock_call_inner_call_no_effect() { let selector = selector_from_name("get_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector_from_name("get_thing"), &ret_data); + cheatnet_state.start_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + &ret_data, + ); let output = call_contract( &mut cached_state, @@ -385,6 +427,7 @@ fn mock_call_library_call_no_effect() { cheatnet_state.start_mock_call( contract_address, selector_from_name("get_constant_thing"), + MockCalldata::Any, &ret_data, ); @@ -418,6 +461,7 @@ fn mock_call_before_deployment() { cheatnet_state.start_mock_call( precalculated_address, selector_from_name("get_thing"), + MockCalldata::Any, &ret_data, ); @@ -460,6 +504,7 @@ fn mock_call_not_implemented() { cheatnet_state.start_mock_call( contract_address, selector_from_name("get_thing_not_implemented"), + MockCalldata::Any, &ret_data, ); @@ -490,6 +535,7 @@ fn mock_call_in_constructor() { cheatnet_state.start_mock_call( balance_contract_address, selector_from_name("get_balance"), + MockCalldata::Any, &ret_data, ); @@ -534,11 +580,17 @@ fn mock_call_two_methods() { let selector2 = selector_from_name("get_constant_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector_from_name("get_thing"), &ret_data); + cheatnet_state.start_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + &ret_data, + ); cheatnet_state.start_mock_call( contract_address, selector_from_name("get_constant_thing"), + MockCalldata::Any, &ret_data, ); @@ -573,7 +625,12 @@ fn mock_call_nonexisting_contract() { let contract_address = ContractAddress::from(218_u8); - cheatnet_state.start_mock_call(contract_address, selector_from_name("get_thing"), &ret_data); + cheatnet_state.start_mock_call( + contract_address, + selector_from_name("get_thing"), + MockCalldata::Any, + &ret_data, + ); let output = call_contract( &mut cached_state, diff --git a/crates/forge/tests/integration/mock_call.rs b/crates/forge/tests/integration/mock_call.rs index 0a018fca9b..34a00a616e 100644 --- a/crates/forge/tests/integration/mock_call.rs +++ b/crates/forge/tests/integration/mock_call.rs @@ -208,3 +208,544 @@ fn mock_calls() { let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); assert_passed(&result); } + +#[test] +fn mock_call_when_simple() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCalldata }; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_call_when_simple() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let specific_mock_ret_data = 421; + let default_mock_ret_data = 404; + let expected_calldata = MockCalldata::Values([].span()); + + start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, default_mock_ret_data); + let thing = dispatcher.get_thing(); + assert(thing == specific_mock_ret_data, 'Incorrect thing'); + + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + let thing = dispatcher.get_thing(); + assert(thing == default_mock_ret_data, 'Incorrect thing'); + + stop_mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any); + let thing = dispatcher.get_thing(); + assert(thing == 420, 'Incorrect thing'); + } + + #[test] + fn mock_call_when_simple_before_dispatcher_created() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let specific_mock_ret_data = 421; + let default_mock_ret_data = 404; + let expected_calldata = MockCalldata::Values([].span()); + + start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, default_mock_ret_data); + let dispatcher = IMockCheckerDispatcher { contract_address }; + let thing = dispatcher.get_thing(); + + assert(thing == specific_mock_ret_data, 'Incorrect thing'); + + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + let thing = dispatcher.get_thing(); + assert(thing == default_mock_ret_data, 'Incorrect thing'); + + stop_mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any); + let thing = dispatcher.get_thing(); + assert(thing == 420, 'Incorrect thing'); + } + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); + assert_passed(&result); +} + +#[test] +fn mock_call_when_complex_types() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use array::ArrayTrait; + use serde::Serde; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCalldata }; + + #[starknet::interface] + trait IMockChecker { + fn get_struct_thing(ref self: TContractState) -> StructThing; + fn get_arr_thing(ref self: TContractState) -> Array; + } + + #[derive(Serde, Drop)] + struct StructThing { + item_one: felt252, + item_two: felt252, + } + + #[test] + fn start_mock_call_when_return_struct() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let default_mock_ret_data = StructThing {item_one: 412, item_two: 421}; + let specific_mock_ret_data = StructThing {item_one: 404, item_two: 401}; + let expected_calldata = MockCalldata::Values([].span()); + + start_mock_call_when(contract_address, selector!("get_struct_thing"), MockCalldata::Any, default_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_struct_thing"), expected_calldata, specific_mock_ret_data); + + let thing: StructThing = dispatcher.get_struct_thing(); + + assert(thing.item_one == 404, 'thing.item_one'); + assert(thing.item_two == 401, 'thing.item_two'); + + stop_mock_call_when(contract_address, selector!("get_struct_thing"), expected_calldata); + let thing: StructThing = dispatcher.get_struct_thing(); + + assert(thing.item_one == 412, 'thing.item_one'); + assert(thing.item_two == 421, 'thing.item_two'); + } + + #[test] + fn start_mock_call_when_return_arr() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let default_mock_ret_data = array![ StructThing {item_one: 112, item_two: 121}, StructThing {item_one: 412, item_two: 421} ]; + let specific_mock_ret_data = array![ StructThing {item_one: 212, item_two: 221}, StructThing {item_one: 512, item_two: 521} ]; + + let expected_calldata = MockCalldata::Values([].span()); + + start_mock_call_when(contract_address, selector!("get_arr_thing"), MockCalldata::Any, default_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_arr_thing"), expected_calldata, specific_mock_ret_data); + + let things: Array = dispatcher.get_arr_thing(); + + let thing = things.at(0); + assert(*thing.item_one == 212, 'thing1.item_one 1'); + assert(*thing.item_two == 221, 'thing1.item_two'); + + let thing = things.at(1); + assert(*thing.item_one == 512, 'thing2.item_one 2'); + assert(*thing.item_two == 521, 'thing2.item_two'); + + stop_mock_call_when(contract_address, selector!("get_arr_thing"), expected_calldata); + + let things: Array = dispatcher.get_arr_thing(); + + let thing = things.at(0); + assert(*thing.item_one == 112, 'thing1.item_one 3'); + assert(*thing.item_two == 121, 'thing1.item_two'); + + let thing = things.at(1); + assert(*thing.item_one == 412, 'thing2.item_one 4'); + assert(*thing.item_two == 421, 'thing2.item_two'); + } + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); + assert_passed(&result); +} + +#[test] +fn mock_calls_when() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCalldata}; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_call_when_one_specific() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let expected_calldata = MockCalldata::Values([].span()); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_call_when_twice_specific() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let expected_calldata = MockCalldata::Values([].span()); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 2); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_call_when_one_any() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, mock_ret_data, 1); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_call_when_twice_any() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, mock_ret_data, 2); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); + assert_passed(&result); +} + +#[test] +fn mock_calls_when_mixed() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCalldata}; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_call_when_one() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let expected_calldata = MockCalldata::Values([].span()); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1); + mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, 422, 1); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421, "Specific calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 422, "Any calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_call_when_multi() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let expected_calldata = MockCalldata::Values([].span()); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 3); + mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, 422, 2); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421, "1st Specific calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421, "2nd Specific calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421, "3rd Specific calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 422, "1st Any calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 422, "2nd Any calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); + assert_passed(&result); +} + +#[test] +fn mock_calls_start_stop_when_mixed() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCalldata, start_mock_call, start_mock_call_when, stop_mock_call, stop_mock_call_when}; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_calls_start_stop_when_mixed() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let mock_when_ret_data = 422; + + let expected_calldata = MockCalldata::Values([].span()); + start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); + start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_when_ret_data); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "1st Mock call when"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "2nd Mock call when"); + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_ret_data, "Mock call"); + + stop_mock_call(contract_address, selector!("get_thing")); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_calls_start_stop_when_count_mixed() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let mock_when_ret_data = 422; + + let expected_calldata = MockCalldata::Values([].span()); + start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_when_ret_data, 2); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "1st Mock call when"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "2nd Mock call when"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_ret_data, "Mock call"); + + stop_mock_call(contract_address, selector!("get_thing")); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); + assert_passed(&result); +} + +#[test] +fn mock_calls_start_stop_when_interleaved() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, MockCalldata, start_mock_call, start_mock_call_when, stop_mock_call, stop_mock_call_when}; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_calls_start_when_and_stop() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_when_ret_data = 422; + + let expected_calldata = MockCalldata::Values([].span()); + start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_when_ret_data); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "1st Mock call when"); + + stop_mock_call(contract_address, selector!("get_thing")); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "2nd Mock call when"); + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_calls_start_and_stop_when() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + + let expected_calldata = MockCalldata::Values([].span()); + + start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_ret_data, "Mock call"); + + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_ret_data, "Mock call"); + + stop_mock_call(contract_address, selector!("get_thing")); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); + assert_passed(&result); +} diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 6d86b478ab..51326c1281 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -99,6 +99,7 @@ * [fee_data_availability_mode](appendix/cheatcodes/fee_data_availability_mode.md) * [account_deployment_data](appendix/cheatcodes/account_deployment_data.md) * [mock_call](appendix/cheatcodes/mock_call.md) + * [mock_call_when](appendix/cheatcodes/mock_call_when.md) * [get_class_hash](appendix/cheatcodes/get_class_hash.md) * [replace_bytecode](appendix/cheatcodes/replace_bytecode.md) * [l1_handler](appendix/cheatcodes/l1_handler.md) diff --git a/docs/src/appendix/cheatcodes.md b/docs/src/appendix/cheatcodes.md index 388f5bde83..fc555f6c5b 100644 --- a/docs/src/appendix/cheatcodes.md +++ b/docs/src/appendix/cheatcodes.md @@ -4,6 +4,9 @@ - [`mock_call`](cheatcodes/mock_call.md#mock_call) - mocks a number of contract calls to an entry point - [`start_mock_call`](cheatcodes/mock_call.md#start_mock_call) - mocks contract call to an entry point - [`stop_mock_call`](cheatcodes/mock_call.md#stop_mock_call) - cancels the `mock_call` / `start_mock_call` for an entry point +- [`mock_call_when`](cheatcodes/mock_call_when.md#mock_call_when) - mocks a number of contract calls to an entry point for a given calldata +- [`start_mock_call_when`](cheatcodes/mock_call_when.md#start_mock_call_when) - mocks contract call to an entry point for a given calldata +- [`stop_mock_call_when`](cheatcodes/mock_call_when.md#stop_mock_call_when) - cancels the `mock_call_when` / `start_mock_call_when` for an entry point - [`get_class_hash`](cheatcodes/get_class_hash.md) - retrieves a class hash of a contract - [`replace_bytecode`](cheatcodes/replace_bytecode.md) - replace the class hash of a contract - [`l1_handler`](cheatcodes/l1_handler.md) - executes a `#[l1_handler]` function to mock a message arriving from Ethereum diff --git a/docs/src/appendix/cheatcodes/mock_call_when.md b/docs/src/appendix/cheatcodes/mock_call_when.md new file mode 100644 index 0000000000..2bd1a5ad54 --- /dev/null +++ b/docs/src/appendix/cheatcodes/mock_call_when.md @@ -0,0 +1,43 @@ +# `mock_call_when` + +Cheatcodes mocking contract entry point calls: + +## `MockCalldata` + +```rust +pub enum MockCalldata { + Any, + Values: Span, +} +``` + +`MockCalldata` is an enum used to specify for which calldata the contract entry point will be mocked. +- `Any` mock the contract entry point for any calldata. +- `Values` mock the contract entry point only for this calldata. + +## `mock_call_when` +> `fn mock_call_when, impl TDestruct: Destruct>( +> contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata, ret_data: T, n_times: u32 +> )` + +Mocks contract call to a `function_selector` of a contract at the given address, with the given calldata, for `n_times` first calls that are made +to the contract. +A call to function `function_selector` will return data provided in `ret_data` argument. +An address with no contract can be mocked as well. +An entrypoint that is not present on the deployed contract is also possible to mock. +Note that the function is not meant for mocking internal calls - it works only for contract entry points. + +## `start_mock_call_when` +> `fn start_mock_call, impl TDestruct: Destruct>( +> contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata, ret_data: T +> )` + +Mocks contract call to a `function_selector` of a contract at the given address, with the given calldata, indefinitely. +See `mock_call_when` for comprehensive definition of how it can be used. + + +### `stop_mock_call_when` + +> `fn stop_mock_call_when(contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata)` + +Cancels the `mock_call_when` / `start_mock_call_when` for the function `function_selector` of a contract at the given addressn with the given calldata diff --git a/snforge_std/src/cheatcodes.cairo b/snforge_std/src/cheatcodes.cairo index b6823aa7d2..8eead125fa 100644 --- a/snforge_std/src/cheatcodes.cairo +++ b/snforge_std/src/cheatcodes.cairo @@ -23,6 +23,16 @@ pub enum CheatSpan { TargetCalls: NonZero, } +/// Enum used to specify the calldata that should be matched when mocking a contract call. +#[derive(Copy, Drop, PartialEq, Clone, Debug, Serde)] +pub enum MockCalldata { + /// Matches any calldata. + Any, + /// Matches the specified serialized calldata. + Values: Span, +} + + pub fn test_selector() -> felt252 { // Result of selector!("TEST_CONTRACT_SELECTOR") since `selector!` macro requires dependency on // `starknet`. @@ -33,6 +43,7 @@ pub fn test_address() -> ContractAddress { contract_address_const::<469394814521890341860918960550914>() } + /// Mocks contract call to a `function_selector` of a contract at the given address, for `n_times` /// first calls that are made to the contract. /// A call to function `function_selector` will return data provided in `ret_data` argument. @@ -47,12 +58,56 @@ pub fn test_address() -> ContractAddress { /// - `n_times` - number of calls to mock the function for pub fn mock_call, impl TDestruct: Destruct>( contract_address: ContractAddress, function_selector: felt252, ret_data: T, n_times: u32, +) { + mock_call_when(contract_address, function_selector, MockCalldata::Any, ret_data, n_times) +} + +/// Mocks contract call to a function of a contract at the given address, indefinitely. +/// See `mock_call` for comprehensive definition of how it can be used. +/// - `contract_address` - targeted contracts' address +/// - `function_selector` - hashed name of the target function (can be obtained with `selector!` +/// macro) +/// - `ret_data` - data to be returned by the function +pub fn start_mock_call, impl TDestruct: Destruct>( + contract_address: ContractAddress, function_selector: felt252, ret_data: T, +) { + start_mock_call_when(contract_address, function_selector, MockCalldata::Any, ret_data) +} + +/// Cancels the `mock_call` / `start_mock_call` for the function with given name and contract +/// address. +/// - `contract_address` - targeted contracts' address +/// - `function_selector` - hashed name of the target function (can be obtained with `selector!` +/// macro) +pub fn stop_mock_call(contract_address: ContractAddress, function_selector: felt252) { + stop_mock_call_when(contract_address, function_selector, MockCalldata::Any) +} + +/// Mocks contract call to a `function_selector` of a contract at the given address, for `n_times` +/// first calls that are made to the contract. +/// A call to function `function_selector` will return data provided in `ret_data` argument. +/// An address with no contract can be mocked as well. +/// An entrypoint that is not present on the deployed contract is also possible to mock. +/// Note that the function is not meant for mocking internal calls - it works only for contract +/// entry points. +/// - `contract_address` - target contract address +/// - `function_selector` - hashed name of the target function (can be obtained with `selector!` +/// macro) +/// - `calldata` - matching calldata +/// - `ret_data` - data to return by the function `function_selector` +/// - `n_times` - number of calls to mock the function for +pub fn mock_call_when, impl TDestruct: Destruct>( + contract_address: ContractAddress, + function_selector: felt252, + calldata: MockCalldata, + ret_data: T, + n_times: u32, ) { assert!(n_times > 0, "cannot `mock_call` 0 times, `n_times` argument must be greater than 0"); let contract_address_felt: felt252 = contract_address.into(); let mut inputs = array![contract_address_felt, function_selector]; - + calldata.serialize(ref inputs); CheatSpan::TargetCalls(n_times.try_into().expect('`n_times` must be > 0')) .serialize(ref inputs); @@ -70,13 +125,17 @@ pub fn mock_call, impl TDestruct: Destruct /// - `contract_address` - targeted contracts' address /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` /// macro) +/// - `calldata` - matching calldata /// - `ret_data` - data to be returned by the function -pub fn start_mock_call, impl TDestruct: Destruct>( - contract_address: ContractAddress, function_selector: felt252, ret_data: T, +pub fn start_mock_call_when, impl TDestruct: Destruct>( + contract_address: ContractAddress, + function_selector: felt252, + calldata: MockCalldata, + ret_data: T, ) { let contract_address_felt: felt252 = contract_address.into(); let mut inputs = array![contract_address_felt, function_selector]; - + calldata.serialize(ref inputs); CheatSpan::Indefinite.serialize(ref inputs); let mut ret_data_arr = ArrayTrait::new(); @@ -87,16 +146,20 @@ pub fn start_mock_call, impl TDestruct: De execute_cheatcode_and_deserialize::<'mock_call', ()>(inputs.span()); } -/// Cancels the `mock_call` / `start_mock_call` for the function with given name and contract -/// address. +/// Cancels the `mock_call_when` / `start_mock_call_when` for the function with given name and +/// contract address. /// - `contract_address` - targeted contracts' address /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` +/// - `calldata` - matching calldata /// macro) -pub fn stop_mock_call(contract_address: ContractAddress, function_selector: felt252) { +pub fn stop_mock_call_when( + contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata, +) { let contract_address_felt: felt252 = contract_address.into(); - execute_cheatcode_and_deserialize::< - 'stop_mock_call', (), - >(array![contract_address_felt, function_selector].span()); + let mut inputs = array![contract_address_felt, function_selector]; + calldata.serialize(ref inputs); + + execute_cheatcode_and_deserialize::<'stop_mock_call', ()>(inputs.span()); } #[derive(Drop, Serde, PartialEq, Debug)] diff --git a/snforge_std/src/lib.cairo b/snforge_std/src/lib.cairo index f819af4ebb..9c9690b829 100644 --- a/snforge_std/src/lib.cairo +++ b/snforge_std/src/lib.cairo @@ -108,8 +108,9 @@ pub use cheatcodes::message_to_l1::{ pub use cheatcodes::storage::store; pub use cheatcodes::storage::{interact_with_state, load, map_entry_address}; pub use cheatcodes::{ - ReplaceBytecodeError, mock_call, replace_bytecode, start_mock_call, stop_mock_call, - test_address, test_selector, + MockCalldata, ReplaceBytecodeError, mock_call, mock_call_when, replace_bytecode, + start_mock_call, start_mock_call_when, stop_mock_call, stop_mock_call_when, test_address, + test_selector, }; pub mod byte_array;