@@ -25,10 +25,27 @@ use std::{error, fmt, io};
2525
2626// TODO: these methods could be on an Ext trait to AsyncWrite
2727
28+ /// Writes a message to the given socket with a length prefix appended to it. Also flushes the socket.
29+ ///
30+ /// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is
31+ /// > compatible with what [`read_length_prefixed`] expects.
32+ pub async fn write_length_prefixed ( socket : & mut ( impl AsyncWrite + Unpin ) , data : impl AsRef < [ u8 ] > )
33+ -> Result < ( ) , io:: Error >
34+ {
35+ write_varint ( socket, data. as_ref ( ) . len ( ) ) . await ?;
36+ socket. write_all ( data. as_ref ( ) ) . await ?;
37+ socket. flush ( ) . await ?;
38+
39+ Ok ( ( ) )
40+ }
41+
2842/// Send a message to the given socket, then shuts down the writing side.
2943///
3044/// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is
3145/// > compatible with what `read_one` expects.
46+ ///
47+ #[ deprecated( since = "0.29.0" , note = "Use `write_length_prefixed` instead. You will need to manually close the stream using `socket.close().await`." ) ]
48+ #[ allow( dead_code) ]
3249pub async fn write_one ( socket : & mut ( impl AsyncWrite + Unpin ) , data : impl AsRef < [ u8 ] > )
3350 -> Result < ( ) , io:: Error >
3451{
@@ -42,6 +59,8 @@ pub async fn write_one(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<
4259///
4360/// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is
4461/// > compatible with what `read_one` expects.
62+ #[ deprecated( since = "0.29.0" , note = "Use `write_length_prefixed` instead." ) ]
63+ #[ allow( dead_code) ]
4564pub async fn write_with_len_prefix ( socket : & mut ( impl AsyncWrite + Unpin ) , data : impl AsRef < [ u8 ] > )
4665 -> Result < ( ) , io:: Error >
4766{
@@ -60,6 +79,7 @@ pub async fn write_varint(socket: &mut (impl AsyncWrite + Unpin), len: usize)
6079 let mut len_data = unsigned_varint:: encode:: usize_buffer ( ) ;
6180 let encoded_len = unsigned_varint:: encode:: usize ( len, & mut len_data) . len ( ) ;
6281 socket. write_all ( & len_data[ ..encoded_len] ) . await ?;
82+
6383 Ok ( ( ) )
6484}
6585
@@ -106,6 +126,27 @@ pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result<usize,
106126 }
107127}
108128
129+ /// Reads a length-prefixed message from the given socket.
130+ ///
131+ /// The `max_size` parameter is the maximum size in bytes of the message that we accept. This is
132+ /// necessary in order to avoid DoS attacks where the remote sends us a message of several
133+ /// gigabytes.
134+ ///
135+ /// > **Note**: Assumes that a variable-length prefix indicates the length of the message. This is
136+ /// > compatible with what [`write_length_prefixed`] does.
137+ pub async fn read_length_prefixed ( socket : & mut ( impl AsyncRead + Unpin ) , max_size : usize ) -> io:: Result < Vec < u8 > >
138+ {
139+ let len = read_varint ( socket) . await ?;
140+ if len > max_size {
141+ return Err ( io:: Error :: new ( io:: ErrorKind :: InvalidData , format ! ( "Received data size ({} bytes) exceeds maximum ({} bytes)" , len, max_size) ) )
142+ }
143+
144+ let mut buf = vec ! [ 0 ; len] ;
145+ socket. read_exact ( & mut buf) . await ?;
146+
147+ Ok ( buf)
148+ }
149+
109150/// Reads a length-prefixed message from the given socket.
110151///
111152/// The `max_size` parameter is the maximum size in bytes of the message that we accept. This is
@@ -114,6 +155,8 @@ pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result<usize,
114155///
115156/// > **Note**: Assumes that a variable-length prefix indicates the length of the message. This is
116157/// > compatible with what `write_one` does.
158+ #[ deprecated( since = "0.29.0" , note = "Use `read_length_prefixed` instead." ) ]
159+ #[ allow( dead_code, deprecated) ]
117160pub async fn read_one ( socket : & mut ( impl AsyncRead + Unpin ) , max_size : usize )
118161 -> Result < Vec < u8 > , ReadOneError >
119162{
@@ -132,6 +175,7 @@ pub async fn read_one(socket: &mut (impl AsyncRead + Unpin), max_size: usize)
132175
133176/// Error while reading one message.
134177#[ derive( Debug ) ]
178+ #[ deprecated( since = "0.29.0" , note = "Use `read_length_prefixed` instead of `read_one` to avoid depending on this type." ) ]
135179pub enum ReadOneError {
136180 /// Error on the socket.
137181 Io ( std:: io:: Error ) ,
@@ -144,12 +188,14 @@ pub enum ReadOneError {
144188 } ,
145189}
146190
191+ #[ allow( deprecated) ]
147192impl From < std:: io:: Error > for ReadOneError {
148193 fn from ( err : std:: io:: Error ) -> ReadOneError {
149194 ReadOneError :: Io ( err)
150195 }
151196}
152197
198+ #[ allow( deprecated) ]
153199impl fmt:: Display for ReadOneError {
154200 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
155201 match * self {
@@ -159,6 +205,7 @@ impl fmt::Display for ReadOneError {
159205 }
160206}
161207
208+ #[ allow( deprecated) ]
162209impl error:: Error for ReadOneError {
163210 fn source ( & self ) -> Option < & ( dyn error:: Error + ' static ) > {
164211 match * self {
@@ -173,15 +220,18 @@ mod tests {
173220 use super :: * ;
174221
175222 #[ test]
176- fn write_one_works ( ) {
223+ fn write_length_prefixed_works ( ) {
177224 let data = ( 0 ..rand:: random :: < usize > ( ) % 10_000 )
178225 . map ( |_| rand:: random :: < u8 > ( ) )
179226 . collect :: < Vec < _ > > ( ) ;
180-
181227 let mut out = vec ! [ 0 ; 10_000 ] ;
182- futures:: executor:: block_on (
183- write_one ( & mut futures:: io:: Cursor :: new ( & mut out[ ..] ) , data. clone ( ) )
184- ) . unwrap ( ) ;
228+
229+ futures:: executor:: block_on ( async {
230+ let mut socket = futures:: io:: Cursor :: new ( & mut out[ ..] ) ;
231+
232+ write_length_prefixed ( & mut socket, & data) . await . unwrap ( ) ;
233+ socket. close ( ) . await . unwrap ( ) ;
234+ } ) ;
185235
186236 let ( out_len, out_data) = unsigned_varint:: decode:: usize ( & out) . unwrap ( ) ;
187237 assert_eq ! ( out_len, data. len( ) ) ;
0 commit comments