Skip to content

Commit e5b462a

Browse files
authored
util: add BodyExt::with_trailers (#99)
* Add `BodyExt::with_trailers` * fix tests * merge with trailers from inner body * clean up setting state * fix `is_end_stream`
1 parent c58b641 commit e5b462a

File tree

4 files changed

+260
-1
lines changed

4 files changed

+260
-1
lines changed

http-body-util/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ http-body = { version = "1", path = "../http-body" }
3333
pin-project-lite = "0.2"
3434

3535
[dev-dependencies]
36-
tokio = { version = "1", features = ["macros", "rt"] }
36+
tokio = { version = "1", features = ["macros", "rt", "sync", "rt-multi-thread"] }

http-body-util/src/combinators/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ mod collect;
55
mod frame;
66
mod map_err;
77
mod map_frame;
8+
mod with_trailers;
89

910
pub use self::{
1011
box_body::{BoxBody, UnsyncBoxBody},
1112
collect::Collect,
1213
frame::Frame,
1314
map_err::MapErr,
1415
map_frame::MapFrame,
16+
with_trailers::WithTrailers,
1517
};
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
use std::{
2+
future::Future,
3+
pin::Pin,
4+
task::{Context, Poll},
5+
};
6+
7+
use futures_util::ready;
8+
use http::HeaderMap;
9+
use http_body::{Body, Frame};
10+
use pin_project_lite::pin_project;
11+
12+
pin_project! {
13+
/// Adds trailers to a body.
14+
///
15+
/// See [`BodyExt::with_trailers`] for more details.
16+
pub struct WithTrailers<T, F> {
17+
#[pin]
18+
state: State<T, F>,
19+
}
20+
}
21+
22+
impl<T, F> WithTrailers<T, F> {
23+
pub(crate) fn new(body: T, trailers: F) -> Self {
24+
Self {
25+
state: State::PollBody {
26+
body,
27+
trailers: Some(trailers),
28+
},
29+
}
30+
}
31+
}
32+
33+
pin_project! {
34+
#[project = StateProj]
35+
enum State<T, F> {
36+
PollBody {
37+
#[pin]
38+
body: T,
39+
trailers: Option<F>,
40+
},
41+
PollTrailers {
42+
#[pin]
43+
trailers: F,
44+
prev_trailers: Option<HeaderMap>,
45+
},
46+
Done,
47+
}
48+
}
49+
50+
impl<T, F> Body for WithTrailers<T, F>
51+
where
52+
T: Body,
53+
F: Future<Output = Option<Result<HeaderMap, T::Error>>>,
54+
{
55+
type Data = T::Data;
56+
type Error = T::Error;
57+
58+
fn poll_frame(
59+
mut self: Pin<&mut Self>,
60+
cx: &mut Context<'_>,
61+
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
62+
loop {
63+
let mut this = self.as_mut().project();
64+
65+
match this.state.as_mut().project() {
66+
StateProj::PollBody { body, trailers } => match ready!(body.poll_frame(cx)?) {
67+
Some(frame) => match frame.into_trailers() {
68+
Ok(prev_trailers) => {
69+
let trailers = trailers.take().unwrap();
70+
this.state.set(State::PollTrailers {
71+
trailers,
72+
prev_trailers: Some(prev_trailers),
73+
});
74+
}
75+
Err(frame) => {
76+
return Poll::Ready(Some(Ok(frame)));
77+
}
78+
},
79+
None => {
80+
let trailers = trailers.take().unwrap();
81+
this.state.set(State::PollTrailers {
82+
trailers,
83+
prev_trailers: None,
84+
});
85+
}
86+
},
87+
StateProj::PollTrailers {
88+
trailers,
89+
prev_trailers,
90+
} => {
91+
let trailers = ready!(trailers.poll(cx)?);
92+
match (trailers, prev_trailers.take()) {
93+
(None, None) => return Poll::Ready(None),
94+
(None, Some(trailers)) | (Some(trailers), None) => {
95+
this.state.set(State::Done);
96+
return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
97+
}
98+
(Some(new_trailers), Some(mut prev_trailers)) => {
99+
prev_trailers.extend(new_trailers);
100+
this.state.set(State::Done);
101+
return Poll::Ready(Some(Ok(Frame::trailers(prev_trailers))));
102+
}
103+
}
104+
}
105+
StateProj::Done => {
106+
return Poll::Ready(None);
107+
}
108+
}
109+
}
110+
}
111+
112+
#[inline]
113+
fn size_hint(&self) -> http_body::SizeHint {
114+
match &self.state {
115+
State::PollBody { body, .. } => body.size_hint(),
116+
State::PollTrailers { .. } | State::Done => Default::default(),
117+
}
118+
}
119+
}
120+
121+
#[cfg(test)]
122+
mod tests {
123+
use std::convert::Infallible;
124+
125+
use bytes::Bytes;
126+
use http::{HeaderMap, HeaderName, HeaderValue};
127+
128+
use crate::{BodyExt, Empty, Full};
129+
130+
#[allow(unused_imports)]
131+
use super::*;
132+
133+
#[tokio::test]
134+
async fn works() {
135+
let mut trailers = HeaderMap::new();
136+
trailers.insert(
137+
HeaderName::from_static("foo"),
138+
HeaderValue::from_static("bar"),
139+
);
140+
141+
let body =
142+
Full::<Bytes>::from("hello").with_trailers(std::future::ready(Some(
143+
Ok::<_, Infallible>(trailers.clone()),
144+
)));
145+
146+
futures_util::pin_mut!(body);
147+
let waker = futures_util::task::noop_waker();
148+
let mut cx = Context::from_waker(&waker);
149+
150+
let data = unwrap_ready(body.as_mut().poll_frame(&mut cx))
151+
.unwrap()
152+
.unwrap()
153+
.into_data()
154+
.unwrap();
155+
assert_eq!(data, "hello");
156+
157+
let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx))
158+
.unwrap()
159+
.unwrap()
160+
.into_trailers()
161+
.unwrap();
162+
assert_eq!(body_trailers, trailers);
163+
164+
assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
165+
}
166+
167+
#[tokio::test]
168+
async fn merges_trailers() {
169+
let mut trailers_1 = HeaderMap::new();
170+
trailers_1.insert(
171+
HeaderName::from_static("foo"),
172+
HeaderValue::from_static("bar"),
173+
);
174+
175+
let mut trailers_2 = HeaderMap::new();
176+
trailers_2.insert(
177+
HeaderName::from_static("baz"),
178+
HeaderValue::from_static("qux"),
179+
);
180+
181+
let body = Empty::<Bytes>::new()
182+
.with_trailers(std::future::ready(Some(Ok::<_, Infallible>(
183+
trailers_1.clone(),
184+
))))
185+
.with_trailers(std::future::ready(Some(Ok::<_, Infallible>(
186+
trailers_2.clone(),
187+
))));
188+
189+
futures_util::pin_mut!(body);
190+
let waker = futures_util::task::noop_waker();
191+
let mut cx = Context::from_waker(&waker);
192+
193+
let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx))
194+
.unwrap()
195+
.unwrap()
196+
.into_trailers()
197+
.unwrap();
198+
199+
let mut all_trailers = HeaderMap::new();
200+
all_trailers.extend(trailers_1);
201+
all_trailers.extend(trailers_2);
202+
assert_eq!(body_trailers, all_trailers);
203+
204+
assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none());
205+
}
206+
207+
fn unwrap_ready<T>(poll: Poll<T>) -> T {
208+
match poll {
209+
Poll::Ready(t) => t,
210+
Poll::Pending => panic!("pending"),
211+
}
212+
}
213+
}

