@@ -11,8 +11,10 @@ use core::{
11
11
sync:: atomic:: { compiler_fence, Ordering } ,
12
12
} ;
13
13
use std:: {
14
+ env,
14
15
io:: { ErrorKind , Read , Write } ,
15
16
net:: { SocketAddr , TcpListener , TcpStream , ToSocketAddrs } ,
17
+ sync:: Arc ,
16
18
} ;
17
19
18
20
#[ cfg( feature = "std" ) ]
@@ -30,7 +32,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
30
32
use tokio:: {
31
33
io:: { AsyncReadExt , AsyncWriteExt } ,
32
34
sync:: { broadcast, mpsc} ,
33
- task:: spawn,
35
+ task:: { spawn, JoinHandle } ,
34
36
} ;
35
37
#[ cfg( feature = "std" ) ]
36
38
use typed_builder:: TypedBuilder ;
75
77
phantom : PhantomData < I > ,
76
78
}
77
79
80
+ const UNDEFINED_CLIENT_ID : ClientId = ClientId ( 0xffffffff ) ;
81
+
78
82
impl < I , MT > TcpEventBroker < I , MT >
79
83
where
80
84
I : Input ,
@@ -105,9 +109,10 @@ where
105
109
106
110
/// Run in the broker until all clients exit
107
111
#[ tokio:: main( flavor = "current_thread" ) ]
112
+ #[ allow( clippy:: too_many_lines) ]
108
113
pub async fn broker_loop ( & mut self ) -> Result < ( ) , Error > {
109
- let ( tx_bc, rx) = broadcast:: channel ( 128 ) ;
110
- let ( tx, mut rx_mpsc) = mpsc:: channel ( 128 ) ;
114
+ let ( tx_bc, rx) = broadcast:: channel ( 1024 ) ;
115
+ let ( tx, mut rx_mpsc) = mpsc:: channel ( 1024 ) ;
111
116
112
117
let exit_cleanly_after = self . exit_cleanly_after ;
113
118
@@ -118,36 +123,61 @@ where
118
123
let listener = tokio:: net:: TcpListener :: from_std ( listener) ?;
119
124
120
125
let tokio_broker = spawn ( async move {
121
- let mut recv_handles = vec ! [ ] ;
126
+ let mut recv_handles: Vec < JoinHandle < _ > > = vec ! [ ] ;
127
+ let mut receivers: Vec < Arc < tokio:: sync:: Mutex < broadcast:: Receiver < _ > > > > = vec ! [ ] ;
122
128
123
129
loop {
130
+ let mut reached_max = false ;
124
131
if let Some ( max_clients) = exit_cleanly_after {
125
132
if max_clients. get ( ) <= recv_handles. len ( ) {
126
- // we waited fro all the clients we wanted to see attached. Now wait for them to close their tcp connections.
127
- break ;
133
+ // we waited for all the clients we wanted to see attached. Now wait for them to close their tcp connections.
134
+ reached_max = true ;
128
135
}
129
136
}
130
137
131
- //println!("loop");
132
138
// Asynchronously wait for an inbound socket.
133
- let ( socket, _) = listener. accept ( ) . await . expect ( "test " ) ;
139
+ let ( socket, _) = listener. accept ( ) . await . expect ( "Accept failed " ) ;
134
140
let ( mut read, mut write) = tokio:: io:: split ( socket) ;
135
- // ClientIds for this broker start at 0.
136
- let this_client_id = ClientId ( recv_handles. len ( ) . try_into ( ) . unwrap ( ) ) ;
141
+
142
+ // Protocol: the new client communicates its old ClientId, or UNDEFINED_CLIENT_ID (0xffffffff) if it is new
143
+ let mut this_client_id = [ 0 ; 4 ] ;
144
+ read. read_exact ( & mut this_client_id)
145
+ . await
146
+ . expect ( "Socket closed?" ) ;
147
+ let this_client_id = ClientId ( u32:: from_le_bytes ( this_client_id) ) ;
148
+
149
+ let ( this_client_id, is_old) = if this_client_id == UNDEFINED_CLIENT_ID {
150
+ if reached_max {
151
+ ( UNDEFINED_CLIENT_ID , false ) // Placeholder id: max clients already reached, this connection is dropped below
152
+ } else {
153
+ // ClientIds for this broker start at 0.
154
+ ( ClientId ( recv_handles. len ( ) . try_into ( ) . unwrap ( ) ) , false )
155
+ }
156
+ } else {
157
+ ( this_client_id, true )
158
+ } ;
159
+
137
160
let this_client_id_bytes = this_client_id. 0 . to_le_bytes ( ) ;
138
161
139
- // Send the client id for this node;
162
+ // Protocol: Send the client id for this node;
140
163
write. write_all ( & this_client_id_bytes) . await . unwrap ( ) ;
141
164
165
+ if !is_old && reached_max {
166
+ continue ;
167
+ }
168
+
142
169
let tx_inner = tx. clone ( ) ;
143
- let mut rx_inner = rx. resubscribe ( ) ;
144
- // Keep all handles around.
145
- recv_handles. push ( spawn ( async move {
170
+
171
+ let handle = async move {
146
172
// In a loop, read data from the socket and write the data back.
147
173
loop {
148
174
let mut len_buf = [ 0 ; 4 ] ;
149
175
150
- read. read_exact ( & mut len_buf) . await . expect ( "Socket closed?" ) ;
176
+ if read. read_exact ( & mut len_buf) . await . is_err ( ) {
177
+ // The socket is closed, the client is restarting
178
+ log:: info!( "Socket closed, client restarting" ) ;
179
+ return ;
180
+ }
151
181
152
182
let mut len = u32:: from_le_bytes ( len_buf) ;
153
183
// we forward the sender id as well, so we add 4 bytes to the message length
@@ -158,26 +188,55 @@ where
158
188
159
189
let mut buf = vec ! [ 0 ; len as usize ] ;
160
190
161
- read. read_exact ( & mut buf)
191
+ if read
192
+ . read_exact ( & mut buf)
162
193
. await
163
- . expect ( "failed to read data from socket" ) ;
194
+ // .expect("Failed to read data from socket"); // TODO verify if we have to handle this error
195
+ . is_err ( )
196
+ {
197
+ // The socket is closed, the client is restarting
198
+ log:: info!( "Socket closed, client restarting" ) ;
199
+ return ;
200
+ }
164
201
165
202
#[ cfg( feature = "tcp_debug" ) ]
166
203
println ! ( "len: {len:?} - {buf:?}" ) ;
167
204
tx_inner. send ( buf) . await . expect ( "Could not send" ) ;
168
205
}
169
- } ) ) ;
206
+ } ;
207
+
208
+ let client_idx = this_client_id. 0 as usize ;
209
+
210
+ // Keep all handles around.
211
+ if is_old {
212
+ recv_handles[ client_idx] . abort ( ) ;
213
+ recv_handles[ client_idx] = spawn ( handle) ;
214
+ } else {
215
+ recv_handles. push ( spawn ( handle) ) ;
216
+ // Get old messages only if new
217
+ let rx_inner = Arc :: new ( tokio:: sync:: Mutex :: new ( rx. resubscribe ( ) ) ) ;
218
+ receivers. push ( rx_inner. clone ( ) ) ;
219
+ }
220
+
221
+ let rx_inner = receivers[ client_idx] . clone ( ) ;
222
+
170
223
// The forwarding end. No need to keep a handle to this (TODO: unless they don't quit/get stuck?)
171
224
spawn ( async move {
172
225
// In a loop, read data from the socket and write the data back.
173
226
loop {
174
- let buf: Vec < u8 > = rx_inner. recv ( ) . await . unwrap_or ( vec ! [ ] ) ;
227
+ let buf: Vec < u8 > = rx_inner
228
+ . lock ( )
229
+ . await
230
+ . recv ( )
231
+ . await
232
+ . expect ( "Could not receive" ) ;
233
+ // TODO handle full capacity, Lagged https://docs.rs/tokio/latest/tokio/sync/broadcast/error/enum.RecvError.html
175
234
176
235
#[ cfg( feature = "tcp_debug" ) ]
177
236
println ! ( "{buf:?}" ) ;
178
237
179
238
if buf. len ( ) <= 4 {
180
- eprintln ! ( "We got no contents (or only the length) in a broadcast" ) ;
239
+ log :: warn !( "We got no contents (or only the length) in a broadcast" ) ;
181
240
continue ;
182
241
}
183
242
@@ -194,17 +253,26 @@ where
194
253
let len_buf: [ u8 ; 4 ] = len. to_le_bytes ( ) ;
195
254
196
255
// Write message length
197
- write. write_all ( & len_buf) . await . expect ( "Writing failed" ) ;
256
+ if write. write_all ( & len_buf) . await . is_err ( ) {
257
+ // The socket is closed, the client is restarting
258
+ log:: info!( "Socket closed, client restarting" ) ;
259
+ return ;
260
+ }
198
261
// Write the rest
199
- write. write_all ( & buf) . await . expect ( "Socket closed?" ) ;
262
+ if write. write_all ( & buf) . await . is_err ( ) {
263
+ // The socket is closed, the client is restarting
264
+ log:: info!( "Socket closed, client restarting" ) ;
265
+ return ;
266
+ }
200
267
}
201
268
} ) ;
202
269
}
203
- println ! ( "joining handles.." ) ;
270
+
271
+ /*log::info!("Joining handles..");
204
272
// wait for all clients to exit/error out
205
273
for recv_handle in recv_handles {
206
274
drop(recv_handle.await);
207
- }
275
+ }*/
208
276
} ) ;
209
277
210
278
loop {
@@ -386,12 +454,20 @@ impl<S> TcpEventManager<S>
386
454
where
387
455
S : UsesInput + HasExecutions + HasClientPerfMonitor ,
388
456
{
389
- /// Create a manager from a raw TCP client
390
- pub fn new < A : ToSocketAddrs > ( addr : & A , configuration : EventConfig ) -> Result < Self , Error > {
457
+ /// Create a manager from a raw TCP client specifying the client id
458
+ pub fn existing < A : ToSocketAddrs > (
459
+ addr : & A ,
460
+ client_id : ClientId ,
461
+ configuration : EventConfig ,
462
+ ) -> Result < Self , Error > {
391
463
let mut tcp = TcpStream :: connect ( addr) ?;
392
464
393
- let mut our_client_id_buf = [ 0_u8 ; 4 ] ;
394
- tcp. read_exact ( & mut our_client_id_buf) . unwrap ( ) ;
465
+ let mut our_client_id_buf = client_id. 0 . to_le_bytes ( ) ;
466
+ tcp. write_all ( & our_client_id_buf)
467
+ . expect ( "Cannot write to the broker" ) ;
468
+
469
+ tcp. read_exact ( & mut our_client_id_buf)
470
+ . expect ( "Cannot read from the broker" ) ;
395
471
let client_id = ClientId ( u32:: from_le_bytes ( our_client_id_buf) ) ;
396
472
397
473
println ! ( "Our client id: {client_id:?}" ) ;
@@ -407,15 +483,49 @@ where
407
483
} )
408
484
}
409
485
486
+ /// Create a manager from a raw TCP client
487
+ pub fn new < A : ToSocketAddrs > ( addr : & A , configuration : EventConfig ) -> Result < Self , Error > {
488
+ Self :: existing ( addr, UNDEFINED_CLIENT_ID , configuration)
489
+ }
490
+
491
+ /// Create an TCP event manager on a port specifying the client id
492
+ ///
493
+ /// If the port is not yet bound, it will act as a broker; otherwise, it
494
+ /// will act as a client.
495
+ pub fn existing_on_port (
496
+ port : u16 ,
497
+ client_id : ClientId ,
498
+ configuration : EventConfig ,
499
+ ) -> Result < Self , Error > {
500
+ Self :: existing ( & ( "127.0.0.1" , port) , client_id, configuration)
501
+ }
502
+
410
503
/// Create an TCP event manager on a port
411
504
///
412
505
/// If the port is not yet bound, it will act as a broker; otherwise, it
413
506
/// will act as a client.
414
- #[ cfg( feature = "std" ) ]
415
507
pub fn on_port ( port : u16 , configuration : EventConfig ) -> Result < Self , Error > {
416
508
Self :: new ( & ( "127.0.0.1" , port) , configuration)
417
509
}
418
510
511
+ /// Create an TCP event manager on a port specifying the client id from env
512
+ ///
513
+ /// If the port is not yet bound, it will act as a broker; otherwise, it
514
+ /// will act as a client.
515
+ pub fn existing_from_env < A : ToSocketAddrs > (
516
+ addr : & A ,
517
+ env_name : & str ,
518
+ configuration : EventConfig ,
519
+ ) -> Result < Self , Error > {
520
+ let this_id = ClientId ( str:: parse :: < u32 > ( & env:: var ( env_name) ?) ?) ;
521
+ Self :: existing ( addr, this_id, configuration)
522
+ }
523
+
524
+ /// Write the client id for a client [`EventManager`] to env vars
525
+ pub fn to_env ( & self , env_name : & str ) {
526
+ env:: set_var ( env_name, format ! ( "{}" , self . client_id. 0 ) ) ;
527
+ }
528
+
419
529
// Handle arriving events in the client
420
530
#[ allow( clippy:: unused_self) ]
421
531
fn handle_in_client < E , Z > (
@@ -731,8 +841,11 @@ where
731
841
fn on_restart ( & mut self , state : & mut S ) -> Result < ( ) , Error > {
732
842
// First, reset the page to 0 so the next iteration can read read from the beginning of this page
733
843
self . staterestorer . reset ( ) ;
734
- self . staterestorer
735
- . save ( & if self . save_state { Some ( state) } else { None } ) ?;
844
+ self . staterestorer . save ( & if self . save_state {
845
+ Some ( ( state, self . tcp_mgr . client_id ) )
846
+ } else {
847
+ None
848
+ } ) ?;
736
849
self . await_restart_safe ( ) ;
737
850
Ok ( ( ) )
738
851
}
@@ -938,7 +1051,7 @@ where
938
1051
} ;
939
1052
940
1053
// We get here if we are on Unix, or we are a broker on Windows (or without forks).
941
- let ( _mgr , core_id) = match self . kind {
1054
+ let ( mgr , core_id) = match self . kind {
942
1055
ManagerKind :: Any => {
943
1056
let connection = create_nonblocking_listener ( ( "127.0.0.1" , self . broker_port ) ) ;
944
1057
match connection {
@@ -994,7 +1107,7 @@ where
994
1107
}
995
1108
996
1109
// We are the fuzzer respawner in a tcp client
997
- // mgr.to_env(_ENV_FUZZER_BROKER_CLIENT_INITIAL);
1110
+ mgr. to_env ( _ENV_FUZZER_BROKER_CLIENT_INITIAL) ;
998
1111
999
1112
// First, create a channel from the current fuzzer to the next to store state between restarts.
1000
1113
#[ cfg( unix) ]
@@ -1030,6 +1143,7 @@ where
1030
1143
// Client->parent loop
1031
1144
loop {
1032
1145
log:: info!( "Spawning next client (id {ctr})" ) ;
1146
+ println ! ( "Spawning next client (id {ctr}) {core_id:?}" ) ;
1033
1147
1034
1148
// On Unix, we fork (when fork feature is enabled)
1035
1149
#[ cfg( all( unix, feature = "fork" ) ) ]
@@ -1091,19 +1205,27 @@ where
1091
1205
}
1092
1206
1093
1207
// If we're restarting, deserialize the old state.
1094
- let ( state, mut mgr) = if let Some ( state_opt) = staterestorer. restore ( ) ? {
1208
+ let ( state, mut mgr) = if let Some ( ( state_opt, this_id ) ) = staterestorer. restore ( ) ? {
1095
1209
(
1096
1210
state_opt,
1097
1211
TcpRestartingEventManager :: with_save_state (
1098
- TcpEventManager :: on_port ( self . broker_port , self . configuration ) ?,
1212
+ TcpEventManager :: existing_on_port (
1213
+ self . broker_port ,
1214
+ this_id,
1215
+ self . configuration ,
1216
+ ) ?,
1099
1217
staterestorer,
1100
1218
self . serialize_state ,
1101
1219
) ,
1102
1220
)
1103
1221
} else {
1104
1222
log:: info!( "First run. Let's set it all up" ) ;
1105
1223
// Mgr to send and receive msgs from/to all other fuzzer instances
1106
- let mgr = TcpEventManager :: < S > :: on_port ( self . broker_port , self . configuration ) ?;
1224
+ let mgr = TcpEventManager :: < S > :: existing_from_env (
1225
+ & ( "127.0.0.1" , self . broker_port ) ,
1226
+ _ENV_FUZZER_BROKER_CLIENT_INITIAL,
1227
+ self . configuration ,
1228
+ ) ?;
1107
1229
1108
1230
(
1109
1231
None ,
0 commit comments