Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion connectrpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ client = []
server = []
async = []
sync = []
reqwest = ["dep:reqwest", "client", "async"]
reqwest = ["dep:reqwest", "dep:tokio", "client", "async"]
axum = ["dep:axum", "server", "async", "dep:hyper"]

[dependencies]
Expand All @@ -35,6 +35,7 @@ form_urlencoded = "1.2.1"
axum = { version = "0.8", optional = true }
hyper = { version = "1.7", default-features = false, optional = true }
futures-util = "0.3"
tokio = { version = "1", features = ["sync"], optional = true }

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
15 changes: 11 additions & 4 deletions connectrpc/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ use crate::connect::{DecodeMessage, EncodeMessage};
use crate::error::Error;
use crate::request::{self, UnaryRequest};
#[cfg(feature = "async")]
use crate::request::{ClientStreamingRequest, ServerStreamingRequest};
use crate::request::{BidiStreamingRequest, ClientStreamingRequest, ServerStreamingRequest};
use crate::response::UnaryResponse;
#[cfg(feature = "async")]
use crate::response::{ClientStreamingResponse, ServerStreamingResponse};
use crate::response::{BidiStreamingResponse, ClientStreamingResponse, ServerStreamingResponse};
use bytes::Bytes;
use http::Uri;

Expand Down Expand Up @@ -65,8 +65,15 @@ where
req: ClientStreamingRequest<I, S>,
) -> impl Future<Output = Result<ClientStreamingResponse<O>>>
where
S: Stream<Item = I> + Send + Sync + 'static,
I: 'static;
S: Stream<Item = I> + Send + Sync + 'static;

fn call_bidi_streaming<SReq>(
&self,
path: &str,
req: BidiStreamingRequest<I, SReq>,
) -> impl Future<Output = Result<BidiStreamingResponse<O>>>
where
SReq: Stream<Item = I> + Send + Sync + 'static;
}

#[cfg(feature = "sync")]
Expand Down
95 changes: 91 additions & 4 deletions connectrpc/src/client/reqwest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@ use crate::Result;
use crate::client::AsyncStreamingClient;
use crate::codec::Codec;
use crate::connect::{DecodeMessage, EncodeMessage};
use crate::request::{self, ClientStreamingRequest, ServerStreamingRequest, UnaryRequest};
use crate::response::{ClientStreamingResponse, ServerStreamingResponse, UnaryResponse};
use crate::error::Error;
use crate::request::{
self, BidiStreamingRequest, ClientStreamingRequest, ServerStreamingRequest, UnaryRequest,
};
use crate::response::{
BidiStreamingResponse, ClientStreamingResponse, ServerStreamingResponse, UnaryResponse,
};
use crate::stream::{ConnectFrame, UnpinStream};
use bytes::Bytes;
use futures_util::Stream;
Expand Down Expand Up @@ -83,7 +88,7 @@ where
status: http::StatusCode::OK,
codec: self.common.message_codec,
metadata: http::HeaderMap::new(),
message_stream: std::boxed::Box::pin(frames),
message_stream: Box::pin(frames),
_marker: std::marker::PhantomData,
})
} else {
Expand All @@ -98,7 +103,6 @@ where
) -> Result<ClientStreamingResponse<O>>
where
S: Stream<Item = I> + Send + Sync + 'static,
I: 'static,
{
use reqwest::Body;

Expand Down Expand Up @@ -170,6 +174,89 @@ where
todo!("Handle error response status: {}", response.status())
}
}

async fn call_bidi_streaming<SReq>(
&self,
path: &str,
req: BidiStreamingRequest<I, SReq>,
) -> Result<BidiStreamingResponse<O>>
where
SReq: Stream<Item = I> + Send + Sync + 'static,
{
use reqwest::Body;

let crate::request::Parts {
metadata,
body: message_stream,
} = req.into_parts();

// Build the base request using the builder
let builder = self
.common
.builder
.clone()
.rpc_path(path)?
.message_codec(self.common.message_codec)
.append_metadata(metadata);

// Encode each message
let codec = self.common.message_codec;
let encoded_stream = message_stream.map(move |msg| codec.encode(&msg));

// Wrap in UnpinStream for Unpin compatibility
let unpin_stream = UnpinStream(Box::pin(encoded_stream));

// Create the HTTP request with frame encoding
let http_req = builder.bidi_streaming(unpin_stream)?;
let timeout = request::get_timeout(&http_req);

// Split the request to get parts and body separately
let (parts, frame_encoder) = http_req.into_parts();

// Construct reqwest request manually
let mut req_builder = self.client.request(parts.method, parts.uri.to_string());

// Add headers
for (name, value) in &parts.headers {
req_builder = req_builder.header(name.clone(), value.clone());
}

// Wrap frame encoder stream directly in reqwest Body (true streaming!)
let body = Body::wrap_stream(frame_encoder);
req_builder = req_builder.body(body);

// Set timeout
if let Some(timeout) = timeout {
req_builder = req_builder.timeout(timeout);
}

// Spawn the request in a background task to allow concurrent request/response streaming
let response_task = tokio::spawn(async move { req_builder.send().await });

// Wait for the response to start coming in
let response = response_task
.await
.map_err(|e| Error::internal(format!("request task failed: {}", e)))?
.map_err(|e| Error::internal(format!("request failed: {}", e)))?;

// Check response status
if response.status().is_success() {
let status = response.status();
let headers = response.headers().clone();
let stream = response.bytes_stream();
let frames = ConnectFrame::bytes_stream(stream);

Ok(BidiStreamingResponse {
status,
codec: self.common.message_codec,
metadata: headers,
message_stream: Box::pin(frames),
_marker: std::marker::PhantomData,
})
} else {
todo!("Handle error response status: {}", response.status())
}
}
}

