Skip to content

Commit 78bfb00

Browse files
committed
Refine body receiver polling
1 parent 833b8bc commit 78bfb00

File tree

1 file changed

+79
-22
lines changed

1 file changed

+79
-22
lines changed

crates/http/src/protocol/body/body_channel.rs

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::protocol::{Message, ParseError, PayloadItem, PayloadSize, RequestHeader};
22
use bytes::Bytes;
3-
use futures::{SinkExt, Stream, StreamExt, channel::mpsc};
3+
use futures::{Sink, SinkExt, Stream, StreamExt, channel::mpsc};
44
use http_body::{Body, Frame, SizeHint};
55
use std::pin::Pin;
66
use std::task::{Context, Poll};
@@ -120,6 +120,7 @@ pub(crate) struct BodyReceiver {
120120
signal_sender: mpsc::Sender<BodyRequestSignal>,
121121
data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
122122
payload_size: PayloadSize,
123+
in_flight: bool,
123124
}
124125

125126
impl BodyReceiver {
@@ -128,21 +129,7 @@ impl BodyReceiver {
128129
data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
129130
payload_size: PayloadSize,
130131
) -> Self {
131-
Self { signal_sender, data_receiver, payload_size }
132-
}
133-
}
134-
135-
impl BodyReceiver {
136-
pub async fn receive_data(&mut self) -> Result<PayloadItem, ParseError> {
137-
if let Err(e) = self.signal_sender.send(BodyRequestSignal::RequestData).await {
138-
error!("failed to send request_more through channel, {}", e);
139-
return Err(ParseError::invalid_body("failed to send signal when receive body data"));
140-
}
141-
142-
self.data_receiver
143-
.next()
144-
.await
145-
.unwrap_or_else(|| Err(ParseError::invalid_body("body stream should not receive None when receive data")))
132+
Self { signal_sender, data_receiver, payload_size, in_flight: false }
146133
}
147134
}
148135

@@ -153,14 +140,40 @@ impl Body for BodyReceiver {
153140
fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
154141
let this = self.get_mut();
155142

156-
tokio::pin! {
157-
let future = this.receive_data();
143+
if !this.in_flight {
144+
match Pin::new(&mut this.signal_sender).poll_ready(cx) {
145+
Poll::Ready(Ok(())) => {
146+
if let Err(e) = Pin::new(&mut this.signal_sender).start_send(BodyRequestSignal::RequestData) {
147+
error!("failed to send request_more through channel, {}", e);
148+
return Poll::Ready(Some(Err(ParseError::invalid_body("failed to send signal when receive body data"))));
149+
}
150+
this.in_flight = true;
151+
}
152+
Poll::Ready(Err(e)) => {
153+
error!("failed to prepare request_more through channel, {}", e);
154+
return Poll::Ready(Some(Err(ParseError::invalid_body("failed to send signal when receive body data"))));
155+
}
156+
Poll::Pending => return Poll::Pending,
157+
}
158158
}
159159

160-
match future.poll(cx) {
161-
Poll::Ready(Ok(PayloadItem::Chunk(bytes))) => Poll::Ready(Some(Ok(Frame::data(bytes)))),
162-
Poll::Ready(Ok(PayloadItem::Eof)) => Poll::Ready(None),
163-
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
160+
match this.data_receiver.poll_next_unpin(cx) {
161+
Poll::Ready(Some(Ok(PayloadItem::Chunk(bytes)))) => {
162+
this.in_flight = false;
163+
Poll::Ready(Some(Ok(Frame::data(bytes))))
164+
}
165+
Poll::Ready(Some(Ok(PayloadItem::Eof))) => {
166+
this.in_flight = false;
167+
Poll::Ready(None)
168+
}
169+
Poll::Ready(Some(Err(e))) => {
170+
this.in_flight = false;
171+
Poll::Ready(Some(Err(e)))
172+
}
173+
Poll::Ready(None) => {
174+
this.in_flight = false;
175+
Poll::Ready(Some(Err(ParseError::invalid_body("body stream should not receive None when receive data"))))
176+
}
164177
Poll::Pending => Poll::Pending,
165178
}
166179
}
@@ -189,3 +202,47 @@ impl From<PayloadSize> for SizeHint {
189202
}
190203
}
191204
}
205+
206+
#[cfg(test)]
207+
mod tests {
208+
use super::*;
209+
use bytes::Bytes;
210+
use futures::channel::mpsc;
211+
use futures::task::noop_waker_ref;
212+
use futures::{FutureExt, StreamExt};
213+
use std::pin::Pin;
214+
use std::task::{Context, Poll};
215+
216+
#[tokio::test]
217+
async fn body_receiver_only_requests_once_until_response() {
218+
let (signal_sender, mut signal_receiver) = mpsc::channel(8);
219+
let (mut data_sender, data_receiver) = mpsc::channel(8);
220+
let mut body_receiver = BodyReceiver::new(signal_sender, data_receiver, PayloadSize::new_chunked());
221+
222+
let waker = noop_waker_ref();
223+
let mut cx = Context::from_waker(waker);
224+
225+
assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Pending));
226+
assert!(matches!(signal_receiver.next().await, Some(BodyRequestSignal::RequestData)));
227+
228+
assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Pending));
229+
assert!(signal_receiver.next().now_or_never().is_none());
230+
231+
data_sender.try_send(Ok(PayloadItem::Chunk(Bytes::from_static(b"hello")))).expect("send chunk");
232+
233+
match Pin::new(&mut body_receiver).poll_frame(&mut cx) {
234+
Poll::Ready(Some(Ok(frame))) => {
235+
let data = frame.into_data().expect("expected data frame");
236+
assert_eq!(data, Bytes::from_static(b"hello"));
237+
}
238+
other => panic!("unexpected poll result: {:?}", other),
239+
}
240+
241+
assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Pending));
242+
assert!(matches!(signal_receiver.next().await, Some(BodyRequestSignal::RequestData)));
243+
244+
data_sender.try_send(Ok(PayloadItem::Eof)).expect("send eof");
245+
246+
assert!(matches!(Pin::new(&mut body_receiver).poll_frame(&mut cx), Poll::Ready(None)));
247+
}
248+
}

0 commit comments

Comments
 (0)