@@ -26,7 +26,7 @@ use std::{
26
26
27
27
use axum:: {
28
28
body:: HttpBody ,
29
- extract:: { FromRef , Path , State } ,
29
+ extract:: { DefaultBodyLimit , FromRef , Path , State } ,
30
30
http:: {
31
31
header:: { CONTENT_TYPE , ETAG , IF_MATCH , IF_NONE_MATCH , LOCATION } ,
32
32
StatusCode ,
@@ -253,6 +253,7 @@ where
253
253
254
254
Router :: new ( )
255
255
. nest ( prefix, router)
256
+ . layer ( DefaultBodyLimit :: disable ( ) )
256
257
. layer ( RequestBodyLimitLayer :: new ( max_bytes) )
257
258
. layer ( SetResponseHeaderLayer :: if_not_present (
258
259
HeaderName :: from_static ( "x-max-bytes" ) ,
@@ -270,14 +271,71 @@ where
270
271
271
272
#[ cfg( test) ]
272
273
mod tests {
274
+ use std:: convert:: Infallible ;
275
+
273
276
use super :: * ;
274
277
275
278
use axum:: http:: {
276
279
header:: { CONTENT_LENGTH , CONTENT_TYPE } ,
277
280
Request ,
278
281
} ;
282
+ use bytes:: Buf ;
279
283
use tower:: util:: ServiceExt ;
280
284
285
+ /// A slow body, which sends the bytes in small chunks (1 byte per chunk by default)
286
+ #[ derive( Clone ) ]
287
+ struct SlowBody {
288
+ body : Bytes ,
289
+ chunk_size : usize ,
290
+ }
291
+
292
+ impl SlowBody {
293
+ const fn from_static ( bytes : & ' static [ u8 ] ) -> Self {
294
+ Self {
295
+ body : Bytes :: from_static ( bytes) ,
296
+ chunk_size : 1 ,
297
+ }
298
+ }
299
+
300
+ const fn from_bytes ( body : Bytes ) -> Self {
301
+ Self {
302
+ body,
303
+ chunk_size : 1 ,
304
+ }
305
+ }
306
+
307
+ const fn with_chunk_size ( mut self , chunk_size : usize ) -> Self {
308
+ self . chunk_size = chunk_size;
309
+ self
310
+ }
311
+ }
312
+
313
+ impl HttpBody for SlowBody {
314
+ type Data = Bytes ;
315
+ type Error = Infallible ;
316
+
317
+ fn poll_data (
318
+ self : std:: pin:: Pin < & mut Self > ,
319
+ _cx : & mut std:: task:: Context < ' _ > ,
320
+ ) -> std:: task:: Poll < Option < Result < Self :: Data , Self :: Error > > > {
321
+ if self . body . is_empty ( ) {
322
+ std:: task:: Poll :: Ready ( None )
323
+ } else {
324
+ let size = self . chunk_size . min ( self . body . len ( ) ) ;
325
+ let ret = self . body . slice ( 0 ..size) ;
326
+ self . get_mut ( ) . body . advance ( size) ;
327
+ std:: task:: Poll :: Ready ( Some ( Ok ( ret) ) )
328
+ }
329
+ }
330
+
331
+ fn poll_trailers (
332
+ self : std:: pin:: Pin < & mut Self > ,
333
+ _cx : & mut std:: task:: Context < ' _ > ,
334
+ ) -> std:: task:: Poll < Result < Option < headers:: HeaderMap > , Self :: Error > > {
335
+ std:: task:: Poll :: Ready ( Ok ( None ) )
336
+ }
337
+ }
338
+
281
339
async fn advance_time ( duration : Duration ) {
282
340
tokio:: task:: yield_now ( ) . await ;
283
341
tokio:: time:: pause ( ) ;
@@ -325,6 +383,63 @@ mod tests {
325
383
assert_eq ! ( response. status( ) , StatusCode :: NOT_FOUND ) ;
326
384
}
327
385
386
+ #[ tokio:: test]
387
+ async fn test_post_max_bytes ( ) {
388
+ let ttl = Duration :: from_secs ( 60 ) ;
389
+
390
+ let body = br#"{"hello": "world"}"# ;
391
+
392
+ // It doesn't work with a way too small size
393
+ let slow_body = SlowBody :: from_static ( body) ;
394
+ let request = Request :: post ( "/" )
395
+ . header ( CONTENT_TYPE , "application/json" )
396
+ . body ( slow_body)
397
+ . unwrap ( ) ;
398
+ let response = router ( "/" , ttl, 8 ) . oneshot ( request) . await . unwrap ( ) ;
399
+ assert_eq ! ( response. status( ) , StatusCode :: PAYLOAD_TOO_LARGE ) ;
400
+
401
+ // It works with exactly the right size
402
+ let slow_body = SlowBody :: from_static ( body) ;
403
+ let request = Request :: post ( "/" )
404
+ . header ( CONTENT_TYPE , "application/json" )
405
+ . body ( slow_body)
406
+ . unwrap ( ) ;
407
+ let response = router ( "/" , ttl, body. len ( ) ) . oneshot ( request) . await . unwrap ( ) ;
408
+ assert_eq ! ( response. status( ) , StatusCode :: CREATED ) ;
409
+
410
+ // It doesn't work even if the size is one too short
411
+ let slow_body = SlowBody :: from_static ( body) ;
412
+ let request = Request :: post ( "/" )
413
+ . header ( CONTENT_TYPE , "application/json" )
414
+ . body ( slow_body)
415
+ . unwrap ( ) ;
416
+ let response = router ( "/" , ttl, body. len ( ) - 1 )
417
+ . oneshot ( request)
418
+ . await
419
+ . unwrap ( ) ;
420
+ assert_eq ! ( response. status( ) , StatusCode :: PAYLOAD_TOO_LARGE ) ;
421
+
422
+ // Try with a big body (4MB), sent in small 128 bytes chunks
423
+ let body = vec ! [ 42 ; 4 * 1024 * 1024 ] . into_boxed_slice ( ) ;
424
+ let slow_body = SlowBody :: from_bytes ( Bytes :: from ( body) ) . with_chunk_size ( 128 ) ;
425
+ let request = Request :: post ( "/" ) . body ( slow_body) . unwrap ( ) ;
426
+ let response = router ( "/" , ttl, 4 * 1024 * 1024 )
427
+ . oneshot ( request)
428
+ . await
429
+ . unwrap ( ) ;
430
+ assert_eq ! ( response. status( ) , StatusCode :: CREATED ) ;
431
+
432
+ // Try with a big body (4MB + 1B), sent in small 128 bytes chunks
433
+ let body = vec ! [ 42 ; 4 * 1024 * 1024 + 1 ] . into_boxed_slice ( ) ;
434
+ let slow_body = SlowBody :: from_bytes ( Bytes :: from ( body) ) . with_chunk_size ( 128 ) ;
435
+ let request = Request :: post ( "/" ) . body ( slow_body) . unwrap ( ) ;
436
+ let response = router ( "/" , ttl, 4 * 1024 * 1024 )
437
+ . oneshot ( request)
438
+ . await
439
+ . unwrap ( ) ;
440
+ assert_eq ! ( response. status( ) , StatusCode :: PAYLOAD_TOO_LARGE ) ;
441
+ }
442
+
328
443
#[ tokio:: test]
329
444
async fn test_post_and_get_if_none_match ( ) {
330
445
let ttl = Duration :: from_secs ( 60 ) ;
0 commit comments