@@ -6,7 +6,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream, R
6
6
use tokio:: net:: TcpStream ;
7
7
use tracing:: { debug, enabled, trace, Level } ;
8
8
9
- use std:: io:: Error ;
9
+ use std:: io:: { Error , ErrorKind } ;
10
10
use std:: net:: SocketAddr ;
11
11
use std:: ops:: Deref ;
12
12
use std:: pin:: Pin ;
@@ -108,8 +108,8 @@ impl Stream {
108
108
pub async fn check ( & mut self ) -> Result < ( ) , crate :: net:: Error > {
109
109
let mut buf = [ 0u8 ; 1 ] ;
110
110
match self {
111
- Self :: Plain ( plain) => plain. get_mut ( ) . peek ( & mut buf) . await ?,
112
- Self :: Tls ( tls) => tls. get_mut ( ) . get_mut ( ) . 0 . peek ( & mut buf) . await ?,
111
+ Self :: Plain ( plain) => eof ( plain. get_mut ( ) . peek ( & mut buf) . await ) ?,
112
+ Self :: Tls ( tls) => eof ( tls. get_mut ( ) . get_mut ( ) . 0 . peek ( & mut buf) . await ) ?,
113
113
Self :: DevNull => 0 ,
114
114
} ;
115
115
@@ -126,8 +126,8 @@ impl Stream {
126
126
let bytes = message. to_bytes ( ) ?;
127
127
128
128
match self {
129
- Stream :: Plain ( ref mut stream) => stream. write_all ( & bytes) . await ?,
130
- Stream :: Tls ( ref mut stream) => stream. write_all ( & bytes) . await ?,
129
+ Stream :: Plain ( ref mut stream) => eof ( stream. write_all ( & bytes) . await ) ?,
130
+ Stream :: Tls ( ref mut stream) => eof ( stream. write_all ( & bytes) . await ) ?,
131
131
Self :: DevNull => ( ) ,
132
132
}
133
133
@@ -165,7 +165,7 @@ impl Stream {
165
165
message : & impl Protocol ,
166
166
) -> Result < usize , crate :: net:: Error > {
167
167
let sent = self . send ( message) . await ?;
168
- self . flush ( ) . await ?;
168
+ eof ( self . flush ( ) . await ) ?;
169
169
trace ! ( "😳" ) ;
170
170
171
171
Ok ( sent)
@@ -180,7 +180,7 @@ impl Stream {
180
180
for message in messages {
181
181
sent += self . send ( message) . await ?;
182
182
}
183
- self . flush ( ) . await ?;
183
+ eof ( self . flush ( ) . await ) ?;
184
184
trace ! ( "😳" ) ;
185
185
Ok ( sent)
186
186
}
@@ -199,15 +199,15 @@ impl Stream {
199
199
200
200
/// Read data into a buffer, avoiding unnecessary allocations.
201
201
pub async fn read_buf ( & mut self , bytes : & mut BytesMut ) -> Result < Message , crate :: net:: Error > {
202
- let code = self . read_u8 ( ) . await ?;
203
- let len = self . read_i32 ( ) . await ?;
202
+ let code = eof ( self . read_u8 ( ) . await ) ?;
203
+ let len = eof ( self . read_i32 ( ) . await ) ?;
204
204
205
205
bytes. put_u8 ( code) ;
206
206
bytes. put_i32 ( len) ;
207
207
208
208
// Length must be at least 4 bytes.
209
209
if len < 4 {
210
- return Err ( crate :: net:: Error :: Eof ) ;
210
+ return Err ( crate :: net:: Error :: UnexpectedEof ) ;
211
211
}
212
212
213
213
let capacity = len as usize + 1 ;
@@ -218,7 +218,7 @@ impl Stream {
218
218
bytes. set_len ( capacity) ;
219
219
}
220
220
221
- self . read_exact ( & mut bytes[ 5 ..capacity] ) . await ?;
221
+ eof ( self . read_exact ( & mut bytes[ 5 ..capacity] ) . await ) ?;
222
222
223
223
let message = Message :: new ( bytes. split ( ) . freeze ( ) ) ;
224
224
@@ -261,6 +261,19 @@ impl Stream {
261
261
}
262
262
}
263
263
264
+ fn eof < T > ( result : std:: io:: Result < T > ) -> Result < T , crate :: net:: Error > {
265
+ match result {
266
+ Ok ( val) => Ok ( val) ,
267
+ Err ( err) => {
268
+ if err. kind ( ) == ErrorKind :: UnexpectedEof {
269
+ Err ( crate :: net:: Error :: UnexpectedEof )
270
+ } else {
271
+ Err ( crate :: net:: Error :: Io ( err) )
272
+ }
273
+ }
274
+ }
275
+ }
276
+
264
277
/// Wrapper around SocketAddr
265
278
/// to make it easier to debug.
266
279
pub struct PeerAddr {
0 commit comments