Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ members = [
"examples/bank",
"examples/bank_threads",
"examples/name_server",
"examples/name_server_with_error",
"examples/ping_pong",
"examples/ping_pong_threads",
"examples/updater",
Expand Down
97 changes: 40 additions & 57 deletions concurrency/src/tasks/gen_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ pub enum GenServerInMsg<G: GenServer> {
}

pub enum CallResponse<G: GenServer> {
Reply(G, G::OutMsg),
Reply(G::OutMsg),
Unused,
Stop(G::OutMsg),
}

pub enum CastResponse<G: GenServer> {
NoReply(G),
pub enum CastResponse {
NoReply,
Unused,
Stop,
}
Expand All @@ -128,7 +128,7 @@ pub enum InitResult<G: GenServer> {
NoSuccess(G),
}

pub trait GenServer: Send + Sized + Clone {
pub trait GenServer: Send + Sized {
type CallMsg: Clone + Send + Sized + Sync;
type CastMsg: Clone + Send + Sized + Sync;
type OutMsg: Send + Sized;
Expand All @@ -154,7 +154,7 @@ pub trait GenServer: Send + Sized + Clone {
) -> impl Future<Output = Result<(), GenServerError>> + Send {
async {
let res = match self.init(handle).await {
Ok(Success(new_state)) => new_state.main_loop(handle, rx).await,
Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await),
Ok(NoSuccess(intermediate_state)) => {
// new_state is NoSuccess, this means the initialization failed, but the error was handled
// in callback. No need to report the error.
Expand Down Expand Up @@ -191,53 +191,44 @@ pub trait GenServer: Send + Sized + Clone {
mut self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
) -> impl Future<Output = Result<Self, GenServerError>> + Send {
) -> impl Future<Output = Self> + Send {
async {
loop {
let (new_state, cont) = self.receive(handle, rx).await?;
self = new_state;
if !cont {
if !self.receive(handle, rx).await {
break;
}
}
tracing::trace!("Stopping GenServer");
Ok(self)
self
}
}

fn receive(
self,
&mut self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
) -> impl Future<Output = Result<(Self, bool), GenServerError>> + Send {
) -> impl Future<Output = bool> + Send {
async move {
let message = rx.recv().await;

// Save current state in case of a rollback
let state_clone = self.clone();

let (keep_running, new_state) = match message {
let keep_running = match message {
Some(GenServerInMsg::Call { sender, message }) => {
let (keep_running, new_state, response) =
let (keep_running, response) =
match AssertUnwindSafe(self.handle_call(message, handle))
.catch_unwind()
.await
{
Ok(response) => match response {
CallResponse::Reply(new_state, response) => {
(true, new_state, Ok(response))
}
CallResponse::Stop(response) => (false, state_clone, Ok(response)),
CallResponse::Reply(response) => (true, Ok(response)),
CallResponse::Stop(response) => (false, Ok(response)),
CallResponse::Unused => {
tracing::error!("GenServer received unexpected CallMessage");
(false, state_clone, Err(GenServerError::CallMsgUnused))
(false, Err(GenServerError::CallMsgUnused))
}
},
Err(error) => {
tracing::error!(
"Error in callback, reverting state - Error: '{error:?}'"
);
(true, state_clone, Err(GenServerError::Callback))
tracing::error!("Error in callback: '{error:?}'");
(false, Err(GenServerError::Callback))
}
};
// Send response back
Expand All @@ -246,51 +237,49 @@ pub trait GenServer: Send + Sized + Clone {
"GenServer failed to send response back, client must have died"
)
};
(keep_running, new_state)
keep_running
}
Some(GenServerInMsg::Cast { message }) => {
match AssertUnwindSafe(self.handle_cast(message, handle))
.catch_unwind()
.await
{
Ok(response) => match response {
CastResponse::NoReply(new_state) => (true, new_state),
CastResponse::Stop => (false, state_clone),
CastResponse::NoReply => true,
CastResponse::Stop => false,
CastResponse::Unused => {
tracing::error!("GenServer received unexpected CastMessage");
(false, state_clone)
false
}
},
Err(error) => {
tracing::trace!(
"Error in callback, reverting state - Error: '{error:?}'"
);
(true, state_clone)
tracing::trace!("Error in callback: '{error:?}'");
false
}
}
}
None => {
// Channel has been closed; won't receive further messages. Stop the server.
(false, self)
false
}
};
Ok((new_state, keep_running))
keep_running
}
}

fn handle_call(
self,
&mut self,
_message: Self::CallMsg,
_handle: &GenServerHandle<Self>,
) -> impl Future<Output = CallResponse<Self>> + Send {
async { CallResponse::Unused }
}

fn handle_cast(
self,
&mut self,
_message: Self::CastMsg,
_handle: &GenServerHandle<Self>,
) -> impl Future<Output = CastResponse<Self>> + Send {
) -> impl Future<Output = CastResponse> + Send {
async { CastResponse::Unused }
}

Expand All @@ -316,7 +305,6 @@ mod tests {
time::Duration,
};

#[derive(Clone)]
struct BadlyBehavedTask;

#[derive(Clone)]
Expand All @@ -336,25 +324,24 @@ mod tests {
type Error = Unused;

async fn handle_call(
self,
&mut self,
_: Self::CallMsg,
_: &GenServerHandle<Self>,
) -> CallResponse<Self> {
CallResponse::Stop(Unused)
}

async fn handle_cast(
self,
&mut self,
_: Self::CastMsg,
_: &GenServerHandle<Self>,
) -> CastResponse<Self> {
) -> CastResponse {
rt::sleep(Duration::from_millis(20)).await;
thread::sleep(Duration::from_secs(2));
CastResponse::Stop
}
}

#[derive(Clone)]
struct WellBehavedTask {
pub count: u64,
}
Expand All @@ -366,28 +353,25 @@ mod tests {
type Error = Unused;

async fn handle_call(
self,
&mut self,
message: Self::CallMsg,
_: &GenServerHandle<Self>,
) -> CallResponse<Self> {
match message {
InMessage::GetCount => {
let count = self.count;
CallResponse::Reply(self, OutMsg::Count(count))
}
InMessage::GetCount => CallResponse::Reply(OutMsg::Count(self.count)),
InMessage::Stop => CallResponse::Stop(OutMsg::Count(self.count)),
}
}

async fn handle_cast(
mut self,
&mut self,
_: Self::CastMsg,
handle: &GenServerHandle<Self>,
) -> CastResponse<Self> {
) -> CastResponse {
self.count += 1;
println!("{:?}: good still alive", thread::current().id());
send_after(Duration::from_millis(100), handle.to_owned(), Unused);
CastResponse::NoReply(self)
CastResponse::NoReply
}
}

Expand Down Expand Up @@ -433,7 +417,7 @@ mod tests {

const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

#[derive(Debug, Default, Clone)]
#[derive(Debug, Default)]
struct SomeTask;

#[derive(Clone)]
Expand All @@ -449,20 +433,20 @@ mod tests {
type Error = Unused;

async fn handle_call(
self,
&mut self,
message: Self::CallMsg,
_handle: &GenServerHandle<Self>,
) -> CallResponse<Self> {
match message {
SomeTaskCallMsg::SlowOperation => {
// Simulate a slow operation that will not resolve in time
rt::sleep(TIMEOUT_DURATION * 2).await;
CallResponse::Reply(self, Unused)
CallResponse::Reply(Unused)
}
SomeTaskCallMsg::FastOperation => {
// Simulate a fast operation that resolves in time
rt::sleep(TIMEOUT_DURATION / 2).await;
CallResponse::Reply(self, Unused)
CallResponse::Reply(Unused)
}
}
}
Expand All @@ -486,7 +470,6 @@ mod tests {
});
}

#[derive(Clone)]
struct SomeTaskThatFailsOnInit {
sender_channel: Arc<Mutex<mpsc::Receiver<u8>>>,
}
Expand Down
11 changes: 5 additions & 6 deletions concurrency/src/tasks/stream_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::tasks::{

type SummatoryHandle = GenServerHandle<Summatory>;

#[derive(Clone)]
struct Summatory {
count: u16,
}
Expand Down Expand Up @@ -40,26 +39,26 @@ impl GenServer for Summatory {
type Error = ();

async fn handle_cast(
mut self,
&mut self,
message: Self::CastMsg,
_handle: &GenServerHandle<Self>,
) -> CastResponse<Self> {
) -> CastResponse {
match message {
SummatoryCastMessage::Add(val) => {
self.count += val;
CastResponse::NoReply(self)
CastResponse::NoReply
}
SummatoryCastMessage::Stop => CastResponse::Stop,
}
}

async fn handle_call(
self,
&mut self,
_message: Self::CallMsg,
_handle: &SummatoryHandle,
) -> CallResponse<Self> {
let current_value = self.count;
CallResponse::Reply(self, current_value)
CallResponse::Reply(current_value)
}
}

Expand Down
Loading
Loading