Skip to content

Commit 6e48a7f

Browse files
authored
Merge pull request #2 from nikola-jokic/nikola-jokic/client-stream
Implement client stream
2 parents 7b0a2da + 349972d commit 6e48a7f

File tree

5 files changed

+467
-15
lines changed

5 files changed

+467
-15
lines changed

connectrpc/src/client/mod.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
#[cfg(feature = "reqwest")]
22
pub mod reqwest;
3+
#[cfg(feature = "async")]
4+
use futures_util::Stream;
35
#[cfg(feature = "reqwest")]
46
pub use reqwest::ReqwestClient;
57

68
use crate::Result;
79
use crate::codec::Codec;
810
use crate::connect::{DecodeMessage, EncodeMessage};
911
use crate::error::Error;
10-
#[cfg(feature = "async")]
11-
use crate::request::StreamingRequest;
1212
use crate::request::{self, UnaryRequest};
1313
#[cfg(feature = "async")]
14-
use crate::response::ServerStreamingResponse;
14+
use crate::request::{ClientStreamingRequest, ServerStreamingRequest};
1515
use crate::response::UnaryResponse;
16+
#[cfg(feature = "async")]
17+
use crate::response::{ClientStreamingResponse, ServerStreamingResponse};
1618
use bytes::Bytes;
1719
use http::Uri;
1820

@@ -54,8 +56,17 @@ where
5456
fn call_server_streaming(
5557
&self,
5658
path: &str,
57-
req: StreamingRequest<I>,
59+
req: ServerStreamingRequest<I>,
5860
) -> impl Future<Output = Result<ServerStreamingResponse<O>>>;
61+
62+
fn call_client_streaming<S>(
63+
&self,
64+
path: &str,
65+
req: ClientStreamingRequest<I, S>,
66+
) -> impl Future<Output = Result<ClientStreamingResponse<O>>>
67+
where
68+
S: Stream<Item = I> + Send + Sync + 'static,
69+
I: 'static;
5970
}
6071

6172
#[cfg(feature = "sync")]
@@ -165,7 +176,7 @@ impl CommonClient {
165176
pub fn streaming_request<Req>(
166177
&self,
167178
path: &str,
168-
req: StreamingRequest<Req>,
179+
req: ServerStreamingRequest<Req>,
169180
) -> Result<http::Request<Vec<u8>>>
170181
where
171182
Req: EncodeMessage,
@@ -177,7 +188,7 @@ impl CommonClient {
177188
.rpc_path(path)?
178189
.message_codec(self.message_codec)
179190
.append_metadata(metadata)
180-
.streaming(body)
191+
.server_streaming(body)
181192
}
182193

183194
/// Parses a unary response from the given HTTP response.

connectrpc/src/client/reqwest.rs

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ use crate::Result;
33
use crate::client::AsyncStreamingClient;
44
use crate::codec::Codec;
55
use crate::connect::{DecodeMessage, EncodeMessage};
6-
use crate::request::{self, StreamingRequest, UnaryRequest};
7-
use crate::response::{ServerStreamingResponse, UnaryResponse};
8-
use crate::stream::ConnectFrame;
6+
use crate::request::{self, ClientStreamingRequest, ServerStreamingRequest, UnaryRequest};
7+
use crate::response::{ClientStreamingResponse, ServerStreamingResponse, UnaryResponse};
8+
use crate::stream::{ConnectFrame, UnpinStream};
99
use bytes::Bytes;
10+
use futures_util::Stream;
11+
use futures_util::stream::StreamExt;
1012
use http::Uri;
1113

1214
/// A client implementation using the `reqwest` HTTP client library.
@@ -67,7 +69,7 @@ where
6769
async fn call_server_streaming(
6870
&self,
6971
path: &str,
70-
req: StreamingRequest<I>,
72+
req: ServerStreamingRequest<I>,
7173
) -> Result<ServerStreamingResponse<O>> {
7274
let req = self.common.streaming_request(path, req)?;
7375
let timeout = request::get_timeout(&req);
@@ -85,7 +87,87 @@ where
8587
_marker: std::marker::PhantomData,
8688
})
8789
} else {
88-
todo!()
90+
todo!("Handle error response status: {}", response.status())
91+
}
92+
}
93+
94+
async fn call_client_streaming<S>(
95+
&self,
96+
path: &str,
97+
req: ClientStreamingRequest<I, S>,
98+
) -> Result<ClientStreamingResponse<O>>
99+
where
100+
S: Stream<Item = I> + Send + Sync + 'static,
101+
I: 'static,
102+
{
103+
use reqwest::Body;
104+
105+
let crate::request::Parts {
106+
metadata,
107+
body: message_stream,
108+
} = req.into_parts();
109+
110+
// Build the base request using the builder
111+
let builder = self
112+
.common
113+
.builder
114+
.clone()
115+
.rpc_path(path)?
116+
.message_codec(self.common.message_codec)
117+
.append_metadata(metadata);
118+
119+
// Encode each message
120+
let codec = self.common.message_codec;
121+
let encoded_stream = message_stream.map(move |msg| codec.encode(&msg));
122+
123+
// Wrap in UnpinStream for Unpin compatibility
124+
let unpin_stream = UnpinStream(Box::pin(encoded_stream));
125+
126+
// Create the HTTP request with frame encoding
127+
let http_req = builder.client_streaming(unpin_stream)?;
128+
let timeout = request::get_timeout(&http_req);
129+
130+
// Split the request to get parts and body separately
131+
let (parts, frame_encoder) = http_req.into_parts();
132+
133+
// Construct reqwest request manually
134+
let mut req_builder = self.client.request(parts.method, parts.uri.to_string());
135+
136+
// Add headers
137+
for (name, value) in &parts.headers {
138+
req_builder = req_builder.header(name.clone(), value.clone());
139+
}
140+
141+
// Wrap frame encoder stream directly in reqwest Body (true streaming!)
142+
let body = Body::wrap_stream(frame_encoder);
143+
req_builder = req_builder.body(body);
144+
145+
// Set timeout
146+
if let Some(timeout) = timeout {
147+
req_builder = req_builder.timeout(timeout);
148+
}
149+
150+
// Execute request
151+
let response = req_builder.send().await?;
152+
153+
// Check response status
154+
if response.status().is_success() {
155+
let status = response.status();
156+
let headers = response.headers().clone();
157+
158+
// Read response body
159+
let body_bytes = response.bytes().await?;
160+
161+
// Decode response
162+
let message: O = self.common.message_codec.decode(&body_bytes)?;
163+
164+
Ok(ClientStreamingResponse {
165+
status,
166+
metadata: headers,
167+
message,
168+
})
169+
} else {
170+
todo!("Handle error response status: {}", response.status())
89171
}
90172
}
91173
}

