diff --git a/linkerd/http/retry/src/peek_trailers.rs b/linkerd/http/retry/src/peek_trailers.rs index f38f42f368..05e237172e 100644 --- a/linkerd/http/retry/src/peek_trailers.rs +++ b/linkerd/http/retry/src/peek_trailers.rs @@ -182,7 +182,7 @@ impl PeekTrailersBody { }, // The body yielded an unknown kind of frame. Some(Ok(None)) => Inner::Buffered { - first: None, + first: Some(Ok(first)), second: None, inner: body, }, @@ -192,7 +192,7 @@ impl PeekTrailersBody { // that a second DATA frame is on the way, and we are no longer willing to // await additional frames. There are no trailers to peek. Inner::Buffered { - first: None, + first: Some(Ok(first)), second: None, inner: body, } @@ -338,6 +338,7 @@ mod tests { use bytes::Bytes; use http::{HeaderMap, HeaderValue}; use http_body::Body; + use http_body_util::BodyExt; use linkerd_error::Error; use linkerd_mock_http_body::MockBody; use std::{ops::Not, task::Poll}; @@ -354,6 +355,17 @@ mod tests { Some(Ok(trls)) } + async fn collect(body: B) -> (Bytes, Option) + where + B: Body, + B::Error: std::fmt::Debug, + { + let coll = body.collect().await.expect("can collect"); + let trls = coll.trailers().cloned(); + let data = coll.to_bytes(); + (data, trls) + } + #[tokio::test] async fn cannot_peek_empty() { let (_guard, _handle) = linkerd_tracing::test::trace_init(); @@ -361,6 +373,10 @@ mod tests { let peek = PeekTrailersBody::read_body(empty).await; assert!(peek.peek_trailers().is_none()); assert!(peek.is_end_stream()); + + let (data, trailers) = collect(peek).await; + assert_eq!(data, ""); + assert!(trailers.is_none()); } #[tokio::test] @@ -370,6 +386,10 @@ mod tests { let peek = PeekTrailersBody::read_body(only_trailers).await; assert!(peek.peek_trailers().is_some()); assert!(peek.is_end_stream().not()); + + let (data, trailers) = collect(peek).await; + assert_eq!(data, ""); + assert_eq!(trailers.unwrap().get("trailer").unwrap(), "shiny"); } #[tokio::test] @@ -381,6 +401,10 @@ mod tests { let peek = PeekTrailersBody::read_body(body).await; assert!(peek.peek_trailers().is_some()); assert!(peek.is_end_stream().not()); + + let (data, trailers) = collect(peek).await; + assert_eq!(data, "hello"); + assert_eq!(trailers.unwrap().get("trailer").unwrap(), "shiny"); } #[tokio::test] @@ -393,6 +417,10 @@ mod tests { let peek = PeekTrailersBody::read_body(body).await; assert!(peek.peek_trailers().is_none()); assert!(peek.is_end_stream().not()); + + let (data, trailers) = collect(peek).await; + assert_eq!(data, "hello"); + assert_eq!(trailers.unwrap().get("trailer").unwrap(), "shiny"); } #[tokio::test] @@ -405,6 +433,10 @@ mod tests { let peek = PeekTrailersBody::read_body(body).await; assert!(peek.peek_trailers().is_some()); assert!(peek.is_end_stream().not()); + + let (data, trailers) = collect(peek).await; + assert_eq!(data, "hello"); + assert_eq!(trailers.unwrap().get("trailer").unwrap(), "shiny"); } #[tokio::test] @@ -417,5 +449,27 @@ mod tests { let peek = PeekTrailersBody::read_body(body).await; assert!(peek.peek_trailers().is_none()); assert!(peek.is_end_stream().not()); + + let (data, trailers) = collect(peek).await; + assert_eq!(data, "hellohello"); + assert_eq!(trailers.unwrap().get("trailer").unwrap(), "shiny"); + } + + #[tokio::test] + async fn cannot_peek_body_with_various_pending_polls() { + let (_guard, _handle) = linkerd_tracing::test::trace_init(); + let body = MockBody::default() + .then_yield_data(Poll::Ready(data())) + .then_yield_data(Poll::Pending) + .then_yield_data(Poll::Ready(data())) + .then_yield_data(Poll::Pending) + .then_yield_trailer(Poll::Ready(trailers())); + let peek = PeekTrailersBody::read_body(body).await; + assert!(peek.peek_trailers().is_none()); + assert!(peek.is_end_stream().not()); + + let (data, trailers) = collect(peek).await; + assert_eq!(data, "hellohello"); + assert_eq!(trailers.unwrap().get("trailer").unwrap(), "shiny"); } }