Skip to content

Commit 9e3b176

Browse files
author
Damion Werner
committed
fix snapshotting not working for sources, add example of stateful source
1 parent 4cd68c6 commit 9e3b176

File tree

4 files changed

+249
-6
lines changed

4 files changed

+249
-6
lines changed

malstrom-core/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,7 @@ slatedb = ["dep:slatedb", "dep:object_store", "dep:tokio-stream"]
3737
[[example]]
3838
name = "slatedb_backend"
3939
required-features = ["slatedb"]
40+
41+
[[example]]
42+
name = "slatedb_backend_failing"
43+
required-features = ["slatedb"]
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
//! Using SlateDB as a persistence backend to resume a stateful source after the worker panics and is restarted
2+
use malstrom::keyed::partitioners::rendezvous_select;
3+
use malstrom::operators::*;
4+
use malstrom::sinks::{StatelessSink, StdOutSink};
5+
use malstrom::sources::{StatefulSource, StatefulSourceImpl, StatefulSourcePartition};
6+
use malstrom::{
7+
runtime::SingleThreadRuntime,
8+
snapshot::SlateDbBackend,
9+
worker::StreamProvider,
10+
};
11+
use object_store::{local::LocalFileSystem, path::Path};
12+
use std::sync::Arc;
13+
use std::thread::sleep;
14+
use std::time::{Duration, Instant};
15+
16+
fn main() {
17+
let filesystem = LocalFileSystem::new();
18+
let persistence = SlateDbBackend::new(Arc::new(filesystem), Path::from("/tmp")).unwrap();
19+
20+
loop {
21+
let job = SingleThreadRuntime::builder()
22+
.persistence(persistence.clone())
23+
.snapshots(Duration::from_secs(1))
24+
.build(build_dataflow);
25+
let thread = std::thread::spawn(move || {
26+
job.execute().unwrap()
27+
});
28+
match thread.join() {
29+
Ok(_) => return,
30+
Err(_) => {
31+
println!("Restarting worker");
32+
continue;
33+
},
34+
}
35+
}
36+
}
37+
38+
fn build_dataflow(provider: &mut dyn StreamProvider) {
39+
let start_time = Instant::now();
40+
let fail_interval = Duration::from_secs(10);
41+
provider
42+
.new_stream()
43+
.source(
44+
"iter-source",
45+
StatefulSource::new(StatefulNumberSource(0)),
46+
)
47+
.key_distribute("key-by-value", |x| x.value & 1 == 1, rendezvous_select)
48+
.stateful_map("sum", |_key, value, state: i32| {
49+
let state = state + value;
50+
(state, Some(state))
51+
})
52+
.inspect("expensive-operation", |_msg, _ctx| {
53+
// we need this to not overflow the sum before "crashing"
54+
sleep(Duration::from_millis(100))
55+
})
56+
.inspect("fail-random", move |_msg, _ctx| {
57+
if Instant::now().duration_since(start_time) > fail_interval {
58+
panic!("Oh no!")
59+
}
60+
})
61+
.sink("stdout", StatelessSink::new(StdOutSink));
62+
}
63+
64+
/// Example source emitting consecutive `i32` numbers; the wrapped value is the next number to emit.
struct StatefulNumberSource(i32);
65+
66+
/// Source-level half of the example: one unit partition whose state is a single `i32`.
impl StatefulSourceImpl<i32, i32> for StatefulNumberSource {
    type Part = ();
    type PartitionState = i32;
    type SourcePartition = Self;

    /// The source always consists of exactly one partition, identified by `()`.
    fn list_parts(&self) -> Vec<Self::Part> {
        vec![()]
    }

    /// Creates the partition, continuing from `part_state` when one is given
    /// and from 0 otherwise. The received state is printed for demonstration.
    fn build_part(
        &mut self,
        _part: &Self::Part,
        part_state: Option<Self::PartitionState>,
    ) -> Self::SourcePartition {
        println!("Build with {part_state:?}");
        let next = part_state.unwrap_or(0);
        Self(next)
    }
}
84+
85+
/// Partition-level half of the example: an endless counter.
impl StatefulSourcePartition<i32, i32> for StatefulNumberSource {
    type PartitionState = i32;

