Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 19 additions & 35 deletions concurrency/src/tasks/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,33 @@ use spawned_rt::tasks::JoinHandle;
///
/// This function returns a handle to the spawned task and a cancellation token
/// to stop it.
pub fn spawn_listener<T, F, S, I, E>(
mut handle: GenServerHandle<T>,
message_builder: F,
mut stream: S,
) -> JoinHandle<()>
pub fn spawn_listener<T, S>(mut handle: GenServerHandle<T>, stream: S) -> JoinHandle<()>
where
T: GenServer + 'static,
F: Fn(I) -> T::CastMsg + Send + 'static + std::marker::Sync,
I: Send,
E: std::fmt::Debug + Send,
S: Unpin + Send + Stream<Item = Result<I, E>> + 'static,
T: GenServer,
S: Send + Stream<Item = T::CastMsg> + 'static,
{
let cancelation_token = handle.cancellation_token();
let join_handle = spawned_rt::tasks::spawn(async move {
let result = select(
Box::pin(cancelation_token.cancelled()),
Box::pin(async {
loop {
match stream.next().await {
// Stream has a new valid Item
Some(Ok(i)) => match handle.cast(message_builder(i)).await {
Ok(_) => tracing::trace!("Message sent successfully"),
Err(e) => {
tracing::error!("Failed to send message: {e:?}");
break;
}
},
// Stream has new data, but failed to extract the Item,
// probably due to decoding problems.
Some(Err(e)) => {
// log the error but keep listener alive for more valid Items
tracing::error!("Error processing stream: {e:?}");
}
None => {
tracing::trace!("Stream finished");
let mut pinned_stream = core::pin::pin!(stream);
let is_cancelled = core::pin::pin!(cancelation_token.cancelled());
let listener_loop = core::pin::pin!(async {
loop {
match pinned_stream.next().await {
Some(msg) => match handle.cast(msg).await {
Ok(_) => tracing::trace!("Message sent successfully"),
Err(e) => {
tracing::error!("Failed to send message: {e:?}");
break;
}
},
None => {
tracing::trace!("Stream finished");
break;
}
}
}),
)
.await;
match result {
}
});
match select(is_cancelled, listener_loop).await {
futures::future::Either::Left(_) => tracing::trace!("GenServer stopped"),
futures::future::Either::Right(_) => (), // Stream finished or errored out
}
Expand Down
83 changes: 53 additions & 30 deletions concurrency/src/tasks/stream_tests.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::tasks::{
send_after, stream::spawn_listener, CallResponse, CastResponse, GenServer, GenServerHandle,
};
use futures::{stream, StreamExt};
use spawned_rt::tasks::{self as rt, BroadcastStream, ReceiverStream};
use std::{io::Error, time::Duration};
use std::time::Duration;

type SummatoryHandle = GenServerHandle<Summatory>;

Expand All @@ -21,6 +22,7 @@ type SummatoryOutMessage = u16;
#[derive(Clone)]
enum SummatoryCastMessage {
Add(u16),
StreamError,
Stop,
}

Expand All @@ -46,6 +48,7 @@ impl GenServer for Summatory {
self.count += val;
CastResponse::NoReply
}
SummatoryCastMessage::StreamError => CastResponse::Stop,
SummatoryCastMessage::Stop => CastResponse::Stop,
}
}
Expand All @@ -60,20 +63,17 @@ impl GenServer for Summatory {
}
}

// In this example, the stream sends u8 values, which are converted to the type
// supported by the GenServer (SummatoryCastMessage / u16).
fn message_builder(value: u8) -> SummatoryCastMessage {
SummatoryCastMessage::Add(value.into())
}

#[test]
pub fn test_sum_numbers_from_stream() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
let mut summatory_handle = Summatory::new(0).start();
let stream = tokio_stream::iter(vec![1u8, 2, 3, 4, 5].into_iter().map(Ok::<u8, ()>));
let stream = stream::iter(vec![1u16, 2, 3, 4, 5].into_iter().map(Ok::<u16, ()>));

spawn_listener(summatory_handle.clone(), message_builder, stream);
spawn_listener(
summatory_handle.clone(),
stream.filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }),
);

// Wait for 1 second so the whole stream is processed
rt::sleep(Duration::from_secs(1)).await;
Expand All @@ -88,7 +88,7 @@ pub fn test_sum_numbers_from_channel() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
let mut summatory_handle = Summatory::new(0).start();
let (tx, rx) = spawned_rt::tasks::mpsc::channel::<Result<u8, ()>>();
let (tx, rx) = spawned_rt::tasks::mpsc::channel::<Result<u16, ()>>();

// Spawn a task to send numbers to the channel
spawned_rt::tasks::spawn(async move {
Expand All @@ -99,8 +99,8 @@ pub fn test_sum_numbers_from_channel() {

spawn_listener(
summatory_handle.clone(),
message_builder,
ReceiverStream::new(rx),
ReceiverStream::new(rx)
.filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }),
);

// Wait for 1 second so the whole stream is processed
Expand All @@ -116,19 +116,19 @@ pub fn test_sum_numbers_from_broadcast_channel() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
let mut summatory_handle = Summatory::new(0).start();
let (tx, rx) = tokio::sync::broadcast::channel::<u8>(5);
let (tx, rx) = tokio::sync::broadcast::channel::<u16>(5);

// Spawn a task to send numbers to the channel
spawned_rt::tasks::spawn(async move {
for i in 1u8..=5 {
for i in 1u16..=5 {
tx.send(i).unwrap();
}
});

spawn_listener(
summatory_handle.clone(),
message_builder,
BroadcastStream::new(rx),
BroadcastStream::new(rx)
.filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }),
);

// Wait for 1 second so the whole stream is processed
Expand All @@ -146,7 +146,7 @@ pub fn test_stream_cancellation() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
let mut summatory_handle = Summatory::new(0).start();
let (tx, rx) = spawned_rt::tasks::mpsc::channel::<Result<u8, ()>>();
let (tx, rx) = spawned_rt::tasks::mpsc::channel::<Result<u16, ()>>();

// Spawn a task to send numbers to the channel
spawned_rt::tasks::spawn(async move {
Expand All @@ -158,8 +158,8 @@ pub fn test_stream_cancellation() {

let listener_handle = spawn_listener(
summatory_handle.clone(),
message_builder,
ReceiverStream::new(rx),
ReceiverStream::new(rx)
.filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }),
);

// Start a timer to stop the stream after a certain time
Expand Down Expand Up @@ -189,20 +189,43 @@ pub fn test_stream_cancellation() {
}

#[test]
pub fn test_stream_skipping_decoding_error() {
pub fn test_halting_on_stream_error() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
let mut summatory_handle = Summatory::new(0).start();
let stream = tokio_stream::iter(vec![Ok(1u16), Ok(2), Ok(3), Err(()), Ok(4), Ok(5)]);
let msg_stream = stream.filter_map(|value| async move {
match value {
Ok(number) => Some(SummatoryCastMessage::Add(number)),
Err(_) => Some(SummatoryCastMessage::StreamError),
}
});

spawn_listener(summatory_handle.clone(), msg_stream);

// Wait for 1 second so the whole stream is processed
rt::sleep(Duration::from_secs(1)).await;

let result = Summatory::get_value(&mut summatory_handle).await;
// GenServer should have been terminated, hence the result should be an error
assert!(result.is_err());
})
}

#[test]
pub fn test_skipping_on_stream_error() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
let mut summatory_handle = Summatory::new(0).start();
let stream = tokio_stream::iter(vec![
Ok(1),
Ok(2),
Ok(3),
Err(Error::other("oh no!")),
Ok(4),
Ok(5),
]);

spawn_listener(summatory_handle.clone(), message_builder, stream);
let stream = tokio_stream::iter(vec![Ok(1u16), Ok(2), Ok(3), Err(()), Ok(4), Ok(5)]);
let msg_stream = stream.filter_map(|value| async move {
match value {
Ok(number) => Some(SummatoryCastMessage::Add(number)),
Err(_) => None,
}
});

spawn_listener(summatory_handle.clone(), msg_stream);

// Wait for 1 second so the whole stream is processed
rt::sleep(Duration::from_secs(1)).await;
Expand Down