Skip to content

Commit 03924c4

Browse files
committed
class downloader
1 parent 23719c6 commit 03924c4

File tree

1 file changed

+87
-43
lines changed

1 file changed

+87
-43
lines changed

crates/katana/pipeline/src/stage/classes.rs

Lines changed: 87 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
use std::time::Duration;
1+
use std::sync::Arc;
22

33
use anyhow::Result;
4+
use backon::{ExponentialBuilder, Retryable};
45
use katana_primitives::block::BlockNumber;
5-
use katana_primitives::class::{
6-
CasmContractClass, ClassHash, CompiledClass, ContractClass, SierraContractClass,
7-
};
8-
use katana_primitives::conversion::rpc::{legacy_rpc_to_class, StarknetRsLegacyContractClass};
6+
use katana_primitives::class::{ClassHash, ContractClass, SierraContractClass};
7+
use katana_primitives::conversion::rpc::StarknetRsLegacyContractClass;
98
use katana_provider::traits::contract::{ContractClassWriter, ContractClassWriterExt};
109
use katana_provider::traits::state_update::StateUpdateProvider;
1110
use katana_rpc_types::class::RpcSierraContractClass;
1211
use starknet::providers::sequencer::models::{BlockId, DeployedClass};
1312
use starknet::providers::{ProviderError, SequencerGatewayProvider};
14-
use tracing::info;
13+
use tracing::warn;
1514

1615
use super::{Stage, StageExecutionInput, StageResult};
1716

@@ -24,22 +23,97 @@ pub enum Error {
2423
#[derive(Debug)]
2524
pub struct Classes<P> {
2625
provider: P,
27-
feeder_gateway: SequencerGatewayProvider,
26+
downloader: Downloader,
2827
}
2928

3029
impl<P> Classes<P> {
3130
pub fn new(provider: P, feeder_gateway: SequencerGatewayProvider) -> Self {
32-
Self { provider, feeder_gateway }
31+
Self { provider, downloader: Downloader::new(feeder_gateway) }
3332
}
3433
}
3534

36-
impl<P> Classes<P>
35+
#[async_trait::async_trait]
36+
impl<P> Stage for Classes<P>
3737
where
38-
P: StateUpdateProvider + ContractClassWriter,
38+
P: StateUpdateProvider + ContractClassWriter + ContractClassWriterExt,
3939
{
40-
#[allow(deprecated)]
41-
async fn get_class(&self, hash: ClassHash, block: BlockNumber) -> Result<ContractClass, Error> {
42-
let class = self.feeder_gateway.get_class_by_hash(hash, BlockId::Number(block)).await?;
40+
fn id(&self) -> &'static str {
41+
"Classes"
42+
}
43+
44+
async fn execute(&mut self, input: &StageExecutionInput) -> StageResult {
45+
for i in input.from..=input.to {
46+
let class_hashes = self.provider.declared_classes(i.into())?.unwrap();
47+
let class_hashes = class_hashes.keys().map(|hash| *hash).collect::<Vec<_>>();
48+
49+
let classes = self.downloader.fetch_classes(&class_hashes, i).await?;
50+
for (hash, class) in classes {
51+
self.provider.set_class(hash, class)?;
52+
}
53+
}
54+
55+
Ok(())
56+
}
57+
}
58+
59+
#[derive(Debug, Clone)]
60+
struct Downloader {
61+
client: Arc<SequencerGatewayProvider>,
62+
}
63+
64+
impl Downloader {
65+
fn new(client: SequencerGatewayProvider) -> Self {
66+
Self { client: Arc::new(client) }
67+
}
68+
69+
async fn fetch_classes(
70+
&self,
71+
classes: &[ClassHash],
72+
block: BlockNumber,
73+
) -> Result<Vec<(ClassHash, ContractClass)>, Error> {
74+
let mut all_results = Vec::with_capacity(classes.len());
75+
76+
for hash in classes {
77+
let mut futures = Vec::new();
78+
79+
futures.push(self.fetch_class_with_retry(*hash, block));
80+
let batch_results = futures::future::join_all(futures).await;
81+
82+
all_results.extend(batch_results);
83+
}
84+
85+
all_results.into_iter().collect()
86+
}
87+
88+
async fn fetch_class_with_retry(
89+
&self,
90+
hash: ClassHash,
91+
block: BlockNumber,
92+
) -> Result<(ClassHash, ContractClass), Error> {
93+
let request = || async move {
94+
#[allow(deprecated)]
95+
self.clone().fetch_class(hash, block).await
96+
};
97+
98+
// Retry only when being rate limited
99+
let result = request
100+
.retry(ExponentialBuilder::default())
101+
.when(|e| matches!(e, Error::Gateway(ProviderError::RateLimited)))
102+
.notify(|error, _| {
103+
warn!(target: "pipeline", hash = format!("{hash:#x}"), %block, %error, "Retrying class download.");
104+
})
105+
.await?;
106+
107+
Ok((hash, result))
108+
}
109+
110+
async fn fetch_class(
111+
&self,
112+
hash: ClassHash,
113+
block: BlockNumber,
114+
) -> Result<ContractClass, Error> {
115+
#[allow(deprecated)]
116+
let class = self.client.get_class_by_hash(hash, BlockId::Number(block)).await?;
43117

44118
let class = match class {
45119
DeployedClass::LegacyClass(legacy) => {
@@ -59,36 +133,6 @@ where
59133
}
60134
}
61135

62-
#[async_trait::async_trait]
63-
impl<P> Stage for Classes<P>
64-
where
65-
P: StateUpdateProvider + ContractClassWriter + ContractClassWriterExt,
66-
{
67-
fn id(&self) -> &'static str {
68-
"Classes"
69-
}
70-
71-
async fn execute(&mut self, input: &StageExecutionInput) -> StageResult {
72-
for i in input.from..=input.to {
73-
// loop thru all the class hashes in the current block
74-
let class_hashes = self.provider.declared_classes(i.into())?.unwrap();
75-
76-
// TODO: do this in parallel
77-
for hash in class_hashes.keys() {
78-
info!(target: "pipeline", class_hash = format!("{hash:#x}"), "Fetching class artifacts.");
79-
80-
// 1. fetch sierra and casm class from fgw
81-
let class = self.get_class(*hash, i).await?;
82-
self.provider.set_class(*hash, class)?;
83-
84-
tokio::time::sleep(Duration::from_secs(1)).await;
85-
}
86-
}
87-
88-
Ok(())
89-
}
90-
}
91-
92136
fn to_inner_legacy_class(class: StarknetRsLegacyContractClass) -> Result<ContractClass> {
93137
let value = serde_json::to_value(class)?;
94138
let class = serde_json::from_value::<katana_primitives::class::LegacyContractClass>(value)?;

0 commit comments

Comments
 (0)