Skip to content

Commit 90ed665

Browse files
committed
Add samod-cli
1 parent 242e557 commit 90ed665

File tree

8 files changed

+937
-47
lines changed

8 files changed

+937
-47
lines changed

Cargo.lock

Lines changed: 405 additions & 45 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
[workspace]
22
edition = "2024"
33
resolver = "3"
4-
members = [ "samod","samod-core", "samod-test-harness"]
4+
members = [ "samod", "samod-cli","samod-core", "samod-test-harness"]
55

66
[workspace.dependencies]
77
automerge = "0.7.1"
8+
rand = "0.9.2"

samod-cli/Cargo.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
[package]
2+
name = "samod-cli"
3+
version = "0.1.0"
4+
edition = "2024"
5+
6+
[dependencies]
7+
clap = { version = "4.5.60", features = ["derive"] }
8+
tokio = { version = "1.49.0", features = ["full"] }
9+
samod = { path = "../samod", version="0.7", features=["tokio", "axum", "tungstenite", "threadpool"] }
10+
url = "2.5.8"
11+
axum = { version = "0.8.8", features = ["ws"] }
12+
tracing-subscriber = { version = "0.3.22", features = ["env-filter", "json"] }
13+
tracing = "0.1.44"
14+
uuid = { version = "1.21.0", features = ["v4"] }
15+
rand = { workspace = true }
16+
opentelemetry = { version = "0.31", features = ["metrics"] }
17+
opentelemetry_sdk = { version = "0.31", features = ["metrics", "rt-tokio"] }
18+
opentelemetry-otlp = { version = "0.31", features = ["metrics", "http-proto"] }
19+
rayon = "1.11.0"

samod-cli/src/listener_arg.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use std::net::{IpAddr, SocketAddr};
2+
3+
#[derive(Debug, Clone)]
4+
pub(crate) enum ListenerArg {
5+
Tcp(SocketAddr),
6+
WebSocket(SocketAddr),
7+
}
8+
9+
impl std::str::FromStr for ListenerArg {
10+
type Err = String;
11+
12+
fn from_str(s: &str) -> Result<Self, Self::Err> {
13+
let url: url::Url = s.parse().map_err(|e: url::ParseError| e.to_string())?;
14+
15+
let host = url.host().ok_or("URL must contain a host")?;
16+
let ip: IpAddr = match host {
17+
url::Host::Ipv4(addr) => addr.into(),
18+
url::Host::Ipv6(addr) => addr.into(),
19+
url::Host::Domain("localhost") => IpAddr::from([127, 0, 0, 1]),
20+
url::Host::Domain(other) => {
21+
return Err(format!(
22+
"expected an IP address or 'localhost', not '{other}'"
23+
));
24+
}
25+
};
26+
27+
let port = url
28+
.port_or_known_default()
29+
.ok_or("URL must contain a port")?;
30+
let addr = SocketAddr::new(ip, port);
31+
32+
match url.scheme() {
33+
"tcp" => Ok(ListenerArg::Tcp(addr)),
34+
"ws" | "wss" => Ok(ListenerArg::WebSocket(addr)),
35+
other => Err(format!(
36+
"unsupported scheme: '{other}', expected 'tcp', 'ws', or 'wss'"
37+
)),
38+
}
39+
}
40+
}

