Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions connectrpc/src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#[cfg(feature = "reqwest")]
pub mod reqwest;
#[cfg(feature = "async")]
use futures_util::Stream;
#[cfg(feature = "reqwest")]
pub use reqwest::ReqwestClient;

use crate::Result;
use crate::codec::Codec;
use crate::connect::{DecodeMessage, EncodeMessage};
use crate::error::Error;
#[cfg(feature = "async")]
use crate::request::StreamingRequest;
use crate::request::{self, UnaryRequest};
#[cfg(feature = "async")]
use crate::response::ServerStreamingResponse;
use crate::request::{ClientStreamingRequest, ServerStreamingRequest};
use crate::response::UnaryResponse;
#[cfg(feature = "async")]
use crate::response::{ClientStreamingResponse, ServerStreamingResponse};
use bytes::Bytes;
use http::Uri;

Expand Down Expand Up @@ -54,8 +56,17 @@ where
fn call_server_streaming(
&self,
path: &str,
req: StreamingRequest<I>,
req: ServerStreamingRequest<I>,
) -> impl Future<Output = Result<ServerStreamingResponse<O>>>;

fn call_client_streaming<S>(
&self,
path: &str,
req: ClientStreamingRequest<I, S>,
) -> impl Future<Output = Result<ClientStreamingResponse<O>>>
where
S: Stream<Item = I> + Send + Sync + 'static,
I: 'static;
}

#[cfg(feature = "sync")]
Expand Down Expand Up @@ -165,7 +176,7 @@ impl CommonClient {
pub fn streaming_request<Req>(
&self,
path: &str,
req: StreamingRequest<Req>,
req: ServerStreamingRequest<Req>,
) -> Result<http::Request<Vec<u8>>>
where
Req: EncodeMessage,
Expand All @@ -177,7 +188,7 @@ impl CommonClient {
.rpc_path(path)?
.message_codec(self.message_codec)
.append_metadata(metadata)
.streaming(body)
.server_streaming(body)
}

/// Parses a unary response from the given HTTP response.
Expand Down
92 changes: 87 additions & 5 deletions connectrpc/src/client/reqwest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use crate::Result;
use crate::client::AsyncStreamingClient;
use crate::codec::Codec;
use crate::connect::{DecodeMessage, EncodeMessage};
use crate::request::{self, StreamingRequest, UnaryRequest};
use crate::response::{ServerStreamingResponse, UnaryResponse};
use crate::stream::ConnectFrame;
use crate::request::{self, ClientStreamingRequest, ServerStreamingRequest, UnaryRequest};
use crate::response::{ClientStreamingResponse, ServerStreamingResponse, UnaryResponse};
use crate::stream::{ConnectFrame, UnpinStream};
use bytes::Bytes;
use futures_util::Stream;
use futures_util::stream::StreamExt;
use http::Uri;

/// A client implementation using the `reqwest` HTTP client library.
Expand Down Expand Up @@ -67,7 +69,7 @@ where
async fn call_server_streaming(
&self,
path: &str,
req: StreamingRequest<I>,
req: ServerStreamingRequest<I>,
) -> Result<ServerStreamingResponse<O>> {
let req = self.common.streaming_request(path, req)?;
let timeout = request::get_timeout(&req);
Expand All @@ -85,7 +87,87 @@ where
_marker: std::marker::PhantomData,
})
} else {
todo!()
todo!("Handle error response status: {}", response.status())
}
}

