Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ tower-http = "0.4.4"
tracing = { version = "0.1.38", features = [ "log" ], default-features = false }
tracing-log = "0.1.3"
tracing-subscriber = { version = "0.3.16", features = [ "env-filter", "json" ] }
# indicatif
tracing-indicatif = "0.3.9"
indicatif = "0.17.9"

url = { version = "2.4.0", features = [ "serde" ] }
walkdir = "2.5.0"
# TODO: see if we still need the git version
Expand Down
1 change: 1 addition & 0 deletions bin/torii/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ anyhow.workspace = true
clap.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
tracing-indicatif.workspace = true

[build-dependencies]
vergen = { version = "9.0.6", features = ["build", "emit_and_set"] }
Expand Down
25 changes: 17 additions & 8 deletions bin/torii/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,32 @@
use clap::Parser;
use cli::Cli;
use torii_runner::Runner;
use tracing_subscriber::{fmt, EnvFilter};
use tracing_indicatif::IndicatifLayer;
use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::Registry;

mod cli;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Set the global tracing subscriber
let filter_layer =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("torii=info"));
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info,torii=info")); // Adjust default filter if needed

let subscriber = fmt::Subscriber::builder()
.with_env_filter(filter_layer)
.finish();
let indicatif_layer = IndicatifLayer::new();

// Set the global subscriber
tracing::subscriber::set_global_default(subscriber)
.expect("Failed to set the global tracing subscriber");
Registry::default()
.with(filter_layer)
.with(
tracing_subscriber::fmt::layer()
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
.with_writer(indicatif_layer.get_stderr_writer()),
)
.with(indicatif_layer)
.init();

let args = Cli::parse().args.with_config_file()?;
let runner = Runner::new(args, env!("TORII_VERSION_SPEC").to_string());
Expand Down
4 changes: 4 additions & 0 deletions crates/cli/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ pub struct ToriiArgs {
#[command(flatten)]
pub sql: SqlOptions,

#[command(flatten)]
pub snapshot: SnapshotOptions,

#[cfg(feature = "server")]
#[command(flatten)]
pub metrics: MetricsOptions,
Expand All @@ -82,6 +85,7 @@ impl Default for ToriiArgs {
events: EventsOptions::default(),
erc: ErcOptions::default(),
sql: SqlOptions::default(),
snapshot: SnapshotOptions::default(),
runner: RunnerOptions::default(),
#[cfg(feature = "server")]
metrics: MetricsOptions::default(),
Expand Down
16 changes: 16 additions & 0 deletions crates/cli/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,22 @@ impl Default for SqlOptions {
}
}

#[derive(Default, Debug, clap::Args, Clone, Serialize, Deserialize, PartialEq, MergeOptions)]
#[serde(default)]
#[command(next_help_heading = "Snapshot options")]
pub struct SnapshotOptions {
/// Snapshot URL to download
#[arg(long = "snapshot.url", help = "The snapshot URL to download.")]
pub url: Option<String>,

/// Optional version of the remote snapshot torii version
#[arg(
long = "snapshot.version",
help = "Optional version of the torii the snapshot has been made from. This is only used to give a warning if there is a version mismatch between the snapshot and this torii."
)]
pub version: Option<String>,
}

#[derive(Default, Debug, clap::Args, Clone, Serialize, Deserialize, PartialEq, MergeOptions)]
#[serde(default)]
#[command(next_help_heading = "Runner options")]
Expand Down
4 changes: 4 additions & 0 deletions crates/runner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ tracing-subscriber.workspace = true
tracing.workspace = true
url.workspace = true
webbrowser = "0.8"
reqwest.workspace = true
indicatif.workspace = true
tracing-indicatif.workspace = true
rand.workspace = true

[dev-dependencies]
assert_matches.workspace = true
Expand Down
97 changes: 94 additions & 3 deletions crates/runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

