Commit 0b386f5

feat(dashboard): enhance dashboard UI and fix Ray runner state reporting

1 parent 1d7b41a
29 files changed (+1328 −292 lines)

daft/context.py

Lines changed: 15 additions & 0 deletions
@@ -77,6 +77,21 @@ def _notify_optimization_start(self, query_id: str) -> None:
     def _notify_optimization_end(self, query_id: str, optimized_plan: str) -> None:
         self._ctx.notify_optimization_end(query_id, optimized_plan)
 
+    def _notify_exec_start(self, query_id: str, physical_plan: str) -> None:
+        self._ctx.notify_exec_start(query_id, physical_plan)
+
+    def _notify_exec_end(self, query_id: str) -> None:
+        self._ctx.notify_exec_end(query_id)
+
+    def _notify_exec_operator_start(self, query_id: str, node_id: int) -> None:
+        self._ctx.notify_exec_operator_start(query_id, node_id)
+
+    def _notify_exec_operator_end(self, query_id: str, node_id: int) -> None:
+        self._ctx.notify_exec_operator_end(query_id, node_id)
+
+    def _notify_exec_emit_stats(self, query_id: str, node_id: int, stats: dict[str, int]) -> None:
+        self._ctx.notify_exec_emit_stats(query_id, node_id, stats)
+
     def _notify_result_out(self, query_id: str, result: PartitionT) -> None:
         from daft.recordbatch.micropartition import MicroPartition
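
These wrappers simply forward to the Rust-side context. Below is a minimal sketch (not part of the commit) of how they chain together over a query's life; the query ID, plan JSON, and stats values are hypothetical placeholders, and the _notify_* hooks are internal APIs called here purely for illustration.

import daft.context

# Placeholder inputs; a real caller passes the runner's query_id and the
# physical plan serialized as JSON.
ctx = daft.context.get_context()
physical_plan_json = '{"id": 0, "children": []}'  # hypothetical plan
ctx._notify_exec_start("query-123", physical_plan_json)
ctx._notify_exec_operator_start("query-123", node_id=0)
ctx._notify_exec_emit_stats("query-123", node_id=0, stats={"rows_emitted": 1024})
ctx._notify_exec_operator_end("query-123", node_id=0)
ctx._notify_exec_end("query-123")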

daft/daft/__init__.pyi

Lines changed: 17 additions & 1 deletion
@@ -1997,6 +1997,7 @@ class DistributedPhysicalPlan:
     def num_partitions(self) -> int: ...
     def repr_ascii(self, simple: bool) -> str: ...
     def repr_mermaid(self, options: MermaidOptions) -> str: ...
+    def repr_json(self) -> str: ...
 
 class DistributedPhysicalPlanRunner:
     def __init__(self) -> None: ...
@@ -2197,8 +2198,18 @@ class QueryEndState(Enum):
 class PyQueryMetadata:
     output_schema: PySchema
     unoptimized_plan: str
+    runner: str
+    ray_dashboard_url: str | None
+    entrypoint: str | None
 
-    def __init__(self, output_schema: PySchema, unoptimized_plan: str) -> None: ...
+    def __init__(
+        self,
+        output_schema: PySchema,
+        unoptimized_plan: str,
+        runner: str,
+        ray_dashboard_url: str | None = None,
+        entrypoint: str | None = None,
+    ) -> None: ...
 
 class PyQueryResult:
     end_state: QueryEndState
@@ -2225,6 +2236,11 @@ class PyDaftContext:
     def notify_result_out(self, query_id: str, result: PartitionT) -> None: ...
     def notify_optimization_start(self, query_id: str) -> None: ...
     def notify_optimization_end(self, query_id: str, optimized_plan: str) -> None: ...
+    def notify_exec_start(self, query_id: str, physical_plan: str) -> None: ...
+    def notify_exec_end(self, query_id: str) -> None: ...
+    def notify_exec_operator_start(self, query_id: str, node_id: int) -> None: ...
+    def notify_exec_operator_end(self, query_id: str, node_id: int) -> None: ...
+    def notify_exec_emit_stats(self, query_id: str, node_id: int, stats: dict[str, int]) -> None: ...
 
 def set_runner_ray(
     address: str | None = None,
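
For reference, a hedged sketch of constructing the widened PyQueryMetadata, passing arguments positionally in the stub's order. The runner label and entrypoint string are placeholders, and df._builder is a private attribute used here only to obtain a schema and plan.

import daft
from daft.daft import PyQueryMetadata

df = daft.from_pydict({"x": [1, 2, 3]})
builder = df._builder  # private LogicalPlanBuilder, for illustration only
metadata = PyQueryMetadata(
    builder.schema()._schema,  # output_schema: PySchema
    builder.repr_json(),       # unoptimized_plan: plan as JSON
    "Native (Swordfish)",      # runner, as the native runner passes it
    None,                      # ray_dashboard_url
    "python my_script.py",     # entrypoint (placeholder)
)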

daft/runners/flotilla.py

Lines changed: 8 additions & 4 deletions
@@ -237,7 +237,9 @@ def try_autoscale(bundles: list[dict[str, int]]) -> None:
     num_cpus=0,
 )
 class RemoteFlotillaRunner:
-    def __init__(self) -> None:
+    def __init__(self, dashboard_url: str | None = None) -> None:
+        if dashboard_url:
+            os.environ["DAFT_DASHBOARD_URL"] = dashboard_url
         self.curr_plans: dict[str, DistributedPhysicalPlan] = {}
         self.curr_result_gens: dict[str, AsyncIterator[RayPartitionRef]] = {}
         self.plan_runner = DistributedPhysicalPlanRunner()
@@ -340,19 +342,21 @@ class FlotillaRunner:
 
     def __init__(self) -> None:
         head_node_id = get_head_node_id()
+        dashboard_url = os.environ.get("DAFT_DASHBOARD_URL")
         self.runner = RemoteFlotillaRunner.options(  # type: ignore
             name=get_flotilla_runner_actor_name(),
             namespace=FLOTILLA_RUNNER_NAMESPACE,
             get_if_exists=True,
+            runtime_env={"env_vars": {"DAFT_DASHBOARD_URL": dashboard_url}} if dashboard_url else None,
             scheduling_strategy=(
                 ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                     node_id=head_node_id,
                     soft=False,
                 )
-                if head_node_id is not None
-                else "DEFAULT"
+                if head_node_id
+                else None
             ),
-        ).remote()
+        ).remote(dashboard_url=dashboard_url)
 
     def stream_plan(
         self,
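
The dashboard URL is forwarded twice here: through the actor's runtime_env and through the constructor argument. Below is a standalone sketch of that runtime_env pattern, assuming only a local Ray installation; the EnvEcho actor and URL value are hypothetical stand-ins for RemoteFlotillaRunner and the real env var.

import os
import ray

@ray.remote
class EnvEcho:
    def dashboard_url(self):
        return os.environ.get("DAFT_DASHBOARD_URL")

ray.init(ignore_reinit_error=True)
dashboard_url = "http://localhost:3238"  # placeholder; normally read from the env
actor = EnvEcho.options(
    # Same mechanism as above: runtime_env carries the variable into the actor.
    runtime_env={"env_vars": {"DAFT_DASHBOARD_URL": dashboard_url}},
).remote()
print(ray.get(actor.dashboard_url.remote()))  # -> http://localhost:3238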

daft/runners/native_runner.py

Lines changed: 10 additions & 1 deletion
@@ -96,12 +96,19 @@ def run_iter(
         output_schema = builder.schema()
 
         # Optimize the logical plan.
-        ctx._notify_query_start(query_id, PyQueryMetadata(output_schema._schema, builder.repr_json()))
+        import sys
+
+        entrypoint = "python " + " ".join(sys.argv)
+        ctx._notify_query_start(
+            query_id,
+            PyQueryMetadata(output_schema._schema, builder.repr_json(), "Native (Swordfish)", None, entrypoint),
+        )
         ctx._notify_optimization_start(query_id)
         builder = builder.optimize(ctx.daft_execution_config)
         ctx._notify_optimization_end(query_id, builder.repr_json())
 
         plan = LocalPhysicalPlan.from_logical_plan_builder(builder._builder)
+
         executor = NativeExecutor()
         results_gen = executor.run(
             plan,
@@ -112,8 +119,10 @@ def run_iter(
         )
 
         try:
+            total_rows = 0
             for result in results_gen:
                 ctx._notify_result_out(query_id, result.partition())
+                total_rows += len(result.partition())
                 yield result
         except KeyboardInterrupt as e:
             query_result = PyQueryResult(QueryEndState.Canceled, "Query canceled by the user.")
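
As a quick illustration of the entrypoint string built above (not part of the commit):

import sys

# For `python my_script.py --limit 10`, sys.argv is
# ["my_script.py", "--limit", "10"], so this reconstructs the command line
# as "python my_script.py --limit 10".
entrypoint = "python " + " ".join(sys.argv)
print(entrypoint)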

daft/runners/ray_runner.py

Lines changed: 81 additions & 10 deletions
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
+import json
 import logging
+import os
+import sys
 import time
 import uuid
 from collections.abc import Generator, Iterable, Iterator
@@ -45,6 +48,9 @@
     FileFormatConfig,
     FileInfos,
     IOConfig,
+    PyQueryMetadata,
+    PyQueryResult,
+    QueryEndState,
 )
 from daft.datatype import DataType
 from daft.filesystem import glob_path_with_stats
@@ -548,19 +554,84 @@ def run_iter(
         ctx = get_context()
         query_id = str(uuid.uuid4())
         daft_execution_config = ctx.daft_execution_config
+        output_schema = builder.schema()
 
-        # Optimize the logical plan.
-        builder = builder.optimize(daft_execution_config)
-
-        distributed_plan = DistributedPhysicalPlan.from_logical_plan_builder(
-            builder._builder, query_id, daft_execution_config
+        # Notify query start
+        ray_dashboard_url = None
+        try:
+            ray_dashboard_url = ray.worker.get_dashboard_url()
+            if ray_dashboard_url:
+                if not ray_dashboard_url.startswith("http"):
+                    ray_dashboard_url = f"http://{ray_dashboard_url}"
+
+                # Try to append Job ID
+                try:
+                    job_id = ray.get_runtime_context().get_job_id()
+                    if job_id:
+                        ray_dashboard_url = f"{ray_dashboard_url}/#/jobs/{job_id}"
+                except Exception:
+                    pass
+        except Exception:
+            pass
+
+        entrypoint = "python " + " ".join(sys.argv)
+
+        ctx._notify_query_start(
+            query_id,
+            PyQueryMetadata(
+                output_schema._schema, builder.repr_json(), "Ray (Flotilla)", ray_dashboard_url, entrypoint
+            ),
         )
-        if self.flotilla_plan_runner is None:
-            self.flotilla_plan_runner = FlotillaRunner()
+        ctx._notify_optimization_start(query_id)
 
-        yield from self.flotilla_plan_runner.stream_plan(
-            distributed_plan, self._part_set_cache.get_all_partition_sets()
-        )
+        # Log Dashboard URL if configured
+        dashboard_url = os.environ.get("DAFT_DASHBOARD_URL")
+        if dashboard_url:
+            print(f"Daft Dashboard: {dashboard_url}/query/{query_id}")
+
+        try:
+            # Optimize the logical plan.
+            builder = builder.optimize(daft_execution_config)
+            ctx._notify_optimization_end(query_id, builder.repr_json())
+
+            distributed_plan = DistributedPhysicalPlan.from_logical_plan_builder(
+                builder._builder, query_id, daft_execution_config
+            )
+            physical_plan_json = distributed_plan.repr_json()
+            ctx._notify_exec_start(query_id, physical_plan_json)
+
+            if self.flotilla_plan_runner is None:
+                self.flotilla_plan_runner = FlotillaRunner()
+
+            total_rows = 0
+            for result in self.flotilla_plan_runner.stream_plan(
+                distributed_plan, self._part_set_cache.get_all_partition_sets()
+            ):
+                if result.metadata() is not None:
+                    total_rows += result.metadata().num_rows
+                yield result
+
+            # Mark all operators as finished to clean up the Dashboard UI before notify_exec_end
+            try:
+                plan_dict = json.loads(physical_plan_json)
+
+                def notify_end(node: dict[str, Any]) -> None:
+                    if "id" in node:
+                        ctx._notify_exec_operator_end(query_id, node["id"])
+                    if "children" in node:
+                        for child in node["children"]:
+                            notify_end(child)
+
+                notify_end(plan_dict)
+            except Exception:
+                pass
+
+            ctx._notify_exec_end(query_id)
+            ctx._notify_query_end(query_id, PyQueryResult(QueryEndState.Finished, ""))
+
+        except Exception as e:
+            ctx._notify_query_end(query_id, PyQueryResult(QueryEndState.Failed, str(e)))
+            raise
 
     def run_iter_tables(
         self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None
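
The operator cleanup above relies on the physical plan JSON being a tree of nodes with "id" and "children" keys. A standalone sketch of that walk, using a hypothetical two-node plan and collecting IDs instead of sending notifications:

import json

# Hypothetical plan JSON with the {"id": ..., "children": [...]} shape the
# cleanup code traverses.
plan_json = '{"id": 0, "children": [{"id": 1, "children": []}]}'

def collect_operator_ids(node: dict) -> list:
    """Depth-first walk mirroring notify_end() above, collecting ids instead."""
    ids = [node["id"]] if "id" in node else []
    for child in node.get("children", []):
        ids.extend(collect_operator_ids(child))
    return ids

print(collect_operator_ids(json.loads(plan_json)))  # -> [0, 1]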

daft/utils.py

Lines changed: 3 additions & 3 deletions
@@ -168,10 +168,10 @@ def np_datetime64_to_timestamp(dt: np.datetime64) -> tuple[int, PyTimeUnit | None]:
     val: np.int64 = dt.astype(np.int64) * np.int64(count)
 
     if unit in ("Y", "M", "W", "D"):
-        val = np.datetime64(dt, "D").astype(np.int64)  # type: ignore
+        val = np.datetime64(dt, "D").astype(np.int64)
         return val.item(), None
     elif unit in ("h", "m"):
-        val = np.datetime64(dt, "s").astype(np.int64)  # type: ignore
+        val = np.datetime64(dt, "s").astype(np.int64)
         return val.item(), PyTimeUnit.seconds()
     elif unit == "s":
         return val.item(), PyTimeUnit.seconds()
@@ -183,7 +183,7 @@ def np_datetime64_to_timestamp(dt: np.datetime64) -> tuple[int, PyTimeUnit | None]:
         return val.item(), PyTimeUnit.nanoseconds()
     else:
         # unit is too small, just convert to nanoseconds
-        val = np.datetime64(dt, "ns").astype(np.int64)  # type: ignore
+        val = np.datetime64(dt, "ns").astype(np.int64)
        return val.item(), PyTimeUnit.nanoseconds()
 
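
For context on the conversion these lines perform (the change itself only drops stale type-ignore comments): a coarse-grained datetime64 is recast to a finer unit before the integer value is extracted. A small self-contained check:

import numpy as np

dt = np.datetime64("2024-03", "M")              # month-precision value
days = np.datetime64(dt, "D").astype(np.int64)  # recast to day precision
print(days.item())  # days since the Unix epoch for 2024-03-01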

src/daft-context/Cargo.toml

Lines changed: 4 additions & 1 deletion
@@ -11,11 +11,14 @@ daft-micropartition = {path = "../daft-micropartition", default-features = false}
 daft-dashboard = {path = "../daft-dashboard", default-features = false}
 daft-runners = {workspace = true}
 pyo3 = {workspace = true, optional = true}
+tokio = {workspace = true}
+oneshot = "0.1.8"
 log = {workspace = true}
 # For debug subscriber
 dashmap = {workspace = true}
 # Client for submitting to dashboard server
-reqwest = {workspace = true, default-features = false}
+reqwest = {workspace = true, default_features = false}
+uuid = {workspace = true}
 
 [features]
 python = [

src/daft-context/src/lib.rs

Lines changed: 94 additions & 0 deletions
@@ -184,6 +184,100 @@ impl DaftContext {
             Ok::<(), DaftError>(())
         })
     }
+
+    pub fn notify_exec_start(&self, query_id: QueryID, physical_plan: String) -> DaftResult<()> {
+        self.with_state(|state| {
+            for subscriber in state.subscribers.values() {
+                subscriber.on_exec_start(query_id.clone(), physical_plan.clone().into())?;
+            }
+            Ok::<(), DaftError>(())
+        })
+    }
+
+    pub fn notify_exec_end(&self, query_id: QueryID) -> DaftResult<()> {
+        let subscribers = self.with_state(|state| {
+            state
+                .subscribers
+                .values()
+                .cloned()
+                .collect::<Vec<Arc<dyn Subscriber>>>()
+        });
+        let rt = common_runtime::get_io_runtime(false);
+        for subscriber in subscribers {
+            let query_id = query_id.clone();
+            let _ = rt.block_within_async_context(async move {
+                if let Err(e) = subscriber.on_exec_end(query_id).await {
+                    log::error!("Failed to notify exec end: {}", e);
+                }
+            });
+        }
+        Ok(())
+    }
+
+    pub fn notify_exec_operator_start(&self, query_id: QueryID, node_id: usize) -> DaftResult<()> {
+        let subscribers = self.with_state(|state| {
+            state
+                .subscribers
+                .values()
+                .cloned()
+                .collect::<Vec<Arc<dyn Subscriber>>>()
+        });
+        let rt = common_runtime::get_io_runtime(false);
+        for subscriber in subscribers {
+            let query_id = query_id.clone();
+            rt.spawn(async move {
+                if let Err(e) = subscriber.on_exec_operator_start(query_id, node_id).await {
+                    log::error!("Failed to notify exec operator start: {}", e);
+                }
+            });
+        }
+        Ok(())
+    }
+
+    pub fn notify_exec_operator_end(&self, query_id: QueryID, node_id: usize) -> DaftResult<()> {
+        let subscribers = self.with_state(|state| {
+            state
+                .subscribers
+                .values()
+                .cloned()
+                .collect::<Vec<Arc<dyn Subscriber>>>()
+        });
+        let rt = common_runtime::get_io_runtime(false);
+        for subscriber in subscribers {
+            let query_id = query_id.clone();
+            rt.spawn(async move {
+                if let Err(e) = subscriber.on_exec_operator_end(query_id, node_id).await {
+                    log::error!("Failed to notify exec operator end: {}", e);
+                }
+            });
+        }
+        Ok(())
+    }
+
+    pub fn notify_exec_emit_stats(
+        &self,
+        query_id: QueryID,
+        stats: Vec<(usize, common_metrics::StatSnapshot)>,
+    ) -> DaftResult<()> {
+        let subscribers = self.with_state(|state| {
+            state
+                .subscribers
+                .values()
+                .cloned()
+                .collect::<Vec<Arc<dyn Subscriber>>>()
+        });
+        let rt = common_runtime::get_io_runtime(false);
+        for subscriber in subscribers {
+            let stats = stats.clone();
+            let query_id = query_id.clone();
+            rt.spawn(async move {
+                if let Err(e) = subscriber.on_exec_emit_stats(query_id, &stats).await {
+                    log::error!("Failed to notify exec emit stats: {}", e);
+                }
+            });
+        }
+        Ok(())
+    }
 }
 
 static DAFT_CONTEXT: OnceLock<DaftContext> = OnceLock::new();
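
The methods above share one fan-out pattern: snapshot the subscriber list while holding the state lock, then dispatch notifications on the I/O runtime without the lock, logging rather than propagating per-subscriber failures. A rough Python/asyncio analogue of that pattern (names illustrative, not Daft API):

import asyncio
import logging

class Subscriber:
    """Stand-in for the Rust Subscriber trait."""
    async def on_exec_operator_end(self, query_id: str, node_id: int) -> None:
        print(f"{query_id}: operator {node_id} finished")

def notify_exec_operator_end(subscribers: list, query_id: str, node_id: int) -> None:
    snapshot = list(subscribers)  # copied "under the lock"; no lock held below

    async def fan_out() -> None:
        for sub in snapshot:
            try:
                await sub.on_exec_operator_end(query_id, node_id)
            except Exception:
                # Log and continue: one failing subscriber must not break the rest.
                logging.exception("Failed to notify exec operator end")

    asyncio.run(fan_out())

notify_exec_operator_end([Subscriber()], "query-123", node_id=7)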
