@@ -88,13 +88,20 @@ pub async fn send_bytes(payload: &[u8], stream: &mut GenericStream) -> Result<()
8888 Ok ( ( ) )
8989}
9090
91+ pub async fn receive_bytes ( stream : & mut GenericStream ) -> Result < Vec < u8 > , Error > {
92+ receive_bytes_with_max_size ( stream, None ) . await
93+ }
94+
9195/// Receive a byte stream. \
9296/// This is part of the basic protocol beneath all communication. \
9397///
9498/// 1. First of, the client sends a u64 as a 4byte vector in BigEndian mode, which specifies the
9599/// length of the payload we're going to receive.
96100/// 2. Receive chunks of [PACKET_SIZE] bytes until we finished all expected bytes.
97- pub async fn receive_bytes ( stream : & mut GenericStream ) -> Result < Vec < u8 > , Error > {
101+ pub async fn receive_bytes_with_max_size (
102+ stream : & mut GenericStream ,
103+ max_size : Option < usize > ,
104+ ) -> Result < Vec < u8 > , Error > {
98105 // Receive the header with the overall message size
99106 let mut header = vec ! [ 0 ; 8 ] ;
100107 stream
@@ -104,6 +111,21 @@ pub async fn receive_bytes(stream: &mut GenericStream) -> Result<Vec<u8>, Error>
104111 let mut header = Cursor :: new ( header) ;
105112 let message_size = ReadBytesExt :: read_u64 :: < BigEndian > ( & mut header) ? as usize ;
106113
114+ if let Some ( max_size) = max_size {
115+ if message_size > max_size {
116+ error ! (
117+ "Client requested message size of {message_size}, but only {max_size} is allowed."
118+ ) ;
119+ return Err ( Error :: MessageTooBig ( message_size, max_size) ) ;
120+ }
121+ }
122+
123+ // Show a warning if we see unusually large payloads. In this case payloads that're bigger than
124+ // 20MB, which is pretty large considering pueue is usually only sending a bit of text.
125+ if message_size > ( 20 * ( 2usize . pow ( 20 ) ) ) {
126+ warn ! ( "Client is sending a large payload: {message_size} bytes." ) ;
127+ }
128+
107129 // Buffer for the whole payload
108130 let mut payload_bytes = Vec :: with_capacity ( message_size) ;
109131
@@ -281,4 +303,104 @@ mod test {
281303
282304 Ok ( ( ) )
283305 }
306+
307+ use tracing:: level_filters:: LevelFilter ;
308+ use tracing_subscriber:: {
309+ EnvFilter , Layer , Registry , field:: MakeExt , filter:: FromEnvError , fmt:: time:: ChronoLocal ,
310+ layer:: SubscriberExt , util:: SubscriberInitExt ,
311+ } ;
312+
313+ pub fn install_tracing ( verbosity : u8 ) -> Result < ( ) , FromEnvError > {
314+ let mut pretty = false ;
315+ let level = match verbosity {
316+ 0 => LevelFilter :: WARN ,
317+ 1 => LevelFilter :: INFO ,
318+ 2 => LevelFilter :: DEBUG ,
319+ 3 => LevelFilter :: TRACE ,
320+ _ => {
321+ pretty = true ;
322+ LevelFilter :: TRACE
323+ }
324+ } ;
325+
326+ // tries to find local offset internally
327+ let timer = ChronoLocal :: new ( "%H:%M:%S" . into ( ) ) ;
328+
329+ type GenericLayer < S > = Box < dyn tracing_subscriber:: Layer < S > + Send + Sync > ;
330+ let fmt_layer: GenericLayer < _ > = match pretty {
331+ false => Box :: new (
332+ tracing_subscriber:: fmt:: layer ( )
333+ . map_fmt_fields ( |f| f. debug_alt ( ) )
334+ . with_timer ( timer)
335+ . with_writer ( std:: io:: stderr) ,
336+ ) ,
337+ true => Box :: new (
338+ tracing_subscriber:: fmt:: layer ( )
339+ . pretty ( )
340+ . with_timer ( timer)
341+ . with_target ( true )
342+ . with_thread_ids ( false )
343+ . with_thread_names ( true )
344+ . with_level ( true )
345+ . with_ansi ( true )
346+ . with_span_events ( tracing_subscriber:: fmt:: format:: FmtSpan :: ACTIVE )
347+ . with_writer ( std:: io:: stderr) ,
348+ ) ,
349+ } ;
350+ let filter_layer = EnvFilter :: builder ( )
351+ . with_default_directive ( level. into ( ) )
352+ . from_env ( ) ?;
353+
354+ Registry :: default ( )
355+ . with ( fmt_layer. with_filter ( filter_layer) )
356+ . with ( tracing_error:: ErrorLayer :: default ( ) )
357+ . init ( ) ;
358+
359+ Ok ( ( ) )
360+ }
361+
362+ /// Ensure there's no OOM if a huge payload during the handshake phase is being requested.
363+ ///
364+ /// We limit the receiving buffer to ~4MB for the incoming secret to prevent (potentially
365+ /// unintended) DoS attacks when something connect to Pueue and sends a malformed secret
366+ /// payload.
367+ #[ tokio:: test]
368+ async fn test_restricted_payload_size ( ) -> Result < ( ) , Error > {
369+ install_tracing ( 3 )
370+ . expect ( "Couldn't init tracing for test, have you initialised tracing twice?" ) ;
371+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await ?;
372+ let addr = listener. local_addr ( ) ?;
373+
374+ let listener: GenericListener = Box :: new ( listener) ;
375+
376+ // Spawn a sub thread that:
377+ // 1. Accepts a new connection.
378+ // 2. Sends a malformed payload.
379+ task:: spawn ( async move {
380+ let mut stream = listener. accept ( ) . await . unwrap ( ) ;
381+
382+ // Send a payload of 9 bytes to the daemon receiver.
383+ // The first 8 bytes determine the payload size in BigEndian.
384+ // This payload requests 2^64 bytes of memory for the secret.
385+ stream
386+ . write_all ( & [ 128 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ] )
387+ . await
388+ . unwrap ( ) ;
389+ } ) ;
390+
391+ // Create a receiver stream
392+ let mut client: GenericStream = Box :: new ( TcpStream :: connect ( & addr) . await ?) ;
393+ // Wait for a short time to allow the sender to send the message
394+ tokio:: time:: sleep ( Duration :: from_millis ( 500 ) ) . await ;
395+
396+ // Get the message while restricting the payload size to 4MB
397+ let result = receive_bytes_with_max_size ( & mut client, Some ( 4 * 2usize . pow ( 20 ) ) ) . await ;
398+
399+ assert ! (
400+ result. is_err( ) ,
401+ "The payload should be rejected due to large size"
402+ ) ;
403+
404+ Ok ( ( ) )
405+ }
284406}
0 commit comments