Skip to content

Commit a561cca

Browse files
authored
chore(observability): Refactor progress bar to remove RuntimeStatsSubscriber (#6030)
## Changes Made

Again, this is a change made while working on the One True Progress Bar (™️ pending) PR. Originally I was going to make the progress bar a QuerySubscriber, but that would be really messy, so instead I just refactored it a bit to remove RuntimeStatsSubscriber.
1 parent 65d8f08 commit a561cca

File tree

5 files changed

+154
-187
lines changed

5 files changed

+154
-187
lines changed

src/daft-context/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ use daft_micropartition::MicroPartitionRef;
1515
#[cfg(feature = "python")]
1616
use pyo3::prelude::*;
1717

18-
pub use crate::subscribers::Subscriber;
19-
use crate::subscribers::{QueryMetadata, QueryResult};
18+
pub use crate::subscribers::{QueryMetadata, QueryResult, Subscriber};
2019

2120
#[derive(Default)]
2221
#[cfg_attr(debug_assertions, derive(Debug))]

src/daft-local-execution/src/runtime_stats/mod.rs

Lines changed: 124 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
mod subscribers;
1+
mod progress_bar;
22
mod values;
33

44
use std::{
@@ -17,6 +17,7 @@ use daft_context::Subscriber;
1717
use daft_dsl::common_treenode::{TreeNode, TreeNodeRecursion};
1818
use futures::future;
1919
use itertools::Itertools;
20+
use progress_bar::{ProgressBar, make_progress_bar_manager};
2021
use tokio::{
2122
runtime::Handle,
2223
sync::{mpsc, oneshot},
@@ -25,12 +26,7 @@ use tokio::{
2526
use tracing::{Instrument, instrument::Instrumented};
2627
pub use values::{Counter, DefaultRuntimeStats, Gauge, RuntimeStats};
2728

28-
use crate::{
29-
pipeline::PipelineNode,
30-
runtime_stats::subscribers::{
31-
RuntimeStatsSubscriber, progress_bar::make_progress_bar_manager, query::SubscriberWrapper,
32-
},
33-
};
29+
use crate::pipeline::PipelineNode;
3430

3531
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3632
pub enum QueryEndState {
@@ -94,7 +90,7 @@ impl RuntimeStatsManager {
9490
pub fn try_new(
9591
handle: &Handle,
9692
pipeline: &Box<dyn PipelineNode>,
97-
query_subscribers: Vec<Arc<dyn Subscriber>>,
93+
subscribers: Vec<Arc<dyn Subscriber>>,
9894
query_id: QueryID,
9995
) -> DaftResult<Self> {
10096
// Construct mapping between node id and their node info and runtime stats
@@ -108,25 +104,25 @@ impl RuntimeStatsManager {
108104
Ok(TreeNodeRecursion::Continue)
109105
});
110106

111-
let mut subscribers: Vec<Box<dyn RuntimeStatsSubscriber>> = Vec::new();
112-
for subscriber in query_subscribers {
113-
subscribers.push(Box::new(SubscriberWrapper::try_new(
114-
subscriber,
115-
query_id.clone(),
116-
serde_json::to_string(&pipeline.repr_json())
117-
.expect("Failed to serialize physical plan")
118-
.into(),
119-
)?));
107+
let serialized_plan: Arc<str> = serde_json::to_string(&pipeline.repr_json())
108+
.expect("Failed to serialize physical plan")
109+
.into();
110+
for subscriber in &subscribers {
111+
subscriber.on_exec_start(query_id.clone(), serialized_plan.clone())?;
120112
}
121113

122-
if should_enable_progress_bar() {
123-
subscribers.push(make_progress_bar_manager(&node_info_map));
124-
}
114+
let progress_bar = if should_enable_progress_bar() {
115+
Some(make_progress_bar_manager(&node_info_map))
116+
} else {
117+
None
118+
};
125119

126120
let throttle_interval = Duration::from_millis(200);
127121
Ok(Self::new_impl(
128122
handle,
123+
query_id,
129124
subscribers,
125+
progress_bar,
130126
node_stats_map,
131127
throttle_interval,
132128
))
@@ -135,7 +131,9 @@ impl RuntimeStatsManager {
135131
// Mostly used for testing purposes so we can inject our own subscribers and throttling interval
136132
fn new_impl(
137133
handle: &Handle,
138-
subscribers: Vec<Box<dyn RuntimeStatsSubscriber>>,
134+
query_id: QueryID,
135+
subscribers: Vec<Arc<dyn Subscriber>>,
136+
progress_bar: Option<Box<dyn ProgressBar>>,
139137
node_stats_map: HashMap<NodeID, Arc<dyn RuntimeStats>>,
140138
throttle_interval: Duration,
141139
) -> Self {
@@ -154,7 +152,11 @@ impl RuntimeStatsManager {
154152
biased;
155153
Some((node_id, is_initialize)) = node_rx.recv() => {
156154
if is_initialize && active_nodes.insert(node_id) {
157-
for res in future::join_all(subscribers.iter().map(|subscriber| subscriber.initialize_node(node_id))).await {
155+
if let Some(progress_bar) = &progress_bar {
156+
progress_bar.initialize_node(node_id);
157+
}
158+
159+
for res in future::join_all(subscribers.iter().map(|subscriber| subscriber.on_exec_operator_start(query_id.clone(), node_id))).await {
158160
if let Err(e) = res {
159161
log::error!("Failed to initialize node: {}", e);
160162
}
@@ -164,9 +166,14 @@ impl RuntimeStatsManager {
164166
let event = runtime_stats.flush();
165167
let event = [(node_id, event)];
166168

169+
if let Some(progress_bar) = &progress_bar {
170+
progress_bar.handle_event(&event);
171+
progress_bar.finalize_node(node_id);
172+
}
173+
167174
for res in future::join_all(subscribers.iter().map(|subscriber| async {
168-
subscriber.handle_event(&event).await?;
169-
subscriber.finalize_node(node_id).await
175+
subscriber.on_exec_emit_stats(query_id.clone(), &event).await?;
176+
subscriber.on_exec_operator_end(query_id.clone(), node_id).await
170177
})).await {
171178
if let Err(e) = res {
172179
log::error!("Failed to finalize node: {}", e);
@@ -196,8 +203,12 @@ impl RuntimeStatsManager {
196203
snapshot_container.push((*node_id, event));
197204
}
198205

206+
if let Some(progress_bar) = &progress_bar {
207+
progress_bar.handle_event(snapshot_container.as_slice());
208+
}
209+
199210
for res in future::join_all(subscribers.iter().map(|subscriber| {
200-
subscriber.handle_event(snapshot_container.as_slice())
211+
subscriber.on_exec_emit_stats(query_id.clone(), snapshot_container.as_slice())
201212
})).await {
202213
if let Err(e) = res {
203214
log::error!("Failed to handle event: {}", e);
@@ -208,8 +219,14 @@ impl RuntimeStatsManager {
208219
}
209220
}
210221

222+
if let Some(progress_bar) = progress_bar
223+
&& let Err(e) = progress_bar.finish()
224+
{
225+
log::warn!("Failed to finish progress bar: {}", e);
226+
}
227+
211228
for subscriber in subscribers {
212-
if let Err(e) = subscriber.finish().await {
229+
if let Err(e) = subscriber.on_exec_end(query_id.clone()).await {
213230
log::error!("Failed to flush subscriber: {}", e);
214231
}
215232
}
@@ -287,7 +304,11 @@ mod tests {
287304

288305
use async_trait::async_trait;
289306
use common_error::DaftResult;
290-
use common_metrics::{CPU_US_KEY, NodeID, ROWS_IN_KEY, ROWS_OUT_KEY, Stat, StatSnapshot};
307+
use common_metrics::{
308+
CPU_US_KEY, NodeID, QueryPlan, ROWS_IN_KEY, ROWS_OUT_KEY, Stat, StatSnapshot,
309+
};
310+
use daft_context::{QueryMetadata, QueryResult, Subscriber};
311+
use daft_micropartition::MicroPartitionRef;
291312
use tokio::time::{Duration, sleep};
292313

293314
use super::*;
@@ -325,44 +346,63 @@ mod tests {
325346
}
326347

327348
#[async_trait]
328-
impl RuntimeStatsSubscriber for MockSubscriber {
329-
fn as_any(&self) -> &dyn std::any::Any {
330-
self
349+
impl Subscriber for MockSubscriber {
350+
fn on_query_start(&self, _: QueryID, __: Arc<QueryMetadata>) -> DaftResult<()> {
351+
Ok(())
331352
}
332-
333-
async fn initialize_node(&self, _node_id: NodeID) -> DaftResult<()> {
353+
fn on_query_end(&self, _: QueryID, __: QueryResult) -> DaftResult<()> {
354+
Ok(())
355+
}
356+
fn on_result_out(&self, _: QueryID, __: MicroPartitionRef) -> DaftResult<()> {
357+
Ok(())
358+
}
359+
fn on_optimization_start(&self, _: QueryID) -> DaftResult<()> {
360+
Ok(())
361+
}
362+
fn on_optimization_end(&self, _: QueryID, __: QueryPlan) -> DaftResult<()> {
363+
Ok(())
364+
}
365+
fn on_exec_start(&self, _: QueryID, __: QueryPlan) -> DaftResult<()> {
334366
Ok(())
335367
}
336368

337-
async fn finalize_node(&self, _node_id: NodeID) -> DaftResult<()> {
369+
async fn on_exec_end(&self, _: QueryID) -> DaftResult<()> {
370+
Ok(())
371+
}
372+
async fn on_exec_operator_start(&self, _: QueryID, _: NodeID) -> DaftResult<()> {
373+
Ok(())
374+
}
375+
async fn on_exec_operator_end(&self, _: QueryID, __: NodeID) -> DaftResult<()> {
338376
Ok(())
339377
}
340378

341-
async fn handle_event(&self, events: &[(NodeID, StatSnapshot)]) -> DaftResult<()> {
379+
async fn on_exec_emit_stats(
380+
&self,
381+
_query_id: QueryID,
382+
stats: &[(NodeID, StatSnapshot)],
383+
) -> DaftResult<()> {
342384
self.state
343385
.total_calls
344386
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
345-
for (_, snapshot) in events {
387+
for (_, snapshot) in stats {
346388
*self.state.event.lock().unwrap() = Some(snapshot.clone());
347389
}
348390
Ok(())
349391
}
350-
351-
async fn finish(self: Box<Self>) -> DaftResult<()> {
352-
Ok(())
353-
}
354392
}
355393

356394
#[tokio::test(start_paused = true)]
357395
async fn test_interval_respected() {
358-
let mock_subscriber = Box::new(MockSubscriber::new());
396+
let mock_subscriber = Arc::new(MockSubscriber::new());
359397
let mock_state = mock_subscriber.state.clone();
360398

361399
let node_stat = Arc::new(DefaultRuntimeStats::new(0)) as Arc<dyn RuntimeStats>;
362400
let throttle_interval = Duration::from_millis(50);
363401
let stats_manager = RuntimeStatsManager::new_impl(
364402
&tokio::runtime::Handle::current(),
403+
"test_query_id".into(),
365404
vec![mock_subscriber],
405+
None,
366406
HashMap::from([(0, node_stat.clone())]),
367407
throttle_interval,
368408
);
@@ -412,16 +452,18 @@ mod tests {
412452

413453
#[tokio::test(start_paused = true)]
414454
async fn test_multiple_subscribers_all_receive_events() {
415-
let subscriber1 = Box::new(MockSubscriber::new());
416-
let subscriber2 = Box::new(MockSubscriber::new());
455+
let subscriber1 = Arc::new(MockSubscriber::new());
456+
let subscriber2 = Arc::new(MockSubscriber::new());
417457
let state1 = subscriber1.state.clone();
418458
let state2 = subscriber2.state.clone();
419459

420460
let node_stat = Arc::new(DefaultRuntimeStats::new(0)) as Arc<dyn RuntimeStats>;
421461
let throttle_interval = Duration::from_millis(50);
422462
let stats_manager = RuntimeStatsManager::new_impl(
423463
&tokio::runtime::Handle::current(),
464+
"test_query_id".into(),
424465
vec![subscriber1, subscriber2],
466+
None,
425467
HashMap::from([(0, node_stat.clone())]),
426468
throttle_interval,
427469
);
@@ -443,35 +485,58 @@ mod tests {
443485
struct FailingSubscriber;
444486

445487
#[async_trait]
446-
impl RuntimeStatsSubscriber for FailingSubscriber {
447-
fn as_any(&self) -> &dyn std::any::Any {
448-
self
488+
impl Subscriber for FailingSubscriber {
489+
fn on_query_start(&self, _: QueryID, __: Arc<QueryMetadata>) -> DaftResult<()> {
490+
Ok(())
491+
}
492+
fn on_query_end(&self, _: QueryID, __: QueryResult) -> DaftResult<()> {
493+
Ok(())
449494
}
450-
async fn initialize_node(&self, _: NodeID) -> DaftResult<()> {
495+
fn on_result_out(&self, _: QueryID, __: MicroPartitionRef) -> DaftResult<()> {
451496
Ok(())
452497
}
453-
async fn finalize_node(&self, _: NodeID) -> DaftResult<()> {
498+
fn on_optimization_start(&self, _: QueryID) -> DaftResult<()> {
454499
Ok(())
455500
}
456-
async fn handle_event(&self, _: &[(NodeID, StatSnapshot)]) -> DaftResult<()> {
501+
fn on_optimization_end(&self, _: QueryID, __: QueryPlan) -> DaftResult<()> {
502+
Ok(())
503+
}
504+
fn on_exec_start(&self, _: QueryID, __: QueryPlan) -> DaftResult<()> {
505+
Ok(())
506+
}
507+
508+
async fn on_exec_end(&self, _: QueryID) -> DaftResult<()> {
509+
Ok(())
510+
}
511+
async fn on_exec_operator_start(&self, _: QueryID, _: NodeID) -> DaftResult<()> {
512+
Ok(())
513+
}
514+
async fn on_exec_operator_end(&self, _: QueryID, __: NodeID) -> DaftResult<()> {
515+
Ok(())
516+
}
517+
518+
async fn on_exec_emit_stats(
519+
&self,
520+
_: QueryID,
521+
__: &[(NodeID, StatSnapshot)],
522+
) -> DaftResult<()> {
457523
Err(common_error::DaftError::InternalError(
458524
"Test error".to_string(),
459525
))
460526
}
461-
async fn finish(self: Box<Self>) -> DaftResult<()> {
462-
Ok(())
463-
}
464527
}
465528

466-
let failing_subscriber = Box::new(FailingSubscriber);
467-
let mock_subscriber = Box::new(MockSubscriber::new());
529+
let failing_subscriber = Arc::new(FailingSubscriber);
530+
let mock_subscriber = Arc::new(MockSubscriber::new());
468531
let state = mock_subscriber.state.clone();
469532

470533
let node_stat = Arc::new(DefaultRuntimeStats::new(0)) as Arc<dyn RuntimeStats>;
471534
let throttle_interval = Duration::from_millis(50);
472535
let stats_manager = RuntimeStatsManager::new_impl(
473536
&tokio::runtime::Handle::current(),
537+
"test_query_id".into(),
474538
vec![failing_subscriber, mock_subscriber],
539+
None,
475540
HashMap::from([(0, node_stat.clone())]),
476541
throttle_interval,
477542
);
@@ -507,14 +572,16 @@ mod tests {
507572

508573
#[tokio::test(start_paused = true)]
509574
async fn test_events_without_init() {
510-
let mock_subscriber = Box::new(MockSubscriber::new());
575+
let mock_subscriber = Arc::new(MockSubscriber::new());
511576
let state = mock_subscriber.state.clone();
512577

513578
let node_stat = Arc::new(DefaultRuntimeStats::new(0)) as Arc<dyn RuntimeStats>;
514579
let throttle_interval = Duration::from_millis(50);
515580
let stats_manager = RuntimeStatsManager::new_impl(
516581
&tokio::runtime::Handle::current(),
582+
"test_query_id".into(),
517583
vec![mock_subscriber],
584+
None,
518585
HashMap::from([(0, node_stat.clone())]),
519586
throttle_interval,
520587
);
@@ -537,15 +604,17 @@ mod tests {
537604

538605
#[tokio::test(start_paused = true)]
539606
async fn test_final_event_before_interval() {
540-
let mock_subscriber = Box::new(MockSubscriber::new());
607+
let mock_subscriber = Arc::new(MockSubscriber::new());
541608
let state = mock_subscriber.state.clone();
542609

543610
// Use 500ms for the throttle interval.
544611
let throttle_interval = Duration::from_millis(500);
545612
let node_stat = Arc::new(DefaultRuntimeStats::new(0)) as Arc<dyn RuntimeStats>;
546613
let stats_manager = RuntimeStatsManager::new_impl(
547614
&tokio::runtime::Handle::current(),
615+
"test_query_id".into(),
548616
vec![mock_subscriber],
617+
None,
549618
HashMap::from([(0, node_stat.clone())]),
550619
throttle_interval,
551620
);

0 commit comments

Comments (0)