Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 52 additions & 83 deletions concurrency/src/tasks/gen_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,45 +25,35 @@ impl<G: GenServer> Clone for GenServerHandle<G> {
}

impl<G: GenServer> GenServerHandle<G> {
pub(crate) fn new(initial_state: G::State) -> Self {
pub(crate) fn new(gen_server: G) -> Self {
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
let cancellation_token = CancellationToken::new();
let handle = GenServerHandle {
tx,
cancellation_token,
};
let mut gen_server: G = GenServer::new();
let handle_clone = handle.clone();
// Ignore the JoinHandle for now. Maybe we'll use it in the future
let _join_handle = rt::spawn(async move {
if gen_server
.run(&handle, &mut rx, initial_state)
.await
.is_err()
{
if gen_server.run(&handle, &mut rx).await.is_err() {
tracing::trace!("GenServer crashed")
};
});
handle_clone
}

pub(crate) fn new_blocking(initial_state: G::State) -> Self {
pub(crate) fn new_blocking(gen_server: G) -> Self {
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
let cancellation_token = CancellationToken::new();
let handle = GenServerHandle {
tx,
cancellation_token,
};
let mut gen_server: G = GenServer::new();
let handle_clone = handle.clone();
// Ignore the JoinHandle for now. Maybe we'll use it in the future
let _join_handle = rt::spawn_blocking(|| {
rt::block_on(async move {
if gen_server
.run(&handle, &mut rx, initial_state)
.await
.is_err()
{
if gen_server.run(&handle, &mut rx).await.is_err() {
tracing::trace!("GenServer crashed")
};
})
Expand Down Expand Up @@ -119,63 +109,59 @@ pub enum GenServerInMsg<G: GenServer> {
}

pub enum CallResponse<G: GenServer> {
Reply(G::State, G::OutMsg),
Reply(G, G::OutMsg),
Unused,
Stop(G::OutMsg),
}

pub enum CastResponse<G: GenServer> {
NoReply(G::State),
NoReply(G),
Unused,
Stop,
}

pub trait GenServer
where
Self: Default + Send + Sized,
{
pub trait GenServer: Default + Send + Sized + Clone {
type CallMsg: Clone + Send + Sized + Sync;
type CastMsg: Clone + Send + Sized + Sync;
type OutMsg: Send + Sized;
type State: Clone + Send;
type Error: Debug + Send;

fn new() -> Self {
Self::default()
}

fn start(initial_state: Self::State) -> GenServerHandle<Self> {
GenServerHandle::new(initial_state)
fn start(self) -> GenServerHandle<Self> {
GenServerHandle::new(self)
}

/// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
/// happen if the task is blocking the thread. As such, for sync compute task
/// or other blocking tasks need to be in their own separate thread, and the OS
/// will manage them through hardware interrupts.
/// Start blocking provides such thread.
fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
GenServerHandle::new_blocking(initial_state)
fn start_blocking(self) -> GenServerHandle<Self> {
GenServerHandle::new_blocking(self)
}

fn run(
&mut self,
self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
state: Self::State,
) -> impl Future<Output = Result<(), GenServerError>> + Send {
async {
let init_result = self
.init(handle, state.clone())
.clone()
.init(handle)
.await
.inspect_err(|err| tracing::error!("Initialization failed: {err:?}"));

let res = match init_result {
Ok(new_state) => self.main_loop(handle, rx, new_state).await,
Ok(new_state) => new_state.main_loop(handle, rx).await,
Err(_) => Err(GenServerError::Initialization),
};

handle.cancellation_token().cancel();
if let Err(err) = self.teardown(handle, state).await {
if let Err(err) = self.teardown(handle).await {
tracing::error!("Error during teardown: {err:?}");
}
res
Expand All @@ -186,23 +172,21 @@ where
/// can be overrided on implementations in case initial steps are
/// required.
fn init(
&mut self,
self,
_handle: &GenServerHandle<Self>,
state: Self::State,
) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
async { Ok(state) }
) -> impl Future<Output = Result<Self, Self::Error>> + Send {
async { Ok(self) }
}

fn main_loop(
&mut self,
mut self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
mut state: Self::State,
) -> impl Future<Output = Result<(), GenServerError>> + Send {
async {
loop {
let (new_state, cont) = self.receive(handle, rx, state).await?;
state = new_state;
let (new_state, cont) = self.receive(handle, rx).await?;
self = new_state;
if !cont {
break;
}
Expand All @@ -213,21 +197,20 @@ where
}

fn receive(
&mut self,
self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
state: Self::State,
) -> impl Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
) -> impl Future<Output = Result<(Self, bool), GenServerError>> + Send {
async move {
let message = rx.recv().await;

// Save current state in case of a rollback
let state_clone = state.clone();
let state_clone = self.clone();

let (keep_running, new_state) = match message {
Some(GenServerInMsg::Call { sender, message }) => {
let (keep_running, new_state, response) =
match AssertUnwindSafe(self.handle_call(message, handle, state))
match AssertUnwindSafe(self.handle_call(message, handle))
.catch_unwind()
.await
{
Expand Down Expand Up @@ -257,7 +240,7 @@ where
(keep_running, new_state)
}
Some(GenServerInMsg::Cast { message }) => {
match AssertUnwindSafe(self.handle_cast(message, handle, state))
match AssertUnwindSafe(self.handle_cast(message, handle))
.catch_unwind()
.await
{
Expand All @@ -279,27 +262,25 @@ where
}
None => {
// Channel has been closed; won't receive further messages. Stop the server.
(false, state)
(false, self)
}
};
Ok((new_state, keep_running))
}
}

fn handle_call(
&mut self,
self,
_message: Self::CallMsg,
_handle: &GenServerHandle<Self>,
_state: Self::State,
) -> impl Future<Output = CallResponse<Self>> + Send {
async { CallResponse::Unused }
}

fn handle_cast(
&mut self,
self,
_message: Self::CastMsg,
_handle: &GenServerHandle<Self>,
_state: Self::State,
) -> impl Future<Output = CastResponse<Self>> + Send {
async { CastResponse::Unused }
}
Expand All @@ -308,9 +289,8 @@ where
/// It can be overrided on implementations in case final steps are required,
/// like closing streams, stopping timers, etc.
fn teardown(
&mut self,
self,
_handle: &GenServerHandle<Self>,
_state: Self::State,
) -> impl Future<Output = Result<(), Self::Error>> + Send {
async { Ok(()) }
}
Expand All @@ -323,7 +303,7 @@ mod tests {
use crate::tasks::send_after;
use std::{thread, time::Duration};

#[derive(Default)]
#[derive(Default, Clone)]
struct BadlyBehavedTask;

#[derive(Clone)]
Expand All @@ -340,80 +320,71 @@ mod tests {
type CallMsg = InMessage;
type CastMsg = ();
type OutMsg = ();
type State = ();
type Error = ();

async fn handle_call(
&mut self,
self,
_: Self::CallMsg,
_: &GenServerHandle<Self>,
_: Self::State,
) -> CallResponse<Self> {
CallResponse::Stop(())
}

async fn handle_cast(
&mut self,
self,
_: Self::CastMsg,
_: &GenServerHandle<Self>,
_: Self::State,
) -> CastResponse<Self> {
rt::sleep(Duration::from_millis(20)).await;
thread::sleep(Duration::from_secs(2));
CastResponse::Stop
}
}

#[derive(Default)]
struct WellBehavedTask;

#[derive(Clone)]
struct CountState {
#[derive(Default, Clone)]
struct WellBehavedTask {
pub count: u64,
}

impl GenServer for WellBehavedTask {
type CallMsg = InMessage;
type CastMsg = ();
type OutMsg = OutMsg;
type State = CountState;
type Error = ();

async fn handle_call(
&mut self,
self,
message: Self::CallMsg,
_: &GenServerHandle<Self>,
state: Self::State,
) -> CallResponse<Self> {
match message {
InMessage::GetCount => {
let count = state.count;
CallResponse::Reply(state, OutMsg::Count(count))
let count = self.count;
CallResponse::Reply(self, OutMsg::Count(count))
}
InMessage::Stop => CallResponse::Stop(OutMsg::Count(state.count)),
InMessage::Stop => CallResponse::Stop(OutMsg::Count(self.count)),
}
}

async fn handle_cast(
&mut self,
mut self,
_: Self::CastMsg,
handle: &GenServerHandle<Self>,
mut state: Self::State,
) -> CastResponse<Self> {
state.count += 1;
self.count += 1;
println!("{:?}: good still alive", thread::current().id());
send_after(Duration::from_millis(100), handle.to_owned(), ());
CastResponse::NoReply(state)
CastResponse::NoReply(self)
}
}

#[test]
pub fn badly_behaved_thread_non_blocking() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
let mut badboy = BadlyBehavedTask::start(());
let mut badboy = BadlyBehavedTask.start();
let _ = badboy.cast(()).await;
let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
let mut goodboy = WellBehavedTask { count: 0 }.start();
let _ = goodboy.cast(()).await;
rt::sleep(Duration::from_secs(1)).await;
let count = goodboy.call(InMessage::GetCount).await.unwrap();
Expand All @@ -431,9 +402,9 @@ mod tests {
pub fn badly_behaved_thread() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
let mut badboy = BadlyBehavedTask::start_blocking(());
let mut badboy = BadlyBehavedTask.start();
let _ = badboy.cast(()).await;
let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
let mut goodboy = WellBehavedTask { count: 0 }.start();
let _ = goodboy.cast(()).await;
rt::sleep(Duration::from_secs(1)).await;
let count = goodboy.call(InMessage::GetCount).await.unwrap();
Expand All @@ -449,7 +420,7 @@ mod tests {

const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

#[derive(Default)]
#[derive(Debug, Default, Clone)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default is no longer needed, right?

struct SomeTask;

#[derive(Clone)]
Expand All @@ -462,25 +433,23 @@ mod tests {
type CallMsg = SomeTaskCallMsg;
type CastMsg = ();
Copy link
Collaborator

@ElFantasma ElFantasma Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not part of this PR, but maybe we can move all these () to Unused 👉 👈?

type OutMsg = ();
type State = ();
type Error = ();

async fn handle_call(
&mut self,
self,
message: Self::CallMsg,
_handle: &GenServerHandle<Self>,
_state: Self::State,
) -> CallResponse<Self> {
match message {
SomeTaskCallMsg::SlowOperation => {
// Simulate a slow operation that will not resolve in time
rt::sleep(TIMEOUT_DURATION * 2).await;
CallResponse::Reply((), ())
CallResponse::Reply(self, ())
}
SomeTaskCallMsg::FastOperation => {
// Simulate a fast operation that resolves in time
rt::sleep(TIMEOUT_DURATION / 2).await;
CallResponse::Reply((), ())
CallResponse::Reply(self, ())
}
}
}
Expand All @@ -490,7 +459,7 @@ mod tests {
pub fn unresolving_task_times_out() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
let mut unresolving_task = SomeTask::start(());
let mut unresolving_task = SomeTask.start();

let result = unresolving_task
.call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
Expand Down
Loading
Loading