    /// Emits the current number (used for both tuple elements) and advances the counter.
    fn poll(&mut self) -> Option<(i32, i32)> {
        let current = self.0;
        self.0 += 1;
        Some((current, current))
    }

    /// The counter never finishes on its own.
    fn is_finished(&mut self) -> bool {
        false
    }

    /// Returns the next number to emit as the partition state, logging that a snapshot happened.
    fn snapshot(&self) -> Self::PartitionState {
        println!("SNAPSHOTTING SOURCE");
        self.0
    }

    /// Consumes the partition and hands back the next number to emit.
    fn collect(self) -> Self::PartitionState {
        let Self(next) = self;
        next
    }
}

malstrom-core/src/sources/stateful.rs

Lines changed: 135 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ where
177177
) -> Self {
178178
let comm_clients =
179179
ctx.create_all_communication_clients::<PartitionFinished<Builder::Part>>();
180-
Self {
180+
let mut this = Self {
181181
partitions: IndexMap::new(),
182182
part_builder,
183183
all_partitions,
@@ -186,7 +186,14 @@ where
186186
// times should not be an issue
187187
max_t: Some(T::MAX),
188188
_phantom: PhantomData,
189+
};
190+
191+
if let Some(state) = ctx.load_state::<IndexMap<Builder::Part, Builder::PartitionState>>() {
192+
for (k, v) in state.into_iter() {
193+
this.add_partition(k, Some(v));
194+
}
189195
}
196+
this
190197
}
191198

192199
fn add_partition(&mut self, part: Builder::Part, part_state: Option<Builder::PartitionState>) {
@@ -264,7 +271,7 @@ where
264271
_output: &mut Output<Builder::Part, VO, TO>,
265272
ctx: &mut OperatorContext,
266273
) {
267-
let state: Vec<_> = self
274+
let state: IndexMap<Builder::Part, Builder::PartitionState> = self
268275
.partitions
269276
.iter()
270277
.map(|(k, v)| (k.clone(), v.snapshot()))
@@ -333,3 +340,129 @@ where
333340
}
334341
}
335342
}
343+
344+
#[cfg(test)]
345+
mod tests {
346+
use std::{sync::Mutex, time::Duration};
347+
348+
use crate::{
349+
operators::*,
350+
runtime::SingleThreadRuntime,
351+
sinks::{StatelessSink, VecSink},
352+
sources::{StatefulSource, StatefulSourceImpl, StatefulSourcePartition},
353+
testing::CapturingPersistenceBackend,
354+
worker::StreamProvider,
355+
};
356+
357+
/// Test source; the wrapped value is the largest number its partition will emit.
struct MockSource(i32);
358+
/// Partition used by the tests in this module; it emits the numbers from `next` up to and including `max`.
struct MockSourcePartition {
    /// Largest number this partition emits.
    max: i32,
    /// Number returned by the next successful `poll`.
    next: i32,
    /// Written by `snapshot` and read by `is_finished`. Wrapped in a `Mutex`
    /// because `snapshot` only receives `&self`.
    was_snapshotted: Mutex<bool>,
}
363+
364+
impl StatefulSourceImpl<i32, i32> for MockSource {
365+
type Part = ();
366+
367+
type PartitionState = i32;
368+
369+
type SourcePartition = MockSourcePartition;
370+
371+
fn list_parts(&self) -> Vec<Self::Part> {
372+
vec![()]
373+
}
374+
375+
fn build_part(
376+
&mut self,
377+
_part: &Self::Part,
378+
part_state: Option<Self::PartitionState>,
379+
) -> Self::SourcePartition {
380+
MockSourcePartition {
381+
max: self.0,
382+
next: part_state.unwrap_or_default(),
383+
was_snapshotted: Mutex::new(false),
384+
}
385+
}
386+
}
387+
388+
/// Partition-level half of the mock: counts from `next` up to and including `max`,
/// then keeps the partition alive until its final position has been snapshotted.
impl StatefulSourcePartition<i32, i32> for MockSourcePartition {
    type PartitionState = i32;