http-body-util/src/lib.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,50 @@ pub trait BodyExt: http_body::Body {
8989
collected: Some(crate::Collected::default()),
9090
}
9191
}
92+
93+
/// Add trailers to the body.
94+
///
95+
/// The trailers will be sent when all previous frames have been sent and the `trailers` future
96+
/// resolves.
97+
///
98+
/// # Example
99+
///
100+
/// ```
101+
/// use http::HeaderMap;
102+
/// use http_body_util::{Full, BodyExt};
103+
/// use bytes::Bytes;
104+
///
105+
/// # #[tokio::main]
106+
/// async fn main() {
107+
/// let (tx, rx) = tokio::sync::oneshot::channel::<HeaderMap>();
108+
///
109+
/// let body = Full::<Bytes>::from("Hello, World!")
110+
/// // add trailers via a future
111+
/// .with_trailers(async move {
112+
/// match rx.await {
113+
/// Ok(trailers) => Some(Ok(trailers)),
114+
/// Err(_err) => None,
115+
/// }
116+
/// });
117+
///
118+
/// // compute the trailers in the background
119+
/// tokio::spawn(async move {
120+
/// let _ = tx.send(compute_trailers().await);
121+
/// });
122+
///
123+
/// async fn compute_trailers() -> HeaderMap {
124+
/// // ...
125+
/// # unimplemented!()
126+
/// }
127+
/// # }
128+
/// ```
129+
fn with_trailers<F>(self, trailers: F) -> combinators::WithTrailers<Self, F>
130+
where
131+
Self: Sized,
132+
F: std::future::Future<Output = Option<Result<http::HeaderMap, Self::Error>>>,
133+
{
134+
combinators::WithTrailers::new(self, trailers)
135+
}
92136
}
93137

94138
impl<T: ?Sized> BodyExt for T where T: http_body::Body {}

0 commit comments

Comments
 (0)