samod-cli/src/main.rs

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
use std::{net::SocketAddr, path::PathBuf, sync::Arc};
2+
3+
use clap::Parser;
4+
5+
mod listener_arg;
6+
use listener_arg::ListenerArg;
7+
mod peer_arg;
8+
use peer_arg::PeerArg;
9+
mod otel_observer;
10+
use otel_observer::OtelObserver;
11+
use samod::{
12+
AcceptorHandle, ConcurrencyConfig, DocumentId, PeerId, Repo,
13+
storage::{InMemoryStorage, TokioFilesystemStorage},
14+
websocket::TungsteniteDialer,
15+
};
16+
use tokio::net::TcpListener;
17+
18+
#[derive(clap::ValueEnum, Clone, Default)]
19+
enum LogFormat {
20+
#[default]
21+
Text,
22+
Json,
23+
}
24+
25+
#[derive(clap::Parser)]
26+
pub(crate) struct Args {
27+
#[arg(long, default_value = "text", help = "Log output format")]
28+
log_format: LogFormat,
29+
#[command(subcommand)]
30+
command: Command,
31+
}
32+
33+
#[derive(clap::Subcommand)]
34+
pub(crate) enum Command {
35+
Serve(ServeCommand),
36+
}
37+
38+
#[derive(clap::Parser)]
39+
pub(crate) struct ServeCommand {
40+
#[arg(short, long, help = "URLS to listen on")]
41+
listeners: Vec<ListenerArg>,
42+
#[arg(short, long, help = "Peer URLs to connect to")]
43+
peers: Vec<PeerArg>,
44+
#[arg(
45+
short,
46+
long,
47+
help = "Path to the directory where samod should store its data"
48+
)]
49+
storage_dir: Option<PathBuf>,
50+
#[arg(
51+
long,
52+
help = "Peer ID prefixes to announce documents to (e.g. storage-server for the public sync server"
53+
)]
54+
relay_peer_id_prefixes: Vec<String>,
55+
#[arg(
56+
long,
57+
help = "Peer ID prefix to use for this server (e.g. storage-server for the public sync server)"
58+
)]
59+
peer_id_prefix: Option<String>,
60+
#[arg(
61+
long,
62+
help = "OTLP HTTP endpoint to export metrics to (e.g. http://localhost:4318 for otel-tui, or http://localhost:9090/api/v1/otlp for Prometheus)"
63+
)]
64+
otel_endpoint: Option<String>,
65+
}
66+
67+
#[tokio::main]
68+
async fn main() {
69+
let Args {
70+
log_format,
71+
command,
72+
} = Args::parse();
73+
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
74+
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
75+
match log_format {
76+
LogFormat::Text => tracing_subscriber::fmt().with_env_filter(env_filter).init(),
77+
LogFormat::Json => tracing_subscriber::fmt()
78+
.json()
79+
.with_env_filter(env_filter)
80+
.init(),
81+
}
82+
83+
match command {
84+
Command::Serve(serve_command) => serve(serve_command).await,
85+
}
86+
}
87+
88+
fn init_meter_provider(endpoint: &str) -> opentelemetry_sdk::metrics::SdkMeterProvider {
89+
use opentelemetry_otlp::WithExportConfig;
90+
use opentelemetry_sdk::Resource;
91+
use opentelemetry_sdk::metrics::SdkMeterProvider;
92+
93+
let endpoint = endpoint.strip_suffix('/').unwrap_or(endpoint);
94+
let exporter = opentelemetry_otlp::MetricExporter::builder()
95+
.with_http()
96+
.with_endpoint(format!("{endpoint}/v1/metrics"))
97+
.build()
98+
.expect("failed to build OTLP metric exporter");
99+
100+
let provider = SdkMeterProvider::builder()
101+
.with_periodic_exporter(exporter)
102+
.with_resource(Resource::builder().with_service_name("samod-cli").build())
103+
.build();
104+
105+
opentelemetry::global::set_meter_provider(provider.clone());
106+
provider
107+
}
108+
109+
async fn serve(
110+
ServeCommand {
111+
listeners,
112+
peers,
113+
storage_dir,
114+
relay_peer_id_prefixes,
115+
peer_id_prefix,
116+
otel_endpoint,
117+
}: ServeCommand,
118+
) {
119+
// Set up OpenTelemetry metrics if an endpoint is configured
120+
let _meter_provider = otel_endpoint.as_deref().map(|endpoint| {
121+
let provider = init_meter_provider(endpoint);
122+
tracing::info!("OpenTelemetry metrics exporting to {}", endpoint);
123+
provider
124+
});
125+
126+
let observer = _meter_provider.as_ref().map(|_| {
127+
let meter = opentelemetry::global::meter("samod-cli");
128+
OtelObserver::new(&meter)
129+
});
130+
131+
let announce_policy = move |_doc_id: DocumentId, peer_id: PeerId| {
132+
relay_peer_id_prefixes
133+
.iter()
134+
.any(|prefix| peer_id.to_string().starts_with(prefix))
135+
};
136+
let peer_id = if let Some(prefix) = peer_id_prefix {
137+
PeerId::from_string(format!("{}-{}", prefix, uuid::Uuid::new_v4()))
138+
} else {
139+
PeerId::new_with_rng(&mut rand::rng())
140+
};
141+
let threadpool = rayon::ThreadPoolBuilder::new()
142+
.build()
143+
.expect("failed to build threadpool");
144+
let repo = match storage_dir {
145+
Some(dir) => {
146+
tracing::info!("using file system storage at {}", dir.display());
147+
let storage = TokioFilesystemStorage::new(dir);
148+
let mut builder = samod::Repo::build_tokio()
149+
.with_storage(storage)
150+
.with_announce_policy(announce_policy)
151+
.with_peer_id(peer_id)
152+
.with_concurrency(ConcurrencyConfig::Threadpool(threadpool));
153+
if let Some(obs) = observer {
154+
builder = builder.with_observer(obs);
155+
}
156+
builder.load().await
157+
}
158+
None => {
159+
tracing::info!("using ephemeral in-memory storage");
160+
let storage = InMemoryStorage::new();
161+
let mut builder = samod::Repo::build_tokio()
162+
.with_storage(storage)
163+
.with_announce_policy(announce_policy)
164+
.with_peer_id(peer_id)
165+
.with_concurrency(ConcurrencyConfig::Threadpool(threadpool));
166+
if let Some(obs) = observer {
167+
builder = builder.with_observer(obs);
168+
}
169+
builder.load().await
170+
}
171+
};
172+
173+
for listener in listeners {
174+
match listener {
175+
ListenerArg::WebSocket(addr) => {
176+
tracing::info!("starting websocket listener on {}", addr);
177+
listen_websocket(repo.clone(), addr).await;
178+
}
179+
ListenerArg::Tcp(addr) => {
180+
tracing::info!("starting tcp listener on {addr}");
181+
listen_tcp(repo.clone(), addr).await;
182+
}
183+
}
184+
}
185+
186+
for peer in peers {
187+
match peer {
188+
PeerArg::Tcp { host, port } => {
189+
tracing::info!("creating outbound connection to {}:{}", host, port);
190+
repo.dial(
191+
samod::BackoffConfig::default(),
192+
Arc::new(samod::tokio_io::TcpDialer::new_host_port(
193+
host.clone(),
194+
port,
195+
)),
196+
)
197+
.inspect_err(|e| {
198+
tracing::warn!("error dialing tcp peer {}:{}: {e}", host, port);
199+
})
200+
.ok();
201+
}
202+
PeerArg::WebSocket(url) => {
203+
tracing::info!("creating outbound connection to {}", url);
204+
repo.dial(
205+
samod::BackoffConfig::default(),
206+
Arc::new(TungsteniteDialer::new(url.clone())),
207+
)
208+
.inspect_err(|e| {
209+
tracing::warn!("error dialing websocket peer {}: {e}", url);
210+
})
211+
.ok();
212+
}
213+
}
214+
}
215+
216+
// Now wait for termination
217+
tokio::signal::ctrl_c()
218+
.await
219+
.expect("failed to listen for ctrl-c signal");
220+
221+
// Flush metrics on shutdown
222+
if let Some(provider) = _meter_provider
223+
&& let Err(e) = provider.shutdown()
224+
{
225+
tracing::warn!("error shutting down meter provider: {e}");
226+
}
227+
}
228+
229+
async fn listen_tcp(repo: Repo, addr: SocketAddr) {
230+
let url = url::Url::parse(&format!("tcp://{}:{}", addr.ip(), addr.port())).unwrap();
231+
let Ok(listener) = TcpListener::bind(addr).await.inspect_err(|e| {
232+
tracing::error!("unable to listen on {url}: {e}");
233+
}) else {
234+
return;
235+
};
236+
let Ok(acceptor) = repo.make_acceptor(url.clone()).inspect_err(|e| {
237+
tracing::warn!("error creating acceptor for {url}: {e}");
238+
}) else {
239+
return;
240+
};
241+
tokio::spawn(async move {
242+
loop {
243+
let Ok((io, addr)) = listener.accept().await.inspect_err(|e| {
244+
tracing::warn!("error accepting tcp connection on {url}: {e}");
245+
}) else {
246+
continue;
247+
};
248+
tracing::info!("accepted tcp connection from {}", addr);
249+
if let Err(e) = acceptor.accept_tokio_io(io) {
250+
tracing::error!(?e, "failed to accept tcp connection from {}", addr);
251+
}
252+
}
253+
});
254+
}
255+
256+
async fn listen_websocket(repo: Repo, addr: SocketAddr) {
257+
let listener = TcpListener::bind(addr)
258+
.await
259+
.expect("unable to bind socket");
260+
let Ok(acceptor) = repo
261+
.make_acceptor(url::Url::parse(&format!("ws://{}", addr)).unwrap())
262+
.inspect_err(|e| {
263+
tracing::warn!("error creating acceptor for {}: {e}", addr);
264+
})
265+
else {
266+
return;
267+
};
268+
269+
let app = axum::Router::new()
270+
.route("/", axum::routing::get(websocket_handler))
271+
.with_state(acceptor.clone());
272+
let server = axum::serve(listener, app).into_future();
273+
tokio::spawn(server);
274+
}
275+
276+
async fn websocket_handler(
277+
ws: axum::extract::ws::WebSocketUpgrade,
278+
axum::extract::State(acceptor): axum::extract::State<AcceptorHandle>,
279+
) -> axum::response::Response {
280+
ws.on_upgrade(|socket| async move {
281+
if let Err(e) = acceptor.accept_axum(socket) {
282+
tracing::error!(?e, "failed to accept axum websocket");
283+
}
284+
})
285+
}

0 commit comments

Comments
 (0)