diff --git a/crates/starknet-devnet-core/src/starknet/defaulter.rs b/crates/starknet-devnet-core/src/starknet/defaulter.rs index aedd2e610..9db89cb6f 100644 --- a/crates/starknet-devnet-core/src/starknet/defaulter.rs +++ b/crates/starknet-devnet-core/src/starknet/defaulter.rs @@ -1,3 +1,7 @@ +use std::collections::HashMap; +use std::ops::Deref; +use std::sync::{Arc, LazyLock, RwLock}; + use blockifier::execution::contract_class::RunnableCompiledClass; use blockifier::state::errors::StateError; use blockifier::state::state_api::StateResult; @@ -6,7 +10,8 @@ use starknet_api::state::StorageKey; use starknet_rs_core::types::Felt; use starknet_types::contract_class::convert_codegen_to_blockifier_compiled_class; use tokio::sync::oneshot; -use tracing::debug; +use tracing::{debug, trace}; +use url::Url; use super::starknet_config::ForkConfig; @@ -37,17 +42,31 @@ impl OriginError { } } +/// ORIGIN READER +pub trait OriginReader: std::fmt::Debug + Send + Sync { + fn get_storage_at( + &self, + contract_address: ContractAddress, + key: StorageKey, + ) -> StateResult; + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult; + fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult; + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult; +} + +/// NODE ORIGIN READER + /// Used for interacting with the origin in forking mode. The calls are blocking. Only handles the /// basic state reading necessary for contract interaction. For other RPC methods, see /// `OriginForwarder` #[derive(Debug, Clone)] -struct BlockingOriginReader { +pub struct NodeApiOriginReader { url: url::Url, block_number: u64, client: reqwest::Client, } -impl BlockingOriginReader { +impl NodeApiOriginReader { fn new(url: url::Url, block_number: u64) -> Self { Self { url, block_number, client: reqwest::Client::new() } } @@ -142,71 +161,7 @@ impl BlockingOriginReader { } } -/// Used for forking - reads from the origin if Some(origin_reader); otherwise returns the default -/// or Err, depending on the method -#[derive(Clone, Debug, Default)] -pub struct StarknetDefaulter { - origin_reader: Option, -} - -impl StarknetDefaulter { - pub fn new(fork_config: ForkConfig) -> Self { - let origin_reader = - if let (Some(fork_url), Some(block)) = (fork_config.url, fork_config.block_number) { - Some(BlockingOriginReader::new(fork_url, block)) - } else { - None - }; - Self { origin_reader } - } - - pub fn get_storage_at( - &self, - contract_address: ContractAddress, - key: StorageKey, - ) -> StateResult { - if let Some(origin) = &self.origin_reader { - origin.get_storage_at(contract_address, key) - } else { - Ok(Default::default()) - } - } - - pub fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { - if let Some(origin) = &self.origin_reader { - origin.get_nonce_at(contract_address) - } else { - Ok(Default::default()) - } - } - - pub fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - if let Some(origin) = &self.origin_reader { - origin.get_compiled_class(class_hash) - } else { - Err(StateError::UndeclaredClassHash(class_hash)) - } - } - - pub fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult { - if let Some(origin) = &self.origin_reader { - origin.get_class_hash_at(contract_address) - } else { - Ok(Default::default()) - } - } -} - -fn convert_json_value_to_felt(json_value: serde_json::Value) -> StateResult { - serde_json::from_value(json_value).map_err(|e| StateError::StateReadError(e.to_string())) -} - -fn convert_patricia_key_to_hex(key: PatriciaKey) -> String { - key.key().to_hex_string() -} - -// Same as StateReader, but with &self instead of &mut self -impl BlockingOriginReader { +impl OriginReader for NodeApiOriginReader { fn get_storage_at( &self, contract_address: ContractAddress, @@ -274,3 +229,196 @@ impl BlockingOriginReader { } } } + +/// EMPTY ORIGIN READER + +#[derive(Debug, Clone)] +pub struct EmptyOriginReader; + +impl OriginReader for EmptyOriginReader { + fn get_storage_at( + &self, + _contract_address: ContractAddress, + _key: StorageKey, + ) -> StateResult { + Ok(Default::default()) + } + + fn get_nonce_at(&self, _contract_address: ContractAddress) -> StateResult { + Ok(Default::default()) + } + + fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + Err(StateError::UndeclaredClassHash(class_hash)) + } + + fn get_class_hash_at(&self, _contract_address: ContractAddress) -> StateResult { + Ok(Default::default()) + } +} + +#[derive(Debug, Clone)] +pub struct StarknetDefaulter { + reader: Arc, +} + +type StarknetDefaulterFactory = fn(Url, u64) -> StarknetDefaulter; + +impl StarknetDefaulter { + pub fn create_node_api_defaulter(url: Url, block_number: u64) -> Self { + Self { reader: Arc::new(NodeApiOriginReader::new(url, block_number)) } + } + + pub fn create_empty_defaulter() -> Self { + Self { reader: Arc::new(EmptyOriginReader {}) } + } + + pub fn new_with_reader(reader: Arc) -> Self { + Self { reader } + } +} + +static STARKNET_DEFAULTERS: LazyLock>> = + LazyLock::new(|| { + let mut m = HashMap::new(); + m.insert("http", StarknetDefaulter::create_node_api_defaulter as StarknetDefaulterFactory); + m.insert("https", StarknetDefaulter::create_node_api_defaulter as StarknetDefaulterFactory); + RwLock::new(m) + }); + +impl StarknetDefaulter { + pub fn register_defaulter( + scheme: &'static str, + factory: StarknetDefaulterFactory, + ) -> Result<(), String> { + { + let defaulters = STARKNET_DEFAULTERS.read().map_err(|_| "Lock error")?; + if defaulters.contains_key(scheme) { + return Err(format!("Defaulter for scheme '{scheme}' already exists")); + } + } + + let mut defaulters = STARKNET_DEFAULTERS.write().map_err(|_| "Lock error")?; + defaulters.insert(scheme, factory); + Ok(()) + } + + pub fn new(fork_config: ForkConfig) -> Self { + if let (Some(url), Some(block_number)) = (fork_config.url, fork_config.block_number) { + let defaulters = STARKNET_DEFAULTERS.read().unwrap(); // Lock the mutex to access the map + + if let Some(factory) = defaulters.get(url.scheme()) { + factory(url, block_number) + } else { + Self::create_empty_defaulter() + } + } else { + Self::create_empty_defaulter() + } + } +} + +impl Default for StarknetDefaulter { + fn default() -> Self { + Self::create_empty_defaulter() + } +} + +impl OriginReader for StarknetDefaulter { + fn get_storage_at( + &self, + contract_address: ContractAddress, + key: StorageKey, + ) -> StateResult { + self.reader + .get_storage_at(contract_address, key) + .map_err(|err| { + debug!( + contract_address = contract_address.to_hex_string(), + key = key.to_hex_string(), + error = format!("{}", err), + "OriginReader::get_storage_at failed", + ); + err + }) + .map(|res| { + debug!( + contract_address = contract_address.to_hex_string(), + key = key.to_hex_string(), + res = res.to_hex_string(), + "OriginReader::get_storage_at success", + ); + res + }) + } + + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { + self.reader + .get_nonce_at(contract_address) + .map_err(|err| { + debug!( + contract_address = contract_address.to_hex_string(), + error = format!("{}", err), + "OriginReader::get_nonce_at failed", + ); + err + }) + .map(|res| { + debug!( + contract_address = contract_address.to_hex_string(), + res = res.to_hex_string(), + "OriginReader::get_nonce_at succeeded", + ); + res + }) + } + + fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + self.reader + .get_compiled_class(class_hash) + .map_err(|err| { + debug!( + class_hash = class_hash.to_hex_string(), + error = format!("{}", err), + "OriginReader::get_compiled_class failed", + ); + err + }) + .map(|res| { + debug!( + class_hash = class_hash.to_hex_string(), + "OriginReader::get_compiled_class succeeded", + ); + res + }) + } + + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult { + self.reader + .get_class_hash_at(contract_address) + .map_err(|err| { + debug!( + contract_address = contract_address.to_hex_string(), + error = format!("{}", err), + "OriginReader::get_class_hash_at failed", + ); + err + }) + .map(|res| { + debug!( + contract_address = contract_address.to_hex_string(), + res = res.to_hex_string(), + "OriginReader::get_class_hash_at succeeded", + ); + res + }) + } +} + +fn convert_json_value_to_felt(json_value: serde_json::Value) -> StateResult { + serde_json::from_value(json_value).map_err(|e| StateError::StateReadError(e.to_string())) +} + +fn convert_patricia_key_to_hex(key: PatriciaKey) -> String { + key.key().to_hex_string() +} diff --git a/crates/starknet-devnet-core/src/starknet/mod.rs b/crates/starknet-devnet-core/src/starknet/mod.rs index 50526b227..4f7698b78 100644 --- a/crates/starknet-devnet-core/src/starknet/mod.rs +++ b/crates/starknet-devnet-core/src/starknet/mod.rs @@ -85,7 +85,7 @@ mod add_deploy_account_transaction; mod add_invoke_transaction; mod add_l1_handler_transaction; mod cheats; -pub(crate) mod defaulter; +pub mod defaulter; mod estimations; pub mod events; mod get_class_impls; @@ -156,6 +156,7 @@ impl Default for Starknet { impl Starknet { pub fn new(config: &StarknetConfig) -> DevnetResult { let defaulter = StarknetDefaulter::new(config.fork_config.clone()); + let rpc_contract_classes = Arc::new(RwLock::new(CommittedClassStorage::default())); let mut state = StarknetState::new(defaulter, rpc_contract_classes.clone()); diff --git a/crates/starknet-devnet-core/src/state/state_readers.rs b/crates/starknet-devnet-core/src/state/state_readers.rs index 300ad4a5d..6e77b6ae0 100644 --- a/crates/starknet-devnet-core/src/state/state_readers.rs +++ b/crates/starknet-devnet-core/src/state/state_readers.rs @@ -8,7 +8,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_rs_core::types::Felt; -use crate::starknet::defaulter::StarknetDefaulter; +use crate::starknet::defaulter::{OriginReader, StarknetDefaulter}; /// A simple implementation of `StateReader` using `HashMap`s as storage. /// Copied from blockifier test_utils, added `impl State`