@@ -177,7 +177,7 @@ where
177177 ) -> Self {
178178 let comm_clients =
179179 ctx. create_all_communication_clients :: < PartitionFinished < Builder :: Part > > ( ) ;
180- Self {
180+ let mut this = Self {
181181 partitions : IndexMap :: new ( ) ,
182182 part_builder,
183183 all_partitions,
@@ -186,7 +186,14 @@ where
186186 // times should not be an issue
187187 max_t : Some ( T :: MAX ) ,
188188 _phantom : PhantomData ,
189+ } ;
190+
191+ if let Some ( state) = ctx. load_state :: < IndexMap < Builder :: Part , Builder :: PartitionState > > ( ) {
192+ for ( k, v) in state. into_iter ( ) {
193+ this. add_partition ( k, Some ( v) ) ;
194+ }
189195 }
196+ this
190197 }
191198
192199 fn add_partition ( & mut self , part : Builder :: Part , part_state : Option < Builder :: PartitionState > ) {
@@ -264,7 +271,7 @@ where
264271 _output : & mut Output < Builder :: Part , VO , TO > ,
265272 ctx : & mut OperatorContext ,
266273 ) {
267- let state: Vec < _ > = self
274+ let state: IndexMap < Builder :: Part , Builder :: PartitionState > = self
268275 . partitions
269276 . iter ( )
270277 . map ( |( k, v) | ( k. clone ( ) , v. snapshot ( ) ) )
@@ -333,3 +340,129 @@ where
333340 }
334341 }
335342}
343+
344+ #[ cfg( test) ]
345+ mod tests {
346+ use std:: { sync:: Mutex , time:: Duration } ;
347+
348+ use crate :: {
349+ operators:: * ,
350+ runtime:: SingleThreadRuntime ,
351+ sinks:: { StatelessSink , VecSink } ,
352+ sources:: { StatefulSource , StatefulSourceImpl , StatefulSourcePartition } ,
353+ testing:: CapturingPersistenceBackend ,
354+ worker:: StreamProvider ,
355+ } ;
356+
357+ struct MockSource ( i32 ) ;
358+ struct MockSourcePartition {
359+ max : i32 ,
360+ next : i32 ,
361+ was_snapshotted : Mutex < bool > ,
362+ }
363+
364+ impl StatefulSourceImpl < i32 , i32 > for MockSource {
365+ type Part = ( ) ;
366+
367+ type PartitionState = i32 ;
368+
369+ type SourcePartition = MockSourcePartition ;
370+
371+ fn list_parts ( & self ) -> Vec < Self :: Part > {
372+ vec ! [ ( ) ]
373+ }
374+
375+ fn build_part (
376+ & mut self ,
377+ _part : & Self :: Part ,
378+ part_state : Option < Self :: PartitionState > ,
379+ ) -> Self :: SourcePartition {
380+ MockSourcePartition {
381+ max : self . 0 ,
382+ next : part_state. unwrap_or_default ( ) ,
383+ was_snapshotted : Mutex :: new ( false ) ,
384+ }
385+ }
386+ }
387+
388+ impl StatefulSourcePartition < i32 , i32 > for MockSourcePartition {
389+ type PartitionState = i32 ;
390+
391+ fn poll ( & mut self ) -> Option < ( i32 , i32 ) > {
392+ if self . next > self . max {
393+ None
394+ } else {
395+ let out = ( self . next , self . next ) ;
396+ self . next += 1 ;
397+ Some ( out)
398+ }
399+ }
400+
401+ fn is_finished ( & mut self ) -> bool {
402+ // only terminate after we have made a snapshot
403+ self . next > self . max && * self . was_snapshotted . lock ( ) . unwrap ( )
404+ }
405+
406+ fn snapshot ( & self ) -> Self :: PartitionState {
407+ * self . was_snapshotted . lock ( ) . unwrap ( ) = true ;
408+ self . next
409+ }
410+
411+ fn collect ( self ) -> Self :: PartitionState {
412+ self . next
413+ }
414+ }
415+
416+ /// Check that state gets loaded from persistence backend
417+ /// on initial start
418+ #[ test]
419+ fn test_state_is_loaded_from_persistence ( ) {
420+ let persistence = CapturingPersistenceBackend :: default ( ) ;
421+
422+ let first_sink = VecSink :: new ( ) ;
423+ let first_collected = first_sink. clone ( ) ;
424+
425+ // execute once, this will finish as soon as a snapshot was taken
426+ let rt = SingleThreadRuntime :: builder ( )
427+ . snapshots ( Duration :: from_millis ( 50 ) )
428+ . persistence ( persistence. clone ( ) )
429+ . build ( move |provider : & mut dyn StreamProvider | {
430+ provider
431+ . new_stream ( )
432+ . source ( "mock-source" , StatefulSource :: new ( MockSource ( 10 ) ) )
433+ . sink ( "vec-sink" , StatelessSink :: new ( first_sink) ) ;
434+ } ) ;
435+ rt. execute ( ) . unwrap ( ) ;
436+ let result: Vec < _ > = first_collected
437+ . drain_vec ( ..)
438+ . iter ( )
439+ . map ( |x| x. value )
440+ . collect ( ) ;
441+ let expected: Vec < _ > = ( 0 ..=10 ) . collect ( ) ;
442+ assert_eq ! ( result, expected) ;
443+
444+ // execute again, only numbers 11-15 should have been counted since we started from the
445+ // state which had already counted to 10
446+ let second_sink = VecSink :: new ( ) ;
447+ let second_collected = second_sink. clone ( ) ;
448+
449+ // execute again
450+ let rt = SingleThreadRuntime :: builder ( )
451+ . snapshots ( Duration :: from_millis ( 50 ) )
452+ . persistence ( persistence)
453+ . build ( move |provider : & mut dyn StreamProvider | {
454+ provider
455+ . new_stream ( )
456+ . source ( "mock-source" , StatefulSource :: new ( MockSource ( 15 ) ) )
457+ . sink ( "vec-sink" , StatelessSink :: new ( second_sink) ) ;
458+ } ) ;
459+ rt. execute ( ) . unwrap ( ) ;
460+ let result: Vec < _ > = second_collected
461+ . drain_vec ( ..)
462+ . iter ( )
463+ . map ( |x| x. value )
464+ . collect ( ) ;
465+ let expected: Vec < _ > = ( 11 ..=15 ) . collect ( ) ;
466+ assert_eq ! ( result, expected) ;
467+ }
468+ }
0 commit comments