3131
3232use bitcoin:: secp256k1:: PublicKey ;
3333
34- use tokio:: net:: TcpStream ;
34+ use tokio:: net:: { tcp , TcpStream } ;
3535use tokio:: { io, time} ;
3636use tokio:: sync:: mpsc;
37- use tokio:: io:: { AsyncReadExt , AsyncWrite , AsyncWriteExt } ;
37+ use tokio:: io:: AsyncWrite ;
3838
3939use lightning:: ln:: peer_handler;
4040use lightning:: ln:: peer_handler:: SocketDescriptor as LnSocketTrait ;
@@ -59,7 +59,7 @@ static ID_COUNTER: AtomicU64 = AtomicU64::new(0);
5959// define a trivial two- and three- select macro with the specific types we need and just use that.
6060
6161pub ( crate ) enum SelectorOutput {
62- A ( Option < ( ) > ) , B ( Option < ( ) > ) , C ( tokio:: io:: Result < usize > ) ,
62+ A ( Option < ( ) > ) , B ( Option < ( ) > ) , C ( tokio:: io:: Result < ( ) > ) ,
6363}
6464
6565pub ( crate ) struct TwoSelector <
@@ -87,15 +87,15 @@ impl<
8787}
8888
8989pub ( crate ) struct ThreeSelector <
90- A : Future < Output =Option < ( ) > > + Unpin , B : Future < Output =Option < ( ) > > + Unpin , C : Future < Output =tokio:: io:: Result < usize > > + Unpin
90+ A : Future < Output =Option < ( ) > > + Unpin , B : Future < Output =Option < ( ) > > + Unpin , C : Future < Output =tokio:: io:: Result < ( ) > > + Unpin
9191> {
9292 pub a : A ,
9393 pub b : B ,
9494 pub c : C ,
9595}
9696
9797impl <
98- A : Future < Output =Option < ( ) > > + Unpin , B : Future < Output =Option < ( ) > > + Unpin , C : Future < Output =tokio:: io:: Result < usize > > + Unpin
98+ A : Future < Output =Option < ( ) > > + Unpin , B : Future < Output =Option < ( ) > > + Unpin , C : Future < Output =tokio:: io:: Result < ( ) > > + Unpin
9999> Future for ThreeSelector < A , B , C > {
100100 type Output = SelectorOutput ;
101101 fn poll ( mut self : Pin < & mut Self > , ctx : & mut task:: Context < ' _ > ) -> Poll < SelectorOutput > {
@@ -119,7 +119,7 @@ impl<
119119/// Connection object (in an Arc<Mutex<>>) in each SocketDescriptor we create as well as in the
120120/// read future (which is returned by schedule_read).
121121struct Connection {
122- writer : Option < io :: WriteHalf < TcpStream > > ,
122+ writer : Option < Arc < TcpStream > > ,
123123 // Because our PeerManager is templated by user-provided types, and we can't (as far as I can
124124 // tell) have a const RawWakerVTable built out of templated functions, we need some indirection
125125 // between being woken up with write-ready and calling PeerManager::write_buffer_space_avail.
@@ -156,7 +156,7 @@ impl Connection {
156156 async fn schedule_read < PM : Deref + ' static + Send + Sync + Clone > (
157157 peer_manager : PM ,
158158 us : Arc < Mutex < Self > > ,
159- mut reader : io :: ReadHalf < TcpStream > ,
159+ reader : Arc < TcpStream > ,
160160 mut read_wake_receiver : mpsc:: Receiver < ( ) > ,
161161 mut write_avail_receiver : mpsc:: Receiver < ( ) > ,
162162 ) where PM :: Target : APeerManager < Descriptor = SocketDescriptor > {
@@ -200,7 +200,7 @@ impl Connection {
200200 ThreeSelector {
201201 a : Box :: pin ( write_avail_receiver. recv ( ) ) ,
202202 b : Box :: pin ( read_wake_receiver. recv ( ) ) ,
203- c : Box :: pin ( reader. read ( & mut buf ) ) ,
203+ c : Box :: pin ( reader. readable ( ) ) ,
204204 } . await
205205 } ;
206206 match select_result {
@@ -211,8 +211,9 @@ impl Connection {
211211 }
212212 } ,
213213 SelectorOutput :: B ( _) => { } ,
214- SelectorOutput :: C ( read) => {
215- match read {
214+ SelectorOutput :: C ( res) => {
215+ if res. is_err ( ) { break Disconnect :: PeerDisconnected ; }
216+ match reader. try_read ( & mut buf) {
216217 Ok ( 0 ) => break Disconnect :: PeerDisconnected ,
217218 Ok ( len) => {
218219 let read_res = peer_manager. as_ref ( ) . read_event ( & mut our_descriptor, & buf[ 0 ..len] ) ;
@@ -226,7 +227,11 @@ impl Connection {
226227 Err ( _) => break Disconnect :: CloseConnection ,
227228 }
228229 } ,
229- Err ( _) => break Disconnect :: PeerDisconnected ,
230+ Err ( e) if e. kind ( ) == std:: io:: ErrorKind :: WouldBlock => {
231+ // readable() is allowed to spuriously wake, so we have to handle
232+ // WouldBlock here.
233+ } ,
234+ Err ( e) => break Disconnect :: PeerDisconnected ,
230235 }
231236 } ,
232237 }
@@ -239,18 +244,14 @@ impl Connection {
239244 // here.
240245 let _ = tokio:: task:: yield_now ( ) . await ;
241246 } ;
242- let writer_option = us. lock ( ) . unwrap ( ) . writer . take ( ) ;
243- if let Some ( mut writer) = writer_option {
244- // If the socket is already closed, shutdown() will fail, so just ignore it.
245- let _ = writer. shutdown ( ) . await ;
246- }
247+ us. lock ( ) . unwrap ( ) . writer . take ( ) ;
247248 if let Disconnect :: PeerDisconnected = disconnect_type {
248249 peer_manager. as_ref ( ) . socket_disconnected ( & our_descriptor) ;
249250 peer_manager. as_ref ( ) . process_events ( ) ;
250251 }
251252 }
252253
253- fn new ( stream : StdTcpStream ) -> ( io :: ReadHalf < TcpStream > , mpsc:: Receiver < ( ) > , mpsc:: Receiver < ( ) > , Arc < Mutex < Self > > ) {
254+ fn new ( stream : StdTcpStream ) -> ( Arc < TcpStream > , mpsc:: Receiver < ( ) > , mpsc:: Receiver < ( ) > , Arc < Mutex < Self > > ) {
254255 // We only ever need a channel of depth 1 here: if we returned a non-full write to the
255256 // PeerManager, we will eventually get notified that there is room in the socket to write
256257 // new bytes, which will generate an event. That event will be popped off the queue before
@@ -262,11 +263,11 @@ impl Connection {
262263 // false.
263264 let ( read_waker, read_receiver) = mpsc:: channel ( 1 ) ;
264265 stream. set_nonblocking ( true ) . unwrap ( ) ;
265- let ( reader , writer ) = io :: split ( TcpStream :: from_std ( stream) . unwrap ( ) ) ;
266+ let tokio_stream = Arc :: new ( TcpStream :: from_std ( stream) . unwrap ( ) ) ;
266267
267- ( reader , write_receiver, read_receiver,
268+ ( Arc :: clone ( & tokio_stream ) , write_receiver, read_receiver,
268269 Arc :: new ( Mutex :: new ( Self {
269- writer : Some ( writer ) , write_avail, read_waker, read_paused : false ,
270+ writer : Some ( tokio_stream ) , write_avail, read_waker, read_paused : false ,
270271 rl_requested_disconnect : false ,
271272 id : ID_COUNTER . fetch_add ( 1 , Ordering :: AcqRel )
272273 } ) ) )
@@ -462,9 +463,9 @@ impl SocketDescriptor {
462463}
463464impl peer_handler:: SocketDescriptor for SocketDescriptor {
464465 fn send_data ( & mut self , data : & [ u8 ] , resume_read : bool ) -> usize {
465- // To send data, we take a lock on our Connection to access the WriteHalf of the TcpStream,
466- // writing to it if there's room in the kernel buffer, or otherwise create a new Waker with
467- // a SocketDescriptor in it which can wake up the write_avail Sender, waking up the
466+ // To send data, we take a lock on our Connection to access the TcpStream, writing to it if
467+ // there's room in the kernel buffer, or otherwise create a new Waker with a
468+ // SocketDescriptor in it which can wake up the write_avail Sender, waking up the
468469 // processing future which will call write_buffer_space_avail and we'll end up back here.
469470 let mut us = self . conn . lock ( ) . unwrap ( ) ;
470471 if us. writer . is_none ( ) {
@@ -484,24 +485,18 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
484485 let mut ctx = task:: Context :: from_waker ( & waker) ;
485486 let mut written_len = 0 ;
486487 loop {
487- match std:: pin:: Pin :: new ( us. writer . as_mut ( ) . unwrap ( ) ) . poll_write ( & mut ctx, & data[ written_len..] ) {
488- task:: Poll :: Ready ( Ok ( res) ) => {
489- // The tokio docs *seem* to indicate this can't happen, and I certainly don't
490- // know how to handle it if it does (cause it should be a Poll::Pending
491- // instead):
492- assert_ne ! ( res, 0 ) ;
493- written_len += res;
494- if written_len == data. len ( ) { return written_len; }
495- } ,
496- task:: Poll :: Ready ( Err ( e) ) => {
497- // The tokio docs *seem* to indicate this can't happen, and I certainly don't
498- // know how to handle it if it does (cause it should be a Poll::Pending
499- // instead):
500- assert_ne ! ( e. kind( ) , io:: ErrorKind :: WouldBlock ) ;
501- // Probably we've already been closed, just return what we have and let the
502- // read thread handle closing logic.
503- return written_len;
488+ match us. writer . as_ref ( ) . unwrap ( ) . poll_write_ready ( & mut ctx) {
489+ task:: Poll :: Ready ( Ok ( ( ) ) ) => {
490+ match us. writer . as_ref ( ) . unwrap ( ) . try_write ( & data[ written_len..] ) {
491+ Ok ( res) => {
492+ debug_assert_ne ! ( res, 0 ) ;
493+ written_len += res;
494+ if written_len == data. len ( ) { return written_len; }
495+ } ,
496+ Err ( e) => return written_len,
497+ }
504498 } ,
499+ task:: Poll :: Ready ( Err ( e) ) => return written_len,
505500 task:: Poll :: Pending => {
506501 // We're queued up for a write event now, but we need to make sure we also
507502 // pause read given we're now waiting on the remote end to ACK (and in
0 commit comments