1
- use futures_lite:: { io, prelude:: * } ;
1
+ use futures_lite:: { io, prelude:: * , ready } ;
2
2
use serde:: { de:: DeserializeOwned , Serialize } ;
3
3
4
4
use std:: fmt:: { self , Debug } ;
@@ -56,6 +56,7 @@ pin_project_lite::pin_project! {
56
56
reader: Box <dyn AsyncBufRead + Unpin + Send + Sync + ' static >,
57
57
mime: Mime ,
58
58
length: Option <usize >,
59
+ bytes_read: usize
59
60
}
60
61
}
61
62
@@ -78,6 +79,7 @@ impl Body {
78
79
reader : Box :: new ( io:: empty ( ) ) ,
79
80
mime : mime:: BYTE_STREAM ,
80
81
length : Some ( 0 ) ,
82
+ bytes_read : 0 ,
81
83
}
82
84
}
83
85
@@ -108,6 +110,7 @@ impl Body {
108
110
reader : Box :: new ( reader) ,
109
111
mime : mime:: BYTE_STREAM ,
110
112
length : len,
113
+ bytes_read : 0 ,
111
114
}
112
115
}
113
116
@@ -151,6 +154,7 @@ impl Body {
151
154
mime : mime:: BYTE_STREAM ,
152
155
length : Some ( bytes. len ( ) ) ,
153
156
reader : Box :: new ( io:: Cursor :: new ( bytes) ) ,
157
+ bytes_read : 0 ,
154
158
}
155
159
}
156
160
@@ -200,6 +204,7 @@ impl Body {
200
204
mime : mime:: PLAIN ,
201
205
length : Some ( s. len ( ) ) ,
202
206
reader : Box :: new ( io:: Cursor :: new ( s. into_bytes ( ) ) ) ,
207
+ bytes_read : 0 ,
203
208
}
204
209
}
205
210
@@ -245,6 +250,7 @@ impl Body {
245
250
length : Some ( bytes. len ( ) ) ,
246
251
reader : Box :: new ( io:: Cursor :: new ( bytes) ) ,
247
252
mime : mime:: JSON ,
253
+ bytes_read : 0 ,
248
254
} ;
249
255
Ok ( body)
250
256
}
@@ -309,6 +315,7 @@ impl Body {
309
315
length : Some ( bytes. len ( ) ) ,
310
316
reader : Box :: new ( io:: Cursor :: new ( bytes) ) ,
311
317
mime : mime:: FORM ,
318
+ bytes_read : 0 ,
312
319
} ;
313
320
Ok ( body)
314
321
}
@@ -377,6 +384,7 @@ impl Body {
377
384
mime,
378
385
length : Some ( len as usize ) ,
379
386
reader : Box :: new ( io:: BufReader :: new ( file) ) ,
387
+ bytes_read : 0 ,
380
388
} )
381
389
}
382
390
@@ -418,6 +426,7 @@ impl Debug for Body {
418
426
f. debug_struct ( "Body" )
419
427
. field ( "reader" , & "<hidden>" )
420
428
. field ( "length" , & self . length )
429
+ . field ( "bytes_read" , & self . bytes_read )
421
430
. finish ( )
422
431
}
423
432
}
@@ -459,15 +468,25 @@ impl AsyncRead for Body {
459
468
cx : & mut Context < ' _ > ,
460
469
buf : & mut [ u8 ] ,
461
470
) -> Poll < io:: Result < usize > > {
462
- Pin :: new ( & mut self . reader ) . poll_read ( cx, buf)
471
+ let mut buf = match self . length {
472
+ None => buf,
473
+ Some ( length) if length == self . bytes_read => return Poll :: Ready ( Ok ( 0 ) ) ,
474
+ Some ( length) => {
475
+ let max_len = ( length - self . bytes_read ) . min ( buf. len ( ) ) ;
476
+ & mut buf[ 0 ..max_len]
477
+ }
478
+ } ;
479
+
480
+ let bytes = ready ! ( Pin :: new( & mut self . reader) . poll_read( cx, & mut buf) ) ?;
481
+ self . bytes_read += bytes;
482
+ Poll :: Ready ( Ok ( bytes) )
463
483
}
464
484
}
465
485
466
486
impl AsyncBufRead for Body {
467
487
#[ allow( missing_doc_code_examples) ]
468
488
fn poll_fill_buf ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < & ' _ [ u8 ] > > {
469
- let this = self . project ( ) ;
470
- this. reader . poll_fill_buf ( cx)
489
+ self . project ( ) . reader . poll_fill_buf ( cx)
471
490
}
472
491
473
492
fn consume ( mut self : Pin < & mut Self > , amt : usize ) {
@@ -500,6 +519,7 @@ fn guess_ext(path: &std::path::Path) -> Option<Mime> {
500
519
#[ cfg( test) ]
501
520
mod test {
502
521
use super :: * ;
522
+ use async_std:: io:: Cursor ;
503
523
use serde:: Deserialize ;
504
524
505
525
#[ async_std:: test]
@@ -523,4 +543,74 @@ mod test {
523
543
let res = body. into_form :: < Foo > ( ) . await ;
524
544
assert_eq ! ( res. unwrap_err( ) . status( ) , 422 ) ;
525
545
}
546
+
547
+ async fn read_with_buffers_of_size < R > ( reader : & mut R , size : usize ) -> crate :: Result < String >
548
+ where
549
+ R : AsyncRead + Unpin ,
550
+ {
551
+ let mut return_buffer = vec ! [ ] ;
552
+ loop {
553
+ let mut buf = vec ! [ 0 ; size] ;
554
+ match reader. read ( & mut buf) . await ? {
555
+ 0 => break Ok ( String :: from_utf8 ( return_buffer) ?) ,
556
+ bytes_read => return_buffer. extend_from_slice ( & buf[ ..bytes_read] ) ,
557
+ }
558
+ }
559
+ }
560
+
561
+ #[ async_std:: test]
562
+ async fn attempting_to_read_past_length ( ) -> crate :: Result < ( ) > {
563
+ for buf_len in 1 ..13 {
564
+ let mut body = Body :: from_reader ( Cursor :: new ( "hello world" ) , Some ( 5 ) ) ;
565
+ assert_eq ! (
566
+ read_with_buffers_of_size( & mut body, buf_len) . await ?,
567
+ "hello"
568
+ ) ;
569
+ assert_eq ! ( body. bytes_read, 5 ) ;
570
+ }
571
+
572
+ Ok ( ( ) )
573
+ }
574
+
575
+ #[ async_std:: test]
576
+ async fn attempting_to_read_when_length_is_greater_than_content ( ) -> crate :: Result < ( ) > {
577
+ for buf_len in 1 ..13 {
578
+ let mut body = Body :: from_reader ( Cursor :: new ( "hello world" ) , Some ( 15 ) ) ;
579
+ assert_eq ! (
580
+ read_with_buffers_of_size( & mut body, buf_len) . await ?,
581
+ "hello world"
582
+ ) ;
583
+ assert_eq ! ( body. bytes_read, 11 ) ;
584
+ }
585
+
586
+ Ok ( ( ) )
587
+ }
588
+
589
+ #[ async_std:: test]
590
+ async fn attempting_to_read_when_length_is_exactly_right ( ) -> crate :: Result < ( ) > {
591
+ for buf_len in 1 ..13 {
592
+ let mut body = Body :: from_reader ( Cursor :: new ( "hello world" ) , Some ( 11 ) ) ;
593
+ assert_eq ! (
594
+ read_with_buffers_of_size( & mut body, buf_len) . await ?,
595
+ "hello world"
596
+ ) ;
597
+ assert_eq ! ( body. bytes_read, 11 ) ;
598
+ }
599
+
600
+ Ok ( ( ) )
601
+ }
602
+
603
+ #[ async_std:: test]
604
+ async fn reading_in_various_buffer_lengths_when_there_is_no_length ( ) -> crate :: Result < ( ) > {
605
+ for buf_len in 1 ..13 {
606
+ let mut body = Body :: from_reader ( Cursor :: new ( "hello world" ) , None ) ;
607
+ assert_eq ! (
608
+ read_with_buffers_of_size( & mut body, buf_len) . await ?,
609
+ "hello world"
610
+ ) ;
611
+ assert_eq ! ( body. bytes_read, 11 ) ;
612
+ }
613
+
614
+ Ok ( ( ) )
615
+ }
526
616
}
0 commit comments