diff --git a/CHANGELOG.md b/CHANGELOG.md index a30db20a22..1d30e77a12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Forge +#### Added + +- `get_current_vm_step` function to get the current step from Cairo VM during test execution. For more see [docs](https://foundry-rs.github.io/starknet-foundry/snforge-library/testing/get_current_vm_step.html) + #### Changed - Gas values in fuzzing test output are now displayed as whole numbers without fractional parts diff --git a/crates/cheatnet/src/runtime_extensions/forge_config_extension.rs b/crates/cheatnet/src/runtime_extensions/forge_config_extension.rs index 82b139887e..2a7b0890b6 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_config_extension.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_config_extension.rs @@ -1,3 +1,4 @@ +use cairo_vm::vm::vm_core::VirtualMachine; use config::RawForgeConfig; use conversions::serde::deserialize::BufferReader; use runtime::{CheatcodeHandlingResult, EnhancedHintError, ExtensionLogic, StarknetRuntime}; @@ -17,6 +18,7 @@ impl<'a> ExtensionLogic for ForgeConfigExtension<'a> { selector: &str, mut input_reader: BufferReader<'_>, _extended_runtime: &mut Self::Runtime, + _vm: &VirtualMachine, ) -> Result { macro_rules! config_cheatcode { ( $prop:ident) => {{ diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs index 8d854afdb2..2d51e9a532 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs @@ -19,6 +19,7 @@ use crate::runtime_extensions::{ }; use crate::trace_data::{CallTrace, CallTraceNode, GasReportData}; use anyhow::{Context, Result, anyhow}; +use blockifier::blockifier_versioned_constants::VersionedConstants; use blockifier::bouncer::vm_resources_to_sierra_gas; use blockifier::context::TransactionContext; use blockifier::execution::call_info::{ @@ -81,6 +82,7 @@ impl<'a> ExtensionLogic for ForgeExtension<'a> { selector: &str, mut input_reader: BufferReader<'_>, extended_runtime: &mut Self::Runtime, + vm: &VirtualMachine, ) -> Result { if let Some(oracle_selector) = self .oracle_hint_service @@ -552,6 +554,33 @@ impl<'a> ExtensionLogic for ForgeExtension<'a> { .cheat_block_hash(block_number, operation); Ok(CheatcodeHandlingResult::from_serializable(())) } + "get_current_vm_step" => { + // Each contract call is executed in separate VM, hence all VM steps + // are calculated as sum of steps from calls + current VM steps. + // Since syscalls are added to VM resources after the execution, we need + // to include them manually here. + let top_call = extended_runtime + .extended_runtime + .extension + .cheatnet_state + .trace_data + .current_call_stack + .top(); + let vm_steps_from_inner_calls = calculate_vm_steps_from_calls(&top_call); + let top_call_syscalls = &extended_runtime + .extended_runtime + .extended_runtime + .hint_handler + .base + .syscalls_usage; + let vm_steps_from_syscalls = &VersionedConstants::latest_constants() + .get_additional_os_syscall_resources(top_call_syscalls) + .n_steps; + let total_vm_steps = + vm_steps_from_inner_calls + vm_steps_from_syscalls + vm.get_current_step(); + + Ok(CheatcodeHandlingResult::from_serializable(total_vm_steps)) + } _ => Ok(CheatcodeHandlingResult::Forwarded), } } @@ -858,3 +887,19 @@ pub fn get_all_used_resources( l1_handler_payload_lengths, } } + +fn calculate_vm_steps_from_calls(top_call: &Rc>) -> usize { + // Resources from inner calls already include syscall resources used in them + let used_resources = + &top_call + .borrow() + .nested_calls + .iter() + .fold(ExecutionResources::default(), |acc, node| match node { + CallTraceNode::EntryPointCall(call_trace) => { + &acc + &call_trace.borrow().used_execution_resources + } + CallTraceNode::DeployWithoutConstructor => acc, + }); + used_resources.n_steps +} diff --git a/crates/forge/tests/data/contracts/hello_starknet.cairo b/crates/forge/tests/data/contracts/hello_starknet.cairo index 2ab543f589..a8e4c76254 100644 --- a/crates/forge/tests/data/contracts/hello_starknet.cairo +++ b/crates/forge/tests/data/contracts/hello_starknet.cairo @@ -2,6 +2,12 @@ trait IHelloStarknet { fn increase_balance(ref self: TContractState, amount: felt252); fn get_balance(self: @TContractState) -> felt252; + fn call_other_contract( + self: @TContractState, + other_contract_address: felt252, + selector: felt252, + calldata: Option>, + ) -> Span; fn do_a_panic(self: @TContractState); fn do_a_panic_with(self: @TContractState, panic_data: Array); fn do_a_panic_with_bytearray(self: @TContractState); @@ -10,6 +16,7 @@ trait IHelloStarknet { #[starknet::contract] mod HelloStarknet { use array::ArrayTrait; + use starknet::{SyscallResultTrait, syscalls}; #[storage] struct Storage { @@ -28,6 +35,23 @@ mod HelloStarknet { self.balance.read() } + fn call_other_contract( + self: @ContractState, + other_contract_address: felt252, + selector: felt252, + calldata: Option>, + ) -> Span { + syscalls::call_contract_syscall( + other_contract_address.try_into().unwrap(), + selector, + match calldata { + Some(data) => data.span(), + None => array![].span(), + }, + ) + .unwrap_syscall() + } + // Panics fn do_a_panic(self: @ContractState) { let mut arr = ArrayTrait::new(); @@ -43,7 +67,10 @@ mod HelloStarknet { // Panics with a bytearray fn do_a_panic_with_bytearray(self: @ContractState) { - assert!(false, "This is a very long\n and multiline message that is certain to fill the buffer"); + assert!( + false, + "This is a very long\n and multiline message that is certain to fill the buffer", + ); } } } diff --git a/crates/forge/tests/integration/get_current_vm_step.rs b/crates/forge/tests/integration/get_current_vm_step.rs new file mode 100644 index 0000000000..33eca5bcc0 --- /dev/null +++ b/crates/forge/tests/integration/get_current_vm_step.rs @@ -0,0 +1,94 @@ +use crate::utils::runner::{Contract, assert_passed}; +use crate::utils::running_tests::run_test_case; +use crate::utils::test_case; +use forge_runner::forge_config::ForgeTrackedResource; +use indoc::indoc; +use std::path::Path; + +#[test] +fn test_get_current_vm_step() { + let test = test_case!( + indoc!( + r#" + use snforge_std::testing::get_current_vm_step; + use snforge_std::{ContractClassTrait, DeclareResultTrait, declare}; + + + const STEPS_MARGIN: u32 = 100; + + // 1173 = cost of 1 deploy syscall without calldata + const DEPLOY_SYSCALL_STEPS: u32 = 1173; + + // 903 = steps of 1 call contract syscall + const CALL_CONTRACT_SYSCALL_STEPS: u32 = 903; + + // 90 = steps of 1 call contract syscall + const STORAGE_READ_SYSCALL_STEPS: u32 = 90; + + #[test] + fn check_current_vm_step() { + let contract = declare("HelloStarknet").unwrap().contract_class(); + let step_a = get_current_vm_step(); + + let (contract_address_a, _) = contract.deploy(@ArrayTrait::new()).unwrap(); + let (contract_address_b, _) = contract.deploy(@ArrayTrait::new()).unwrap(); + // Sycalls between step_a and step_b: + // top call: 2 x deploy syscall + // inner call: -/- + let step_b = get_current_vm_step(); + + let expected_steps_taken = 2 * DEPLOY_SYSCALL_STEPS + 130; // 130 are steps from VM + let expected_lower = expected_steps_taken + step_a - STEPS_MARGIN; + let expected_upper = expected_steps_taken + step_a + STEPS_MARGIN; + assert!( + expected_lower <= step_b && step_b <= expected_upper, + "step_b ({step_b}) not in [{expected_lower}, {expected_upper}]", + ); + + let dispatcher_a = IHelloStarknetDispatcher { contract_address: contract_address_a }; + + // contract A calls `get_balance` from contract B + let _balance = dispatcher_a + .call_other_contract( + contract_address_b.try_into().unwrap(), selector!("get_balance"), None, + ); + + // Sycalls between step_b and step_c: + // top call: 1 x call contract syscall + // inner calls: 1 x storage read syscall, 1 x call contract syscall + let step_c = get_current_vm_step(); + + let expected_steps_taken = 2 * CALL_CONTRACT_SYSCALL_STEPS + + 1 * STORAGE_READ_SYSCALL_STEPS + + 277; // 277 are steps from VM + let expected_lower = expected_steps_taken + step_b - STEPS_MARGIN; + let expected_upper = expected_steps_taken + step_b + STEPS_MARGIN; + assert!( + expected_lower <= step_c && step_c <= expected_upper, + "step_c ({step_c}) not in [{expected_lower}, {expected_upper}]", + ); + } + + #[starknet::interface] + pub trait IHelloStarknet { + fn get_balance(self: @TContractState) -> felt252; + fn call_other_contract( + self: @TContractState, + other_contract_address: felt252, + selector: felt252, + calldata: Option>, + ) -> Span; + } + "# + ), + Contract::from_code_path( + "HelloStarknet".to_string(), + Path::new("tests/data/contracts/hello_starknet.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); + + assert_passed(&result); +} diff --git a/crates/forge/tests/integration/mod.rs b/crates/forge/tests/integration/mod.rs index a89e5fa30b..dd627c3d76 100644 --- a/crates/forge/tests/integration/mod.rs +++ b/crates/forge/tests/integration/mod.rs @@ -19,6 +19,7 @@ mod gas; mod generate_random_felt; mod get_available_gas; mod get_class_hash; +mod get_current_vm_step; mod interact_with_state; mod l1_handler_executor; mod message_to_l1; diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index 4cd1882906..fb89e2343e 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -279,6 +279,7 @@ impl ExtendedRuntime { &selector, BufferReader::new(&inputs), &mut self.extended_runtime, + vm, ); let res = match result { @@ -423,6 +424,7 @@ pub trait ExtensionLogic { _selector: &str, _input_reader: BufferReader, _extended_runtime: &mut Self::Runtime, + _vm: &VirtualMachine, ) -> Result { Ok(CheatcodeHandlingResult::Forwarded) } diff --git a/crates/sncast/src/starknet_commands/script/run.rs b/crates/sncast/src/starknet_commands/script/run.rs index ea0b398ee0..b9baeda866 100644 --- a/crates/sncast/src/starknet_commands/script/run.rs +++ b/crates/sncast/src/starknet_commands/script/run.rs @@ -104,6 +104,7 @@ impl<'a> ExtensionLogic for CastScriptExtension<'a> { selector: &str, mut input_reader: BufferReader, _extended_runtime: &mut Self::Runtime, + _vm: &VirtualMachine, ) -> Result { match selector { "call" => { diff --git a/docs/listings/testing_reference/Scarb.toml b/docs/listings/testing_reference/Scarb.toml new file mode 100644 index 0000000000..5e79b4acf9 --- /dev/null +++ b/docs/listings/testing_reference/Scarb.toml @@ -0,0 +1,12 @@ +[package] +name = "testing_reference" +version = "0.1.0" +edition = "2024_07" + +[dependencies] +starknet = "2.12.0" + +[dev-dependencies] +snforge_std = { path = "../../../snforge_std" } + +[[target.starknet-contract]] diff --git a/docs/listings/testing_reference/src/lib.cairo b/docs/listings/testing_reference/src/lib.cairo new file mode 100644 index 0000000000..7dee15a673 --- /dev/null +++ b/docs/listings/testing_reference/src/lib.cairo @@ -0,0 +1,27 @@ +#[starknet::interface] +pub trait ICounter { + fn increment(ref self: TContractState); +} + +#[starknet::contract] +pub mod Counter { + use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + + #[storage] + struct Storage { + i: felt252, + } + + #[constructor] + fn constructor(ref self: ContractState) { + self.i.write(0); + } + + #[abi(embed_v0)] + impl CounterImpl of super::ICounter { + fn increment(ref self: ContractState) { + let current_value = self.i.read(); + self.i.write(current_value + 1); + } + } +} diff --git a/docs/listings/testing_reference/tests/tests.cairo b/docs/listings/testing_reference/tests/tests.cairo new file mode 100644 index 0000000000..4ef4f11a3c --- /dev/null +++ b/docs/listings/testing_reference/tests/tests.cairo @@ -0,0 +1,30 @@ +use snforge_std::testing::get_current_vm_step; +use snforge_std::{ContractClassTrait, DeclareResultTrait, declare}; +use testing_reference::{ICounterSafeDispatcher, ICounterSafeDispatcherTrait}; + +#[feature("safe_dispatcher")] +fn setup() { + // Deploy contract + let (contract_address, _) = declare("Counter") + .unwrap() + .contract_class() + .deploy(@array![]) + .unwrap(); + + let dispatcher = ICounterSafeDispatcher { contract_address }; + + // Increment counter a few times + dispatcher.increment(); + dispatcher.increment(); + dispatcher.increment(); +} + +#[test] +fn test_setup_steps() { + let steps_start = get_current_vm_step(); + setup(); + let steps_end = get_current_vm_step(); + + // Assert that setup used no more than 100 steps + assert!(steps_end - steps_start <= 100); +} diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 111a8991b7..66a1847713 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -129,6 +129,8 @@ * [env](appendix/snforge-library/env.md) * [signature](appendix/snforge-library/signature.md) * [fuzzable](appendix/snforge-library/fuzzable.md) + * [testing](appendix/snforge-library/testing.md) + * [get_current_vm_step](appendix/snforge-library/testing/get_current_vm_step.md) * [`sncast` Commands](appendix/sncast.md) * [common flags](appendix/sncast/common.md) * [account](appendix/sncast/account/account.md) diff --git a/docs/src/appendix/snforge-library/testing.md b/docs/src/appendix/snforge-library/testing.md new file mode 100644 index 0000000000..a59fc4358e --- /dev/null +++ b/docs/src/appendix/snforge-library/testing.md @@ -0,0 +1,7 @@ +# `testing` Module + +Module containing functions useful for testing. + +## Functions + +* [`get_current_vm_step`](./testing/get_current_vm_step.md) diff --git a/docs/src/appendix/snforge-library/testing/get_current_vm_step.md b/docs/src/appendix/snforge-library/testing/get_current_vm_step.md new file mode 100644 index 0000000000..e071e86ebb --- /dev/null +++ b/docs/src/appendix/snforge-library/testing/get_current_vm_step.md @@ -0,0 +1,49 @@ +# `get_current_vm_step` + +Gets the current step from Cairo VM during test execution. + +```rust +fn get_current_vm_step() -> u32; +``` + +## Example + +Contract code: + +```rust +{{#include ../../../../listings/testing_reference/src/lib.cairo}} +``` + +Test code: + +```rust +{{#include ../../../../listings/testing_reference/tests/tests.cairo}} +``` + + +Let's run the test: +```shell +$ snforge test test_setup_steps +``` + +
+Output: + +```shell +Collected 1 test(s) from testing_reference package +Running 0 test(s) from src/ +Running 1 test(s) from tests/ +[FAIL] testing_reference_integrationtest::tests::test_setup_steps + +Failure data: + "assertion failed: `steps_end - steps_start <= 100`." + +Tests: 0 passed, 1 failed, 0 ignored, 0 filtered out + +Failures: + testing_reference_integrationtest::tests::test_setup_steps +``` +
+
+ +The test fails because the `setup` function exceeded the allowed number of steps (100 in this case). diff --git a/snforge_std/src/lib.cairo b/snforge_std/src/lib.cairo index f819af4ebb..7c1ff760ab 100644 --- a/snforge_std/src/lib.cairo +++ b/snforge_std/src/lib.cairo @@ -126,6 +126,8 @@ pub mod fuzzable; pub mod signature; +pub mod testing; + pub mod trace; #[doc(hidden)] diff --git a/snforge_std/src/testing.cairo b/snforge_std/src/testing.cairo new file mode 100644 index 0000000000..2b8a3801d5 --- /dev/null +++ b/snforge_std/src/testing.cairo @@ -0,0 +1,6 @@ +use crate::cheatcode::execute_cheatcode_and_deserialize; + +/// Gets the current step from Cairo VM during test execution +pub fn get_current_vm_step() -> u32 { + execute_cheatcode_and_deserialize::<'get_current_vm_step', u32>(array![].span()) +}