use std::cmp;
use std::net::SocketAddr;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -29,6 +30,8 @@ use starknet::core::types::{BlockId, BlockTag};
use starknet::providers::jsonrpc::HttpTransport;
use starknet::providers::{JsonRpcClient, Provider};
use tempfile::{NamedTempFile, TempDir};
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tokio::sync::broadcast;
use tokio::sync::broadcast::Sender;
use tokio_stream::StreamExt;
Expand All @@ -42,7 +45,8 @@ use torii_sqlite::executor::Executor;
use torii_sqlite::simple_broker::SimpleBroker;
use torii_sqlite::types::{Contract, ContractType, Model};
use torii_sqlite::{Sql, SqlConfig};
use tracing::{error, info, warn};
use tracing::{error, info, info_span, warn, Instrument, Span};
use tracing_indicatif::span_ext::IndicatifSpanExt;
use url::form_urlencoded;

mod constants;
Expand Down Expand Up @@ -114,15 +118,52 @@ impl Runner {
}

let tempfile = NamedTempFile::new()?;
let database_path = if let Some(db_dir) = self.args.db_dir {
let database_path = if let Some(db_dir) = &self.args.db_dir {
// Create the directory if it doesn't exist
std::fs::create_dir_all(&db_dir)?;
std::fs::create_dir_all(db_dir)?;
// Set the database file path inside the directory
db_dir.join("torii.db")
} else {
tempfile.path().to_path_buf()
};

// Download snapshot if URL is provided
if let Some(snapshot_url) = self.args.snapshot.url {
// We don't wanna download our snapshot into an existing database. So only proceed if we don't have an existing db dir
// or if we have a tempfile path.
if self.args.db_dir.is_none() || !database_path.exists() {
info!(target: LOG_TARGET, url = %snapshot_url, path = %database_path.display(), "Downloading snapshot...");

// Check for version mismatch
if let Some(snapshot_version) = self.args.snapshot.version {
if snapshot_version != self.version_spec {
warn!(
target: LOG_TARGET,
snapshot_version = %snapshot_version,
current_version = %self.version_spec,
"Snapshot version mismatch. This may cause issues."
);
}
}

let client = reqwest::Client::new();
if let Err(e) =
stream_snapshot_into_file(&snapshot_url, &database_path, &client).await
{
error!(target: LOG_TARGET, error = %e, "Failed to download snapshot.");
// Decide if we should exit or continue with a fresh DB
// For now, let's exit as the user explicitly requested a snapshot.
return Err(e);
}
info!(target: LOG_TARGET, "Snapshot downloaded successfully.");
} else {
error!(target: LOG_TARGET, "A database already exists at the given path. If you want to download a new snapshot, please delete the existing database file or provide a different path.");
return Err(anyhow::anyhow!(
"Database file already exists at the specified path."
));
}
}

let mut options = SqliteConnectOptions::from_str(&database_path.to_string_lossy())?
.create_if_missing(true)
.with_regexp();
Expand Down Expand Up @@ -408,3 +449,53 @@ async fn verify_contracts_deployed(

Ok(undeployed)
}

/// Streams a snapshot into a file, displaying progress and handling potential errors.
///
/// # Arguments
/// * `url` - The URL to download from.
/// * `destination_path` - The path to save the downloaded file.
/// * `client` - An instance of `reqwest::Client`.
///
/// # Returns
/// * `Ok(())` if the download is successful.
/// * `Err(anyhow::Error)` if any error occurs during download or file writing.
async fn stream_snapshot_into_file(
url: &str,
destination_path: &Path,
client: &reqwest::Client,
) -> anyhow::Result<()> {
let response = client.get(url).send().await?.error_for_status()?;
let total_size = response.content_length().unwrap_or(0);

let span = info_span!("download_snapshot", url);
span.pb_set_style(
&indicatif::ProgressStyle::default_bar()
.template(
"{msg} [{elapsed_precise}] \n[{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta}) {percent}%",
)?
.progress_chars("█▓░"),
);
span.pb_set_length(total_size);
span.pb_set_message(&format!("Downloading {}", url));

let instrumented_future = async {
let mut file = File::create(destination_path).await?;
let mut downloaded: u64 = 0;
let mut stream = response.bytes_stream();

while let Some(item) = stream.next().await {
let chunk = item?;
file.write_all(&chunk).await?;
let new = cmp::min(downloaded.saturating_add(chunk.len() as u64), total_size);
downloaded = new;
Span::current().pb_set_position(new);
}

Span::current().pb_set_message("Downloaded snapshot successfully");
Ok(())
}
.instrument(span);

instrumented_future.await
}
Loading