Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 216 additions & 68 deletions crates/starknet-devnet-core/src/starknet/defaulter.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use std::collections::HashMap;
use std::ops::Deref;
use std::sync::{Arc, LazyLock, RwLock};

use blockifier::execution::contract_class::RunnableCompiledClass;
use blockifier::state::errors::StateError;
use blockifier::state::state_api::StateResult;
Expand All @@ -6,7 +10,8 @@ use starknet_api::state::StorageKey;
use starknet_rs_core::types::Felt;
use starknet_types::contract_class::convert_codegen_to_blockifier_compiled_class;
use tokio::sync::oneshot;
use tracing::debug;
use tracing::{debug, trace};
use url::Url;

use super::starknet_config::ForkConfig;

Expand Down Expand Up @@ -37,17 +42,31 @@ impl OriginError {
}
}

/// ORIGIN READER
pub trait OriginReader: std::fmt::Debug + Send + Sync {
fn get_storage_at(
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<Felt>;
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce>;
fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass>;
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash>;
}

/// NODE ORIGIN READER

/// Used for interacting with the origin in forking mode. The calls are blocking. Only handles the
/// basic state reading necessary for contract interaction. For other RPC methods, see
/// `OriginForwarder`
#[derive(Debug, Clone)]
struct BlockingOriginReader {
pub struct NodeApiOriginReader {
url: url::Url,
block_number: u64,
client: reqwest::Client,
}

impl BlockingOriginReader {
impl NodeApiOriginReader {
fn new(url: url::Url, block_number: u64) -> Self {
Self { url, block_number, client: reqwest::Client::new() }
}
Expand Down Expand Up @@ -142,71 +161,7 @@ impl BlockingOriginReader {
}
}

/// Used for forking - reads from the origin if Some(origin_reader); otherwise returns the default
/// or Err, depending on the method
#[derive(Clone, Debug, Default)]
pub struct StarknetDefaulter {
origin_reader: Option<BlockingOriginReader>,
}

impl StarknetDefaulter {
pub fn new(fork_config: ForkConfig) -> Self {
let origin_reader =
if let (Some(fork_url), Some(block)) = (fork_config.url, fork_config.block_number) {
Some(BlockingOriginReader::new(fork_url, block))
} else {
None
};
Self { origin_reader }
}

pub fn get_storage_at(
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<Felt> {
if let Some(origin) = &self.origin_reader {
origin.get_storage_at(contract_address, key)
} else {
Ok(Default::default())
}
}

pub fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
if let Some(origin) = &self.origin_reader {
origin.get_nonce_at(contract_address)
} else {
Ok(Default::default())
}
}

pub fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
if let Some(origin) = &self.origin_reader {
origin.get_compiled_class(class_hash)
} else {
Err(StateError::UndeclaredClassHash(class_hash))
}
}

pub fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
if let Some(origin) = &self.origin_reader {
origin.get_class_hash_at(contract_address)
} else {
Ok(Default::default())
}
}
}

fn convert_json_value_to_felt(json_value: serde_json::Value) -> StateResult<Felt> {
serde_json::from_value(json_value).map_err(|e| StateError::StateReadError(e.to_string()))
}

fn convert_patricia_key_to_hex(key: PatriciaKey) -> String {
key.key().to_hex_string()
}

// Same as StateReader, but with &self instead of &mut self
impl BlockingOriginReader {
impl OriginReader for NodeApiOriginReader {
fn get_storage_at(
&self,
contract_address: ContractAddress,
Expand Down Expand Up @@ -274,3 +229,196 @@ impl BlockingOriginReader {
}
}
}

/// EMPTY ORIGIN READER

#[derive(Debug, Clone)]
pub struct EmptyOriginReader;

impl OriginReader for EmptyOriginReader {
fn get_storage_at(
&self,
_contract_address: ContractAddress,
_key: StorageKey,
) -> StateResult<Felt> {
Ok(Default::default())
}

fn get_nonce_at(&self, _contract_address: ContractAddress) -> StateResult<Nonce> {
Ok(Default::default())
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
Err(StateError::UndeclaredClassHash(class_hash))
}

fn get_class_hash_at(&self, _contract_address: ContractAddress) -> StateResult<ClassHash> {
Ok(Default::default())
}
}

#[derive(Debug, Clone)]
pub struct StarknetDefaulter {
reader: Arc<dyn OriginReader>,
}

type StarknetDefaulterFactory = fn(Url, u64) -> StarknetDefaulter;

impl StarknetDefaulter {
pub fn create_node_api_defaulter(url: Url, block_number: u64) -> Self {
Self { reader: Arc::new(NodeApiOriginReader::new(url, block_number)) }
}

pub fn create_empty_defaulter() -> Self {
Self { reader: Arc::new(EmptyOriginReader {}) }
}

pub fn new_with_reader(reader: Arc<dyn OriginReader>) -> Self {
Self { reader }
}
}

static STARKNET_DEFAULTERS: LazyLock<RwLock<HashMap<&str, StarknetDefaulterFactory>>> =
LazyLock::new(|| {
let mut m = HashMap::new();
m.insert("http", StarknetDefaulter::create_node_api_defaulter as StarknetDefaulterFactory);
m.insert("https", StarknetDefaulter::create_node_api_defaulter as StarknetDefaulterFactory);
RwLock::new(m)
});

impl StarknetDefaulter {
pub fn register_defaulter(
scheme: &'static str,
factory: StarknetDefaulterFactory,
) -> Result<(), String> {
{
let defaulters = STARKNET_DEFAULTERS.read().map_err(|_| "Lock error")?;
if defaulters.contains_key(scheme) {
return Err(format!("Defaulter for scheme '{scheme}' already exists"));
}
}

let mut defaulters = STARKNET_DEFAULTERS.write().map_err(|_| "Lock error")?;
defaulters.insert(scheme, factory);
Ok(())
}

pub fn new(fork_config: ForkConfig) -> Self {
if let (Some(url), Some(block_number)) = (fork_config.url, fork_config.block_number) {
let defaulters = STARKNET_DEFAULTERS.read().unwrap(); // Lock the mutex to access the map

if let Some(factory) = defaulters.get(url.scheme()) {
factory(url, block_number)
} else {
Self::create_empty_defaulter()
}
} else {
Self::create_empty_defaulter()
}
}
}

impl Default for StarknetDefaulter {
fn default() -> Self {
Self::create_empty_defaulter()
}
}

impl OriginReader for StarknetDefaulter {
fn get_storage_at(
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<Felt> {
self.reader
.get_storage_at(contract_address, key)
.map_err(|err| {
debug!(
contract_address = contract_address.to_hex_string(),
key = key.to_hex_string(),
error = format!("{}", err),
"OriginReader::get_storage_at failed",
);
err
})
.map(|res| {
debug!(
contract_address = contract_address.to_hex_string(),
key = key.to_hex_string(),
res = res.to_hex_string(),
"OriginReader::get_storage_at success",
);
res
})
}

fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
self.reader
.get_nonce_at(contract_address)
.map_err(|err| {
debug!(
contract_address = contract_address.to_hex_string(),
error = format!("{}", err),
"OriginReader::get_nonce_at failed",
);
err
})
.map(|res| {
debug!(
contract_address = contract_address.to_hex_string(),
res = res.to_hex_string(),
"OriginReader::get_nonce_at succeeded",
);
res
})
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
self.reader
.get_compiled_class(class_hash)
.map_err(|err| {
debug!(
class_hash = class_hash.to_hex_string(),
error = format!("{}", err),
"OriginReader::get_compiled_class failed",
);
err
})
.map(|res| {
debug!(
class_hash = class_hash.to_hex_string(),
"OriginReader::get_compiled_class succeeded",
);
res
})
}

fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
self.reader
.get_class_hash_at(contract_address)
.map_err(|err| {
debug!(
contract_address = contract_address.to_hex_string(),
error = format!("{}", err),
"OriginReader::get_class_hash_at failed",
);
err
})
.map(|res| {
debug!(
contract_address = contract_address.to_hex_string(),
res = res.to_hex_string(),
"OriginReader::get_class_hash_at succeeded",
);
res
})
}
}

fn convert_json_value_to_felt(json_value: serde_json::Value) -> StateResult<Felt> {
serde_json::from_value(json_value).map_err(|e| StateError::StateReadError(e.to_string()))
}

fn convert_patricia_key_to_hex(key: PatriciaKey) -> String {
key.key().to_hex_string()
}
3 changes: 2 additions & 1 deletion crates/starknet-devnet-core/src/starknet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ mod add_deploy_account_transaction;
mod add_invoke_transaction;
mod add_l1_handler_transaction;
mod cheats;
pub(crate) mod defaulter;
pub mod defaulter;
mod estimations;
pub mod events;
mod get_class_impls;
Expand Down Expand Up @@ -156,6 +156,7 @@ impl Default for Starknet {
impl Starknet {
pub fn new(config: &StarknetConfig) -> DevnetResult<Self> {
let defaulter = StarknetDefaulter::new(config.fork_config.clone());

let rpc_contract_classes = Arc::new(RwLock::new(CommittedClassStorage::default()));
let mut state = StarknetState::new(defaulter, rpc_contract_classes.clone());

Expand Down
2 changes: 1 addition & 1 deletion crates/starknet-devnet-core/src/state/state_readers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::state::StorageKey;
use starknet_rs_core::types::Felt;

use crate::starknet::defaulter::StarknetDefaulter;
use crate::starknet::defaulter::{OriginReader, StarknetDefaulter};

/// A simple implementation of `StateReader` using `HashMap`s as storage.
/// Copied from blockifier test_utils, added `impl State`
Expand Down