@@ -36,6 +36,11 @@ use log::{debug, info, warn};
3636use env_logger;
3737use crate :: response:: BodyWithTrailers ;
3838use std:: sync:: Once ;
39+ use tokio:: time:: timeout;
40+ use std:: future:: Future ;
41+ use std:: pin:: Pin ;
42+ use std:: task:: { Context , Poll } ;
43+ use tokio:: time:: { Sleep , sleep} ;
3944
4045static LOGGER_INIT : Once = Once :: new ( ) ;
4146
@@ -47,6 +52,7 @@ struct ServerConfig {
4752 bind_address : String ,
4853 tokio_threads : Option < usize > ,
4954 debug : bool ,
55+ recv_timeout : u64 ,
5056}
5157
5258impl ServerConfig {
@@ -55,6 +61,7 @@ impl ServerConfig {
5561 bind_address : String :: from ( "127.0.0.1:3000" ) ,
5662 tokio_threads : None ,
5763 debug : false ,
64+ recv_timeout : 30000 , // Default 30 second timeout
5865 }
5966 }
6067}
@@ -101,6 +108,10 @@ impl Server {
101108 server_config. debug = bool:: try_convert ( debug) ?;
102109 }
103110
111+ if let Some ( recv_timeout) = config. get ( magnus:: Symbol :: new ( "recv_timeout" ) ) {
112+ server_config. recv_timeout = u64:: try_convert ( recv_timeout) ?;
113+ }
114+
104115 // Initialize logging if debug is enabled, but only do it once
105116 if server_config. debug {
106117 LOGGER_INIT . call_once ( || {
@@ -215,16 +226,19 @@ impl Server {
215226 let work_tx = work_tx. clone ( ) ;
216227
217228 let server_task = tokio:: spawn ( async move {
229+ let timer = hyper_util:: rt:: TokioTimer :: new ( ) ;
230+
218231 if config. bind_address . starts_with ( "unix:" ) {
219232 let path = config. bind_address . trim_start_matches ( "unix:" ) ;
220233 let listener = UnixListener :: bind ( path) . unwrap ( ) ;
221234
222235 loop {
223236 let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
224237 let work_tx = work_tx. clone ( ) ;
238+ let timer = timer. clone ( ) ;
225239
226240 tokio:: task:: spawn ( async move {
227- handle_connection ( stream, work_tx) . await ;
241+ handle_connection ( stream, work_tx, config . recv_timeout , timer ) . await ;
228242 } ) ;
229243 }
230244 } else {
@@ -235,9 +249,10 @@ impl Server {
235249 loop {
236250 let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
237251 let work_tx = work_tx. clone ( ) ;
252+ let timer = timer. clone ( ) ;
238253
239254 tokio:: task:: spawn ( async move {
240- handle_connection ( stream, work_tx) . await ;
255+ handle_connection ( stream, work_tx, config . recv_timeout , timer ) . await ;
241256 } ) ;
242257 }
243258 }
@@ -282,19 +297,27 @@ impl Server {
282297async fn handle_request (
283298 req : HyperRequest < Incoming > ,
284299 work_tx : Arc < crossbeam_channel:: Sender < RequestWithCompletion > > ,
300+ recv_timeout : u64 ,
285301) -> Result < HyperResponse < BodyWithTrailers > , Error > {
286302 debug ! ( "Received request: {:?}" , req) ;
287303 debug ! ( "HTTP version: {:?}" , req. version( ) ) ;
288304 debug ! ( "Headers: {:?}" , req. headers( ) ) ;
289305
290306 let ( parts, body) = req. into_parts ( ) ;
291307
292- // Collect the body
293- let body_bytes = match body. collect ( ) . await {
294- Ok ( collected) => collected. to_bytes ( ) ,
295- Err ( e) => {
308+ // Collect the body with timeout
309+ let body_bytes = match timeout (
310+ std:: time:: Duration :: from_millis ( recv_timeout) ,
311+ body. collect ( )
312+ ) . await {
313+ Ok ( Ok ( collected) ) => collected. to_bytes ( ) ,
314+ Ok ( Err ( e) ) => {
296315 debug ! ( "Error collecting body: {:?}" , e) ;
297316 return Err ( e) ;
317+ } ,
318+ Err ( _) => {
319+ debug ! ( "Timeout collecting body" ) ;
320+ return Ok ( create_timeout_response ( ) ) ;
298321 }
299322 } ;
300323
@@ -336,23 +359,42 @@ async fn handle_request(
336359 }
337360}
338361
362+ fn create_timeout_response ( ) -> HyperResponse < BodyWithTrailers > {
363+ let builder = HyperResponse :: builder ( )
364+ . status ( StatusCode :: REQUEST_TIMEOUT )
365+ . header ( "content-type" , "text/plain" ) ;
366+
367+ builder. body ( BodyWithTrailers :: new ( Bytes :: from ( "Request timed out while receiving body" ) , None ) )
368+ . unwrap ( )
369+ }
370+
339371async fn handle_connection (
340372 stream : impl tokio:: io:: AsyncRead + tokio:: io:: AsyncWrite + Unpin + Send + ' static ,
341373 work_tx : Arc < crossbeam_channel:: Sender < RequestWithCompletion > > ,
374+ recv_timeout : u64 ,
375+ timer : hyper_util:: rt:: TokioTimer ,
342376) {
343377 info ! ( "New connection established" ) ;
344378
345379 let service = service_fn ( move |req : HyperRequest < Incoming > | {
346380 debug ! ( "Service handling request" ) ;
347381 let work_tx = work_tx. clone ( ) ;
348- handle_request ( req, work_tx)
382+ handle_request ( req, work_tx, recv_timeout )
349383 } ) ;
350384
351385 let io = TokioIo :: new ( stream) ;
352386
353- debug ! ( "Setting up HTTP/2 connection" ) ;
354- let builder = auto:: Builder :: new ( hyper_util:: rt:: TokioExecutor :: new ( ) ) ;
355-
387+ debug ! ( "Setting up connection" ) ;
388+ let mut builder = auto:: Builder :: new ( hyper_util:: rt:: TokioExecutor :: new ( ) ) ;
389+
390+ builder. http1 ( )
391+ . header_read_timeout ( std:: time:: Duration :: from_millis ( recv_timeout) )
392+ . timer ( timer. clone ( ) ) ;
393+
394+ builder. http2 ( )
395+ . keep_alive_interval ( std:: time:: Duration :: from_secs ( 10 ) )
396+ . timer ( timer) ;
397+
356398 if let Err ( err) = builder
357399 . serve_connection ( io, service)
358400 . await
0 commit comments