diff --git a/concurrency/src/tasks/stream.rs b/concurrency/src/tasks/stream.rs index 26a8da3..492c4f9 100644 --- a/concurrency/src/tasks/stream.rs +++ b/concurrency/src/tasks/stream.rs @@ -8,49 +8,33 @@ use spawned_rt::tasks::JoinHandle; /// /// This function returns a handle to the spawned task and a cancellation token /// to stop it. -pub fn spawn_listener( - mut handle: GenServerHandle, - message_builder: F, - mut stream: S, -) -> JoinHandle<()> +pub fn spawn_listener(mut handle: GenServerHandle, stream: S) -> JoinHandle<()> where - T: GenServer + 'static, - F: Fn(I) -> T::CastMsg + Send + 'static + std::marker::Sync, - I: Send, - E: std::fmt::Debug + Send, - S: Unpin + Send + Stream> + 'static, + T: GenServer, + S: Send + Stream + 'static, { let cancelation_token = handle.cancellation_token(); let join_handle = spawned_rt::tasks::spawn(async move { - let result = select( - Box::pin(cancelation_token.cancelled()), - Box::pin(async { - loop { - match stream.next().await { - // Stream has a new valid Item - Some(Ok(i)) => match handle.cast(message_builder(i)).await { - Ok(_) => tracing::trace!("Message sent successfully"), - Err(e) => { - tracing::error!("Failed to send message: {e:?}"); - break; - } - }, - // Stream has new data, but failed to extract the Item, - // probably due to decoding problems. - Some(Err(e)) => { - // log the error but keep listener alive for more valid Items - tracing::error!("Error processing stream: {e:?}"); - } - None => { - tracing::trace!("Stream finished"); + let mut pinned_stream = core::pin::pin!(stream); + let is_cancelled = core::pin::pin!(cancelation_token.cancelled()); + let listener_loop = core::pin::pin!(async { + loop { + match pinned_stream.next().await { + Some(msg) => match handle.cast(msg).await { + Ok(_) => tracing::trace!("Message sent successfully"), + Err(e) => { + tracing::error!("Failed to send message: {e:?}"); break; } + }, + None => { + tracing::trace!("Stream finished"); + break; } } - }), - ) - .await; - match result { + } + }); + match select(is_cancelled, listener_loop).await { futures::future::Either::Left(_) => tracing::trace!("GenServer stopped"), futures::future::Either::Right(_) => (), // Stream finished or errored out } diff --git a/concurrency/src/tasks/stream_tests.rs b/concurrency/src/tasks/stream_tests.rs index 14d89d7..bebc023 100644 --- a/concurrency/src/tasks/stream_tests.rs +++ b/concurrency/src/tasks/stream_tests.rs @@ -1,8 +1,9 @@ use crate::tasks::{ send_after, stream::spawn_listener, CallResponse, CastResponse, GenServer, GenServerHandle, }; +use futures::{stream, StreamExt}; use spawned_rt::tasks::{self as rt, BroadcastStream, ReceiverStream}; -use std::{io::Error, time::Duration}; +use std::time::Duration; type SummatoryHandle = GenServerHandle; @@ -21,6 +22,7 @@ type SummatoryOutMessage = u16; #[derive(Clone)] enum SummatoryCastMessage { Add(u16), + StreamError, Stop, } @@ -46,6 +48,7 @@ impl GenServer for Summatory { self.count += val; CastResponse::NoReply } + SummatoryCastMessage::StreamError => CastResponse::Stop, SummatoryCastMessage::Stop => CastResponse::Stop, } } @@ -60,20 +63,17 @@ impl GenServer for Summatory { } } -// In this example, the stream sends u8 values, which are converted to the type -// supported by the GenServer (SummatoryCastMessage / u16). -fn message_builder(value: u8) -> SummatoryCastMessage { - SummatoryCastMessage::Add(value.into()) -} - #[test] pub fn test_sum_numbers_from_stream() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { let mut summatory_handle = Summatory::new(0).start(); - let stream = tokio_stream::iter(vec![1u8, 2, 3, 4, 5].into_iter().map(Ok::)); + let stream = stream::iter(vec![1u16, 2, 3, 4, 5].into_iter().map(Ok::)); - spawn_listener(summatory_handle.clone(), message_builder, stream); + spawn_listener( + summatory_handle.clone(), + stream.filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }), + ); // Wait for 1 second so the whole stream is processed rt::sleep(Duration::from_secs(1)).await; @@ -88,7 +88,7 @@ pub fn test_sum_numbers_from_channel() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { let mut summatory_handle = Summatory::new(0).start(); - let (tx, rx) = spawned_rt::tasks::mpsc::channel::>(); + let (tx, rx) = spawned_rt::tasks::mpsc::channel::>(); // Spawn a task to send numbers to the channel spawned_rt::tasks::spawn(async move { @@ -99,8 +99,8 @@ pub fn test_sum_numbers_from_channel() { spawn_listener( summatory_handle.clone(), - message_builder, - ReceiverStream::new(rx), + ReceiverStream::new(rx) + .filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }), ); // Wait for 1 second so the whole stream is processed @@ -116,19 +116,19 @@ pub fn test_sum_numbers_from_broadcast_channel() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { let mut summatory_handle = Summatory::new(0).start(); - let (tx, rx) = tokio::sync::broadcast::channel::(5); + let (tx, rx) = tokio::sync::broadcast::channel::(5); // Spawn a task to send numbers to the channel spawned_rt::tasks::spawn(async move { - for i in 1u8..=5 { + for i in 1u16..=5 { tx.send(i).unwrap(); } }); spawn_listener( summatory_handle.clone(), - message_builder, - BroadcastStream::new(rx), + BroadcastStream::new(rx) + .filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }), ); // Wait for 1 second so the whole stream is processed @@ -146,7 +146,7 @@ pub fn test_stream_cancellation() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { let mut summatory_handle = Summatory::new(0).start(); - let (tx, rx) = spawned_rt::tasks::mpsc::channel::>(); + let (tx, rx) = spawned_rt::tasks::mpsc::channel::>(); // Spawn a task to send numbers to the channel spawned_rt::tasks::spawn(async move { @@ -158,8 +158,8 @@ pub fn test_stream_cancellation() { let listener_handle = spawn_listener( summatory_handle.clone(), - message_builder, - ReceiverStream::new(rx), + ReceiverStream::new(rx) + .filter_map(|result| async move { result.ok().map(SummatoryCastMessage::Add) }), ); // Start a timer to stop the stream after a certain time @@ -189,20 +189,43 @@ pub fn test_stream_cancellation() { } #[test] -pub fn test_stream_skipping_decoding_error() { +pub fn test_halting_on_stream_error() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let mut summatory_handle = Summatory::new(0).start(); + let stream = tokio_stream::iter(vec![Ok(1u16), Ok(2), Ok(3), Err(()), Ok(4), Ok(5)]); + let msg_stream = stream.filter_map(|value| async move { + match value { + Ok(number) => Some(SummatoryCastMessage::Add(number)), + Err(_) => Some(SummatoryCastMessage::StreamError), + } + }); + + spawn_listener(summatory_handle.clone(), msg_stream); + + // Wait for 1 second so the whole stream is processed + rt::sleep(Duration::from_secs(1)).await; + + let result = Summatory::get_value(&mut summatory_handle).await; + // GenServer should have been terminated, hence the result should be an error + assert!(result.is_err()); + }) +} + +#[test] +pub fn test_skipping_on_stream_error() { let runtime = rt::Runtime::new().unwrap(); runtime.block_on(async move { let mut summatory_handle = Summatory::new(0).start(); - let stream = tokio_stream::iter(vec![ - Ok(1), - Ok(2), - Ok(3), - Err(Error::other("oh no!")), - Ok(4), - Ok(5), - ]); - - spawn_listener(summatory_handle.clone(), message_builder, stream); + let stream = tokio_stream::iter(vec![Ok(1u16), Ok(2), Ok(3), Err(()), Ok(4), Ok(5)]); + let msg_stream = stream.filter_map(|value| async move { + match value { + Ok(number) => Some(SummatoryCastMessage::Add(number)), + Err(_) => None, + } + }); + + spawn_listener(summatory_handle.clone(), msg_stream); // Wait for 1 second so the whole stream is processed rt::sleep(Duration::from_secs(1)).await;