@@ -3,7 +3,7 @@ use crate::{execution::stats::UpdateStats, prelude::*};
33use super :: stats;
44use futures:: future:: try_join_all;
55use sqlx:: PgPool ;
6- use tokio:: { sync:: watch, time:: MissedTickBehavior } ;
6+ use tokio:: { sync:: watch, task :: JoinSet , time:: MissedTickBehavior } ;
77
88pub struct FlowLiveUpdaterUpdates {
99 pub active_sources : Vec < String > ,
@@ -22,7 +22,8 @@ struct UpdateReceiveState {
2222
2323pub struct FlowLiveUpdater {
2424 flow_ctx : Arc < FlowContext > ,
25- tasks : Vec < ( tokio:: task:: JoinHandle < Result < ( ) > > , Arc < stats:: UpdateStats > ) > ,
25+ join_set : Mutex < Option < JoinSet < Result < ( ) > > > > ,
26+ stats_per_task : Vec < Arc < stats:: UpdateStats > > ,
2627 recv_state : tokio:: sync:: Mutex < UpdateReceiveState > ,
2728 num_remaining_tasks_rx : watch:: Receiver < usize > ,
2829
@@ -267,7 +268,11 @@ impl SourceUpdateTask {
267268 . boxed ( )
268269 } ) ;
269270
270- try_join_all ( futs) . await ?;
271+ let join_result = try_join_all ( futs) . await ;
272+ if let Err ( err) = join_result {
273+ error ! ( "Error in source `{}`: {:?}" , import_op. name, err) ;
274+ return Err ( err) ;
275+ }
271276 Ok ( ( ) )
272277 }
273278}
@@ -288,27 +293,30 @@ impl FlowLiveUpdater {
288293
289294 let ( num_remaining_tasks_tx, num_remaining_tasks_rx) =
290295 watch:: channel ( plan. import_ops . len ( ) ) ;
291- let tasks = ( 0 ..plan. import_ops . len ( ) )
292- . map ( |source_idx| {
293- let source_update_stats = Arc :: new ( stats:: UpdateStats :: default ( ) ) ;
294- let source_update_task = SourceUpdateTask {
295- source_idx,
296- flow : flow_ctx. flow . clone ( ) ,
297- plan : plan. clone ( ) ,
298- execution_ctx : execution_ctx. clone ( ) ,
299- source_update_stats : source_update_stats. clone ( ) ,
300- pool : pool. clone ( ) ,
301- options : options. clone ( ) ,
302- status_tx : status_tx. clone ( ) ,
303- num_remaining_tasks_tx : num_remaining_tasks_tx. clone ( ) ,
304- } ;
305- let task = tokio:: spawn ( source_update_task. run ( ) ) ;
306- ( task, source_update_stats)
307- } )
308- . collect ( ) ;
296+
297+ let mut join_set = JoinSet :: new ( ) ;
298+ let mut stats_per_task = Vec :: new ( ) ;
299+
300+ for source_idx in 0 ..plan. import_ops . len ( ) {
301+ let source_update_stats = Arc :: new ( stats:: UpdateStats :: default ( ) ) ;
302+ let source_update_task = SourceUpdateTask {
303+ source_idx,
304+ flow : flow_ctx. flow . clone ( ) ,
305+ plan : plan. clone ( ) ,
306+ execution_ctx : execution_ctx. clone ( ) ,
307+ source_update_stats : source_update_stats. clone ( ) ,
308+ pool : pool. clone ( ) ,
309+ options : options. clone ( ) ,
310+ status_tx : status_tx. clone ( ) ,
311+ num_remaining_tasks_tx : num_remaining_tasks_tx. clone ( ) ,
312+ } ;
313+ join_set. spawn ( source_update_task. run ( ) ) ;
314+ stats_per_task. push ( source_update_stats) ;
315+ }
309316 Ok ( Self {
310317 flow_ctx,
311- tasks,
318+ join_set : Mutex :: new ( Some ( join_set) ) ,
319+ stats_per_task,
312320 recv_state : tokio:: sync:: Mutex :: new ( UpdateReceiveState {
313321 status_rx,
314322 last_num_source_updates : vec ! [ 0 ; plan. import_ops. len( ) ] ,
@@ -322,27 +330,43 @@ impl FlowLiveUpdater {
322330 }
323331
324332 pub async fn wait ( & self ) -> Result < ( ) > {
325- let mut rx = self . num_remaining_tasks_rx . clone ( ) ;
326- if * rx. borrow ( ) == 0 {
333+ {
334+ let mut rx = self . num_remaining_tasks_rx . clone ( ) ;
335+ rx. wait_for ( |v| * v == 0 ) . await ?;
336+ }
337+
338+ let Some ( mut join_set) = self . join_set . lock ( ) . unwrap ( ) . take ( ) else {
327339 return Ok ( ( ) ) ;
340+ } ;
341+ while let Some ( task_result) = join_set. join_next ( ) . await {
342+ match task_result {
343+ Ok ( Ok ( _) ) => { }
344+ Ok ( Err ( err) ) => {
345+ return Err ( err) ;
346+ }
347+ Err ( err) if err. is_cancelled ( ) => { }
348+ Err ( err) => {
349+ return Err ( err. into ( ) ) ;
350+ }
351+ }
328352 }
329- rx. wait_for ( |v| * v == 0 ) . await ?;
330353 Ok ( ( ) )
331354 }
332355
333356 pub fn abort ( & self ) {
334- for ( task, _) in & self . tasks {
335- task. abort ( ) ;
357+ let mut join_set = self . join_set . lock ( ) . unwrap ( ) ;
358+ if let Some ( join_set) = & mut * join_set {
359+ join_set. abort_all ( ) ;
336360 }
337361 }
338362
339363 pub fn index_update_info ( & self ) -> stats:: IndexUpdateInfo {
340364 stats:: IndexUpdateInfo {
341365 sources : std:: iter:: zip (
342366 self . flow_ctx . flow . flow_instance . import_ops . iter ( ) ,
343- self . tasks . iter ( ) ,
367+ self . stats_per_task . iter ( ) ,
344368 )
345- . map ( |( import_op, ( _ , stats) ) | stats:: SourceUpdateInfo {
369+ . map ( |( import_op, stats) | stats:: SourceUpdateInfo {
346370 source_name : import_op. name . clone ( ) ,
347371 stats : stats. as_ref ( ) . clone ( ) ,
348372 } )
0 commit comments