Skip to content

Commit 19aac2f

Browse files
Fix TCP manager and restarts (#1556)
* Fix TCP manager and restarts * clippy * clippy * clippy
1 parent 652c24c commit 19aac2f

File tree

1 file changed

+158
-36
lines changed

1 file changed

+158
-36
lines changed

libafl/src/events/tcp.rs

Lines changed: 158 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ use core::{
1111
sync::atomic::{compiler_fence, Ordering},
1212
};
1313
use std::{
14+
env,
1415
io::{ErrorKind, Read, Write},
1516
net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
17+
sync::Arc,
1618
};
1719

1820
#[cfg(feature = "std")]
@@ -30,7 +32,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
3032
use tokio::{
3133
io::{AsyncReadExt, AsyncWriteExt},
3234
sync::{broadcast, mpsc},
33-
task::spawn,
35+
task::{spawn, JoinHandle},
3436
};
3537
#[cfg(feature = "std")]
3638
use typed_builder::TypedBuilder;
@@ -75,6 +77,8 @@ where
7577
phantom: PhantomData<I>,
7678
}
7779

80+
const UNDEFINED_CLIENT_ID: ClientId = ClientId(0xffffffff);
81+
7882
impl<I, MT> TcpEventBroker<I, MT>
7983
where
8084
I: Input,
@@ -105,9 +109,10 @@ where
105109

