Skip to content

Commit a12f337

Browse files
authored
Merge pull request #23 from rustaceanrob/10-23-kill
Add a `shutdown` on the writer
2 parents 3df56ec + 38d005e commit a12f337

File tree

1 file changed

+32
-13
lines changed

1 file changed

+32
-13
lines changed

src/net.rs

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{
33
io::{self, BufRead, BufReader, Write},
44
net::{SocketAddr, TcpListener, TcpStream},
55
sync::{
6-
mpsc::{self, SendError},
6+
mpsc::{self},
77
Arc, Mutex,
88
},
99
thread::JoinHandle,
@@ -205,22 +205,34 @@ impl Default for TimeoutParams {
205205
}
206206
}
207207

208+
#[derive(Debug)]
209+
enum WriteRequest {
210+
Shutdown,
211+
SendMessage(NetworkMessage),
212+
}
213+
208214
/// Send messages to an open connection.
209215
#[derive(Debug)]
210216
pub struct ConnectionWriter {
211-
sender: mpsc::Sender<NetworkMessage>,
217+
sender: mpsc::Sender<WriteRequest>,
212218
task_handle: JoinHandle<Result<(), io::Error>>,
213219
}
214220

221+
#[allow(clippy::result_large_err)]
215222
impl ConnectionWriter {
216223
/// Send a network message to this peer. Errors indicate that the connection is terminated and
217224
/// no further messages will succeed.
218-
#[allow(clippy::result_large_err)]
219-
pub fn send_message(
220-
&self,
221-
network_message: NetworkMessage,
222-
) -> Result<(), SendError<NetworkMessage>> {
223-
self.sender.send(network_message)
225+
pub fn send_message(&self, network_message: NetworkMessage) -> Result<(), Error> {
226+
self.sender
227+
.send(WriteRequest::SendMessage(network_message))
228+
.map_err(|_| Error::ChannelClosed)
229+
}
230+
231+
/// Kill both sides of the connection, erroring if the stream is already closed.
232+
pub fn shutdown(&self) -> Result<(), Error> {
233+
self.sender
234+
.send(WriteRequest::Shutdown)
235+
.map_err(|_| Error::ChannelClosed)
224236
}
225237

226238
/// In the event of a failed message, investigate IO related failures if the connection was not
@@ -234,7 +246,7 @@ impl ConnectionWriter {
234246
struct OpenWriter {
235247
tcp_stream: TcpStream,
236248
transport: WriteTransport,
237-
receiver: mpsc::Receiver<NetworkMessage>,
249+
receiver: mpsc::Receiver<WriteRequest>,
238250
outbound_ping_state: Arc<Mutex<OutboundPing>>,
239251
ping_interval: Duration,
240252
}
@@ -244,10 +256,14 @@ impl OpenWriter {
244256
loop {
245257
let message = self.receiver.recv_timeout(Duration::from_secs(1));
246258
match message {
247-
Ok(network_message) => {
248-
self.transport
249-
.write_message(network_message, &mut self.tcp_stream)?;
250-
}
259+
Ok(request) => match request {
260+
WriteRequest::SendMessage(message) => self
261+
.transport
262+
.write_message(message, &mut self.tcp_stream)?,
263+
WriteRequest::Shutdown => {
264+
self.tcp_stream.shutdown(std::net::Shutdown::Both)?;
265+
}
266+
},
251267
Err(e) => match e {
252268
mpsc::RecvTimeoutError::Timeout => (),
253269
_ => return Ok(()),
@@ -436,6 +452,8 @@ pub enum Error {
436452
UnexpectedMagic(Magic),
437453
/// The peer did not send a version message.
438454
MissingVersion,
455+
/// The channel to the message writing thread was closed.
456+
ChannelClosed,
439457
}
440458

441459
impl Display for Error {
@@ -446,6 +464,7 @@ impl Display for Error {
446464
Error::Handshake(e) => e.fmt(f),
447465
Error::UnexpectedMagic(magic) => write!(f, "unexpected network magic: {magic}"),
448466
Error::MissingVersion => write!(f, "missing version message."),
467+
Error::ChannelClosed => write!(f, "channel closed"),
449468
}
450469
}
451470
}

0 commit comments

Comments
 (0)