async fn call_client_streaming<S>(
&self,
path: &str,
req: ClientStreamingRequest<I, S>,
) -> Result<ClientStreamingResponse<O>>
where
S: Stream<Item = I> + Send + Sync + 'static,
I: 'static,
{
use reqwest::Body;

let crate::request::Parts {
metadata,
body: message_stream,
} = req.into_parts();

// Build the base request using the builder
let builder = self
.common
.builder
.clone()
.rpc_path(path)?
.message_codec(self.common.message_codec)
.append_metadata(metadata);

// Encode each message
let codec = self.common.message_codec;
let encoded_stream = message_stream.map(move |msg| codec.encode(&msg));

// Wrap in UnpinStream for Unpin compatibility
let unpin_stream = UnpinStream(Box::pin(encoded_stream));

// Create the HTTP request with frame encoding
let http_req = builder.client_streaming(unpin_stream)?;
let timeout = request::get_timeout(&http_req);

// Split the request to get parts and body separately
let (parts, frame_encoder) = http_req.into_parts();

// Construct reqwest request manually
let mut req_builder = self.client.request(parts.method, parts.uri.to_string());

// Add headers
for (name, value) in &parts.headers {
req_builder = req_builder.header(name.clone(), value.clone());
}

// Wrap frame encoder stream directly in reqwest Body (true streaming!)
let body = Body::wrap_stream(frame_encoder);
req_builder = req_builder.body(body);

// Set timeout
if let Some(timeout) = timeout {
req_builder = req_builder.timeout(timeout);
}

// Execute request
let response = req_builder.send().await?;

// Check response status
if response.status().is_success() {
let status = response.status();
let headers = response.headers().clone();

// Read response body
let body_bytes = response.bytes().await?;

// Decode response
let message: O = self.common.message_codec.decode(&body_bytes)?;

Ok(ClientStreamingResponse {
status,
metadata: headers,
message,
})
} else {
todo!("Handle error response status: {}", response.status())
}
}
}
Expand Down
109 changes: 106 additions & 3 deletions connectrpc/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::header::{
CONTENT_TYPE,
};
use crate::metadata::Metadata;
use crate::stream::FrameEncoder;
use futures_util::Stream;
use http::uri::{Authority, Scheme};
use http::{HeaderMap, HeaderName, HeaderValue, Method, Uri};
use std::time::Duration;
Expand Down Expand Up @@ -269,7 +271,7 @@ impl Builder {
/// POST request will be used.
///
/// https://connectrpc.com/docs/protocol#streaming-request
pub fn streaming(mut self, message: Vec<u8>) -> Result<http::Request<Vec<u8>>> {
pub fn server_streaming(mut self, message: Vec<u8>) -> Result<http::Request<Vec<u8>>> {
self.validate()?;
let mut req = self.request_base(Method::POST, message)?;
let headers = req.headers_mut();
Expand All @@ -289,6 +291,40 @@ impl Builder {
Ok(req)
}

/// Build a client streaming request with the given message stream as the body.
/// POST request will be used.
///
/// The message_stream should yield encoded messages (Vec<u8>).
/// This method wraps the stream in Connect protocol frames.
///
/// https://connectrpc.com/docs/protocol#streaming-request
pub fn client_streaming<S>(
mut self,
message_stream: S,
) -> Result<http::Request<FrameEncoder<S>>>
where
S: Stream<Item = Vec<u8>> + Send + Sync + Unpin,
{
self.validate()?;
let encoder = FrameEncoder::new(message_stream);
let mut req = self.request_base(Method::POST, encoder)?;
let headers = req.headers_mut();
headers.insert(
CONTENT_TYPE,
HeaderValue::from_str(&format!(
"application/connect+{}",
self.message_codec.as_ref().unwrap().name()
))?,
);
headers.insert(CONNECT_CONTENT_ENCODING, self.request_content_encoding());
// Streaming-Accept-Encoding → "connect-accept-encoding" Content-Coding [...]
for value in std::mem::take(&mut self.accept_encodings) {
req.headers_mut().append(CONNECT_ACCEPT_ENCODING, value);
}

Ok(req)
}

/// Validate that all required fields are set.
///
/// This method will be called automatically by the build methods.
Expand Down Expand Up @@ -408,15 +444,15 @@ where
}
}

pub struct StreamingRequest<T>
pub struct ServerStreamingRequest<T>
where
T: Send + Sync,
{
metadata: HeaderMap,
message: T,
}

impl<T> StreamingRequest<T>
impl<T> ServerStreamingRequest<T>
where
T: Send + Sync,
{
Expand Down Expand Up @@ -465,6 +501,73 @@ where
}
}

pub struct ClientStreamingRequest<T, S>
where
T: Send + Sync,
S: Stream<Item = T> + Send + Sync,
{
metadata: HeaderMap,
message_stream: S,
_phantom: std::marker::PhantomData<T>,
}

impl<T, S> ClientStreamingRequest<T, S>
where
T: Send + Sync,
S: Stream<Item = T> + Send + Sync,
{
/// Create a new client streaming request with the given message stream and empty metadata.
pub fn new(message_stream: S) -> Self {
Self {
metadata: HeaderMap::new(),
message_stream,
_phantom: std::marker::PhantomData,
}
}

pub fn with_metadata(mut self, metadata: HeaderMap) -> Self {
self.metadata = metadata;
self
}

/// Returns a reference to the metadata.
pub fn metadata(&self) -> &HeaderMap {
&self.metadata
}

/// Returns a mutable reference to the metadata.
pub fn metadata_mut(&mut self) -> &mut HeaderMap {
&mut self.metadata
}

/// Decomposes the request into its parts.
pub fn into_parts(self) -> Parts<S> {
Parts {
metadata: self.metadata,
body: self.message_stream,
}
}

/// Creates a request from its parts.
pub fn from_parts(parts: Parts<S>) -> Self {
Self {
metadata: parts.metadata,
message_stream: parts.body,
_phantom: std::marker::PhantomData,
}
}

/// Consumes the request, returning the message stream.
pub fn into_message_stream(self) -> S {
self.message_stream
}

/// Returns a reference to the message stream.
pub fn message_stream(&self) -> &S {
&self.message_stream
}
}

#[cfg(feature = "client")]
pub(crate) fn get_timeout<T>(req: &http::Request<T>) -> Option<Duration> {
req.headers()
Expand Down
44 changes: 43 additions & 1 deletion connectrpc/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use crate::connect::DecodeMessage;
use crate::header::HeaderMap;
use crate::stream::ConnectFrame;
use crate::{Codec, Result};
use core::fmt;
use futures_util::Stream;
use futures_util::StreamExt;
use core::fmt;
use std::pin::Pin;

/// The parts of a unary response.
Expand Down Expand Up @@ -172,6 +172,48 @@ where
}
}

#[derive(Debug)]
pub struct ClientStreamingResponse<T>
where
T: Send + Sync,
{
pub status: http::StatusCode,
pub metadata: HeaderMap,
pub message: T,
}

impl<T> ClientStreamingResponse<T>
where
T: Send + Sync,
{
/// Returns the http status code of the response.
pub fn status(&self) -> http::StatusCode {
self.status
}

/// Returns the metadata of the response.
pub fn metadata(&self) -> &HeaderMap {
&self.metadata
}

pub fn into_message(self) -> T {
self.message
}

pub fn into_parts(self) -> Parts<T> {
Parts {
status: self.status,
metadata: self.metadata,
message: self.message,
}
}

/// Returns a reference to the message of the response.
pub fn message(&self) -> &T {
&self.message
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading