@@ -17,15 +17,14 @@ use crate::connection::UnixSocketConnection;
17
17
18
18
use async_std:: {
19
19
prelude:: * ,
20
- sync:: { Arc , Mutex } ,
20
+ sync:: { Arc , Mutex , RwLock } ,
21
21
task,
22
22
} ;
23
23
use futures:: channel:: {
24
24
mpsc:: { UnboundedReceiver , UnboundedSender } ,
25
25
oneshot:: { channel, Sender } ,
26
26
} ;
27
27
use futures_util:: sink:: SinkExt ;
28
- use std:: sync:: RwLock ;
29
28
use std:: {
30
29
collections:: HashMap ,
31
30
net:: { AddrParseError , SocketAddr } ,
@@ -41,7 +40,7 @@ pub struct JunoModule {
41
40
requests : ArcRequestList ,
42
41
functions : ArcFunctionList ,
43
42
hook_listeners : ArcHookListenerList ,
44
- message_buffer : Buffer ,
43
+ message_buffer : Arc < RwLock < Buffer > > ,
45
44
registered : Arc < RwLock < bool > > ,
46
45
}
47
46
@@ -83,7 +82,7 @@ impl JunoModule {
83
82
requests : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
84
83
functions : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
85
84
hook_listeners : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
86
- message_buffer : vec ! [ ] ,
85
+ message_buffer : Arc :: new ( RwLock :: new ( vec ! [ ] ) ) ,
87
86
registered : Arc :: new ( RwLock :: new ( false ) ) ,
88
87
}
89
88
}
@@ -165,6 +164,7 @@ impl JunoModule {
165
164
let requests = self . requests . clone ( ) ;
166
165
let functions = self . functions . clone ( ) ;
167
166
let hook_listeners = self . hook_listeners . clone ( ) ;
167
+ let message_buffer = self . message_buffer . clone ( ) ;
168
168
let registered_store = self . registered . clone ( ) ;
169
169
170
170
// Run the read-write loop
@@ -175,6 +175,7 @@ impl JunoModule {
175
175
requests,
176
176
functions,
177
177
hook_listeners,
178
+ message_buffer,
178
179
registered_store,
179
180
write_sender,
180
181
)
@@ -186,22 +187,21 @@ impl JunoModule {
186
187
187
188
async fn send_request ( & mut self , request : BaseMessage ) -> Result < Value > {
188
189
if let BaseMessage :: RegisterModuleRequest { .. } = request {
189
- if * self . registered . read ( ) . unwrap ( ) {
190
+ if * self . registered . read ( ) . await {
190
191
return Err ( Error :: Internal ( String :: from ( "Module already registered" ) ) ) ;
191
192
}
192
193
}
193
194
194
195
let request_type = request. get_type ( ) ;
195
196
let request_id = request. get_request_id ( ) . clone ( ) ;
196
197
let mut encoded = self . protocol . encode ( request) ;
197
- if * self . registered . read ( ) . unwrap ( ) || request_type == 1 {
198
- if self . message_buffer . len ( ) != 0 {
199
- self . connection . send ( self . message_buffer . clone ( ) ) . await ;
200
- self . message_buffer . clear ( ) ;
201
- }
198
+ if * self . registered . read ( ) . await || request_type == 1 {
202
199
self . connection . send ( encoded) . await ;
203
200
} else {
204
- self . message_buffer . append ( & mut encoded) ;
201
+ self . message_buffer
202
+ . write ( )
203
+ . await
204
+ . append ( & mut encoded) ;
205
205
}
206
206
207
207
let ( sender, receiver) = channel :: < Result < Value > > ( ) ;
@@ -223,6 +223,7 @@ async fn on_data_listener(
223
223
requests : ArcRequestList ,
224
224
functions : ArcFunctionList ,
225
225
hook_listeners : ArcHookListenerList ,
226
+ message_buffer : Arc < RwLock < Buffer > > ,
226
227
registered_store : Arc < RwLock < bool > > ,
227
228
mut write_sender : UnboundedSender < Buffer > ,
228
229
) {
@@ -254,7 +255,7 @@ async fn on_data_listener(
254
255
Ok ( Value :: Null )
255
256
}
256
257
BaseMessage :: TriggerHookResponse { .. } => {
257
- execute_hook_triggered ( message, & registered_store, & hook_listeners) . await
258
+ execute_hook_triggered ( message, & message_buffer , & write_sender , & registered_store, & hook_listeners) . await
258
259
}
259
260
BaseMessage :: Error { error, .. } => Err ( Error :: FromJuno ( error) ) ,
260
261
_ => Ok ( Value :: Null ) ,
@@ -290,16 +291,23 @@ async fn execute_function_call(message: BaseMessage, functions: &ArcFunctionList
290
291
291
292
async fn execute_hook_triggered (
292
293
message : BaseMessage ,
294
+ message_buffer : & Arc < RwLock < Buffer > > ,
295
+ mut write_sender : & UnboundedSender < Buffer > ,
293
296
registered_store : & Arc < RwLock < bool > > ,
294
297
hook_listeners : & ArcHookListenerList ,
295
298
) -> Result < Value > {
296
299
if let BaseMessage :: TriggerHookResponse { hook, data, .. } = message {
297
300
if hook. is_some ( ) {
298
301
let hook = hook. unwrap ( ) ;
299
302
if hook == "juno.activated" {
300
- * registered_store. write ( ) . unwrap ( ) = true ;
303
+ * registered_store. write ( ) . await = true ;
304
+ let mut buffer = message_buffer. write ( ) . await ;
305
+ if let Err ( err) = write_sender. send ( buffer. clone ( ) ) . await {
306
+ println ! ( "Error writing remaining buffer: {}" , err) ;
307
+ }
308
+ buffer. clear ( ) ;
301
309
} else if & hook == "juno.deactivated" {
302
- * registered_store. write ( ) . unwrap ( ) = false ;
310
+ * registered_store. write ( ) . await = false ;
303
311
} else {
304
312
let hook_listeners = hook_listeners. lock ( ) . await ;
305
313
if !hook_listeners. contains_key ( & hook) {
0 commit comments