connectrpc/src/request.rs

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use crate::header::{
88
CONTENT_TYPE,
99
};
1010
use crate::metadata::Metadata;
11+
use crate::stream::FrameEncoder;
12+
use futures_util::Stream;
1113
use http::uri::{Authority, Scheme};
1214
use http::{HeaderMap, HeaderName, HeaderValue, Method, Uri};
1315
use std::time::Duration;
@@ -269,7 +271,7 @@ impl Builder {
269271
/// POST request will be used.
270272
///
271273
/// https://connectrpc.com/docs/protocol#streaming-request
272-
pub fn streaming(mut self, message: Vec<u8>) -> Result<http::Request<Vec<u8>>> {
274+
pub fn server_streaming(mut self, message: Vec<u8>) -> Result<http::Request<Vec<u8>>> {
273275
self.validate()?;
274276
let mut req = self.request_base(Method::POST, message)?;
275277
let headers = req.headers_mut();
@@ -289,6 +291,40 @@ impl Builder {
289291
Ok(req)
290292
}
291293

294+
/// Build a client streaming request with the given message stream as the body.
295+
/// POST request will be used.
296+
///
297+
/// The message_stream should yield encoded messages (Vec<u8>).
298+
/// This method wraps the stream in Connect protocol frames.
299+
///
300+
/// https://connectrpc.com/docs/protocol#streaming-request
301+
pub fn client_streaming<S>(
302+
mut self,
303+
message_stream: S,
304+
) -> Result<http::Request<FrameEncoder<S>>>
305+
where
306+
S: Stream<Item = Vec<u8>> + Send + Sync + Unpin,
307+
{
308+
self.validate()?;
309+
let encoder = FrameEncoder::new(message_stream);
310+
let mut req = self.request_base(Method::POST, encoder)?;
311+
let headers = req.headers_mut();
312+
headers.insert(
313+
CONTENT_TYPE,
314+
HeaderValue::from_str(&format!(
315+
"application/connect+{}",
316+
self.message_codec.as_ref().unwrap().name()
317+
))?,
318+
);
319+
headers.insert(CONNECT_CONTENT_ENCODING, self.request_content_encoding());
320+
// Streaming-Accept-Encoding → "connect-accept-encoding" Content-Coding [...]
321+
for value in std::mem::take(&mut self.accept_encodings) {
322+
req.headers_mut().append(CONNECT_ACCEPT_ENCODING, value);
323+
}
324+
325+
Ok(req)
326+
}
327+
292328
/// Validate that all required fields are set.
293329
///
294330
/// This method will be called automatically by the build methods.
@@ -408,15 +444,15 @@ where
408444
}
409445
}
410446

411-
pub struct StreamingRequest<T>
447+
pub struct ServerStreamingRequest<T>
412448
where
413449
T: Send + Sync,
414450
{
415451
metadata: HeaderMap,
416452
message: T,
417453
}
418454

419-
impl<T> StreamingRequest<T>
455+
impl<T> ServerStreamingRequest<T>
420456
where
421457
T: Send + Sync,
422458
{
@@ -465,6 +501,73 @@ where
465501
}
466502
}
467503

