1- mod subscribers ;
1+ mod progress_bar ;
22mod values;
33
44use std:: {
@@ -17,6 +17,7 @@ use daft_context::Subscriber;
1717use daft_dsl:: common_treenode:: { TreeNode , TreeNodeRecursion } ;
1818use futures:: future;
1919use itertools:: Itertools ;
20+ use progress_bar:: { ProgressBar , make_progress_bar_manager} ;
2021use tokio:: {
2122 runtime:: Handle ,
2223 sync:: { mpsc, oneshot} ,
@@ -25,12 +26,7 @@ use tokio::{
2526use tracing:: { Instrument , instrument:: Instrumented } ;
2627pub use values:: { Counter , DefaultRuntimeStats , Gauge , RuntimeStats } ;
2728
28- use crate :: {
29- pipeline:: PipelineNode ,
30- runtime_stats:: subscribers:: {
31- RuntimeStatsSubscriber , progress_bar:: make_progress_bar_manager, query:: SubscriberWrapper ,
32- } ,
33- } ;
29+ use crate :: pipeline:: PipelineNode ;
3430
3531#[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
3632pub enum QueryEndState {
@@ -94,7 +90,7 @@ impl RuntimeStatsManager {
9490 pub fn try_new (
9591 handle : & Handle ,
9692 pipeline : & Box < dyn PipelineNode > ,
97- query_subscribers : Vec < Arc < dyn Subscriber > > ,
93+ subscribers : Vec < Arc < dyn Subscriber > > ,
9894 query_id : QueryID ,
9995 ) -> DaftResult < Self > {
10096 // Construct mapping between node id and their node info and runtime stats
@@ -108,25 +104,25 @@ impl RuntimeStatsManager {
108104 Ok ( TreeNodeRecursion :: Continue )
109105 } ) ;
110106
111- let mut subscribers: Vec < Box < dyn RuntimeStatsSubscriber > > = Vec :: new ( ) ;
112- for subscriber in query_subscribers {
113- subscribers. push ( Box :: new ( SubscriberWrapper :: try_new (
114- subscriber,
115- query_id. clone ( ) ,
116- serde_json:: to_string ( & pipeline. repr_json ( ) )
117- . expect ( "Failed to serialize physical plan" )
118- . into ( ) ,
119- ) ?) ) ;
107+ let serialized_plan: Arc < str > = serde_json:: to_string ( & pipeline. repr_json ( ) )
108+ . expect ( "Failed to serialize physical plan" )
109+ . into ( ) ;
110+ for subscriber in & subscribers {
111+ subscriber. on_exec_start ( query_id. clone ( ) , serialized_plan. clone ( ) ) ?;
120112 }
121113
122- if should_enable_progress_bar ( ) {
123- subscribers. push ( make_progress_bar_manager ( & node_info_map) ) ;
124- }
114+ let progress_bar = if should_enable_progress_bar ( ) {
115+ Some ( make_progress_bar_manager ( & node_info_map) )
116+ } else {
117+ None
118+ } ;
125119
126120 let throttle_interval = Duration :: from_millis ( 200 ) ;
127121 Ok ( Self :: new_impl (
128122 handle,
123+ query_id,
129124 subscribers,
125+ progress_bar,
130126 node_stats_map,
131127 throttle_interval,
132128 ) )
@@ -135,7 +131,9 @@ impl RuntimeStatsManager {
135131 // Mostly used for testing purposes so we can inject our own subscribers and throttling interval
136132 fn new_impl (
137133 handle : & Handle ,
138- subscribers : Vec < Box < dyn RuntimeStatsSubscriber > > ,
134+ query_id : QueryID ,
135+ subscribers : Vec < Arc < dyn Subscriber > > ,
136+ progress_bar : Option < Box < dyn ProgressBar > > ,
139137 node_stats_map : HashMap < NodeID , Arc < dyn RuntimeStats > > ,
140138 throttle_interval : Duration ,
141139 ) -> Self {
@@ -154,7 +152,11 @@ impl RuntimeStatsManager {
154152 biased;
155153 Some ( ( node_id, is_initialize) ) = node_rx. recv( ) => {
156154 if is_initialize && active_nodes. insert( node_id) {
157- for res in future:: join_all( subscribers. iter( ) . map( |subscriber| subscriber. initialize_node( node_id) ) ) . await {
155+ if let Some ( progress_bar) = & progress_bar {
156+ progress_bar. initialize_node( node_id) ;
157+ }
158+
159+ for res in future:: join_all( subscribers. iter( ) . map( |subscriber| subscriber. on_exec_operator_start( query_id. clone( ) , node_id) ) ) . await {
158160 if let Err ( e) = res {
159161 log:: error!( "Failed to initialize node: {}" , e) ;
160162 }
@@ -164,9 +166,14 @@ impl RuntimeStatsManager {
164166 let event = runtime_stats. flush( ) ;
165167 let event = [ ( node_id, event) ] ;
166168
169+ if let Some ( progress_bar) = & progress_bar {
170+ progress_bar. handle_event( & event) ;
171+ progress_bar. finalize_node( node_id) ;
172+ }
173+
167174 for res in future:: join_all( subscribers. iter( ) . map( |subscriber| async {
168- subscriber. handle_event ( & event) . await ?;
169- subscriber. finalize_node ( node_id) . await
175+ subscriber. on_exec_emit_stats ( query_id . clone ( ) , & event) . await ?;
176+ subscriber. on_exec_operator_end ( query_id . clone ( ) , node_id) . await
170177 } ) ) . await {
171178 if let Err ( e) = res {
172179 log:: error!( "Failed to finalize node: {}" , e) ;
@@ -196,8 +203,12 @@ impl RuntimeStatsManager {
196203 snapshot_container. push( ( * node_id, event) ) ;
197204 }
198205
206+ if let Some ( progress_bar) = & progress_bar {
207+ progress_bar. handle_event( snapshot_container. as_slice( ) ) ;
208+ }
209+
199210 for res in future:: join_all( subscribers. iter( ) . map( |subscriber| {
200- subscriber. handle_event ( snapshot_container. as_slice( ) )
211+ subscriber. on_exec_emit_stats ( query_id . clone ( ) , snapshot_container. as_slice( ) )
201212 } ) ) . await {
202213 if let Err ( e) = res {
203214 log:: error!( "Failed to handle event: {}" , e) ;
@@ -208,8 +219,14 @@ impl RuntimeStatsManager {
208219 }
209220 }
210221
222+ if let Some ( progress_bar) = progress_bar
223+ && let Err ( e) = progress_bar. finish ( )
224+ {
225+ log:: warn!( "Failed to finish progress bar: {}" , e) ;
226+ }
227+
211228 for subscriber in subscribers {
212- if let Err ( e) = subscriber. finish ( ) . await {
229+ if let Err ( e) = subscriber. on_exec_end ( query_id . clone ( ) ) . await {
213230 log:: error!( "Failed to flush subscriber: {}" , e) ;
214231 }
215232 }
@@ -287,7 +304,11 @@ mod tests {
287304
288305 use async_trait:: async_trait;
289306 use common_error:: DaftResult ;
290- use common_metrics:: { CPU_US_KEY , NodeID , ROWS_IN_KEY , ROWS_OUT_KEY , Stat , StatSnapshot } ;
307+ use common_metrics:: {
308+ CPU_US_KEY , NodeID , QueryPlan , ROWS_IN_KEY , ROWS_OUT_KEY , Stat , StatSnapshot ,
309+ } ;
310+ use daft_context:: { QueryMetadata , QueryResult , Subscriber } ;
311+ use daft_micropartition:: MicroPartitionRef ;
291312 use tokio:: time:: { Duration , sleep} ;
292313
293314 use super :: * ;
@@ -325,44 +346,63 @@ mod tests {
325346 }
326347
327348 #[ async_trait]
328- impl RuntimeStatsSubscriber for MockSubscriber {
329- fn as_any ( & self ) -> & dyn std :: any :: Any {
330- self
349+ impl Subscriber for MockSubscriber {
350+ fn on_query_start ( & self , _ : QueryID , __ : Arc < QueryMetadata > ) -> DaftResult < ( ) > {
351+ Ok ( ( ) )
331352 }
332-
333- async fn initialize_node ( & self , _node_id : NodeID ) -> DaftResult < ( ) > {
353+ fn on_query_end ( & self , _: QueryID , __ : QueryResult ) -> DaftResult < ( ) > {
354+ Ok ( ( ) )
355+ }
356+ fn on_result_out ( & self , _: QueryID , __ : MicroPartitionRef ) -> DaftResult < ( ) > {
357+ Ok ( ( ) )
358+ }
359+ fn on_optimization_start ( & self , _: QueryID ) -> DaftResult < ( ) > {
360+ Ok ( ( ) )
361+ }
362+ fn on_optimization_end ( & self , _: QueryID , __ : QueryPlan ) -> DaftResult < ( ) > {
363+ Ok ( ( ) )
364+ }
365+ fn on_exec_start ( & self , _: QueryID , __ : QueryPlan ) -> DaftResult < ( ) > {
334366 Ok ( ( ) )
335367 }
336368
337- async fn finalize_node ( & self , _node_id : NodeID ) -> DaftResult < ( ) > {
369+ async fn on_exec_end ( & self , _: QueryID ) -> DaftResult < ( ) > {
370+ Ok ( ( ) )
371+ }
372+ async fn on_exec_operator_start ( & self , _: QueryID , _: NodeID ) -> DaftResult < ( ) > {
373+ Ok ( ( ) )
374+ }
375+ async fn on_exec_operator_end ( & self , _: QueryID , __ : NodeID ) -> DaftResult < ( ) > {
338376 Ok ( ( ) )
339377 }
340378
341- async fn handle_event ( & self , events : & [ ( NodeID , StatSnapshot ) ] ) -> DaftResult < ( ) > {
379+ async fn on_exec_emit_stats (
380+ & self ,
381+ _query_id : QueryID ,
382+ stats : & [ ( NodeID , StatSnapshot ) ] ,
383+ ) -> DaftResult < ( ) > {
342384 self . state
343385 . total_calls
344386 . fetch_add ( 1 , std:: sync:: atomic:: Ordering :: SeqCst ) ;
345- for ( _, snapshot) in events {
387+ for ( _, snapshot) in stats {
346388 * self . state . event . lock ( ) . unwrap ( ) = Some ( snapshot. clone ( ) ) ;
347389 }
348390 Ok ( ( ) )
349391 }
350-
351- async fn finish ( self : Box < Self > ) -> DaftResult < ( ) > {
352- Ok ( ( ) )
353- }
354392 }
355393
356394 #[ tokio:: test( start_paused = true ) ]
357395 async fn test_interval_respected ( ) {
358- let mock_subscriber = Box :: new ( MockSubscriber :: new ( ) ) ;
396+ let mock_subscriber = Arc :: new ( MockSubscriber :: new ( ) ) ;
359397 let mock_state = mock_subscriber. state . clone ( ) ;
360398
361399 let node_stat = Arc :: new ( DefaultRuntimeStats :: new ( 0 ) ) as Arc < dyn RuntimeStats > ;
362400 let throttle_interval = Duration :: from_millis ( 50 ) ;
363401 let stats_manager = RuntimeStatsManager :: new_impl (
364402 & tokio:: runtime:: Handle :: current ( ) ,
403+ "test_query_id" . into ( ) ,
365404 vec ! [ mock_subscriber] ,
405+ None ,
366406 HashMap :: from ( [ ( 0 , node_stat. clone ( ) ) ] ) ,
367407 throttle_interval,
368408 ) ;
@@ -412,16 +452,18 @@ mod tests {
412452
413453 #[ tokio:: test( start_paused = true ) ]
414454 async fn test_multiple_subscribers_all_receive_events ( ) {
415- let subscriber1 = Box :: new ( MockSubscriber :: new ( ) ) ;
416- let subscriber2 = Box :: new ( MockSubscriber :: new ( ) ) ;
455+ let subscriber1 = Arc :: new ( MockSubscriber :: new ( ) ) ;
456+ let subscriber2 = Arc :: new ( MockSubscriber :: new ( ) ) ;
417457 let state1 = subscriber1. state . clone ( ) ;
418458 let state2 = subscriber2. state . clone ( ) ;
419459
420460 let node_stat = Arc :: new ( DefaultRuntimeStats :: new ( 0 ) ) as Arc < dyn RuntimeStats > ;
421461 let throttle_interval = Duration :: from_millis ( 50 ) ;
422462 let stats_manager = RuntimeStatsManager :: new_impl (
423463 & tokio:: runtime:: Handle :: current ( ) ,
464+ "test_query_id" . into ( ) ,
424465 vec ! [ subscriber1, subscriber2] ,
466+ None ,
425467 HashMap :: from ( [ ( 0 , node_stat. clone ( ) ) ] ) ,
426468 throttle_interval,
427469 ) ;
@@ -443,35 +485,58 @@ mod tests {
443485 struct FailingSubscriber ;
444486
445487 #[ async_trait]
446- impl RuntimeStatsSubscriber for FailingSubscriber {
447- fn as_any ( & self ) -> & dyn std:: any:: Any {
448- self
488+ impl Subscriber for FailingSubscriber {
489+ fn on_query_start ( & self , _: QueryID , __ : Arc < QueryMetadata > ) -> DaftResult < ( ) > {
490+ Ok ( ( ) )
491+ }
492+ fn on_query_end ( & self , _: QueryID , __ : QueryResult ) -> DaftResult < ( ) > {
493+ Ok ( ( ) )
449494 }
450- async fn initialize_node ( & self , _: NodeID ) -> DaftResult < ( ) > {
495+ fn on_result_out ( & self , _: QueryID , __ : MicroPartitionRef ) -> DaftResult < ( ) > {
451496 Ok ( ( ) )
452497 }
453- async fn finalize_node ( & self , _: NodeID ) -> DaftResult < ( ) > {
498+ fn on_optimization_start ( & self , _: QueryID ) -> DaftResult < ( ) > {
454499 Ok ( ( ) )
455500 }
456- async fn handle_event ( & self , _: & [ ( NodeID , StatSnapshot ) ] ) -> DaftResult < ( ) > {
501+ fn on_optimization_end ( & self , _: QueryID , __ : QueryPlan ) -> DaftResult < ( ) > {
502+ Ok ( ( ) )
503+ }
504+ fn on_exec_start ( & self , _: QueryID , __ : QueryPlan ) -> DaftResult < ( ) > {
505+ Ok ( ( ) )
506+ }
507+
508+ async fn on_exec_end ( & self , _: QueryID ) -> DaftResult < ( ) > {
509+ Ok ( ( ) )
510+ }
511+ async fn on_exec_operator_start ( & self , _: QueryID , _: NodeID ) -> DaftResult < ( ) > {
512+ Ok ( ( ) )
513+ }
514+ async fn on_exec_operator_end ( & self , _: QueryID , __ : NodeID ) -> DaftResult < ( ) > {
515+ Ok ( ( ) )
516+ }
517+
518+ async fn on_exec_emit_stats (
519+ & self ,
520+ _: QueryID ,
521+ __ : & [ ( NodeID , StatSnapshot ) ] ,
522+ ) -> DaftResult < ( ) > {
457523 Err ( common_error:: DaftError :: InternalError (
458524 "Test error" . to_string ( ) ,
459525 ) )
460526 }
461- async fn finish ( self : Box < Self > ) -> DaftResult < ( ) > {
462- Ok ( ( ) )
463- }
464527 }
465528
466- let failing_subscriber = Box :: new ( FailingSubscriber ) ;
467- let mock_subscriber = Box :: new ( MockSubscriber :: new ( ) ) ;
529+ let failing_subscriber = Arc :: new ( FailingSubscriber ) ;
530+ let mock_subscriber = Arc :: new ( MockSubscriber :: new ( ) ) ;
468531 let state = mock_subscriber. state . clone ( ) ;
469532
470533 let node_stat = Arc :: new ( DefaultRuntimeStats :: new ( 0 ) ) as Arc < dyn RuntimeStats > ;
471534 let throttle_interval = Duration :: from_millis ( 50 ) ;
472535 let stats_manager = RuntimeStatsManager :: new_impl (
473536 & tokio:: runtime:: Handle :: current ( ) ,
537+ "test_query_id" . into ( ) ,
474538 vec ! [ failing_subscriber, mock_subscriber] ,
539+ None ,
475540 HashMap :: from ( [ ( 0 , node_stat. clone ( ) ) ] ) ,
476541 throttle_interval,
477542 ) ;
@@ -507,14 +572,16 @@ mod tests {
507572
508573 #[ tokio:: test( start_paused = true ) ]
509574 async fn test_events_without_init ( ) {
510- let mock_subscriber = Box :: new ( MockSubscriber :: new ( ) ) ;
575+ let mock_subscriber = Arc :: new ( MockSubscriber :: new ( ) ) ;
511576 let state = mock_subscriber. state . clone ( ) ;
512577
513578 let node_stat = Arc :: new ( DefaultRuntimeStats :: new ( 0 ) ) as Arc < dyn RuntimeStats > ;
514579 let throttle_interval = Duration :: from_millis ( 50 ) ;
515580 let stats_manager = RuntimeStatsManager :: new_impl (
516581 & tokio:: runtime:: Handle :: current ( ) ,
582+ "test_query_id" . into ( ) ,
517583 vec ! [ mock_subscriber] ,
584+ None ,
518585 HashMap :: from ( [ ( 0 , node_stat. clone ( ) ) ] ) ,
519586 throttle_interval,
520587 ) ;
@@ -537,15 +604,17 @@ mod tests {
537604
538605 #[ tokio:: test( start_paused = true ) ]
539606 async fn test_final_event_before_interval ( ) {
540- let mock_subscriber = Box :: new ( MockSubscriber :: new ( ) ) ;
607+ let mock_subscriber = Arc :: new ( MockSubscriber :: new ( ) ) ;
541608 let state = mock_subscriber. state . clone ( ) ;
542609
543610 // Use 500ms for the throttle interval.
544611 let throttle_interval = Duration :: from_millis ( 500 ) ;
545612 let node_stat = Arc :: new ( DefaultRuntimeStats :: new ( 0 ) ) as Arc < dyn RuntimeStats > ;
546613 let stats_manager = RuntimeStatsManager :: new_impl (
547614 & tokio:: runtime:: Handle :: current ( ) ,
615+ "test_query_id" . into ( ) ,
548616 vec ! [ mock_subscriber] ,
617+ None ,
549618 HashMap :: from ( [ ( 0 , node_stat. clone ( ) ) ] ) ,
550619 throttle_interval,
551620 ) ;
0 commit comments