@@ -25,6 +25,7 @@ use futures::channel::{
25
25
oneshot:: { channel, Sender } ,
26
26
} ;
27
27
use futures_util:: sink:: SinkExt ;
28
+ use std:: sync:: RwLock ;
28
29
use std:: {
29
30
collections:: HashMap ,
30
31
net:: { AddrParseError , SocketAddr } ,
@@ -41,7 +42,7 @@ pub struct JunoModule {
41
42
functions : ArcFunctionList ,
42
43
hook_listeners : ArcHookListenerList ,
43
44
message_buffer : Buffer ,
44
- registered : bool ,
45
+ registered : Arc < RwLock < bool > > ,
45
46
}
46
47
47
48
impl JunoModule {
@@ -62,27 +63,17 @@ impl JunoModule {
62
63
63
64
#[ cfg( target_family = "unix" ) ]
64
65
pub fn from_unix_socket ( socket_path : & str ) -> Self {
65
- JunoModule {
66
- protocol : BaseProtocol :: default ( ) ,
67
- connection : Box :: new ( UnixSocketConnection :: new ( socket_path. to_string ( ) ) ) ,
68
- requests : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
69
- functions : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
70
- hook_listeners : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
71
- message_buffer : vec ! [ ] ,
72
- registered : false ,
73
- }
66
+ JunoModule :: new (
67
+ BaseProtocol :: default ( ) ,
68
+ Box :: new ( UnixSocketConnection :: new ( socket_path. to_string ( ) ) ) ,
69
+ )
74
70
}
75
71
76
72
pub fn from_inet_socket ( host : & str , port : u16 ) -> Self {
77
- JunoModule {
78
- protocol : BaseProtocol :: default ( ) ,
79
- connection : Box :: new ( InetSocketConnection :: new ( format ! ( "{}:{}" , host, port) ) ) ,
80
- requests : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
81
- functions : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
82
- hook_listeners : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
83
- message_buffer : vec ! [ ] ,
84
- registered : false ,
85
- }
73
+ JunoModule :: new (
74
+ BaseProtocol :: default ( ) ,
75
+ Box :: new ( InetSocketConnection :: new ( format ! ( "{}:{}" , host, port) ) ) ,
76
+ )
86
77
}
87
78
88
79
pub fn new ( protocol : BaseProtocol , connection : Box < dyn BaseConnection + Send + Sync > ) -> Self {
@@ -93,7 +84,7 @@ impl JunoModule {
93
84
functions : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
94
85
hook_listeners : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
95
86
message_buffer : vec ! [ ] ,
96
- registered : false ,
87
+ registered : Arc :: new ( RwLock :: new ( false ) ) ,
97
88
}
98
89
}
99
90
@@ -109,8 +100,6 @@ impl JunoModule {
109
100
self . protocol
110
101
. initialize ( String :: from ( module_id) , String :: from ( version) , dependencies) ;
111
102
self . send_request ( request) . await ?;
112
-
113
- self . registered = true ;
114
103
Ok ( ( ) )
115
104
}
116
105
@@ -169,7 +158,7 @@ impl JunoModule {
169
158
}
170
159
171
160
fn ensure_registered ( & self ) -> Result < ( ) > {
172
- if !self . registered {
161
+ if !* self . registered . read ( ) . unwrap ( ) {
173
162
return Err ( Error :: Internal ( String :: from (
174
163
"Module not registered. Did you .await the call to initialize?" ,
175
164
) ) ) ;
@@ -187,6 +176,7 @@ impl JunoModule {
187
176
let requests = self . requests . clone ( ) ;
188
177
let functions = self . functions . clone ( ) ;
189
178
let hook_listeners = self . hook_listeners . clone ( ) ;
179
+ let registered_store = self . registered . clone ( ) ;
190
180
191
181
// Run the read-write loop
192
182
task:: spawn ( async {
@@ -196,6 +186,7 @@ impl JunoModule {
196
186
requests,
197
187
functions,
198
188
hook_listeners,
189
+ registered_store,
199
190
write_sender,
200
191
)
201
192
. await ;
@@ -206,23 +197,15 @@ impl JunoModule {
206
197
207
198
async fn send_request ( & mut self , request : BaseMessage ) -> Result < Value > {
208
199
if let BaseMessage :: RegisterModuleRequest { .. } = request {
209
- if self . registered {
210
- let ( sender, receiver) = channel :: < Result < Value > > ( ) ;
211
- sender. send ( Ok ( Value :: Null ) ) . unwrap ( ) ;
212
-
213
- return match receiver. await {
214
- Ok ( value) => value,
215
- Err ( _) => Err ( Error :: Internal ( String :: from (
216
- "Request sender was dropped before data could be retrieved" ,
217
- ) ) ) ,
218
- } ;
200
+ if * self . registered . read ( ) . unwrap ( ) {
201
+ return Err ( Error :: Internal ( String :: from ( "Module already registered" ) ) ) ;
219
202
}
220
203
}
221
204
222
205
let request_type = request. get_type ( ) ;
223
206
let request_id = request. get_request_id ( ) . clone ( ) ;
224
207
let mut encoded = self . protocol . encode ( request) ;
225
- if self . registered || request_type == 1 {
208
+ if * self . registered . read ( ) . unwrap ( ) || request_type == 1 {
226
209
self . connection . send ( encoded) . await ;
227
210
} else {
228
211
self . message_buffer . append ( & mut encoded) ;
@@ -247,6 +230,7 @@ async fn on_data_listener(
247
230
requests : ArcRequestList ,
248
231
functions : ArcFunctionList ,
249
232
hook_listeners : ArcHookListenerList ,
233
+ registered_store : Arc < RwLock < bool > > ,
250
234
mut write_sender : UnboundedSender < Buffer > ,
251
235
) {
252
236
while let Some ( data) = receiver. next ( ) . await {
@@ -277,7 +261,7 @@ async fn on_data_listener(
277
261
Ok ( Value :: Null )
278
262
}
279
263
BaseMessage :: TriggerHookRequest { .. } => {
280
- execute_hook_triggered ( message, & hook_listeners) . await
264
+ execute_hook_triggered ( message, & registered_store , & hook_listeners) . await
281
265
}
282
266
BaseMessage :: Error { error, .. } => Err ( Error :: FromJuno ( error) ) ,
283
267
_ => Ok ( Value :: Null ) ,
@@ -313,15 +297,22 @@ async fn execute_function_call(message: BaseMessage, functions: &ArcFunctionList
313
297
314
298
async fn execute_hook_triggered (
315
299
message : BaseMessage ,
300
+ registered_store : & Arc < RwLock < bool > > ,
316
301
hook_listeners : & ArcHookListenerList ,
317
302
) -> Result < Value > {
318
303
if let BaseMessage :: TriggerHookRequest { hook, .. } = message {
319
- let hook_listeners = hook_listeners. lock ( ) . await ;
320
- if !hook_listeners. contains_key ( & hook) {
321
- todo ! ( "Wtf do I do now? Need to propogate errors. How do I do that?" ) ;
322
- }
323
- for listener in & hook_listeners[ & hook] {
324
- listener ( Value :: Null ) ;
304
+ if & hook == "juno.activated" {
305
+ * registered_store. write ( ) . unwrap ( ) = true ;
306
+ } else if & hook == "juno.deactivated" {
307
+ * registered_store. write ( ) . unwrap ( ) = false ;
308
+ } else {
309
+ let hook_listeners = hook_listeners. lock ( ) . await ;
310
+ if !hook_listeners. contains_key ( & hook) {
311
+ todo ! ( "Wtf do I do now? Need to propogate errors. How do I do that?" ) ;
312
+ }
313
+ for listener in & hook_listeners[ & hook] {
314
+ listener ( Value :: Null ) ;
315
+ }
325
316
}
326
317
} else {
327
318
panic ! ( "Cannot execute hook from a request that wasn't a TriggerHookRequest!" ) ;
0 commit comments