1- use crate :: { InputProvider , OutputStream , VarName , semantics:: AsyncConfig } ;
1+ use crate :: {
2+ InputProvider , OutputStream , VarName ,
3+ semantics:: AsyncConfig ,
4+ stream_utils:: { self , SenderWithAck } ,
5+ } ;
26use async_stream:: stream;
37use async_trait:: async_trait;
4- use std:: collections:: BTreeMap ;
8+ use futures:: stream:: FuturesUnordered ;
9+ use futures:: { FutureExt , StreamExt } ;
10+ use std:: collections:: { BTreeMap , BTreeSet } ;
511use tracing:: debug;
612use unsync:: spsc:: Sender as SpscSender ;
713
814const CHANNEL_SIZE : usize = 10 ;
915
1016struct Channel < AC : AsyncConfig > {
11- sender : Option < SpscSender < AC :: Val > > ,
12- control_sender : Option < SpscSender < ( ) > > ,
17+ sender : Option < SpscSender < AC :: Val > > , // Data sent from user to receiver
1318 receiver : Option < OutputStream < AC :: Val > > ,
1419}
1520
1621pub struct ManualInputProvider < AC : AsyncConfig > {
17- vars : BTreeMap < VarName , Channel < AC > > ,
22+ map : BTreeMap < VarName , Channel < AC > > ,
23+ senders : Option < BTreeMap < VarName , SenderWithAck < ( ) > > > , // Control tick'er used within
24+ // control_stream
1825}
1926
2027impl < AC : AsyncConfig > ManualInputProvider < AC > {
@@ -26,72 +33,86 @@ impl<AC: AsyncConfig> ManualInputProvider<AC> {
2633 // On top of the regular InputProvider interface, it needs to have a sender channel for pushing
2734 // values.
2835 pub fn new ( input_vars : Vec < VarName > ) -> Self {
29- let vars: BTreeMap < VarName , Channel < AC > > = input_vars
30- . into_iter ( )
31- . map ( |v| {
32- let ( tx, mut rx) = unsync:: spsc:: channel ( CHANNEL_SIZE ) ;
33- let ( ctrl_tx, mut ctrl_rx) = unsync:: spsc:: channel ( CHANNEL_SIZE ) ;
34-
35- let rx: OutputStream < AC :: Val > = Box :: pin ( stream ! {
36- while let Some ( _) = ctrl_rx. recv( ) . await {
37- if let Some ( val) = rx. recv( ) . await {
38- yield val;
39- }
40- else {
41- return ;
42- }
43- }
44- } ) ;
45- (
46- v,
47- Channel {
48- sender : Some ( tx) ,
49- control_sender : Some ( ctrl_tx) ,
50- receiver : Some ( rx) ,
51- } ,
52- )
53- } )
54- . collect ( ) ;
55-
56- Self { vars }
36+ let mut map = BTreeMap :: new ( ) ;
37+ let mut senders = BTreeMap :: new ( ) ;
38+ for name in input_vars. into_iter ( ) {
39+ let ( tx, mut rx) = unsync:: spsc:: channel ( CHANNEL_SIZE ) ;
40+ let ( ctrl_tx, mut ctrl_rx) = stream_utils:: channel_with_ack ( CHANNEL_SIZE ) ;
41+ let var_stream: OutputStream < AC :: Val > = Box :: pin ( stream ! {
42+ while let Some ( _) = ctrl_rx. next( ) . await {
43+ if let Some ( val) = rx. recv( ) . await {
44+ yield val;
45+ } else {
46+ return ;
47+ }
48+ } } ) ;
49+ map. insert (
50+ name. clone ( ) ,
51+ Channel {
52+ sender : Some ( tx) ,
53+ receiver : Some ( var_stream) ,
54+ } ,
55+ ) ;
56+ senders. insert ( name, ctrl_tx) ;
57+ }
58+ Self {
59+ map,
60+ senders : Some ( senders) ,
61+ }
5762 }
5863
5964 pub fn sender_channel ( & mut self , var : & VarName ) -> Option < SpscSender < AC :: Val > > {
60- self . vars . get_mut ( var) ?. sender . take ( )
65+ self . map . get_mut ( var) ?. sender . take ( )
6166 }
6267}
6368
6469#[ async_trait( ?Send ) ]
6570impl < AC : AsyncConfig > InputProvider for ManualInputProvider < AC > {
6671 type Val = AC :: Val ;
6772 fn var_stream ( & mut self , var : & VarName ) -> Option < OutputStream < Self :: Val > > {
68- self . vars
73+ self . map
6974 . get_mut ( var)
7075 . and_then ( |channel| channel. receiver . take ( ) )
7176 }
7277
7378 async fn control_stream ( & mut self ) -> OutputStream < anyhow:: Result < ( ) > > {
74- let mut ctrl_tx = BTreeMap :: new ( ) ;
75- for ( var, channel) in self . vars . iter_mut ( ) {
76- let sender = channel
77- . control_sender
78- . take ( )
79- . expect ( "control stream can only be taken once" ) ;
80- ctrl_tx. insert ( var. clone ( ) , sender) ;
81- }
79+ let mut ctrl_tx = self
80+ . senders
81+ . take ( )
82+ . expect ( "control stream can only be taken once" ) ;
83+ // Set of those streams that have been taken with var(),
84+ // otherwise we deadlock waiting for acks from those that are not taken.
85+ let taken = self
86+ . map
87+ . iter ( )
88+ . filter_map ( |( name, stream) | stream. receiver . is_none ( ) . then ( || name. clone ( ) ) )
89+ . collect :: < BTreeSet < _ > > ( ) ;
90+ ctrl_tx. retain ( |name, _| taken. contains ( name) ) ;
8291
8392 Box :: pin ( stream ! {
8493 loop {
85- let mut dead = Vec :: new ( ) ;
86- // Send to each sender:
87- for ( name , sender ) in & mut ctrl_tx {
88- debug! ( "Sending tick to var stream {}" , name ) ;
89- if let Err ( e ) = sender. send ( ( ) ) . await {
90- // Not an error, most likely because the channel is done
91- debug! ( "Failed to send tick to var stream {}: {}" , name, e ) ;
92- dead . push( name . clone ( ) ) ;
94+ // Must be in scope to ensure only one mutable borrow
95+ let dead = {
96+ let mut futs = FuturesUnordered :: new ( ) ;
97+
98+ // Create sender tasks
99+ for ( name , sender ) in & mut ctrl_tx {
100+ let task = sender . send( ( ) ) . map ( |res| ( name. clone ( ) , res ) ) ;
101+ futs . push( task ) ;
93102 }
94- }
103+
104+ let mut dead = Vec :: new( ) ;
105+ // Futs returns None when empty - does not indicate tasks result
106+ while let Some ( ( name, res) ) = futs. next( ) . await {
107+ if let Err ( e) = res {
108+ // Not an error, most likely because the channel is done
109+ debug!( "Failed to send tick to var stream {}: {}" , name, e) ;
110+ dead. push( name) ;
111+ }
112+ }
113+ dead
114+ } ;
115+
95116 for name in dead {
96117 ctrl_tx. remove( & name) ;
97118 }
@@ -100,11 +121,6 @@ impl<AC: AsyncConfig> InputProvider for ManualInputProvider<AC> {
100121 if ctrl_tx. is_empty( ) {
101122 return ;
102123 }
103- // Timer to avoid starvation - has been seen in tests.
104- // (smool::future::yield_now() does not do the trick)
105- // A better but more complex solution would be to add backpressure from the
106- // tick receiver.
107- smol:: Timer :: after( std:: time:: Duration :: from_millis( 1 ) ) . await ;
108124 }
109125 } )
110126 }
@@ -282,7 +298,7 @@ mod tests {
282298 async fn var_stream_progress_different_len ( _ex : Rc < LocalExecutor < ' static > > ) {
283299 let xs = vec ! [ Value :: Int ( 1 ) , Value :: Int ( 2 ) ] ;
284300 let mut x_iter = xs. into_iter ( ) ;
285- let _ys: Vec < Value > = vec ! [ ] ; // ys are empty..
301+ let _ys: Vec < Value > = vec ! [ ] ; // ys are empty..
286302 let mut provider = ManualInputProvider :: < TestConfig > :: new ( vec ! [ "x" . into( ) , "y" . into( ) ] ) ;
287303
288304 let mut x_stream = provider
@@ -361,71 +377,95 @@ mod tests {
361377 }
362378
363379 #[ apply( async_test) ]
364- async fn control_stream_without_consuming ( _ex : Rc < LocalExecutor < ' static > > ) {
365- // Tests that the InputProvider does not hang if a Runtime does not consume all the
366- // var_streams/sender_channels.
367- let mut provider = ManualInputProvider :: < TestConfig > :: new ( vec ! [ "x" . into( ) , "y" . into( ) ] ) ;
368- let mut control_stream = provider. control_stream ( ) . await ;
369- for _ in 1 ..15 {
370- let _ = with_timeout ( control_stream. next ( ) , 1 , "control_stream.next()" )
371- . await
372- . expect ( "Control stream should not hang" ) ;
373- }
374-
375- let mut provider = ManualInputProvider :: < TestConfig > :: new ( vec ! [ "x" . into( ) , "y" . into( ) ] ) ;
376- let _x_stream = provider
377- . var_stream ( & "x" . into ( ) )
378- . expect ( "x stream should be available" ) ;
379- let mut control_stream = provider. control_stream ( ) . await ;
380- for _ in 1 ..15 {
381- let _ = with_timeout ( control_stream. next ( ) , 1 , "control_stream.next()" )
382- . await
383- . expect ( "Control stream should not hang" ) ;
384- }
385- }
386-
387- #[ apply( async_test) ]
388- async fn var_stream_reverse_ticks ( _ex : Rc < LocalExecutor < ' static > > ) {
389- // Checks that if we first send the values through sender channels, and then tick the
390- // control stream, the var_streams still yield the values.
391- let xs = vec ! [ Value :: Int ( 1 ) ] ;
380+ async fn var_stream_large_regression ( _ex : Rc < LocalExecutor < ' static > > ) {
381+ // Test that checks that ManualInputProvider can handle a large number of ticks without
382+ // deadlocking or running out of memory.
383+ // Introduced after regression with runtime test
384+
385+ const SIZE : usize = 3000 ;
386+ let xs: Vec < Value > = ( 0 ..SIZE ) . map ( |x| Value :: Int ( x as i64 ) ) . collect ( ) ;
387+ let ys: Vec < Value > = ( 0 ..SIZE ) . map ( |x| Value :: Int ( x as i64 ) ) . collect ( ) ;
388+ let add = if SIZE % 2 == 0 { 0 } else { 1 } ;
389+ let es: Vec < Value > = std:: iter:: repeat ( Value :: Deferred )
390+ . take ( ( SIZE / 2 ) as usize )
391+ . chain ( ( 0 ..( SIZE / 2 ) + add) . map ( |_| Value :: Str ( "x + y" . into ( ) ) ) )
392+ . collect ( ) ;
392393 let mut x_iter = xs. into_iter ( ) ;
393- let mut provider = ManualInputProvider :: < TestConfig > :: new ( vec ! [ "x" . into( ) ] ) ;
394+ let mut y_iter = ys. into_iter ( ) ;
395+ let mut e_iter = es. into_iter ( ) ;
396+ let mut provider =
397+ ManualInputProvider :: < TestConfig > :: new ( vec ! [ "x" . into( ) , "y" . into( ) , "e" . into( ) ] ) ;
394398
395- let mut x_stream = provider
396- . var_stream ( & "x" . into ( ) )
397- . expect ( "x stream should be available" ) ;
398399 let mut x_sender = provider
399400 . sender_channel ( & "x" . into ( ) )
400401 . expect ( "x sender should exist" ) ;
402+ let mut y_sender = provider
403+ . sender_channel ( & "y" . into ( ) )
404+ . expect ( "y sender should exist" ) ;
405+ let mut e_sender = provider
406+ . sender_channel ( & "e" . into ( ) )
407+ . expect ( "e sender should exist" ) ;
408+
409+ let mut x_stream = provider
410+ . var_stream ( & "x" . into ( ) )
411+ . expect ( "x stream should be available" ) ;
412+ let mut y_stream = provider
413+ . var_stream ( & "y" . into ( ) )
414+ . expect ( "y stream should be available" ) ;
415+ let mut e_stream = provider
416+ . var_stream ( & "e" . into ( ) )
417+ . expect ( "e stream should be available" ) ;
401418
402419 let mut control_stream = provider. control_stream ( ) . await ;
403420
404- let _ = x_sender
405- . send ( x_iter. next ( ) . unwrap ( ) )
406- . await
407- . expect ( "x_sender should be able to send" ) ;
421+ for _ in 0 ..SIZE {
422+ let _ = with_timeout ( control_stream. next ( ) , 1 , "control_stream.next()" )
423+ . await
424+ . expect ( "control stream should yield a value" ) ;
425+ let _ = x_sender
426+ . send ( x_iter. next ( ) . unwrap ( ) )
427+ . await
428+ . expect ( "x_sender should be able to send" ) ;
429+ let _ = with_timeout ( x_stream. next ( ) , 1 , "x_stream.next()" )
430+ . await
431+ . expect ( "x stream should yield a value" ) ;
432+ let _ = y_sender
433+ . send ( y_iter. next ( ) . unwrap ( ) )
434+ . await
435+ . expect ( "y_sender should be able to send" ) ;
436+ let _ = with_timeout ( y_stream. next ( ) , 1 , "y_stream.next()" )
437+ . await
438+ . expect ( "y stream should yield a value" ) ;
439+ let _ = e_sender
440+ . send ( e_iter. next ( ) . unwrap ( ) )
441+ . await
442+ . expect ( "e_sender should be able to send" ) ;
443+ let _ = with_timeout ( e_stream. next ( ) , 1 , "e_stream.next()" )
444+ . await
445+ . expect ( "e stream should yield a value" ) ;
446+ }
447+
408448 let _ = control_stream
409449 . next ( )
410450 . await
411451 . expect ( "control stream should yield Ok" ) ;
412452
453+ std:: mem:: drop ( x_sender) ;
454+ std:: mem:: drop ( y_sender) ;
455+ std:: mem:: drop ( e_sender) ;
413456 let x_res = with_timeout ( x_stream. next ( ) , 1 , "x_stream.next()" )
414457 . await
415458 . expect ( "x stream should yield a value" ) ;
416-
417- assert_eq ! ( x_res, Some ( 1 . into( ) ) ) ;
418-
419- std:: mem:: drop ( x_sender) ; // Drop sender to indicate it is done
420- let _ = control_stream
421- . next ( )
459+ let y_res = with_timeout ( y_stream. next ( ) , 1 , "y_stream.next()" )
422460 . await
423- . expect ( "control stream should yield Ok" ) ;
424-
425- let x_res = with_timeout ( x_stream. next ( ) , 1 , "x_stream.next()" )
461+ . expect ( "y stream should yield a value" ) ;
462+ let e_res = with_timeout ( e_stream. next ( ) , 1 , "e_stream.next()" )
426463 . await
427- . expect ( "x stream should yield a value" ) ;
464+ . expect ( "e stream should yield a value" ) ;
428465
466+ // All are exhausted:
429467 assert_eq ! ( x_res, None ) ;
468+ assert_eq ! ( y_res, None ) ;
469+ assert_eq ! ( e_res, None ) ;
430470 }
431471}
0 commit comments