    /// Emits the next number (used for both tuple elements), or `None` once
    /// everything up to `max` has been emitted.
    fn poll(&mut self) -> Option<(i32, i32)> {
        if self.next > self.max {
            return None;
        }
        let current = self.next;
        self.next += 1;
        Some((current, current))
    }

    /// Reports completion only when all numbers were emitted AND the final
    /// position has been captured by a snapshot.
    fn is_finished(&mut self) -> bool {
        // only terminate after we have made a snapshot
        self.next > self.max && *self.was_snapshotted.lock().unwrap()
    }

    /// Returns the next number to emit as the partition state.
    ///
    /// The snapshot flag is only set when the partition is already exhausted.
    /// A snapshot taken while numbers are still being emitted would otherwise
    /// let the partition finish with a stale persisted position, making tests
    /// that resume from the snapshot depend on timing.
    fn snapshot(&self) -> Self::PartitionState {
        if self.next > self.max {
            *self.was_snapshotted.lock().unwrap() = true;
        }
        self.next
    }

    /// Consumes the partition and hands back the next number to emit.
    fn collect(self) -> Self::PartitionState {
        self.next
    }
}
415+
416+
/// Check that state gets loaded from persistence backend
/// on initial start
#[test]
fn test_state_is_loaded_from_persistence() {
    /// Runs a job whose mock source counts up to `max` against `persistence`
    /// and returns the values that reached the sink.
    fn run_and_collect(persistence: CapturingPersistenceBackend, max: i32) -> Vec<i32> {
        let sink = VecSink::new();
        let collected = sink.clone();

        let rt = SingleThreadRuntime::builder()
            .snapshots(Duration::from_millis(50))
            .persistence(persistence)
            .build(move |provider: &mut dyn StreamProvider| {
                provider
                    .new_stream()
                    .source("mock-source", StatefulSource::new(MockSource(max)))
                    .sink("vec-sink", StatelessSink::new(sink));
            });
        rt.execute().unwrap();

        let values: Vec<i32> = collected.drain_vec(..).iter().map(|x| x.value).collect();
        values
    }

    let persistence = CapturingPersistenceBackend::default();

    // execute once, this will finish as soon as a snapshot was taken
    let first_run = run_and_collect(persistence.clone(), 10);
    let expected: Vec<_> = (0..=10).collect();
    assert_eq!(first_run, expected);

    // execute again, only numbers 11-15 should have been counted since we started from the
    // state which had already counted to 10
    let second_run = run_and_collect(persistence, 15);
    let expected: Vec<_> = (11..=15).collect();
    assert_eq!(second_run, expected);
}
468+
}

malstrom-core/src/testing/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::sync::Arc;
12
use std::{collections::HashMap, rc::Rc, sync::Mutex};
23

34
use crate::keyed::distributed::{Acquire, Collect, Interrogate};
@@ -37,13 +38,13 @@ where
3738
/// If you have a clone of this backend you can retrieve the state using
3839
/// the corresponding operator_id
3940
pub struct CapturingPersistenceBackend {
40-
capture: Rc<Mutex<HashMap<OperatorId, Vec<u8>>>>,
41+
capture: Arc<Mutex<HashMap<OperatorId, Vec<u8>>>>,
4142
}
4243
impl PersistenceBackend for CapturingPersistenceBackend {
4344
type Client = CapturingPersistenceBackend;
4445

4546
fn last_commited(&self) -> Option<SnapshotVersion> {
46-
None
47+
Some(SnapshotVersion::default())
4748
}
4849

4950
fn for_version(
@@ -55,8 +56,7 @@ impl PersistenceBackend for CapturingPersistenceBackend {
5556
}
5657

5758
fn commit_version(&self, _snapshot_version: &crate::snapshot::SnapshotVersion) {
58-
// TODO
59-
todo!()
59+
// nothing happening here
6060
}
6161
}
6262

0 commit comments

Comments
 (0)