504+
pub struct ClientStreamingRequest<T, S>
505+
where
506+
T: Send + Sync,
507+
S: Stream<Item = T> + Send + Sync,
508+
{
509+
metadata: HeaderMap,
510+
message_stream: S,
511+
_phantom: std::marker::PhantomData<T>,
512+
}
513+
514+
impl<T, S> ClientStreamingRequest<T, S>
515+
where
516+
T: Send + Sync,
517+
S: Stream<Item = T> + Send + Sync,
518+
{
519+
/// Create a new client streaming request with the given message stream and empty metadata.
520+
pub fn new(message_stream: S) -> Self {
521+
Self {
522+
metadata: HeaderMap::new(),
523+
message_stream,
524+
_phantom: std::marker::PhantomData,
525+
}
526+
}
527+
528+
pub fn with_metadata(mut self, metadata: HeaderMap) -> Self {
529+
self.metadata = metadata;
530+
self
531+
}
532+
533+
/// Returns a reference to the metadata.
534+
pub fn metadata(&self) -> &HeaderMap {
535+
&self.metadata
536+
}
537+
538+
/// Returns a mutable reference to the metadata.
539+
pub fn metadata_mut(&mut self) -> &mut HeaderMap {
540+
&mut self.metadata
541+
}
542+
543+
/// Decomposes the request into its parts.
544+
pub fn into_parts(self) -> Parts<S> {
545+
Parts {
546+
metadata: self.metadata,
547+
body: self.message_stream,
548+
}
549+
}
550+
551+
/// Creates a request from its parts.
552+
pub fn from_parts(parts: Parts<S>) -> Self {
553+
Self {
554+
metadata: parts.metadata,
555+
message_stream: parts.body,
556+
_phantom: std::marker::PhantomData,
557+
}
558+
}
559+
560+
/// Consumes the request, returning the message stream.
561+
pub fn into_message_stream(self) -> S {
562+
self.message_stream
563+
}
564+
565+
/// Returns a reference to the message stream.
566+
pub fn message_stream(&self) -> &S {
567+
&self.message_stream
568+
}
569+
}
570+
468571
#[cfg(feature = "client")]
469572
pub(crate) fn get_timeout<T>(req: &http::Request<T>) -> Option<Duration> {
470573
req.headers()

connectrpc/src/response.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ use crate::connect::DecodeMessage;
22
use crate::header::HeaderMap;
33
use crate::stream::ConnectFrame;
44
use crate::{Codec, Result};
5+
use core::fmt;
56
use futures_util::Stream;
67
use futures_util::StreamExt;
7-
use core::fmt;
88
use std::pin::Pin;
99

1010
/// The parts of a unary response.
@@ -172,6 +172,48 @@ where
172172
}
173173
}
174174

175+
#[derive(Debug)]
176+
pub struct ClientStreamingResponse<T>
177+
where
178+
T: Send + Sync,
179+
{
180+
pub status: http::StatusCode,
181+
pub metadata: HeaderMap,
182+
pub message: T,
183+
}
184+
185+
impl<T> ClientStreamingResponse<T>
186+
where
187+
T: Send + Sync,
188+
{
189+
/// Returns the http status code of the response.
190+
pub fn status(&self) -> http::StatusCode {
191+
self.status
192+
}
193+
194+
/// Returns the metadata of the response.
195+
pub fn metadata(&self) -> &HeaderMap {
196+
&self.metadata
197+
}
198+
199+
pub fn into_message(self) -> T {
200+
self.message
201+
}
202+
203+
pub fn into_parts(self) -> Parts<T> {
204+
Parts {
205+
status: self.status,
206+
metadata: self.metadata,
207+
message: self.message,
208+
}
209+
}
210+
211+
/// Returns a reference to the message of the response.
212+
pub fn message(&self) -> &T {
213+
&self.message
214+
}
215+
}
216+
175217
#[cfg(test)]
176218
mod tests {
177219
use super::*;

0 commit comments

Comments
 (0)