Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion connectrpc/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::request::{BidiStreamingRequest, ClientStreamingRequest, ServerStreami
use crate::response::UnaryResponse;
#[cfg(feature = "async")]
use crate::response::{BidiStreamingResponse, ClientStreamingResponse, ServerStreamingResponse};
use crate::stream::ServerStreamingEncoder;
use bytes::Bytes;
use http::Uri;

Expand Down Expand Up @@ -184,7 +185,7 @@ impl CommonClient {
&self,
path: &str,
req: ServerStreamingRequest<Req>,
) -> Result<http::Request<Vec<u8>>>
) -> Result<http::Request<ServerStreamingEncoder>>
where
Req: EncodeMessage,
{
Expand Down
29 changes: 25 additions & 4 deletions connectrpc/src/client/reqwest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,31 @@ where
path: &str,
req: ServerStreamingRequest<I>,
) -> Result<ServerStreamingResponse<O>> {
let req = self.common.streaming_request(path, req)?;
let timeout = request::get_timeout(&req);
let mut req: reqwest::Request = req.try_into()?;
*req.timeout_mut() = timeout;
use reqwest::Body;

let http_req = self.common.streaming_request(path, req)?;
let timeout = request::get_timeout(&http_req);

// Split the request to get parts and body separately
let (parts, frame_encoder) = http_req.into_parts();

// Construct reqwest request manually
let mut req_builder = self.client.request(parts.method, parts.uri.to_string());

// Add headers
for (name, value) in &parts.headers {
req_builder = req_builder.header(name.clone(), value.clone());
}

// Wrap frame encoder stream in reqwest Body
let body = Body::wrap_stream(frame_encoder);
req_builder = req_builder.body(body);

// Set timeout and execute
let mut req = req_builder.build()?;
if let Some(timeout) = timeout {
*req.timeout_mut() = Some(timeout);
}
let response = self.client.execute(req).await?;
if response.status().is_success() {
let stream = response.bytes_stream();
Expand Down
107 changes: 100 additions & 7 deletions connectrpc/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ use crate::header::{
CONTENT_TYPE,
};
use crate::metadata::Metadata;
use crate::stream::FrameEncoder;
use crate::stream::ServerStreamingEncoder;
use crate::stream::StreamingFrameEncoder;
use crate::stream::UnpinStream;
use futures_util::Stream;
use futures_util::stream;
use http::uri::{Authority, Scheme};
use http::{HeaderMap, HeaderName, HeaderValue, Method, Uri};
use std::time::Duration;
Expand Down Expand Up @@ -271,9 +274,14 @@ impl Builder {
/// POST request will be used.
///
/// https://connectrpc.com/docs/protocol#streaming-request
pub fn server_streaming(mut self, message: Vec<u8>) -> Result<http::Request<Vec<u8>>> {
pub fn server_streaming(
mut self,
message: Vec<u8>,
) -> Result<http::Request<ServerStreamingEncoder>> {
self.validate()?;
let mut req = self.request_base(Method::POST, message)?;
let stream = UnpinStream(Box::pin(stream::iter(std::iter::once(message))));
let encoder = StreamingFrameEncoder::new(stream);
let mut req = self.request_base(Method::POST, encoder)?;
let headers = req.headers_mut();
headers.insert(
CONTENT_TYPE,
Expand Down Expand Up @@ -301,12 +309,12 @@ impl Builder {
pub fn client_streaming<S>(
mut self,
message_stream: S,
) -> Result<http::Request<FrameEncoder<S>>>
) -> Result<http::Request<StreamingFrameEncoder<S>>>
where
S: Stream<Item = Vec<u8>> + Send + Sync + Unpin,
{
self.validate()?;
let encoder = FrameEncoder::new(message_stream);
let encoder = StreamingFrameEncoder::new(message_stream);
let mut req = self.request_base(Method::POST, encoder)?;
let headers = req.headers_mut();
headers.insert(
Expand All @@ -333,12 +341,15 @@ impl Builder {
/// The response will also be a stream of frames that can be decoded.
///
/// https://connectrpc.com/docs/protocol#streaming-request
pub fn bidi_streaming<S>(mut self, message_stream: S) -> Result<http::Request<FrameEncoder<S>>>
pub fn bidi_streaming<S>(
mut self,
message_stream: S,
) -> Result<http::Request<StreamingFrameEncoder<S>>>
where
S: Stream<Item = Vec<u8>> + Send + Sync + Unpin,
{
self.validate()?;
let encoder = FrameEncoder::new(message_stream);
let encoder = StreamingFrameEncoder::new(message_stream);
let mut req = self.request_base(Method::POST, encoder)?;
let headers = req.headers_mut();
headers.insert(
Expand Down Expand Up @@ -750,4 +761,86 @@ mod tests {
assert_eq!(query_map.get("encoding").unwrap(), codec.name());
}
}

#[tokio::test]
async fn test_builder_server_streaming_frames() {
use crate::stream::{ConnectFrame, StreamDecoder};
use futures_util::StreamExt;

// Test that server streaming creates exactly 2 frames:
// 1. Message frame containing the request
// 2. End-of-stream frame
for codec in [Codec::Proto, Codec::Json] {
let request = HelloRequest {
name: "world".to_string(),
};
let body = codec.encode(&request);

let req = Builder::new()
.scheme(Scheme::HTTPS)
.authority("example.com")
.unwrap()
.rpc_path("/helloworld.Greeter/SayHello")
.unwrap()
.message_codec(codec)
.server_streaming(body.clone())
.unwrap();

// Basic request validation
assert_eq!(req.method(), Method::POST);
assert_eq!(
req.uri(),
&"https://example.com/helloworld.Greeter/SayHello"
);
assert_eq!(
req.headers().get(CONTENT_TYPE).unwrap(),
&format!("application/connect+{}", codec.name())
);

// Extract the streaming body and decode frames using StreamDecoder
let stream_encoder = req.into_body();
let frame_stream = StreamDecoder::decode_frames(stream_encoder);

// Collect all decoded frames from the stream
let frames: Vec<ConnectFrame> = frame_stream
.map(|result| result.expect("Frame should be valid"))
.collect()
.await;

// Should have exactly 2 frames: message frame + end frame
assert_eq!(
frames.len(),
2,
"Expected exactly 2 frames for codec {}",
codec.name()
);

// Validate first frame (message frame)
let message_frame = &frames[0];
assert!(
!message_frame.compressed,
"Message frame should not be compressed"
);
assert!(
!message_frame.end,
"Message frame should not be end-of-stream"
);
assert_eq!(
message_frame.data.len(),
body.len(),
"Message data length should match encoded body length"
);
assert_eq!(
&message_frame.data[..],
&body[..],
"Message data should match original"
);

// Validate second frame (end frame)
let end_frame = &frames[1];
assert!(!end_frame.compressed, "End frame should not be compressed");
assert!(end_frame.end, "End frame should be end-of-stream");
assert_eq!(end_frame.data.len(), 0, "End frame should have no data");
}
}
}
92 changes: 62 additions & 30 deletions connectrpc/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@ use http_body::Body;
use http_body_util::BodyExt;
use std::pin::Pin;

/// Type alias for server streaming frame encoder.
/// This encapsulates the internal implementation and allows for future changes.
pub type ServerStreamingEncoder =
StreamingFrameEncoder<UnpinStream<futures_util::stream::Iter<std::iter::Once<Vec<u8>>>>>;

#[derive(Debug, Clone)]
pub struct ConnectFrame {
pub compressed: bool,
pub end: bool,
pub data: Bytes,
}

pub const FLAGS_COMPRESSED: u8 = 0b1;
pub const FLAGS_END: u8 = 0b1;
pub const FLAGS_COMPRESSED: u8 = 0b1; // bit 0
pub const FLAGS_END: u8 = 0b10; // bit 1

impl ConnectFrame {
pub fn body_stream<B>(body: B) -> impl Stream<Item = Result<Self>>
Expand Down Expand Up @@ -109,30 +114,13 @@ impl FrameParseState {
/// - Byte 0: Flags (bit 0 = compressed, bit 1 = end-of-stream)
/// - Bytes 1-4: Message length (u32 big-endian)
/// - Bytes 5+: Message data
pub struct FrameEncoder<S> {
message_stream: S,
finished: bool,
}

impl<S> FrameEncoder<S>
where
S: Stream<Item = Vec<u8>> + Send + Sync,
{
/// Create a new frame encoder from a message stream.
///
/// The stream should yield encoded messages (Vec<u8>).
/// Each message will be wrapped in a Connect frame.
pub fn new(message_stream: S) -> Self {
Self {
message_stream,
finished: false,
}
}
pub struct FrameEncoder;

impl FrameEncoder {
/// Encode a single message into a ConnectFrame.
///
/// Returns the frame bytes that can be sent over HTTP.
fn encode_message(message_data: Vec<u8>) -> Result<Bytes> {
pub fn encode_message(message_data: Vec<u8>) -> Result<Bytes> {
let message_len = message_data.len() as u32;

// Frame format: [flags(1) | length(4) | data]
Expand All @@ -151,7 +139,7 @@ where
}

/// Encode the final frame (end-of-stream marker).
fn encode_end_frame() -> Bytes {
pub fn encode_end_frame() -> Bytes {
let mut frame = BytesMut::with_capacity(5);

// Flags: not compressed, end-of-stream
Expand All @@ -164,7 +152,28 @@ where
}
}

impl<S> Stream for FrameEncoder<S>
pub struct StreamingFrameEncoder<S> {
message_stream: S,
finished: bool,
}

impl<S> StreamingFrameEncoder<S>
where
S: Stream<Item = Vec<u8>> + Send + Sync,
{
/// Create a new frame encoder from a message stream.
///
/// The stream should yield encoded messages (Vec<u8>).
/// Each message will be wrapped in a Connect frame.
pub fn new(message_stream: S) -> Self {
Self {
message_stream,
finished: false,
}
}
}

impl<S> Stream for StreamingFrameEncoder<S>
where
S: Stream<Item = Vec<u8>> + Send + Sync + Unpin,
{
Expand All @@ -185,23 +194,23 @@ where
match std::pin::Pin::new(&mut self.message_stream).poll_next(cx) {
Poll::Ready(Some(message_data)) => {
// Encode the message into a frame
match Self::encode_message(message_data) {
match FrameEncoder::encode_message(message_data) {
Ok(frame) => Poll::Ready(Some(Ok(frame))),
Err(err) => Poll::Ready(Some(Err(err))),
}
}
Poll::Ready(None) => {
// Stream ended, send the end-of-stream frame
self.finished = true;
Poll::Ready(Some(Ok(Self::encode_end_frame())))
Poll::Ready(Some(Ok(FrameEncoder::encode_end_frame())))
}
Poll::Pending => Poll::Pending,
}
}
}

// Wrapper to make non-Unpin streams compatible with FrameEncoder
pub(crate) struct UnpinStream<S: Stream>(pub(crate) Pin<Box<S>>);
pub struct UnpinStream<S: Stream>(pub(crate) Pin<Box<S>>);

impl<S: Stream> Unpin for UnpinStream<S> {}

Expand All @@ -226,7 +235,7 @@ mod tests {

let messages = vec![vec![1, 2, 3, 4, 5]];
let message_stream = stream::iter(messages);
let mut encoder = FrameEncoder::new(message_stream);
let mut encoder = StreamingFrameEncoder::new(message_stream);

// First frame: the message
let frame1 = encoder.next().await.unwrap().unwrap();
Expand Down Expand Up @@ -255,7 +264,7 @@ mod tests {

let messages = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
let message_stream = stream::iter(messages);
let mut encoder = FrameEncoder::new(message_stream);
let mut encoder = StreamingFrameEncoder::new(message_stream);

// Frame 1: first message (3 bytes)
let frame1 = encoder.next().await.unwrap().unwrap();
Expand Down Expand Up @@ -298,7 +307,7 @@ mod tests {

let messages: Vec<Vec<u8>> = vec![];
let message_stream = stream::iter(messages);
let mut encoder = FrameEncoder::new(message_stream);
let mut encoder = StreamingFrameEncoder::new(message_stream);

// Only frame: end-of-stream
let frame = encoder.next().await.unwrap().unwrap();
Expand All @@ -312,3 +321,26 @@ mod tests {
assert!(encoder.next().await.is_none());
}
}

/// A utility for decoding Connect frames from a stream of bytes.
///
/// This provides a convenient way to parse frames for testing and validation.
pub struct StreamDecoder;

impl StreamDecoder {
/// Decode frames from a stream that yields `Result<Bytes>`.
///
/// This is particularly useful for testing frame encoders, as it can
/// directly consume the output of `StreamingFrameEncoder` and parse
/// it back into `ConnectFrame` objects.
pub fn decode_frames<S>(stream: S) -> impl Stream<Item = Result<ConnectFrame>>
where
S: Stream<Item = Result<Bytes>>,
{
// Convert the Result<Bytes> stream to the format expected by ConnectFrame::bytes_stream
let byte_stream =
stream.map(|result: Result<Bytes>| result.map_err(|e| Box::new(e) as BoxError));

ConnectFrame::bytes_stream(byte_stream)
}
}