diff --git a/connectrpc/Cargo.toml b/connectrpc/Cargo.toml index c1f17fb..c66cd80 100644 --- a/connectrpc/Cargo.toml +++ b/connectrpc/Cargo.toml @@ -16,7 +16,7 @@ client = [] server = [] async = [] sync = [] -reqwest = ["dep:reqwest", "client", "async"] +reqwest = ["dep:reqwest", "dep:tokio", "client", "async"] axum = ["dep:axum", "server", "async", "dep:hyper"] [dependencies] @@ -35,6 +35,7 @@ form_urlencoded = "1.2.1" axum = { version = "0.8", optional = true } hyper = { version = "1.7", default-features = false, optional = true } futures-util = "0.3" +tokio = { version = "1", features = ["sync"], optional = true } [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/connectrpc/src/client/mod.rs b/connectrpc/src/client/mod.rs index d9f163b..f633bb7 100644 --- a/connectrpc/src/client/mod.rs +++ b/connectrpc/src/client/mod.rs @@ -11,10 +11,10 @@ use crate::connect::{DecodeMessage, EncodeMessage}; use crate::error::Error; use crate::request::{self, UnaryRequest}; #[cfg(feature = "async")] -use crate::request::{ClientStreamingRequest, ServerStreamingRequest}; +use crate::request::{BidiStreamingRequest, ClientStreamingRequest, ServerStreamingRequest}; use crate::response::UnaryResponse; #[cfg(feature = "async")] -use crate::response::{ClientStreamingResponse, ServerStreamingResponse}; +use crate::response::{BidiStreamingResponse, ClientStreamingResponse, ServerStreamingResponse}; use bytes::Bytes; use http::Uri; @@ -65,8 +65,15 @@ where req: ClientStreamingRequest, ) -> impl Future>> where - S: Stream + Send + Sync + 'static, - I: 'static; + S: Stream + Send + Sync + 'static; + + fn call_bidi_streaming( + &self, + path: &str, + req: BidiStreamingRequest, + ) -> impl Future>> + where + SReq: Stream + Send + Sync + 'static; } #[cfg(feature = "sync")] diff --git a/connectrpc/src/client/reqwest.rs b/connectrpc/src/client/reqwest.rs index 665ac2b..bab6964 100644 --- a/connectrpc/src/client/reqwest.rs +++ b/connectrpc/src/client/reqwest.rs @@ -3,8 +3,13 @@ use crate::Result; use crate::client::AsyncStreamingClient; use crate::codec::Codec; use crate::connect::{DecodeMessage, EncodeMessage}; -use crate::request::{self, ClientStreamingRequest, ServerStreamingRequest, UnaryRequest}; -use crate::response::{ClientStreamingResponse, ServerStreamingResponse, UnaryResponse}; +use crate::error::Error; +use crate::request::{ + self, BidiStreamingRequest, ClientStreamingRequest, ServerStreamingRequest, UnaryRequest, +}; +use crate::response::{ + BidiStreamingResponse, ClientStreamingResponse, ServerStreamingResponse, UnaryResponse, +}; use crate::stream::{ConnectFrame, UnpinStream}; use bytes::Bytes; use futures_util::Stream; @@ -83,7 +88,7 @@ where status: http::StatusCode::OK, codec: self.common.message_codec, metadata: http::HeaderMap::new(), - message_stream: std::boxed::Box::pin(frames), + message_stream: Box::pin(frames), _marker: std::marker::PhantomData, }) } else { @@ -98,7 +103,6 @@ where ) -> Result> where S: Stream + Send + Sync + 'static, - I: 'static, { use reqwest::Body; @@ -170,6 +174,89 @@ where todo!("Handle error response status: {}", response.status()) } } + + async fn call_bidi_streaming( + &self, + path: &str, + req: BidiStreamingRequest, + ) -> Result> + where + SReq: Stream + Send + Sync + 'static, + { + use reqwest::Body; + + let crate::request::Parts { + metadata, + body: message_stream, + } = req.into_parts(); + + // Build the base request using the builder + let builder = self + .common + .builder + .clone() + .rpc_path(path)? + .message_codec(self.common.message_codec) + .append_metadata(metadata); + + // Encode each message + let codec = self.common.message_codec; + let encoded_stream = message_stream.map(move |msg| codec.encode(&msg)); + + // Wrap in UnpinStream for Unpin compatibility + let unpin_stream = UnpinStream(Box::pin(encoded_stream)); + + // Create the HTTP request with frame encoding + let http_req = builder.bidi_streaming(unpin_stream)?; + let timeout = request::get_timeout(&http_req); + + // Split the request to get parts and body separately + let (parts, frame_encoder) = http_req.into_parts(); + + // Construct reqwest request manually + let mut req_builder = self.client.request(parts.method, parts.uri.to_string()); + + // Add headers + for (name, value) in &parts.headers { + req_builder = req_builder.header(name.clone(), value.clone()); + } + + // Wrap frame encoder stream directly in reqwest Body (true streaming!) + let body = Body::wrap_stream(frame_encoder); + req_builder = req_builder.body(body); + + // Set timeout + if let Some(timeout) = timeout { + req_builder = req_builder.timeout(timeout); + } + + // Spawn the request in a background task to allow concurrent request/response streaming + let response_task = tokio::spawn(async move { req_builder.send().await }); + + // Wait for the response to start coming in + let response = response_task + .await + .map_err(|e| Error::internal(format!("request task failed: {}", e)))? + .map_err(|e| Error::internal(format!("request failed: {}", e)))?; + + // Check response status + if response.status().is_success() { + let status = response.status(); + let headers = response.headers().clone(); + let stream = response.bytes_stream(); + let frames = ConnectFrame::bytes_stream(stream); + + Ok(BidiStreamingResponse { + status, + codec: self.common.message_codec, + metadata: headers, + message_stream: Box::pin(frames), + _marker: std::marker::PhantomData, + }) + } else { + todo!("Handle error response status: {}", response.status()) + } + } } impl ReqwestClient { diff --git a/connectrpc/src/request.rs b/connectrpc/src/request.rs index 734dfe9..4252646 100644 --- a/connectrpc/src/request.rs +++ b/connectrpc/src/request.rs @@ -325,6 +325,38 @@ impl Builder { Ok(req) } + /// Build a bidirectional streaming request with the given message stream as the body. + /// POST request will be used. + /// + /// The message_stream should yield encoded messages (Vec). + /// This method wraps the stream in Connect protocol frames. + /// The response will also be a stream of frames that can be decoded. + /// + /// https://connectrpc.com/docs/protocol#streaming-request + pub fn bidi_streaming(mut self, message_stream: S) -> Result>> + where + S: Stream> + Send + Sync + Unpin, + { + self.validate()?; + let encoder = FrameEncoder::new(message_stream); + let mut req = self.request_base(Method::POST, encoder)?; + let headers = req.headers_mut(); + headers.insert( + CONTENT_TYPE, + HeaderValue::from_str(&format!( + "application/connect+{}", + self.message_codec.as_ref().unwrap().name() + ))?, + ); + headers.insert(CONNECT_CONTENT_ENCODING, self.request_content_encoding()); + // Streaming-Accept-Encoding → "connect-accept-encoding" Content-Coding [...] + for value in std::mem::take(&mut self.accept_encodings) { + req.headers_mut().append(CONNECT_ACCEPT_ENCODING, value); + } + + Ok(req) + } + /// Validate that all required fields are set. /// /// This method will be called automatically by the build methods. @@ -568,6 +600,73 @@ where } } +pub struct BidiStreamingRequest +where + T: Send + Sync, + S: Stream + Send + Sync, +{ + metadata: HeaderMap, + message_stream: S, + _phantom: std::marker::PhantomData, +} + +impl BidiStreamingRequest +where + T: Send + Sync, + S: Stream + Send + Sync, +{ + /// Create a new client streaming request with the given message stream and empty metadata. + pub fn new(message_stream: S) -> Self { + Self { + metadata: HeaderMap::new(), + message_stream, + _phantom: std::marker::PhantomData, + } + } + + pub fn with_metadata(mut self, metadata: HeaderMap) -> Self { + self.metadata = metadata; + self + } + + /// Returns a reference to the metadata. + pub fn metadata(&self) -> &HeaderMap { + &self.metadata + } + + /// Returns a mutable reference to the metadata. + pub fn metadata_mut(&mut self) -> &mut HeaderMap { + &mut self.metadata + } + + /// Decomposes the request into its parts. + pub fn into_parts(self) -> Parts { + Parts { + metadata: self.metadata, + body: self.message_stream, + } + } + + /// Creates a request from its parts. + pub fn from_parts(parts: Parts) -> Self { + Self { + metadata: parts.metadata, + message_stream: parts.body, + _phantom: std::marker::PhantomData, + } + } + + /// Consumes the request, returning the message stream. + pub fn into_message_stream(self) -> S { + self.message_stream + } + + /// Returns a reference to the message stream. + pub fn message_stream(&self) -> &S { + &self.message_stream + } +} + #[cfg(feature = "client")] pub(crate) fn get_timeout(req: &http::Request) -> Option { req.headers() diff --git a/connectrpc/src/response.rs b/connectrpc/src/response.rs index 3ebd2f5..2b51e5e 100644 --- a/connectrpc/src/response.rs +++ b/connectrpc/src/response.rs @@ -172,6 +172,67 @@ where } } +pub struct BidiStreamingResponse +where + T: Send + Sync, +{ + pub status: http::StatusCode, + pub metadata: HeaderMap, + pub codec: Codec, + pub message_stream: Pin> + Send + Sync>>, + pub _marker: std::marker::PhantomData, +} + +impl BidiStreamingResponse +where + T: Send + Sync, +{ + /// Returns the http status code of the response. + pub fn status(&self) -> http::StatusCode { + self.status + } + + /// Returns the metadata of the response. + pub fn metadata(&self) -> &HeaderMap { + &self.metadata + } +} + +impl BidiStreamingResponse +where + T: DecodeMessage + Send + 'static, +{ + /// Consumes the response and returns a stream of decoded messages. + /// Each item in the stream is a `Result` where `T` is the decoded message type. + /// The stream automatically deserializes ConnectFrames using the configured codec. + pub fn into_message_stream(self) -> impl Stream> + Send { + futures_util::stream::unfold( + (self.message_stream, self.codec), + |(mut frame_stream, codec)| async move { + loop { + match frame_stream.next().await { + Some(Ok(frame)) => { + // Skip empty frames + if !frame.data.is_empty() { + // Decode the frame data into a message + let result = codec.decode::(&frame.data); + return Some((result, (frame_stream, codec))); + } + // Continue to next frame if this one is empty + } + Some(Err(e)) => { + return Some((Err(e), (frame_stream, codec))); + } + None => { + return None; + } + } + } + }, + ) + } +} + #[derive(Debug)] pub struct ClientStreamingResponse where