Skip to content

Commit 6d2bbc1

Browse files
authored
feat: add --addr flag to daft-dashboard cli (#5444)
1 parent fc5932c commit 6d2bbc1

File tree

3 files changed

+43
-19
lines changed

3 files changed

+43
-19
lines changed

src/daft-cli/src/python.rs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1-
use std::str::FromStr;
1+
use std::{
2+
net::{IpAddr, SocketAddr},
3+
str::FromStr,
4+
};
25

36
use clap::{Args, Parser, Subcommand, arg};
47
use pyo3::prelude::*;
58
use tracing_subscriber::{self, filter::Directive, layer::SubscriberExt, util::SubscriberInitExt};
69

710
#[derive(Debug, Args)]
811
struct DashboardArgs {
12+
/// The address to launch the dashboard on
13+
#[arg(short, long, default_value = "0.0.0.0")]
14+
addr: IpAddr,
915
#[arg(short, long, default_value_t = 80)]
1016
/// The port to launch the dashboard on
1117
port: u16,
@@ -35,6 +41,15 @@ fn run_dashboard(py: Python, args: DashboardArgs) {
3541
let filter = Directive::from_str(if args.verbose { "INFO" } else { "ERROR" })
3642
.expect("Failed to parse tracing filter");
3743

44+
if args.addr.is_unspecified() {
45+
println!("{}", console::style(format!(
46+
"⚠️ Listening on all network interfaces ({})! This is not recommended in production.",
47+
args.addr
48+
)).yellow().bold());
49+
}
50+
51+
let socket_addr = SocketAddr::from((args.addr, args.port));
52+
3853
// Set the subscriber for the detached run
3954
tracing_subscriber::registry()
4055
.with(
@@ -56,26 +71,25 @@ fn run_dashboard(py: Python, args: DashboardArgs) {
5671
"{} To get started, run your Daft script with env `{}`",
5772
console::style("█").magenta(),
5873
console::style(format!(
59-
"DAFT_DASHBOARD_URL=\"http://{}:{}\" python ...",
60-
daft_dashboard::DEFAULT_SERVER_ADDR,
61-
args.port
74+
"DAFT_DASHBOARD_URL=\"http://{}\" python ...",
75+
socket_addr
6276
))
6377
.bold(),
6478
);
6579
println!(
6680
"✨ View the dashboard at {}. Press Ctrl+C to shutdown",
67-
console::style(format!(
68-
"http://{}:{}",
69-
daft_dashboard::DEFAULT_SERVER_ADDR,
70-
args.port
71-
))
72-
.bold()
73-
.magenta()
74-
.underlined(),
81+
console::style(format!("http://{}", socket_addr))
82+
.bold()
83+
.magenta()
84+
.underlined(),
7585
);
76-
daft_dashboard::launch_server(args.port, async move { shutdown_rx.await.unwrap() })
77-
.await
78-
.expect("Failed to launch dashboard server");
86+
daft_dashboard::launch_server(
87+
args.addr,
88+
args.port,
89+
async move { shutdown_rx.await.unwrap() },
90+
)
91+
.await
92+
.expect("Failed to launch dashboard server");
7993
});
8094

8195
loop {

src/daft-dashboard/src/lib.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ pub mod engine;
55
pub mod python;
66
pub(crate) mod state;
77

8-
use std::{net::Ipv4Addr, sync::Arc};
8+
use std::{
9+
net::{IpAddr, Ipv4Addr},
10+
sync::Arc,
11+
};
912

1013
use axum::{
1114
Json, Router,
@@ -28,6 +31,7 @@ use tracing::Level;
2831

2932
use crate::state::{DashboardState, GLOBAL_DASHBOARD_STATE};
3033

34+
// NOTE(void001): default listen to all ipv4 address, which pose a security risk in production environment
3135
pub const DEFAULT_SERVER_ADDR: Ipv4Addr = Ipv4Addr::UNSPECIFIED;
3236
pub const DEFAULT_SERVER_PORT: u16 = 3238;
3337

@@ -274,6 +278,7 @@ async fn ping() -> StatusCode {
274278
}
275279

276280
pub async fn launch_server(
281+
addr: IpAddr,
277282
port: u16,
278283
shutdown_fn: impl Future<Output = ()> + Send + 'static,
279284
) -> std::io::Result<()> {
@@ -303,7 +308,7 @@ pub async fn launch_server(
303308
.with_state(GLOBAL_DASHBOARD_STATE.clone());
304309

305310
// Start the server
306-
let listener = TcpListener::bind((DEFAULT_SERVER_ADDR, port)).await?;
311+
let listener = TcpListener::bind((addr, port)).await?;
307312
axum::serve(listener, app)
308313
.with_graceful_shutdown(shutdown_fn)
309314
.await

src/daft-dashboard/src/python.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ pub fn launch(noop_if_initialized: bool) -> PyResult<ConnectionHandle> {
8383
}
8484
}
8585

86-
let port = super::DEFAULT_SERVER_PORT; // TODO: Make configurable
86+
let port = super::DEFAULT_SERVER_PORT;
8787
let (send, recv) = oneshot::channel::<()>();
8888

8989
let handle = ConnectionHandle {
@@ -93,7 +93,12 @@ pub fn launch(noop_if_initialized: bool) -> PyResult<ConnectionHandle> {
9393
let _ = std::thread::spawn(move || {
9494
DASHBOARD_ENABLED.store(true, Ordering::SeqCst);
9595
let res = tokio_runtime().block_on(async {
96-
super::launch_server(port, async move { recv.await.unwrap() }).await
96+
super::launch_server(
97+
std::net::IpAddr::V4(super::DEFAULT_SERVER_ADDR),
98+
port,
99+
async move { recv.await.unwrap() },
100+
)
101+
.await
97102
});
98103
DASHBOARD_ENABLED.store(false, Ordering::SeqCst);
99104
res

0 commit comments

Comments
 (0)