Skip to content

Commit 299bf1b

Browse files
committed
Updated registration only after activation hook
1 parent 39a28f6 commit 299bf1b

File tree

2 files changed

+33
-42
lines changed

2 files changed

+33
-42
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
authors = ["Rakshith Ravi <[email protected]>"]
33
edition = "2018"
44
name = "juno"
5-
version = "0.1.3-1"
5+
version = "0.1.3-2"
66
license = "MIT"
77
description = "A helper rust library for the juno microservices framework"
88
homepage = "https://github.com/bytesonus/juno-rust"

src/juno_module.rs

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use futures::channel::{
2525
oneshot::{channel, Sender},
2626
};
2727
use futures_util::sink::SinkExt;
28+
use std::sync::RwLock;
2829
use std::{
2930
collections::HashMap,
3031
net::{AddrParseError, SocketAddr},
@@ -41,7 +42,7 @@ pub struct JunoModule {
4142
functions: ArcFunctionList,
4243
hook_listeners: ArcHookListenerList,
4344
message_buffer: Buffer,
44-
registered: bool,
45+
registered: Arc<RwLock<bool>>,
4546
}
4647

4748
impl JunoModule {
@@ -62,27 +63,17 @@ impl JunoModule {
6263

6364
#[cfg(target_family = "unix")]
6465
pub fn from_unix_socket(socket_path: &str) -> Self {
65-
JunoModule {
66-
protocol: BaseProtocol::default(),
67-
connection: Box::new(UnixSocketConnection::new(socket_path.to_string())),
68-
requests: Arc::new(Mutex::new(HashMap::new())),
69-
functions: Arc::new(Mutex::new(HashMap::new())),
70-
hook_listeners: Arc::new(Mutex::new(HashMap::new())),
71-
message_buffer: vec![],
72-
registered: false,
73-
}
66+
JunoModule::new(
67+
BaseProtocol::default(),
68+
Box::new(UnixSocketConnection::new(socket_path.to_string())),
69+
)
7470
}
7571

7672
pub fn from_inet_socket(host: &str, port: u16) -> Self {
77-
JunoModule {
78-
protocol: BaseProtocol::default(),
79-
connection: Box::new(InetSocketConnection::new(format!("{}:{}", host, port))),
80-
requests: Arc::new(Mutex::new(HashMap::new())),
81-
functions: Arc::new(Mutex::new(HashMap::new())),
82-
hook_listeners: Arc::new(Mutex::new(HashMap::new())),
83-
message_buffer: vec![],
84-
registered: false,
85-
}
73+
JunoModule::new(
74+
BaseProtocol::default(),
75+
Box::new(InetSocketConnection::new(format!("{}:{}", host, port))),
76+
)
8677
}
8778

8879
pub fn new(protocol: BaseProtocol, connection: Box<dyn BaseConnection + Send + Sync>) -> Self {
@@ -93,7 +84,7 @@ impl JunoModule {
9384
functions: Arc::new(Mutex::new(HashMap::new())),
9485
hook_listeners: Arc::new(Mutex::new(HashMap::new())),
9586
message_buffer: vec![],
96-
registered: false,
87+
registered: Arc::new(RwLock::new(false)),
9788
}
9889
}
9990

@@ -109,8 +100,6 @@ impl JunoModule {
109100
self.protocol
110101
.initialize(String::from(module_id), String::from(version), dependencies);
111102
self.send_request(request).await?;
112-
113-
self.registered = true;
114103
Ok(())
115104
}
116105

@@ -169,7 +158,7 @@ impl JunoModule {
169158
}
170159

171160
fn ensure_registered(&self) -> Result<()> {
172-
if !self.registered {
161+
if !*self.registered.read().unwrap() {
173162
return Err(Error::Internal(String::from(
174163
"Module not registered. Did you .await the call to initialize?",
175164
)));
@@ -187,6 +176,7 @@ impl JunoModule {
187176
let requests = self.requests.clone();
188177
let functions = self.functions.clone();
189178
let hook_listeners = self.hook_listeners.clone();
179+
let registered_store = self.registered.clone();
190180

191181
// Run the read-write loop
192182
task::spawn(async {
@@ -196,6 +186,7 @@ impl JunoModule {
196186
requests,
197187
functions,
198188
hook_listeners,
189+
registered_store,
199190
write_sender,
200191
)
201192
.await;
@@ -206,23 +197,15 @@ impl JunoModule {
206197

207198
async fn send_request(&mut self, request: BaseMessage) -> Result<Value> {
208199
if let BaseMessage::RegisterModuleRequest { .. } = request {
209-
if self.registered {
210-
let (sender, receiver) = channel::<Result<Value>>();
211-
sender.send(Ok(Value::Null)).unwrap();
212-
213-
return match receiver.await {
214-
Ok(value) => value,
215-
Err(_) => Err(Error::Internal(String::from(
216-
"Request sender was dropped before data could be retrieved",
217-
))),
218-
};
200+
if *self.registered.read().unwrap() {
201+
return Err(Error::Internal(String::from("Module already registered")));
219202
}
220203
}
221204

222205
let request_type = request.get_type();
223206
let request_id = request.get_request_id().clone();
224207
let mut encoded = self.protocol.encode(request);
225-
if self.registered || request_type == 1 {
208+
if *self.registered.read().unwrap() || request_type == 1 {
226209
self.connection.send(encoded).await;
227210
} else {
228211
self.message_buffer.append(&mut encoded);
@@ -247,6 +230,7 @@ async fn on_data_listener(
247230
requests: ArcRequestList,
248231
functions: ArcFunctionList,
249232
hook_listeners: ArcHookListenerList,
233+
registered_store: Arc<RwLock<bool>>,
250234
mut write_sender: UnboundedSender<Buffer>,
251235
) {
252236
while let Some(data) = receiver.next().await {
@@ -277,7 +261,7 @@ async fn on_data_listener(
277261
Ok(Value::Null)
278262
}
279263
BaseMessage::TriggerHookRequest { .. } => {
280-
execute_hook_triggered(message, &hook_listeners).await
264+
execute_hook_triggered(message, &registered_store, &hook_listeners).await
281265
}
282266
BaseMessage::Error { error, .. } => Err(Error::FromJuno(error)),
283267
_ => Ok(Value::Null),
@@ -313,15 +297,22 @@ async fn execute_function_call(message: BaseMessage, functions: &ArcFunctionList
313297

314298
async fn execute_hook_triggered(
315299
message: BaseMessage,
300+
registered_store: &Arc<RwLock<bool>>,
316301
hook_listeners: &ArcHookListenerList,
317302
) -> Result<Value> {
318303
if let BaseMessage::TriggerHookRequest { hook, .. } = message {
319-
let hook_listeners = hook_listeners.lock().await;
320-
if !hook_listeners.contains_key(&hook) {
321-
todo!("Wtf do I do now? Need to propogate errors. How do I do that?");
322-
}
323-
for listener in &hook_listeners[&hook] {
324-
listener(Value::Null);
304+
if &hook == "juno.activated" {
305+
*registered_store.write().unwrap() = true;
306+
} else if &hook == "juno.deactivated" {
307+
*registered_store.write().unwrap() = false;
308+
} else {
309+
let hook_listeners = hook_listeners.lock().await;
310+
if !hook_listeners.contains_key(&hook) {
311+
todo!("Wtf do I do now? Need to propogate errors. How do I do that?");
312+
}
313+
for listener in &hook_listeners[&hook] {
314+
listener(Value::Null);
315+
}
325316
}
326317
} else {
327318
panic!("Cannot execute hook from a request that wasn't a TriggerHookRequest!");

0 commit comments

Comments
 (0)