11use crate :: protocol:: { Message , ParseError , PayloadItem , PayloadSize , RequestHeader } ;
22use bytes:: Bytes ;
3- use futures:: { SinkExt , Stream , StreamExt , channel:: mpsc} ;
3+ use futures:: { Sink , SinkExt , Stream , StreamExt , channel:: mpsc} ;
44use http_body:: { Body , Frame , SizeHint } ;
55use std:: pin:: Pin ;
66use std:: task:: { Context , Poll } ;
@@ -120,6 +120,7 @@ pub(crate) struct BodyReceiver {
120120 signal_sender : mpsc:: Sender < BodyRequestSignal > ,
121121 data_receiver : mpsc:: Receiver < Result < PayloadItem , ParseError > > ,
122122 payload_size : PayloadSize ,
123+ in_flight : bool ,
123124}
124125
125126impl BodyReceiver {
@@ -128,21 +129,7 @@ impl BodyReceiver {
128129 data_receiver : mpsc:: Receiver < Result < PayloadItem , ParseError > > ,
129130 payload_size : PayloadSize ,
130131 ) -> Self {
131- Self { signal_sender, data_receiver, payload_size }
132- }
133- }
134-
135- impl BodyReceiver {
136- pub async fn receive_data ( & mut self ) -> Result < PayloadItem , ParseError > {
137- if let Err ( e) = self . signal_sender . send ( BodyRequestSignal :: RequestData ) . await {
138- error ! ( "failed to send request_more through channel, {}" , e) ;
139- return Err ( ParseError :: invalid_body ( "failed to send signal when receive body data" ) ) ;
140- }
141-
142- self . data_receiver
143- . next ( )
144- . await
145- . unwrap_or_else ( || Err ( ParseError :: invalid_body ( "body stream should not receive None when receive data" ) ) )
132+ Self { signal_sender, data_receiver, payload_size, in_flight : false }
146133 }
147134}
148135
@@ -153,14 +140,40 @@ impl Body for BodyReceiver {
153140 fn poll_frame ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Result < Frame < Self :: Data > , Self :: Error > > > {
154141 let this = self . get_mut ( ) ;
155142
156- tokio:: pin! {
157- let future = this. receive_data( ) ;
143+ if !this. in_flight {
144+ match Pin :: new ( & mut this. signal_sender ) . poll_ready ( cx) {
145+ Poll :: Ready ( Ok ( ( ) ) ) => {
146+ if let Err ( e) = Pin :: new ( & mut this. signal_sender ) . start_send ( BodyRequestSignal :: RequestData ) {
147+ error ! ( "failed to send request_more through channel, {}" , e) ;
148+ return Poll :: Ready ( Some ( Err ( ParseError :: invalid_body ( "failed to send signal when receive body data" ) ) ) ) ;
149+ }
150+ this. in_flight = true ;
151+ }
152+ Poll :: Ready ( Err ( e) ) => {
153+ error ! ( "failed to prepare request_more through channel, {}" , e) ;
154+ return Poll :: Ready ( Some ( Err ( ParseError :: invalid_body ( "failed to send signal when receive body data" ) ) ) ) ;
155+ }
156+ Poll :: Pending => return Poll :: Pending ,
157+ }
158158 }
159159
160- match future. poll ( cx) {
161- Poll :: Ready ( Ok ( PayloadItem :: Chunk ( bytes) ) ) => Poll :: Ready ( Some ( Ok ( Frame :: data ( bytes) ) ) ) ,
162- Poll :: Ready ( Ok ( PayloadItem :: Eof ) ) => Poll :: Ready ( None ) ,
163- Poll :: Ready ( Err ( e) ) => Poll :: Ready ( Some ( Err ( e) ) ) ,
160+ match this. data_receiver . poll_next_unpin ( cx) {
161+ Poll :: Ready ( Some ( Ok ( PayloadItem :: Chunk ( bytes) ) ) ) => {
162+ this. in_flight = false ;
163+ Poll :: Ready ( Some ( Ok ( Frame :: data ( bytes) ) ) )
164+ }
165+ Poll :: Ready ( Some ( Ok ( PayloadItem :: Eof ) ) ) => {
166+ this. in_flight = false ;
167+ Poll :: Ready ( None )
168+ }
169+ Poll :: Ready ( Some ( Err ( e) ) ) => {
170+ this. in_flight = false ;
171+ Poll :: Ready ( Some ( Err ( e) ) )
172+ }
173+ Poll :: Ready ( None ) => {
174+ this. in_flight = false ;
175+ Poll :: Ready ( Some ( Err ( ParseError :: invalid_body ( "body stream should not receive None when receive data" ) ) ) )
176+ }
164177 Poll :: Pending => Poll :: Pending ,
165178 }
166179 }
@@ -189,3 +202,47 @@ impl From<PayloadSize> for SizeHint {
189202 }
190203 }
191204}
205+
206+ #[ cfg( test) ]
207+ mod tests {
208+ use super :: * ;
209+ use bytes:: Bytes ;
210+ use futures:: channel:: mpsc;
211+ use futures:: task:: noop_waker_ref;
212+ use futures:: { FutureExt , StreamExt } ;
213+ use std:: pin:: Pin ;
214+ use std:: task:: { Context , Poll } ;
215+
216+ #[ tokio:: test]
217+ async fn body_receiver_only_requests_once_until_response ( ) {
218+ let ( signal_sender, mut signal_receiver) = mpsc:: channel ( 8 ) ;
219+ let ( mut data_sender, data_receiver) = mpsc:: channel ( 8 ) ;
220+ let mut body_receiver = BodyReceiver :: new ( signal_sender, data_receiver, PayloadSize :: new_chunked ( ) ) ;
221+
222+ let waker = noop_waker_ref ( ) ;
223+ let mut cx = Context :: from_waker ( waker) ;
224+
225+ assert ! ( matches!( Pin :: new( & mut body_receiver) . poll_frame( & mut cx) , Poll :: Pending ) ) ;
226+ assert ! ( matches!( signal_receiver. next( ) . await , Some ( BodyRequestSignal :: RequestData ) ) ) ;
227+
228+ assert ! ( matches!( Pin :: new( & mut body_receiver) . poll_frame( & mut cx) , Poll :: Pending ) ) ;
229+ assert ! ( signal_receiver. next( ) . now_or_never( ) . is_none( ) ) ;
230+
231+ data_sender. try_send ( Ok ( PayloadItem :: Chunk ( Bytes :: from_static ( b"hello" ) ) ) ) . expect ( "send chunk" ) ;
232+
233+ match Pin :: new ( & mut body_receiver) . poll_frame ( & mut cx) {
234+ Poll :: Ready ( Some ( Ok ( frame) ) ) => {
235+ let data = frame. into_data ( ) . expect ( "expected data frame" ) ;
236+ assert_eq ! ( data, Bytes :: from_static( b"hello" ) ) ;
237+ }
238+ other => panic ! ( "unexpected poll result: {:?}" , other) ,
239+ }
240+
241+ assert ! ( matches!( Pin :: new( & mut body_receiver) . poll_frame( & mut cx) , Poll :: Pending ) ) ;
242+ assert ! ( matches!( signal_receiver. next( ) . await , Some ( BodyRequestSignal :: RequestData ) ) ) ;
243+
244+ data_sender. try_send ( Ok ( PayloadItem :: Eof ) ) . expect ( "send eof" ) ;
245+
246+ assert ! ( matches!( Pin :: new( & mut body_receiver) . poll_frame( & mut cx) , Poll :: Ready ( None ) ) ) ;
247+ }
248+ }
0 commit comments