impl ReqwestClient {
Expand Down
99 changes: 99 additions & 0 deletions connectrpc/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,38 @@ impl Builder {
Ok(req)
}

/// Build a bidirectional streaming request with the given message stream as the body.
/// POST request will be used.
///
/// The message_stream should yield encoded messages (Vec<u8>).
/// This method wraps the stream in Connect protocol frames.
/// The response will also be a stream of frames that can be decoded.
///
/// https://connectrpc.com/docs/protocol#streaming-request
pub fn bidi_streaming<S>(mut self, message_stream: S) -> Result<http::Request<FrameEncoder<S>>>
where
S: Stream<Item = Vec<u8>> + Send + Sync + Unpin,
{
self.validate()?;
let encoder = FrameEncoder::new(message_stream);
let mut req = self.request_base(Method::POST, encoder)?;
let headers = req.headers_mut();
headers.insert(
CONTENT_TYPE,
HeaderValue::from_str(&format!(
"application/connect+{}",
self.message_codec.as_ref().unwrap().name()
))?,
);
headers.insert(CONNECT_CONTENT_ENCODING, self.request_content_encoding());
// Streaming-Accept-Encoding → "connect-accept-encoding" Content-Coding [...]
for value in std::mem::take(&mut self.accept_encodings) {
req.headers_mut().append(CONNECT_ACCEPT_ENCODING, value);
}

Ok(req)
}

/// Validate that all required fields are set.
///
/// This method will be called automatically by the build methods.
Expand Down Expand Up @@ -568,6 +600,73 @@ where
}
}

pub struct BidiStreamingRequest<T, S>
where
T: Send + Sync,
S: Stream<Item = T> + Send + Sync,
{
metadata: HeaderMap,
message_stream: S,
_phantom: std::marker::PhantomData<T>,
}

impl<T, S> BidiStreamingRequest<T, S>
where
T: Send + Sync,
S: Stream<Item = T> + Send + Sync,
{
/// Create a new client streaming request with the given message stream and empty metadata.
pub fn new(message_stream: S) -> Self {
Self {
metadata: HeaderMap::new(),
message_stream,
_phantom: std::marker::PhantomData,
}
}

pub fn with_metadata(mut self, metadata: HeaderMap) -> Self {
self.metadata = metadata;
self
}

/// Returns a reference to the metadata.
pub fn metadata(&self) -> &HeaderMap {
&self.metadata
}

/// Returns a mutable reference to the metadata.
pub fn metadata_mut(&mut self) -> &mut HeaderMap {
&mut self.metadata
}

/// Decomposes the request into its parts.
pub fn into_parts(self) -> Parts<S> {
Parts {
metadata: self.metadata,
body: self.message_stream,
}
}

/// Creates a request from its parts.
pub fn from_parts(parts: Parts<S>) -> Self {
Self {
metadata: parts.metadata,
message_stream: parts.body,
_phantom: std::marker::PhantomData,
}
}

/// Consumes the request, returning the message stream.
pub fn into_message_stream(self) -> S {
self.message_stream
}

/// Returns a reference to the message stream.
pub fn message_stream(&self) -> &S {
&self.message_stream
}
}

#[cfg(feature = "client")]
pub(crate) fn get_timeout<T>(req: &http::Request<T>) -> Option<Duration> {
req.headers()
Expand Down
61 changes: 61 additions & 0 deletions connectrpc/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,67 @@ where
}
}

pub struct BidiStreamingResponse<T>
where
T: Send + Sync,
{
pub status: http::StatusCode,
pub metadata: HeaderMap,
pub codec: Codec,
pub message_stream: Pin<Box<dyn Stream<Item = Result<ConnectFrame>> + Send + Sync>>,
pub _marker: std::marker::PhantomData<T>,
}

impl<T> BidiStreamingResponse<T>
where
T: Send + Sync,
{
/// Returns the http status code of the response.
pub fn status(&self) -> http::StatusCode {
self.status
}

/// Returns the metadata of the response.
pub fn metadata(&self) -> &HeaderMap {
&self.metadata
}
}

impl<T> BidiStreamingResponse<T>
where
T: DecodeMessage + Send + 'static,
{
/// Consumes the response and returns a stream of decoded messages.
/// Each item in the stream is a `Result<T>` where `T` is the decoded message type.
/// The stream automatically deserializes ConnectFrames using the configured codec.
pub fn into_message_stream(self) -> impl Stream<Item = Result<T>> + Send {
futures_util::stream::unfold(
(self.message_stream, self.codec),
|(mut frame_stream, codec)| async move {
loop {
match frame_stream.next().await {
Some(Ok(frame)) => {
// Skip empty frames
if !frame.data.is_empty() {
// Decode the frame data into a message
let result = codec.decode::<T>(&frame.data);
return Some((result, (frame_stream, codec)));
}
// Continue to next frame if this one is empty
}
Some(Err(e)) => {
return Some((Err(e), (frame_stream, codec)));
}
None => {
return None;
}
}
}
},
)
}
}

#[derive(Debug)]
pub struct ClientStreamingResponse<T>
where
Expand Down