diff --git a/crates/core/src/backend/storage.rs b/crates/core/src/backend/storage.rs index 63cb36c0e..6d526058c 100644 --- a/crates/core/src/backend/storage.rs +++ b/crates/core/src/backend/storage.rs @@ -141,7 +141,7 @@ impl Blockchain { // TODO: convert this to block number instead of BlockHashOrNumber so that it is easier to // check if the requested block is within the supported range or not. - let database = ForkedProvider::new(db, block_id, provider.clone()); + let database = ForkedProvider::new(db, block_num, provider.clone()); // initialize parent fork block // diff --git a/crates/primitives/src/block.rs b/crates/primitives/src/block.rs index 15866a6ea..708abfb7e 100644 --- a/crates/primitives/src/block.rs +++ b/crates/primitives/src/block.rs @@ -52,7 +52,7 @@ pub enum ConfirmedBlockIdOrTag { L1Accepted, } -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum BlockHashOrNumber { Hash(BlockHash), diff --git a/crates/storage/fork/src/lib.rs b/crates/storage/fork/src/lib.rs index 97a53c45e..78ee06856 100644 --- a/crates/storage/fork/src/lib.rs +++ b/crates/storage/fork/src/lib.rs @@ -14,7 +14,7 @@ use futures::channel::mpsc::{channel as async_channel, Receiver, SendError, Send use futures::future::BoxFuture; use futures::stream::Stream; use futures::{Future, FutureExt}; -use katana_primitives::block::{BlockHashOrNumber, BlockIdOrTag}; +use katana_primitives::block::{BlockIdOrTag, BlockNumber}; use katana_primitives::class::{ ClassHash, CompiledClassHash, ComputeClassHashError, ContractClass, ContractClassCompilationError, @@ -68,10 +68,10 @@ struct Request

{ /// Each request consists of a payload and the sender half of a oneshot channel that will be used /// to send the result back to the backend handle. enum BackendRequest { - Nonce(Request), - Class(Request), - ClassHash(Request), - Storage(Request<(ContractAddress, StorageKey)>), + Nonce(Request<(ContractAddress, BlockNumber)>), + Class(Request<(ClassHash, BlockNumber)>), + ClassHash(Request<(ContractAddress, BlockNumber)>), + Storage(Request<((ContractAddress, StorageKey), BlockNumber)>), // Test-only request kind for requesting the backend stats #[cfg(test)] Stats(OneshotSender), @@ -79,30 +79,40 @@ enum BackendRequest { impl BackendRequest { /// Create a new request for fetching the nonce of a contract. - fn nonce(address: ContractAddress) -> (BackendRequest, OneshotReceiver) { + fn nonce( + address: ContractAddress, + block_id: BlockNumber, + ) -> (BackendRequest, OneshotReceiver) { let (sender, receiver) = oneshot(); - (BackendRequest::Nonce(Request { payload: address, sender }), receiver) + (BackendRequest::Nonce(Request { payload: (address, block_id), sender }), receiver) } /// Create a new request for fetching the class definitions of a contract. - fn class(hash: ClassHash) -> (BackendRequest, OneshotReceiver) { + fn class( + hash: ClassHash, + block_id: BlockNumber, + ) -> (BackendRequest, OneshotReceiver) { let (sender, receiver) = oneshot(); - (BackendRequest::Class(Request { payload: hash, sender }), receiver) + (BackendRequest::Class(Request { payload: (hash, block_id), sender }), receiver) } /// Create a new request for fetching the class hash of a contract. - fn class_hash(address: ContractAddress) -> (BackendRequest, OneshotReceiver) { + fn class_hash( + address: ContractAddress, + block_id: BlockNumber, + ) -> (BackendRequest, OneshotReceiver) { let (sender, receiver) = oneshot(); - (BackendRequest::ClassHash(Request { payload: address, sender }), receiver) + (BackendRequest::ClassHash(Request { payload: (address, block_id), sender }), receiver) } /// Create a new request for fetching the storage value of a contract. fn storage( address: ContractAddress, key: StorageKey, + block_id: BlockNumber, ) -> (BackendRequest, OneshotReceiver) { let (sender, receiver) = oneshot(); - (BackendRequest::Storage(Request { payload: (address, key), sender }), receiver) + (BackendRequest::Storage(Request { payload: ((address, key), block_id), sender }), receiver) } #[cfg(test)] @@ -118,10 +128,10 @@ type BackendRequestFuture = BoxFuture<'static, BackendResponse>; // This is used for request deduplication. #[derive(Eq, Hash, PartialEq, Clone, Copy, Debug)] enum BackendRequestIdentifier { - Nonce(ContractAddress), - Class(ClassHash), - ClassHash(ContractAddress), - Storage((ContractAddress, StorageKey)), + Nonce(ContractAddress, BlockNumber), + Class(ClassHash, BlockNumber), + ClassHash(ContractAddress, BlockNumber), + Storage((ContractAddress, StorageKey), BlockNumber), } /// The backend for the forked provider. @@ -139,8 +149,6 @@ pub struct Backend { queued_requests: VecDeque, /// A channel for receiving requests from the [BackendHandle]s. incoming: Receiver, - /// Pinned block id for all requests. - block_id: BlockHashOrNumber, } ///////////////////////////////////////////////////////////////// @@ -150,14 +158,11 @@ pub struct Backend { impl Backend { // TODO(kariy): create a `.start()` method start running the backend logic and let the users // choose which thread to running it on instead of spawning the thread ourselves. - /// Create a new [Backend] with the given provider and block id, and returns a handle to it. The - /// backend will start processing requests immediately upon creation. + /// Create a new [Backend] with the given provider and returns a handle to it. The backend + /// will start processing requests immediately upon creation. #[allow(clippy::new_ret_no_self)] - pub fn new( - provider: StarknetClient, - block_id: BlockHashOrNumber, - ) -> Result { - let (handle, backend) = Self::new_inner(provider, block_id); + pub fn new(provider: StarknetClient) -> Result { + let (handle, backend) = Self::new_inner(provider); thread::Builder::new() .name("forking-backend".into()) @@ -175,14 +180,10 @@ impl Backend { Ok(handle) } - fn new_inner( - provider: StarknetClient, - block_id: BlockHashOrNumber, - ) -> (BackendClient, Backend) { + fn new_inner(provider: StarknetClient) -> (BackendClient, Backend) { // Create async channel to receive requests from the handle. let (tx, rx) = async_channel(100); let backend = Backend { - block_id, incoming: rx, provider: Arc::new(provider), request_dedup_map: HashMap::new(), @@ -196,20 +197,20 @@ impl Backend { /// This method is responsible for transforming the incoming request /// sent from a [BackendHandle] into a RPC request to the remote network. fn handle_requests(&mut self, request: BackendRequest) { - let block_id = BlockIdOrTag::from(self.block_id); let provider = self.provider.clone(); // Check if there are similar requests in the queue before sending the request match request { - BackendRequest::Nonce(Request { payload, sender }) => { - let req_key = BackendRequestIdentifier::Nonce(payload); + BackendRequest::Nonce(Request { payload: (address, block_id), sender }) => { + let req_key = BackendRequestIdentifier::Nonce(address, block_id); + let block_id = BlockIdOrTag::from(block_id); self.dedup_request( req_key, sender, Box::pin(async move { let res = provider - .get_nonce(block_id, payload) + .get_nonce(block_id, address) .await .map_err(|e| BackendError::StarknetProvider(Arc::new(e))); @@ -218,8 +219,9 @@ impl Backend { ); } - BackendRequest::Storage(Request { payload: (addr, key), sender }) => { - let req_key = BackendRequestIdentifier::Storage((addr, key)); + BackendRequest::Storage(Request { payload: ((addr, key), block_id), sender }) => { + let req_key = BackendRequestIdentifier::Storage((addr, key), block_id); + let block_id = BlockIdOrTag::from(block_id); self.dedup_request( req_key, @@ -235,15 +237,16 @@ impl Backend { ); } - BackendRequest::ClassHash(Request { payload, sender }) => { - let req_key = BackendRequestIdentifier::ClassHash(payload); + BackendRequest::ClassHash(Request { payload: (address, block_id), sender }) => { + let req_key = BackendRequestIdentifier::ClassHash(address, block_id); + let block_id = BlockIdOrTag::from(block_id); self.dedup_request( req_key, sender, Box::pin(async move { let res = provider - .get_class_hash_at(block_id, payload) + .get_class_hash_at(block_id, address) .await .map_err(|e| BackendError::StarknetProvider(Arc::new(e))); @@ -252,15 +255,16 @@ impl Backend { ); } - BackendRequest::Class(Request { payload, sender }) => { - let req_key = BackendRequestIdentifier::Class(payload); + BackendRequest::Class(Request { payload: (hash, block_id), sender }) => { + let req_key = BackendRequestIdentifier::Class(hash, block_id); + let block_id = BlockIdOrTag::from(block_id); self.dedup_request( req_key, sender, Box::pin(async move { let res = provider - .get_class(block_id, payload) + .get_class(block_id, hash) .await .map_err(|e| BackendError::StarknetProvider(Arc::new(e))); @@ -370,7 +374,6 @@ impl Debug for Backend { .field("pending_requests", &self.pending_requests.len()) .field("queued_requests", &self.queued_requests.len()) .field("incoming", &self.incoming) - .field("block", &self.block_id) .finish() } } @@ -417,9 +420,13 @@ impl Clone for BackendClient { ///////////////////////////////////////////////////////////////// impl BackendClient { - pub fn get_nonce(&self, address: ContractAddress) -> Result, BackendClientError> { + pub fn get_nonce( + &self, + address: ContractAddress, + block_id: BlockNumber, + ) -> Result, BackendClientError> { trace!(target: LOG_TARGET, %address, "Requesting contract nonce."); - let (req, rx) = BackendRequest::nonce(address); + let (req, rx) = BackendRequest::nonce(address, block_id); self.request(req)?; match rx.recv()? { BackendResponse::Nonce(res) => handle_not_found_err(res), @@ -431,9 +438,10 @@ impl BackendClient { &self, address: ContractAddress, key: StorageKey, + block_id: BlockNumber, ) -> Result, BackendClientError> { trace!(target: LOG_TARGET, %address, key = %format!("{key:#x}"), "Requesting contract storage."); - let (req, rx) = BackendRequest::storage(address, key); + let (req, rx) = BackendRequest::storage(address, key, block_id); self.request(req)?; match rx.recv()? { BackendResponse::Storage(res) => handle_not_found_err(res), @@ -444,9 +452,10 @@ impl BackendClient { pub fn get_class_hash_at( &self, address: ContractAddress, + block_id: BlockNumber, ) -> Result, BackendClientError> { trace!(target: LOG_TARGET, %address, "Requesting contract class hash."); - let (req, rx) = BackendRequest::class_hash(address); + let (req, rx) = BackendRequest::class_hash(address, block_id); self.request(req)?; match rx.recv()? { BackendResponse::ClassHashAt(res) => handle_not_found_err(res), @@ -457,9 +466,10 @@ impl BackendClient { pub fn get_class_at( &self, class_hash: ClassHash, + block_id: BlockNumber, ) -> Result, BackendClientError> { trace!(target: LOG_TARGET, class_hash = %format!("{class_hash:#x}"), "Requesting class."); - let (req, rx) = BackendRequest::class(class_hash); + let (req, rx) = BackendRequest::class(class_hash, block_id); self.request(req)?; match rx.recv()? { BackendResponse::ClassAt(res) => { @@ -476,9 +486,10 @@ impl BackendClient { pub fn get_compiled_class_hash( &self, class_hash: ClassHash, + block_id: BlockNumber, ) -> Result, BackendClientError> { trace!(target: LOG_TARGET, class_hash = %format!("{class_hash:#x}"), "Requesting compiled class hash."); - if let Some(class) = self.get_class_at(class_hash)? { + if let Some(class) = self.get_class_at(class_hash, block_id)? { let class = class.compile()?; Ok(Some(class.class_hash()?)) } else { @@ -526,7 +537,6 @@ pub(crate) mod test_utils { use std::sync::mpsc::{sync_channel, SyncSender}; - use katana_primitives::block::BlockNumber; use katana_rpc_client::HttpClientBuilder; use serde_json::{json, Value}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -535,10 +545,10 @@ pub(crate) mod test_utils { use super::*; - pub fn create_forked_backend(rpc_url: &str, block_num: BlockNumber) -> BackendClient { + pub fn create_forked_backend(rpc_url: &str) -> BackendClient { let url = Url::parse(rpc_url).expect("valid url"); let provider = StarknetClient::new(HttpClientBuilder::new().build(url).unwrap()); - Backend::new(provider, block_num.into()).unwrap() + Backend::new(provider).unwrap() } // Starts a TCP server that never close the connection. @@ -570,11 +580,15 @@ pub(crate) mod test_utils { use tokio::runtime::Builder; let (tx, rx) = sync_channel::<()>(1); + let (ready_tx, ready_rx) = sync_channel::<()>(1); thread::spawn(move || { Builder::new_current_thread().enable_all().build().unwrap().block_on(async move { let listener = TcpListener::bind(addr).await.unwrap(); + // Signal that the server is ready + ready_tx.send(()).unwrap(); + loop { let (mut socket, _) = listener.accept().await.unwrap(); @@ -621,6 +635,9 @@ pub(crate) mod test_utils { }); }); + // Wait for the server to be ready + ready_rx.recv().unwrap(); + // Returning the sender to allow controlling the response timing. tx } @@ -650,7 +667,8 @@ mod tests { // start a mock remote network start_tcp_server("127.0.0.1:8080".to_string()); - let handle = create_forked_backend("http://127.0.0.1:8080", 1); + let handle = create_forked_backend("http://127.0.0.1:8080"); + let block_id = 1; // check no pending requests let stats = handle.stats().expect(ERROR_STATS); @@ -659,23 +677,23 @@ mod tests { // send requests to the backend let h1 = handle.clone(); thread::spawn(move || { - h1.get_nonce(felt!("0x1").into()).expect(ERROR_SEND_REQUEST); + h1.get_nonce(felt!("0x1").into(), block_id).expect(ERROR_SEND_REQUEST); }); let h2 = handle.clone(); thread::spawn(move || { - h2.get_class_at(felt!("0x1")).expect(ERROR_SEND_REQUEST); + h2.get_class_at(felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); let h3 = handle.clone(); thread::spawn(move || { - h3.get_compiled_class_hash(felt!("0x2")).expect(ERROR_SEND_REQUEST); + h3.get_compiled_class_hash(felt!("0x2"), block_id).expect(ERROR_SEND_REQUEST); }); let h4 = handle.clone(); thread::spawn(move || { - h4.get_class_hash_at(felt!("0x1").into()).expect(ERROR_SEND_REQUEST); + h4.get_class_hash_at(felt!("0x1").into(), block_id).expect(ERROR_SEND_REQUEST); }); let h5 = handle.clone(); thread::spawn(move || { - h5.get_storage(felt!("0x1").into(), felt!("0x1")).expect(ERROR_SEND_REQUEST); + h5.get_storage(felt!("0x1").into(), felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -691,7 +709,8 @@ mod tests { // start a mock remote network start_tcp_server("127.0.0.1:8081".to_string()); - let handle = create_forked_backend("http://127.0.0.1:8081", 1); + let handle = create_forked_backend("http://127.0.0.1:8081"); + let block_id = 1; // check no pending requests let stats = handle.stats().expect(ERROR_STATS); @@ -700,11 +719,11 @@ mod tests { // send requests to the backend let h1 = handle.clone(); thread::spawn(move || { - h1.get_nonce(felt!("0x1").into()).expect(ERROR_SEND_REQUEST); + h1.get_nonce(felt!("0x1").into(), block_id).expect(ERROR_SEND_REQUEST); }); let h2 = handle.clone(); thread::spawn(move || { - h2.get_nonce(felt!("0x1").into()).expect(ERROR_SEND_REQUEST); + h2.get_nonce(felt!("0x1").into(), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -717,7 +736,7 @@ mod tests { // Different request, should be counted let h3 = handle.clone(); thread::spawn(move || { - h3.get_nonce(felt!("0x2").into()).expect(ERROR_SEND_REQUEST); + h3.get_nonce(felt!("0x2").into(), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -733,7 +752,8 @@ mod tests { // start a mock remote network start_tcp_server("127.0.0.1:8082".to_string()); - let handle = create_forked_backend("http://127.0.0.1:8082", 1); + let handle = create_forked_backend("http://127.0.0.1:8082"); + let block_id = 1; // check no pending requests let stats = handle.stats().expect(ERROR_STATS); @@ -742,11 +762,11 @@ mod tests { // send requests to the backend let h1 = handle.clone(); thread::spawn(move || { - h1.get_class_at(felt!("0x1")).expect(ERROR_SEND_REQUEST); + h1.get_class_at(felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); let h2 = handle.clone(); thread::spawn(move || { - h2.get_class_at(felt!("0x1")).expect(ERROR_SEND_REQUEST); + h2.get_class_at(felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -759,7 +779,7 @@ mod tests { // Different request, should be counted let h3 = handle.clone(); thread::spawn(move || { - h3.get_class_at(felt!("0x2")).expect(ERROR_SEND_REQUEST); + h3.get_class_at(felt!("0x2"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -775,7 +795,8 @@ mod tests { // start a mock remote network start_tcp_server("127.0.0.1:8083".to_string()); - let handle = create_forked_backend("http://127.0.0.1:8083", 1); + let handle = create_forked_backend("http://127.0.0.1:8083"); + let block_id = 1; // check no pending requests let stats = handle.stats().expect(ERROR_STATS); @@ -784,11 +805,11 @@ mod tests { // send requests to the backend let h1 = handle.clone(); thread::spawn(move || { - h1.get_compiled_class_hash(felt!("0x1")).expect(ERROR_SEND_REQUEST); + h1.get_compiled_class_hash(felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); let h2 = handle.clone(); thread::spawn(move || { - h2.get_compiled_class_hash(felt!("0x1")).expect(ERROR_SEND_REQUEST); + h2.get_compiled_class_hash(felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -801,7 +822,7 @@ mod tests { // Different request, should be counted let h3 = handle.clone(); thread::spawn(move || { - h3.get_compiled_class_hash(felt!("0x2")).expect(ERROR_SEND_REQUEST); + h3.get_compiled_class_hash(felt!("0x2"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -817,7 +838,8 @@ mod tests { // start a mock remote network start_tcp_server("127.0.0.1:8084".to_string()); - let handle = create_forked_backend("http://127.0.0.1:8084", 1); + let handle = create_forked_backend("http://127.0.0.1:8084"); + let block_id = 1; // check no pending requests let stats = handle.stats().expect(ERROR_STATS); @@ -826,12 +848,12 @@ mod tests { // send requests to the backend let h1 = handle.clone(); thread::spawn(move || { - h1.get_class_at(felt!("0x1")).expect(ERROR_SEND_REQUEST); + h1.get_class_at(felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); // Since this also calls to the same request as the previous one, it should be deduped let h2 = handle.clone(); thread::spawn(move || { - h2.get_compiled_class_hash(felt!("0x1")).expect(ERROR_SEND_REQUEST); + h2.get_compiled_class_hash(felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -844,7 +866,7 @@ mod tests { // Different request, should be counted let h3 = handle.clone(); thread::spawn(move || { - h3.get_class_at(felt!("0x2")).expect(ERROR_SEND_REQUEST); + h3.get_class_at(felt!("0x2"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -860,7 +882,8 @@ mod tests { // start a mock remote network start_tcp_server("127.0.0.1:8085".to_string()); - let handle = create_forked_backend("http://127.0.0.1:8085", 1); + let handle = create_forked_backend("http://127.0.0.1:8085"); + let block_id = 1; // check no pending requests let stats = handle.stats().expect(ERROR_STATS); @@ -869,11 +892,11 @@ mod tests { // send requests to the backend let h1 = handle.clone(); thread::spawn(move || { - h1.get_class_hash_at(felt!("0x1").into()).expect(ERROR_SEND_REQUEST); + h1.get_class_hash_at(felt!("0x1").into(), block_id).expect(ERROR_SEND_REQUEST); }); let h2 = handle.clone(); thread::spawn(move || { - h2.get_class_hash_at(felt!("0x1").into()).expect(ERROR_SEND_REQUEST); + h2.get_class_hash_at(felt!("0x1").into(), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -886,7 +909,7 @@ mod tests { // Different request, should be counted let h3 = handle.clone(); thread::spawn(move || { - h3.get_class_hash_at(felt!("0x2").into()).expect(ERROR_SEND_REQUEST); + h3.get_class_hash_at(felt!("0x2").into(), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -902,7 +925,8 @@ mod tests { // start a mock remote network start_tcp_server("127.0.0.1:8086".to_string()); - let handle = create_forked_backend("http://127.0.0.1:8086", 1); + let handle = create_forked_backend("http://127.0.0.1:8086"); + let block_id = 1; // check no pending requests let stats = handle.stats().expect(ERROR_STATS); @@ -911,11 +935,11 @@ mod tests { // send requests to the backend let h1 = handle.clone(); thread::spawn(move || { - h1.get_storage(felt!("0x1").into(), felt!("0x1")).expect(ERROR_SEND_REQUEST); + h1.get_storage(felt!("0x1").into(), felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); let h2 = handle.clone(); thread::spawn(move || { - h2.get_storage(felt!("0x1").into(), felt!("0x1")).expect(ERROR_SEND_REQUEST); + h2.get_storage(felt!("0x1").into(), felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -928,7 +952,7 @@ mod tests { // Different request, should be counted let h3 = handle.clone(); thread::spawn(move || { - h3.get_storage(felt!("0x2").into(), felt!("0x3")).expect(ERROR_SEND_REQUEST); + h3.get_storage(felt!("0x2").into(), felt!("0x3"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -944,7 +968,8 @@ mod tests { // start a mock remote network start_tcp_server("127.0.0.1:8087".to_string()); - let handle = create_forked_backend("http://127.0.0.1:8087", 1); + let handle = create_forked_backend("http://127.0.0.1:8087"); + let block_id = 1; // check no pending requests let stats = handle.stats().expect(ERROR_STATS); @@ -953,11 +978,11 @@ mod tests { // send requests to the backend let h1 = handle.clone(); thread::spawn(move || { - h1.get_storage(felt!("0x1").into(), felt!("0x1")).expect(ERROR_SEND_REQUEST); + h1.get_storage(felt!("0x1").into(), felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); let h2 = handle.clone(); thread::spawn(move || { - h2.get_storage(felt!("0x1").into(), felt!("0x1")).expect(ERROR_SEND_REQUEST); + h2.get_storage(felt!("0x1").into(), felt!("0x1"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -970,12 +995,12 @@ mod tests { // Different request, should be counted let h3 = handle.clone(); thread::spawn(move || { - h3.get_storage(felt!("0x1").into(), felt!("0x3")).expect(ERROR_SEND_REQUEST); + h3.get_storage(felt!("0x1").into(), felt!("0x3"), block_id).expect(ERROR_SEND_REQUEST); }); // Different request, should be counted let h4 = handle.clone(); thread::spawn(move || { - h4.get_storage(felt!("0x1").into(), felt!("0x6")).expect(ERROR_SEND_REQUEST); + h4.get_storage(felt!("0x1").into(), felt!("0x6"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -988,7 +1013,7 @@ mod tests { // Same request as the last one, shouldn't be counted let h5 = handle.clone(); thread::spawn(move || { - h5.get_storage(felt!("0x1").into(), felt!("0x6")).expect(ERROR_SEND_REQUEST); + h5.get_storage(felt!("0x1").into(), felt!("0x6"), block_id).expect(ERROR_SEND_REQUEST); }); // wait for the requests to be handled @@ -1005,7 +1030,8 @@ mod tests { let result = "0x123"; let sender = start_mock_rpc_server("127.0.0.1:8090".to_string(), result.to_string()); - let handle = create_forked_backend("http://127.0.0.1:8090", 1); + let handle = create_forked_backend("http://127.0.0.1:8090"); + let block_id = 1; let addr = ContractAddress(felt!("0x1")); // Collect results from multiple identical nonce requests @@ -1016,7 +1042,7 @@ mod tests { let h = handle.clone(); let results = results.clone(); thread::spawn(move || { - let res = h.get_nonce(addr); + let res = h.get_nonce(addr, block_id); results.lock().unwrap().push(res); }) }) diff --git a/crates/storage/provider/provider/src/providers/fork/mod.rs b/crates/storage/provider/provider/src/providers/fork/mod.rs index 8a88ffe78..6403a36ea 100644 --- a/crates/storage/provider/provider/src/providers/fork/mod.rs +++ b/crates/storage/provider/provider/src/providers/fork/mod.rs @@ -39,6 +39,7 @@ mod trie; pub struct ForkedProvider { backend: BackendClient, provider: Arc>, + block_id: BlockNumber, } impl ForkedProvider { @@ -47,23 +48,27 @@ impl ForkedProvider { /// - `db`: The database to use for the provider. /// - `block_id`: The block number or hash to use as the fork point. /// - `provider`: The Starknet JSON-RPC client to use for the provider. - pub fn new(db: Db, block_id: BlockHashOrNumber, provider: StarknetClient) -> Self { - let backend = Backend::new(provider, block_id).expect("failed to create backend"); + pub fn new(db: Db, block_id: BlockNumber, provider: StarknetClient) -> Self { + let backend = Backend::new(provider).expect("failed to create backend"); let provider = Arc::new(DbProvider::new(db)); - Self { provider, backend } + Self { provider, backend, block_id } } pub fn backend(&self) -> &BackendClient { &self.backend } + + pub fn block_id(&self) -> BlockNumber { + self.block_id + } } impl ForkedProvider { /// Creates a new [`ForkedProvider`] using an ephemeral database. - pub fn new_ephemeral(block_id: BlockHashOrNumber, provider: StarknetClient) -> Self { - let backend = Backend::new(provider, block_id).expect("failed to create backend"); + pub fn new_ephemeral(block_id: BlockNumber, provider: StarknetClient) -> Self { + let backend = Backend::new(provider).expect("failed to create backend"); let provider = Arc::new(DbProvider::new_in_memory()); - Self { provider, backend } + Self { provider, backend, block_id } } } diff --git a/crates/storage/provider/provider/src/providers/fork/state.rs b/crates/storage/provider/provider/src/providers/fork/state.rs index 3829eda57..28cdd4a61 100644 --- a/crates/storage/provider/provider/src/providers/fork/state.rs +++ b/crates/storage/provider/provider/src/providers/fork/state.rs @@ -30,7 +30,12 @@ where let tx = self.provider.db().tx()?; let db = self.provider.clone(); let provider = db::state::LatestStateProvider::new(tx); - Ok(Box::new(LatestStateProvider { db, backend: self.backend.clone(), provider })) + Ok(Box::new(LatestStateProvider { + db, + backend: self.backend.clone(), + provider, + forked_block_id: self.block_id, + })) } fn historical( @@ -57,7 +62,14 @@ where let tx = db.db().tx()?; let client = self.backend.clone(); - Ok(Some(Box::new(HistoricalStateProvider::new(db, tx, block, client)))) + Ok(Some(Box::new(HistoricalStateProvider::new( + db, + tx, + block, + client, + block, + self.block_id, + )))) } } @@ -66,6 +78,7 @@ struct LatestStateProvider { db: Arc>, backend: BackendClient, provider: db::state::LatestStateProvider, + forked_block_id: BlockNumber, } impl ContractClassProvider for LatestStateProvider @@ -75,7 +88,7 @@ where fn class(&self, hash: ClassHash) -> ProviderResult> { if let Some(class) = self.provider.class(hash)? { Ok(Some(class)) - } else if let Some(class) = self.backend.get_class_at(hash)? { + } else if let Some(class) = self.backend.get_class_at(hash, self.forked_block_id)? { self.db.db().update(|tx| tx.put::(hash, class.clone().into()))??; Ok(Some(class)) } else { @@ -89,7 +102,9 @@ where ) -> ProviderResult> { if let res @ Some(..) = self.provider.compiled_class_hash_of_class_hash(hash)? { Ok(res) - } else if let Some(compiled_hash) = self.backend.get_compiled_class_hash(hash)? { + } else if let Some(compiled_hash) = + self.backend.get_compiled_class_hash(hash, self.forked_block_id)? + { self.db .db() .update(|tx| tx.put::(hash, compiled_hash))??; @@ -107,10 +122,10 @@ where fn nonce(&self, address: ContractAddress) -> ProviderResult> { if let res @ Some(..) = self.provider.nonce(address)? { Ok(res) - } else if let Some(nonce) = self.backend.get_nonce(address)? { + } else if let Some(nonce) = self.backend.get_nonce(address, self.forked_block_id)? { let class_hash = self .backend - .get_class_hash_at(address)? + .get_class_hash_at(address, self.forked_block_id)? .ok_or(ProviderError::MissingContractClassHash { address })?; let entry = GenericContractInfo { nonce, class_hash }; @@ -128,10 +143,12 @@ where ) -> ProviderResult> { if let res @ Some(..) = self.provider.class_hash_of_contract(address)? { Ok(res) - } else if let Some(class_hash) = self.backend.get_class_hash_at(address)? { + } else if let Some(class_hash) = + self.backend.get_class_hash_at(address, self.forked_block_id)? + { let nonce = self .backend - .get_nonce(address)? + .get_nonce(address, self.forked_block_id)? .ok_or(ProviderError::MissingContractNonce { address })?; let entry = GenericContractInfo { class_hash, nonce }; @@ -150,7 +167,7 @@ where ) -> ProviderResult> { if let res @ Some(..) = self.provider.storage(address, key)? { Ok(res) - } else if let Some(value) = self.backend.get_storage(address, key)? { + } else if let Some(value) = self.backend.get_storage(address, key, self.forked_block_id)? { let entry = StorageEntry { key, value }; self.db.db().update(|tx| tx.put::(address, entry))??; Ok(Some(value)) @@ -206,27 +223,36 @@ struct HistoricalStateProvider { db: Arc>, backend: BackendClient, provider: db::state::HistoricalStateProvider, + block_number: BlockNumber, + forked_block_number: BlockNumber, } impl HistoricalStateProvider { - pub fn new( + fn new( db: Arc>, tx: Db::Tx, block: BlockNumber, backend: BackendClient, + block_number: BlockNumber, + forked_block_number: BlockNumber, ) -> Self { let provider = db::state::HistoricalStateProvider::new(tx, block); - Self { db, backend, provider } + Self { db, backend, provider, block_number, forked_block_number } } } impl ContractClassProvider for HistoricalStateProvider { fn class(&self, hash: ClassHash) -> ProviderResult> { if let res @ Some(..) = self.provider.class(hash)? { - Ok(res) - } else if let Some(class) = self.backend.get_class_at(hash)? { - self.db.db().tx_mut()?.put::(hash, class.clone().into())?; - Ok(Some(class)) + return Ok(res); + } + + if self.block_number > self.forked_block_number { + return Ok(None); + } + + if let class @ Some(..) = self.backend.get_class_at(hash, self.block_number)? { + Ok(class) } else { Ok(None) } @@ -237,8 +263,16 @@ impl ContractClassProvider for HistoricalStateProvider { hash: ClassHash, ) -> ProviderResult> { if let res @ Some(..) = self.provider.compiled_class_hash_of_class_hash(hash)? { - Ok(res) - } else if let Some(compiled_hash) = self.backend.get_compiled_class_hash(hash)? { + return Ok(res); + } + + if self.block_number > self.forked_block_number { + return Ok(None); + } + + if let Some(compiled_hash) = + self.backend.get_compiled_class_hash(hash, self.block_number)? + { self.db.db().tx_mut()?.put::(hash, compiled_hash)?; Ok(Some(compiled_hash)) } else { @@ -250,8 +284,14 @@ impl ContractClassProvider for HistoricalStateProvider { impl StateProvider for HistoricalStateProvider { fn nonce(&self, address: ContractAddress) -> ProviderResult> { if let res @ Some(..) = self.provider.nonce(address)? { - Ok(res) - } else if let res @ Some(nonce) = self.backend.get_nonce(address)? { + return Ok(res); + } + + if self.block_number > self.forked_block_number { + return Ok(None); + } + + if let res @ Some(nonce) = self.backend.get_nonce(address, self.block_number)? { let block = self.provider.block(); let entry = ContractNonceChange { contract_address: address, nonce }; @@ -267,12 +307,18 @@ impl StateProvider for HistoricalStateProvider { address: ContractAddress, ) -> ProviderResult> { if let res @ Some(..) = self.provider.class_hash_of_contract(address)? { - Ok(res) - } else if let res @ Some(class_hash) = self.backend.get_class_hash_at(address)? { + return Ok(res); + } + + if self.block_number > self.forked_block_number { + return Ok(None); + } + + if let res @ Some(hash) = self.backend.get_class_hash_at(address, self.block_number)? { let block = self.provider.block(); - // TODO: this is technically wrong, we probably should insert the `ClassChangeHistory` - // entry on the state update level instead. - let entry = ContractClassChange::deployed(address, class_hash); + // TODO: this is technically wrong, we probably should insert the + // `ClassChangeHistory` entry on the state update level instead. + let entry = ContractClassChange::deployed(address, hash); self.db.db().tx_mut()?.put::(block, entry)?; Ok(res) @@ -287,8 +333,14 @@ impl StateProvider for HistoricalStateProvider { key: StorageKey, ) -> ProviderResult> { if let res @ Some(..) = self.provider.storage(address, key)? { - Ok(res) - } else if let res @ Some(value) = self.backend.get_storage(address, key)? { + return Ok(res); + } + + if self.block_number > self.forked_block_number { + return Ok(None); + } + + if let res @ Some(value) = self.backend.get_storage(address, key, self.block_number)? { let key = ContractStorageKey { contract_address: address, key }; let block = self.provider.block();