106110
/// Run in the broker until all clients exit
107111
#[tokio::main(flavor = "current_thread")]
112+
#[allow(clippy::too_many_lines)]
108113
pub async fn broker_loop(&mut self) -> Result<(), Error> {
109-
let (tx_bc, rx) = broadcast::channel(128);
110-
let (tx, mut rx_mpsc) = mpsc::channel(128);
114+
let (tx_bc, rx) = broadcast::channel(1024);
115+
let (tx, mut rx_mpsc) = mpsc::channel(1024);
111116

112117
let exit_cleanly_after = self.exit_cleanly_after;
113118

@@ -118,36 +123,61 @@ where
118123
let listener = tokio::net::TcpListener::from_std(listener)?;
119124

120125
let tokio_broker = spawn(async move {
121-
let mut recv_handles = vec![];
126+
let mut recv_handles: Vec<JoinHandle<_>> = vec![];
127+
let mut receivers: Vec<Arc<tokio::sync::Mutex<broadcast::Receiver<_>>>> = vec![];
122128

123129
loop {
130+
let mut reached_max = false;
124131
if let Some(max_clients) = exit_cleanly_after {
125132
if max_clients.get() <= recv_handles.len() {
126-
// we waited fro all the clients we wanted to see attached. Now wait for them to close their tcp connections.
127-
break;
133+
// we waited for all the clients we wanted to see attached. Now wait for them to close their tcp connections.
134+
reached_max = true;
128135
}
129136
}
130137

131-
//println!("loop");
132138
// Asynchronously wait for an inbound socket.
133-
let (socket, _) = listener.accept().await.expect("test");
139+
let (socket, _) = listener.accept().await.expect("Accept failed");
134140
let (mut read, mut write) = tokio::io::split(socket);
135-
// ClientIds for this broker start at 0.
136-
let this_client_id = ClientId(recv_handles.len().try_into().unwrap());
141+
142+
// Protocol: the new client communicate its old ClientId or -1 if new
143+
let mut this_client_id = [0; 4];
144+
read.read_exact(&mut this_client_id)
145+
.await
146+
.expect("Socket closed?");
147+
let this_client_id = ClientId(u32::from_le_bytes(this_client_id));
148+
149+
let (this_client_id, is_old) = if this_client_id == UNDEFINED_CLIENT_ID {
150+
if reached_max {
151+
(UNDEFINED_CLIENT_ID, false) // Dumb id
152+
} else {
153+
// ClientIds for this broker start at 0.
154+
(ClientId(recv_handles.len().try_into().unwrap()), false)
155+
}
156+
} else {
157+
(this_client_id, true)
158+
};
159+
137160
let this_client_id_bytes = this_client_id.0.to_le_bytes();
138161

139-
// Send the client id for this node;
162+
// Protocol: Send the client id for this node;
140163
write.write_all(&this_client_id_bytes).await.unwrap();
141164

165+
if !is_old && reached_max {
166+
continue;
167+
}
168+
142169
let tx_inner = tx.clone();
143-
let mut rx_inner = rx.resubscribe();
144-
// Keep all handles around.
145-
recv_handles.push(spawn(async move {
170+
171+
let handle = async move {
146172
// In a loop, read data from the socket and write the data back.
147173
loop {
148174
let mut len_buf = [0; 4];
149175

150-
read.read_exact(&mut len_buf).await.expect("Socket closed?");
176+
if read.read_exact(&mut len_buf).await.is_err() {
177+
// The socket is closed, the client is restarting
178+
log::info!("Socket closed, client restarting");
179+
return;
180+
}
151181

152182
let mut len = u32::from_le_bytes(len_buf);
153183
// we forward the sender id as well, so we add 4 bytes to the message length
@@ -158,26 +188,55 @@ where
158188

159189
let mut buf = vec![0; len as usize];
160190

161-
read.read_exact(&mut buf)
191+
if read
192+
.read_exact(&mut buf)
162193
.await
163-
.expect("failed to read data from socket");
194+
// .expect("Failed to read data from socket"); // TODO verify if we have to handle this error
195+
.is_err()
196+
{
197+
// The socket is closed, the client is restarting
198+
log::info!("Socket closed, client restarting");
199+
return;
200+
}
164201

165202
#[cfg(feature = "tcp_debug")]
166203
println!("len: {len:?} - {buf:?}");
167204
tx_inner.send(buf).await.expect("Could not send");
168205
}
169-
}));
206+
};
207+
208+
let client_idx = this_client_id.0 as usize;
209+
210+
// Keep all handles around.
211+
if is_old {
212+
recv_handles[client_idx].abort();
213+
recv_handles[client_idx] = spawn(handle);
214+
} else {
215+
recv_handles.push(spawn(handle));
216+
// Get old messages only if new
217+
let rx_inner = Arc::new(tokio::sync::Mutex::new(rx.resubscribe()));
218+
receivers.push(rx_inner.clone());
219+
}
220+
221+
let rx_inner = receivers[client_idx].clone();
222+
170223
// The forwarding end. No need to keep a handle to this (TODO: unless they don't quit/get stuck?)
171224
spawn(async move {
172225
// In a loop, read data from the socket and write the data back.
173226
loop {
174-
let buf: Vec<u8> = rx_inner.recv().await.unwrap_or(vec![]);
227+
let buf: Vec<u8> = rx_inner
228+
.lock()
229+
.await
230+
.recv()
231+
.await
232+
.expect("Could not receive");
233+
// TODO handle full capacity, Lagged https://docs.rs/tokio/latest/tokio/sync/broadcast/error/enum.RecvError.html
175234

176235
#[cfg(feature = "tcp_debug")]
177236
println!("{buf:?}");
178237

179238
if buf.len() <= 4 {
180-
eprintln!("We got no contents (or only the length) in a broadcast");
239+
log::warn!("We got no contents (or only the length) in a broadcast");
181240
continue;
182241
}
183242

@@ -194,17 +253,26 @@ where
194253
let len_buf: [u8; 4] = len.to_le_bytes();
195254

196255
// Write message length
197-
write.write_all(&len_buf).await.expect("Writing failed");
256+
if write.write_all(&len_buf).await.is_err() {
257+
// The socket is closed, the client is restarting
258+
log::info!("Socket closed, client restarting");
259+
return;
260+
}
198261
// Write the rest
199-
write.write_all(&buf).await.expect("Socket closed?");
262+
if write.write_all(&buf).await.is_err() {
263+
// The socket is closed, the client is restarting
264+
log::info!("Socket closed, client restarting");
265+
return;
266+
}
200267
}
201268
});
202269
}
203-
println!("joining handles..");
270+
271+
/*log::info!("Joining handles..");
204272
// wait for all clients to exit/error out
205273
for recv_handle in recv_handles {
206274
drop(recv_handle.await);
207-
}
275+
}*/
208276
});
209277

210278
loop {
@@ -386,12 +454,20 @@ impl<S> TcpEventManager<S>
386454
where
387455
S: UsesInput + HasExecutions + HasClientPerfMonitor,
388456
{
389-
/// Create a manager from a raw TCP client
390-
pub fn new<A: ToSocketAddrs>(addr: &A, configuration: EventConfig) -> Result<Self, Error> {
457+
/// Create a manager from a raw TCP client specifying the client id
458+
pub fn existing<A: ToSocketAddrs>(
459+
addr: &A,
460+
client_id: ClientId,
461+
configuration: EventConfig,
462+
) -> Result<Self, Error> {
391463
let mut tcp = TcpStream::connect(addr)?;
392464

393-
let mut our_client_id_buf = [0_u8; 4];
394-
tcp.read_exact(&mut our_client_id_buf).unwrap();
465+
let mut our_client_id_buf = client_id.0.to_le_bytes();
466+
tcp.write_all(&our_client_id_buf)
467+
.expect("Cannot write to the broker");
468+
469+
tcp.read_exact(&mut our_client_id_buf)
470+
.expect("Cannot read from the broker");
395471
let client_id = ClientId(u32::from_le_bytes(our_client_id_buf));
396472

397473
println!("Our client id: {client_id:?}");
@@ -407,15 +483,49 @@ where
407483
})
408484
}
409485

486+
/// Create a manager from a raw TCP client
487+
pub fn new<A: ToSocketAddrs>(addr: &A, configuration: EventConfig) -> Result<Self, Error> {
488+
Self::existing(addr, UNDEFINED_CLIENT_ID, configuration)
489+
}
490+
491+
/// Create an TCP event manager on a port specifying the client id
492+
///
493+
/// If the port is not yet bound, it will act as a broker; otherwise, it
494+
/// will act as a client.
495+
pub fn existing_on_port(
496+
port: u16,
497+
client_id: ClientId,
498+
configuration: EventConfig,
499+
) -> Result<Self, Error> {
500+
Self::existing(&("127.0.0.1", port), client_id, configuration)
501+
}
502+
410503
/// Create an TCP event manager on a port
411504
///
412505
/// If the port is not yet bound, it will act as a broker; otherwise, it
413506
/// will act as a client.
414-
#[cfg(feature = "std")]
415507
pub fn on_port(port: u16, configuration: EventConfig) -> Result<Self, Error> {
416508
Self::new(&("127.0.0.1", port), configuration)
417509
}
418510

511+
/// Create an TCP event manager on a port specifying the client id from env
512+
///
513+
/// If the port is not yet bound, it will act as a broker; otherwise, it
514+
/// will act as a client.
515+
pub fn existing_from_env<A: ToSocketAddrs>(
516+
addr: &A,
517+
env_name: &str,
518+
configuration: EventConfig,
519+
) -> Result<Self, Error> {
520+
let this_id = ClientId(str::parse::<u32>(&env::var(env_name)?)?);
521+
Self::existing(addr, this_id, configuration)
522+
}
523+
524+
/// Write the client id for a client [`EventManager`] to env vars
525+
pub fn to_env(&self, env_name: &str) {
526+
env::set_var(env_name, format!("{}", self.client_id.0));
527+
}
528+
419529
// Handle arriving events in the client
420530
#[allow(clippy::unused_self)]
421531
fn handle_in_client<E, Z>(
@@ -731,8 +841,11 @@ where
731841
fn on_restart(&mut self, state: &mut S) -> Result<(), Error> {
732842
// First, reset the page to 0 so the next iteration can read read from the beginning of this page
733843
self.staterestorer.reset();
734-
self.staterestorer
735-
.save(&if self.save_state { Some(state) } else { None })?;
844+
self.staterestorer.save(&if self.save_state {
845+
Some((state, self.tcp_mgr.client_id))
846+
} else {
847+
None
848+
})?;
736849
self.await_restart_safe();
737850
Ok(())
738851
}
@@ -938,7 +1051,7 @@ where
9381051
};
9391052

9401053
// We get here if we are on Unix, or we are a broker on Windows (or without forks).
941-
let (_mgr, core_id) = match self.kind {
1054+
let (mgr, core_id) = match self.kind {
9421055
ManagerKind::Any => {
9431056
let connection = create_nonblocking_listener(("127.0.0.1", self.broker_port));
9441057
match connection {
@@ -994,7 +1107,7 @@ where
9941107
}
9951108

9961109
// We are the fuzzer respawner in a tcp client
997-
//mgr.to_env(_ENV_FUZZER_BROKER_CLIENT_INITIAL);
1110+
mgr.to_env(_ENV_FUZZER_BROKER_CLIENT_INITIAL);
9981111

9991112
// First, create a channel from the current fuzzer to the next to store state between restarts.
10001113
#[cfg(unix)]
@@ -1030,6 +1143,7 @@ where
10301143
// Client->parent loop
10311144
loop {
10321145
log::info!("Spawning next client (id {ctr})");
1146+
println!("Spawning next client (id {ctr}) {core_id:?}");
10331147

10341148
// On Unix, we fork (when fork feature is enabled)
10351149
#[cfg(all(unix, feature = "fork"))]
@@ -1091,19 +1205,27 @@ where
10911205
}
10921206

10931207
// If we're restarting, deserialize the old state.
1094-
let (state, mut mgr) = if let Some(state_opt) = staterestorer.restore()? {
1208+
let (state, mut mgr) = if let Some((state_opt, this_id)) = staterestorer.restore()? {
10951209
(
10961210
state_opt,
10971211
TcpRestartingEventManager::with_save_state(
1098-
TcpEventManager::on_port(self.broker_port, self.configuration)?,
1212+
TcpEventManager::existing_on_port(
1213+
self.broker_port,
1214+
this_id,
1215+
self.configuration,
1216+
)?,
10991217
staterestorer,
11001218
self.serialize_state,
11011219
),
11021220
)
11031221
} else {
11041222
log::info!("First run. Let's set it all up");
11051223
// Mgr to send and receive msgs from/to all other fuzzer instances
1106-
let mgr = TcpEventManager::<S>::on_port(self.broker_port, self.configuration)?;
1224+
let mgr = TcpEventManager::<S>::existing_from_env(
1225+
&("127.0.0.1", self.broker_port),
1226+
_ENV_FUZZER_BROKER_CLIENT_INITIAL,
1227+
self.configuration,
1228+
)?;
11071229

11081230
(
11091231
None,

0 commit comments

Comments
 (0)