@@ -3,7 +3,7 @@ use std::{
33 io:: { self , BufRead , BufReader , Write } ,
44 net:: { SocketAddr , TcpListener , TcpStream } ,
55 sync:: {
6- mpsc:: { self , SendError } ,
6+ mpsc:: { self } ,
77 Arc , Mutex ,
88 } ,
99 thread:: JoinHandle ,
@@ -205,22 +205,34 @@ impl Default for TimeoutParams {
205205 }
206206}
207207
208+ #[ derive( Debug ) ]
209+ enum WriteRequest {
210+ Shutdown ,
211+ SendMessage ( NetworkMessage ) ,
212+ }
213+
208214/// Send messages to an open connection.
209215#[ derive( Debug ) ]
210216pub struct ConnectionWriter {
211- sender : mpsc:: Sender < NetworkMessage > ,
217+ sender : mpsc:: Sender < WriteRequest > ,
212218 task_handle : JoinHandle < Result < ( ) , io:: Error > > ,
213219}
214220
221+ #[ allow( clippy:: result_large_err) ]
215222impl ConnectionWriter {
216223 /// Send a network message to this peer. Errors indicate that the connection is terminated and
217224 /// no further messages will succeed.
218- #[ allow( clippy:: result_large_err) ]
219- pub fn send_message (
220- & self ,
221- network_message : NetworkMessage ,
222- ) -> Result < ( ) , SendError < NetworkMessage > > {
223- self . sender . send ( network_message)
225+ pub fn send_message ( & self , network_message : NetworkMessage ) -> Result < ( ) , Error > {
226+ self . sender
227+ . send ( WriteRequest :: SendMessage ( network_message) )
228+ . map_err ( |_| Error :: ChannelClosed )
229+ }
230+
231+ /// Kill both sides of the connection, erroring if the stream is already closed.
232+ pub fn shutdown ( & self ) -> Result < ( ) , Error > {
233+ self . sender
234+ . send ( WriteRequest :: Shutdown )
235+ . map_err ( |_| Error :: ChannelClosed )
224236 }
225237
226238 /// In the event of a failed message, investigate IO related failures if the connection was not
@@ -234,7 +246,7 @@ impl ConnectionWriter {
234246struct OpenWriter {
235247 tcp_stream : TcpStream ,
236248 transport : WriteTransport ,
237- receiver : mpsc:: Receiver < NetworkMessage > ,
249+ receiver : mpsc:: Receiver < WriteRequest > ,
238250 outbound_ping_state : Arc < Mutex < OutboundPing > > ,
239251 ping_interval : Duration ,
240252}
@@ -244,10 +256,14 @@ impl OpenWriter {
244256 loop {
245257 let message = self . receiver . recv_timeout ( Duration :: from_secs ( 1 ) ) ;
246258 match message {
247- Ok ( network_message) => {
248- self . transport
249- . write_message ( network_message, & mut self . tcp_stream ) ?;
250- }
259+ Ok ( request) => match request {
260+ WriteRequest :: SendMessage ( message) => self
261+ . transport
262+ . write_message ( message, & mut self . tcp_stream ) ?,
263+ WriteRequest :: Shutdown => {
264+ self . tcp_stream . shutdown ( std:: net:: Shutdown :: Both ) ?;
265+ }
266+ } ,
251267 Err ( e) => match e {
252268 mpsc:: RecvTimeoutError :: Timeout => ( ) ,
253269 _ => return Ok ( ( ) ) ,
@@ -436,6 +452,8 @@ pub enum Error {
436452 UnexpectedMagic ( Magic ) ,
437453 /// The peer did not send a version message.
438454 MissingVersion ,
455+ /// The channel to the message writing thread was closed.
456+ ChannelClosed ,
439457}
440458
441459impl Display for Error {
@@ -446,6 +464,7 @@ impl Display for Error {
446464 Error :: Handshake ( e) => e. fmt ( f) ,
447465 Error :: UnexpectedMagic ( magic) => write ! ( f, "unexpected network magic: {magic}" ) ,
448466 Error :: MissingVersion => write ! ( f, "missing version message." ) ,
467+ Error :: ChannelClosed => write ! ( f, "channel closed" ) ,
449468 }
450469 }
451470}
0 commit comments