@@ -105,17 +105,17 @@ impl<S: Read + Write> Client<S> {
105105 /// panic!("unexpected result")
106106 /// }
107107 ///
108- /// tpkt = tpkt::Client::new(link::Link::new(link::Stream::Raw(Cursor::new(vec![0, 6 , 0, 0, 0, 0]))));
108+ /// tpkt = tpkt::Client::new(link::Link::new(link::Stream::Raw(Cursor::new(vec![0, 7, 0 , 0, 0, 0, 0]))));
109109 /// if let tpkt::Payload::FastPath(_, c) = tpkt.read().unwrap() {
110- /// assert_eq!(c.into_inner(), vec![0, 0, 0, 0])
110+ /// assert_eq!(c.into_inner(), vec![0, 0, 0, 0, 0 ])
111111 /// }
112112 /// else {
113113 /// panic!("unexpected result")
114114 /// }
115115 ///
116- /// tpkt = tpkt::Client::new(link::Link::new(link::Stream::Raw(Cursor::new(vec![0, 0x80, 7 , 0, 0, 0, 0]))));
116+ /// tpkt = tpkt::Client::new(link::Link::new(link::Stream::Raw(Cursor::new(vec![0, 0x80, 8, 0 , 0, 0, 0, 0]))));
117117 /// if let tpkt::Payload::FastPath(_, c) = tpkt.read().unwrap() {
118- /// assert_eq!(c.into_inner(), vec![0, 0, 0, 0])
118+ /// assert_eq!(c.into_inner(), vec![0, 0, 0, 0, 0 ])
119119 /// }
120120 /// else {
121121 /// panic!("unexpected result")
@@ -136,9 +136,15 @@ impl<S: Read + Write> Client<S> {
136136 let mut size = U16 :: BE ( 0 ) ;
137137 size. read ( & mut buffer) ?;
138138
139- // now wait for body
140- Ok ( Payload :: Raw ( Cursor :: new ( self . transport . read ( size. inner ( ) as usize - 4 ) ?) ) )
141-
139+ // Minimal size must be 7
140+ // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/18a27ef9-6f9a-4501-b000-94b1fe3c2c10
141+ if size. inner ( ) < 7 {
142+ Err ( Error :: RdpError ( RdpError :: new ( RdpErrorKind :: InvalidSize , "Invalid minimal size for TPKT" ) ) )
143+ }
144+ else {
145+ // now wait for body
146+ Ok ( Payload :: Raw ( Cursor :: new ( self . transport . read ( size. inner ( ) as usize - 4 ) ?) ) )
147+ }
142148 } else {
143149 // fast path
144150 let sec_flag = ( action >> 6 ) & 0x3 ;
@@ -147,12 +153,20 @@ impl<S: Read + Write> Client<S> {
147153 if short_length & 0x80 != 0 {
148154 let mut hi_length: u8 = 0 ;
149155 hi_length. read ( & mut Cursor :: new ( self . transport . read ( 1 ) ?) ) ?;
150- let length : u16 = ( ( short_length & !0x80 ) as u16 ) << 8 ;
156+ let length: u16 = ( ( short_length & !0x80 ) as u16 ) << 8 ;
151157 let length = length | hi_length as u16 ;
152- Ok ( Payload :: FastPath ( sec_flag, Cursor :: new ( self . transport . read ( length as usize - 3 ) ?) ) )
158+ if length < 7 {
159+ Err ( Error :: RdpError ( RdpError :: new ( RdpErrorKind :: InvalidSize , "Invalid minimal size for TPKT" ) ) )
160+ } else {
161+ Ok ( Payload :: FastPath ( sec_flag, Cursor :: new ( self . transport . read ( length as usize - 3 ) ?) ) )
162+ }
153163 }
154164 else {
155- Ok ( Payload :: FastPath ( sec_flag, Cursor :: new ( self . transport . read ( short_length as usize - 2 ) ?) ) )
165+ if short_length < 7 {
166+ Err ( Error :: RdpError ( RdpError :: new ( RdpErrorKind :: InvalidSize , "Invalid minimal size for TPKT" ) ) )
167+ } else {
168+ Ok ( Payload :: FastPath ( sec_flag, Cursor :: new ( self . transport . read ( short_length as usize - 2 ) ?) ) )
169+ }
156170 }
157171 }
158172 }
@@ -210,6 +224,7 @@ mod test {
210224 use super :: * ;
211225 use std:: io:: Cursor ;
212226 use model:: data:: { U32 , DataType } ;
227+ use model:: link:: Stream ;
213228
214229 /// Test the tpkt header type in write context
215230 #[ test]
@@ -233,4 +248,29 @@ mod test {
233248 assert_eq ! ( cast!( DataType :: U16 , message[ "size" ] ) . unwrap( ) , 8 ) ;
234249 assert_eq ! ( cast!( DataType :: U8 , message[ "action" ] ) . unwrap( ) , Action :: FastPathActionX224 as u8 ) ;
235250 }
251+
252+ fn process ( data : & [ u8 ] ) {
253+ let cur = Cursor :: new ( data. to_vec ( ) ) ;
254+ let link = Link :: new ( Stream :: Raw ( cur) ) ;
255+ let mut client = Client :: new ( link) ;
256+ let _ = client. read ( ) ;
257+ }
258+
259+ #[ test]
260+ fn test_tpkt_size_overflow_case_1 ( ) {
261+ let buf = b"\x00 \x00 \x03 \x00 \x00 \x00 " ;
262+ process ( buf) ;
263+ }
264+
265+ #[ test]
266+ fn test_tpkt_size_overflow_case_2 ( ) {
267+ let buf = b"\x00 \x80 \x00 \x00 \x00 \x00 " ;
268+ process ( buf) ;
269+ }
270+
271+ #[ test]
272+ fn test_tpkt_size_overflow_case_3 ( ) {
273+ let buf = b"\x03 \xe8 \x00 \x00 \x80 \x00 " ;
274+ process ( buf) ;
275+ }
236276}
0 commit comments