Skip to content

Commit a0f50cd

Browse files
committed
update threads GenServer to not use clone, similar to lambdaclass#42
1 parent 863877b commit a0f50cd

File tree

6 files changed

+86
-99
lines changed

6 files changed

+86
-99
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

concurrency/src/threads/gen_server.rs

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
22
//! See examples/name_server for a usage example.
3-
use spawned_rt::threads::{self as rt, mpsc, oneshot};
3+
use spawned_rt::threads::{self as rt, mpsc, oneshot, CancellationToken};
44
use std::{
55
fmt::Debug,
66
panic::{catch_unwind, AssertUnwindSafe},
@@ -11,20 +11,26 @@ use crate::error::GenServerError;
1111
#[derive(Debug)]
1212
pub struct GenServerHandle<G: GenServer + 'static> {
1313
pub tx: mpsc::Sender<GenServerInMsg<G>>,
14+
cancellation_token: CancellationToken,
1415
}
1516

1617
impl<G: GenServer> Clone for GenServerHandle<G> {
1718
fn clone(&self) -> Self {
1819
Self {
1920
tx: self.tx.clone(),
21+
cancellation_token: self.cancellation_token.clone(),
2022
}
2123
}
2224
}
2325

2426
impl<G: GenServer> GenServerHandle<G> {
2527
pub(crate) fn new(gen_server: G) -> Self {
2628
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
27-
let handle = GenServerHandle { tx };
29+
let cancellation_token = CancellationToken::new();
30+
let handle = GenServerHandle {
31+
tx,
32+
cancellation_token,
33+
};
2834
let handle_clone = handle.clone();
2935
// Ignore the JoinHandle for now. Maybe we'll use it in the future
3036
let _join_handle = rt::spawn(move || {
@@ -69,18 +75,18 @@ pub enum GenServerInMsg<G: GenServer> {
6975
}
7076

7177
pub enum CallResponse<G: GenServer> {
72-
Reply(G, G::OutMsg),
78+
Reply(G::OutMsg),
7379
Unused,
7480
Stop(G::OutMsg),
7581
}
7682

77-
pub enum CastResponse<G: GenServer> {
78-
NoReply(G),
83+
pub enum CastResponse {
84+
NoReply,
7985
Unused,
8086
Stop,
8187
}
8288

83-
pub trait GenServer: Send + Sized + Clone {
89+
pub trait GenServer: Send + Sized {
8490
type CallMsg: Clone + Send + Sized;
8591
type CastMsg: Clone + Send + Sized;
8692
type OutMsg: Send + Sized;
@@ -101,16 +107,16 @@ pub trait GenServer: Send + Sized + Clone {
101107
handle: &GenServerHandle<Self>,
102108
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
103109
) -> Result<(), GenServerError> {
104-
match self.init(handle) {
105-
Ok(new_state) => {
106-
new_state.main_loop(handle, rx)?;
107-
Ok(())
108-
}
110+
let mut cancellation_token = handle.cancellation_token.clone();
111+
let res = match self.init(handle) {
112+
Ok(new_state) => Ok(new_state.main_loop(handle, rx)?),
109113
Err(err) => {
110114
tracing::error!("Initialization failed: {err:?}");
111115
Err(GenServerError::Initialization)
112116
}
113-
}
117+
};
118+
cancellation_token.cancel();
119+
res
114120
}
115121

116122
/// Initialization function. It's called before main loop. It
@@ -126,90 +132,82 @@ pub trait GenServer: Send + Sized + Clone {
126132
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
127133
) -> Result<(), GenServerError> {
128134
loop {
129-
let (new_state, cont) = self.receive(handle, rx)?;
130-
if !cont {
135+
if !self.receive(handle, rx)? {
131136
break;
132137
}
133-
self = new_state;
134138
}
135139
tracing::trace!("Stopping GenServer");
136140
Ok(())
137141
}
138142

139143
fn receive(
140-
self,
144+
&mut self,
141145
handle: &GenServerHandle<Self>,
142146
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
143-
) -> Result<(Self, bool), GenServerError> {
147+
) -> Result<bool, GenServerError> {
144148
let message = rx.recv().ok();
145149

146-
// Save current state in case of a rollback
147-
let state_clone = self.clone();
148-
149-
let (keep_running, new_state) = match message {
150+
let keep_running = match message {
150151
Some(GenServerInMsg::Call { sender, message }) => {
151-
let (keep_running, new_state, response) =
152-
match catch_unwind(AssertUnwindSafe(|| self.handle_call(message, handle))) {
153-
Ok(response) => match response {
154-
CallResponse::Reply(new_state, response) => {
155-
(true, new_state, Ok(response))
156-
}
157-
CallResponse::Stop(response) => (false, state_clone, Ok(response)),
158-
CallResponse::Unused => {
159-
tracing::error!("GenServer received unexpected CallMessage");
160-
(false, state_clone, Err(GenServerError::CallMsgUnused))
161-
}
162-
},
163-
Err(error) => {
164-
tracing::trace!(
165-
"Error in callback, reverting state - Error: '{error:?}'"
166-
);
167-
(true, state_clone, Err(GenServerError::Callback))
152+
let (keep_running, response) = match catch_unwind(AssertUnwindSafe(|| {
153+
self.handle_call(message, handle)
154+
})) {
155+
Ok(response) => match response {
156+
CallResponse::Reply(response) => (true, Ok(response)),
157+
CallResponse::Stop(response) => (false, Ok(response)),
158+
CallResponse::Unused => {
159+
tracing::error!("GenServer received unexpected CallMessage");
160+
(false, Err(GenServerError::CallMsgUnused))
168161
}
169-
};
162+
},
163+
Err(error) => {
164+
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
165+
(true, Err(GenServerError::Callback))
166+
}
167+
};
170168
// Send response back
171169
if sender.send(response).is_err() {
172170
tracing::trace!("GenServer failed to send response back, client must have died")
173171
};
174-
(keep_running, new_state)
172+
keep_running
175173
}
176174
Some(GenServerInMsg::Cast { message }) => {
177175
match catch_unwind(AssertUnwindSafe(|| self.handle_cast(message, handle))) {
178176
Ok(response) => match response {
179-
CastResponse::NoReply(new_state) => (true, new_state),
180-
CastResponse::Stop => (false, state_clone),
177+
CastResponse::NoReply => true,
178+
CastResponse::Stop => false,
181179
CastResponse::Unused => {
182180
tracing::error!("GenServer received unexpected CastMessage");
183-
(false, state_clone)
181+
false
184182
}
185183
},
186184
Err(error) => {
187185
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
188-
(true, state_clone)
186+
true
189187
}
190188
}
191189
}
192190
None => {
193191
// Channel has been closed; won't receive further messages. Stop the server.
194-
(false, self)
192+
false
195193
}
196194
};
197-
Ok((new_state, keep_running))
195+
Ok(keep_running)
198196
}
199197

200198
fn handle_call(
201-
self,
199+
&mut self,
202200
_message: Self::CallMsg,
203201
_handle: &GenServerHandle<Self>,
204202
) -> CallResponse<Self> {
205203
CallResponse::Unused
206204
}
207205

208206
fn handle_cast(
209-
self,
207+
&mut self,
210208
_message: Self::CastMsg,
211209
_handle: &GenServerHandle<Self>,
212-
) -> CastResponse<Self> {
210+
) -> CastResponse {
213211
CastResponse::Unused
214212
}
215213
}

concurrency/src/threads/timer_tests.rs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,20 @@ impl GenServer for Repeater {
6363
Ok(self)
6464
}
6565

66-
fn handle_call(self, _message: Self::CallMsg, _handle: &RepeaterHandle) -> CallResponse<Self> {
66+
fn handle_call(
67+
&mut self,
68+
_message: Self::CallMsg,
69+
_handle: &RepeaterHandle,
70+
) -> CallResponse<Self> {
6771
let count = self.count;
68-
CallResponse::Reply(self, RepeaterOutMessage::Count(count))
72+
CallResponse::Reply(RepeaterOutMessage::Count(count))
6973
}
7074

7175
fn handle_cast(
72-
mut self,
76+
&mut self,
7377
message: Self::CastMsg,
7478
_handle: &GenServerHandle<Self>,
75-
) -> CastResponse<Self> {
79+
) -> CastResponse {
7680
match message {
7781
RepeaterCastMessage::Inc => {
7882
self.count += 1;
@@ -83,7 +87,7 @@ impl GenServer for Repeater {
8387
};
8488
}
8589
};
86-
CastResponse::NoReply(self)
90+
CastResponse::NoReply
8791
}
8892
}
8993

@@ -156,22 +160,22 @@ impl GenServer for Delayed {
156160
type OutMsg = DelayedOutMessage;
157161
type Error = ();
158162

159-
fn handle_call(self, _message: Self::CallMsg, _handle: &DelayedHandle) -> CallResponse<Self> {
163+
fn handle_call(
164+
&mut self,
165+
_message: Self::CallMsg,
166+
_handle: &DelayedHandle,
167+
) -> CallResponse<Self> {
160168
let count = self.count;
161-
CallResponse::Reply(self, DelayedOutMessage::Count(count))
169+
CallResponse::Reply(DelayedOutMessage::Count(count))
162170
}
163171

164-
fn handle_cast(
165-
mut self,
166-
message: Self::CastMsg,
167-
_handle: &DelayedHandle,
168-
) -> CastResponse<Self> {
172+
fn handle_cast(&mut self, message: Self::CastMsg, _handle: &DelayedHandle) -> CastResponse {
169173
match message {
170174
DelayedCastMessage::Inc => {
171175
self.count += 1;
172176
}
173177
};
174-
CastResponse::NoReply(self)
178+
CastResponse::NoReply
175179
}
176180
}
177181

examples/bank_threads/src/server.rs

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -61,53 +61,42 @@ impl GenServer for Bank {
6161
Ok(self)
6262
}
6363

64-
fn handle_call(mut self, message: Self::CallMsg, _handle: &BankHandle) -> CallResponse<Self> {
64+
fn handle_call(&mut self, message: Self::CallMsg, _handle: &BankHandle) -> CallResponse<Self> {
6565
match message.clone() {
6666
Self::CallMsg::New { who } => match self.accounts.get(&who) {
67-
Some(_amount) => {
68-
CallResponse::Reply(self, Err(BankError::AlreadyACustomer { who }))
69-
}
67+
Some(_amount) => CallResponse::Reply(Err(BankError::AlreadyACustomer { who })),
7068
None => {
7169
self.accounts.insert(who.clone(), 0);
72-
CallResponse::Reply(self, Ok(OutMessage::Welcome { who }))
70+
CallResponse::Reply(Ok(OutMessage::Welcome { who }))
7371
}
7472
},
7573
Self::CallMsg::Add { who, amount } => match self.accounts.get(&who) {
7674
Some(current) => {
7775
let new_amount = current + amount;
7876
self.accounts.insert(who.clone(), new_amount);
79-
CallResponse::Reply(
80-
self,
81-
Ok(OutMessage::Balance {
82-
who,
83-
amount: new_amount,
84-
}),
85-
)
77+
CallResponse::Reply(Ok(OutMessage::Balance {
78+
who,
79+
amount: new_amount,
80+
}))
8681
}
87-
None => CallResponse::Reply(self, Err(BankError::NotACustomer { who })),
82+
None => CallResponse::Reply(Err(BankError::NotACustomer { who })),
8883
},
8984
Self::CallMsg::Remove { who, amount } => match self.accounts.get(&who) {
9085
Some(&current) => match current < amount {
91-
true => CallResponse::Reply(
92-
self,
93-
Err(BankError::InsufficientBalance {
94-
who,
95-
amount: current,
96-
}),
97-
),
86+
true => CallResponse::Reply(Err(BankError::InsufficientBalance {
87+
who,
88+
amount: current,
89+
})),
9890
false => {
9991
let new_amount = current - amount;
10092
self.accounts.insert(who.clone(), new_amount);
101-
CallResponse::Reply(
102-
self,
103-
Ok(OutMessage::WidrawOk {
104-
who,
105-
amount: new_amount,
106-
}),
107-
)
93+
CallResponse::Reply(Ok(OutMessage::WidrawOk {
94+
who,
95+
amount: new_amount,
96+
}))
10897
}
10998
},
110-
None => CallResponse::Reply(self, Err(BankError::NotACustomer { who })),
99+
None => CallResponse::Reply(Err(BankError::NotACustomer { who })),
111100
},
112101
Self::CallMsg::Stop => CallResponse::Stop(Ok(OutMessage::Stopped)),
113102
}

examples/updater_threads/src/server.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,7 @@ impl GenServer for UpdaterServer {
2828
Ok(self)
2929
}
3030

31-
fn handle_cast(
32-
self,
33-
message: Self::CastMsg,
34-
handle: &UpdateServerHandle,
35-
) -> CastResponse<Self> {
31+
fn handle_cast(&mut self, message: Self::CastMsg, handle: &UpdateServerHandle) -> CastResponse {
3632
match message {
3733
Self::CastMsg::Check => {
3834
send_after(self.periodicity, handle.clone(), InMessage::Check);
@@ -42,7 +38,7 @@ impl GenServer for UpdaterServer {
4238

4339
tracing::info!("Response: {resp:?}");
4440

45-
CastResponse::NoReply(self)
41+
CastResponse::NoReply
4642
}
4743
}
4844
}

rt/src/threads/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ where
3434
spawn(f)
3535
}
3636

37-
#[derive(Clone, Default)]
37+
#[derive(Clone, Debug, Default)]
3838
pub struct CancellationToken {
3939
is_cancelled: Arc<AtomicBool>,
4040
}

0 commit comments

Comments
 (0)