Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions crates/transport-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,21 @@ impl AsyncWrite for ParamWriter {
}
}

#[derive(Debug)]
pub struct NatsContext {
pub headers: Option<HeaderMap>,
pub subject: Subject,
}

impl Default for NatsContext {
fn default() -> Self {
Self {
headers: None,
subject: Subject::from(""),
}
}
}

impl wrpc_transport::Invoke for Client {
type Context = Option<HeaderMap>;
type Outgoing = ParamWriter;
Expand Down Expand Up @@ -1140,11 +1155,12 @@ async fn handle_message(
reply: tx,
payload,
headers,
subject,
..
}: async_nats::Message,
paths: &[Box<[Option<usize>]>],
tasks: Arc<JoinSet<()>>,
) -> anyhow::Result<(Option<HeaderMap>, SubjectWriter, Reader)> {
) -> anyhow::Result<(NatsContext, SubjectWriter, Reader)> {
let tx = tx.context("peer did not specify a reply subject")?;

let mut cmds = Vec::with_capacity(paths.len().saturating_add(1));
Expand Down Expand Up @@ -1180,7 +1196,7 @@ async fn handle_message(
.await
.context("failed to publish handshake accept")?;
Ok((
headers,
NatsContext { headers, subject },
SubjectWriter::new(
nats.clone(),
Subject::from(result_subject(&tx)),
Expand All @@ -1201,7 +1217,7 @@ async fn handle_message(
}

impl wrpc_transport::Serve for Client {
type Context = Option<HeaderMap>;
type Context = NatsContext;
type Outgoing = SubjectWriter;
type Incoming = Reader;

Expand Down
4 changes: 2 additions & 2 deletions examples/rust/hello-nats-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ struct Args {
#[derive(Clone, Copy)]
struct Server;

impl bindings::exports::wrpc_examples::hello::handler::Handler<Option<async_nats::HeaderMap>>
impl bindings::exports::wrpc_examples::hello::handler::Handler<wrpc_transport_nats::NatsContext>
for Server
{
async fn hello(&self, _: Option<async_nats::HeaderMap>) -> anyhow::Result<String> {
async fn hello(&self, _: wrpc_transport_nats::NatsContext) -> anyhow::Result<String> {
Ok("hello from Rust".to_string())
}
}
Expand Down
4 changes: 2 additions & 2 deletions examples/rust/streams-nats-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ struct Args {
#[derive(Clone, Copy)]
struct Server;

impl bindings::exports::wrpc_examples::streams::handler::Handler<Option<async_nats::HeaderMap>>
impl bindings::exports::wrpc_examples::streams::handler::Handler<wrpc_transport_nats::NatsContext>
for Server
{
async fn echo(
&self,
_cx: Option<async_nats::HeaderMap>,
_cx: wrpc_transport_nats::NatsContext,
Req { numbers, bytes }: Req,
) -> anyhow::Result<(
Pin<Box<dyn Stream<Item = Vec<u64>> + Send>>,
Expand Down
72 changes: 38 additions & 34 deletions tests/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ use wrpc_transport::frame::{AcceptExt as _, Oneshot};
use wrpc_transport::{Accept, InvokeExt as _, ResourceBorrow, ResourceOwn, ServeExt as _};

#[instrument(skip_all, ret)]
async fn assert_bindgen_async<C, I, S>(clt: Arc<I>, srv: Arc<S>) -> anyhow::Result<()>
async fn assert_bindgen_async<IC, SC, I, S>(clt: Arc<I>, srv: Arc<S>) -> anyhow::Result<()>
where
C: Send + Sync + Default,
I: wrpc::Invoke<Context = C> + 'static,
S: wrpc::Serve<Context = C> + Send + 'static,
IC: Send + Sync + Default,
SC: Send + Sync + Default,
I: wrpc::Invoke<Context = IC> + 'static,
S: wrpc::Serve<Context = SC> + Send + 'static,
{
let span = Span::current();
let (shutdown_tx, shutdown_rx) = oneshot::channel();
Expand Down Expand Up @@ -190,11 +191,12 @@ where
}

#[instrument(skip_all, ret)]
async fn assert_bindgen_sync<C, I, S>(clt: Arc<I>, srv: Arc<S>) -> anyhow::Result<()>
async fn assert_bindgen_sync<IC, SC, I, S>(clt: Arc<I>, srv: Arc<S>) -> anyhow::Result<()>
where
C: Send + Sync + Default,
I: wrpc::Invoke<Context = C> + 'static,
S: wrpc::Serve<Context = C> + Send + 'static,
IC: Send + Sync + Default,
SC: Send + Sync + Default,
I: wrpc::Invoke<Context = IC> + 'static,
S: wrpc::Serve<Context = SC> + Send + 'static,
{
let span = Span::current();
let (shutdown_tx, shutdown_rx) = oneshot::channel();
Expand Down Expand Up @@ -494,33 +496,34 @@ where
// TODO: Remove the need for this
sleep(Duration::from_secs(1)).await;

impl<C, T> exports::bar::Handler<C> for Component<T>
impl<IC, SC, T> exports::bar::Handler<SC> for Component<T>
where
C: Send + Sync + Default,
T: wrpc::Invoke<Context = C>,
IC: Send + Sync + Default,
SC: Send + Sync + Default,
T: wrpc::Invoke<Context = IC>,
{
async fn bar(&self, _cx: C) -> anyhow::Result<String> {
async fn bar(&self, _cx: SC) -> anyhow::Result<String> {
use shared::Abc;

info!("calling `wrpc-test:integration/test.foo.f`");
foo::foo(self.0.as_ref(), C::default(), "foo")
foo::foo(self.0.as_ref(), IC::default(), "foo")
.await
.context("failed to call `wrpc-test:integration/test.foo.foo`")?;

info!("calling `wrpc-test:integration/test.f`");
let v = f(self.0.as_ref(), C::default(), "foo")
let v = f(self.0.as_ref(), IC::default(), "foo")
.await
.context("failed to call `wrpc-test:integration/test.f`")?;
assert_eq!(v, 42);

info!("calling `wrpc-test:integration/shared.fallible`");
let v = shared::fallible(self.0.as_ref(), C::default())
let v = shared::fallible(self.0.as_ref(), IC::default())
.await
.context("failed to call `wrpc-test:integration/shared.fallible`")?;
assert_eq!(v, Ok(true));

info!("calling `wrpc-test:integration/shared.numbers`");
let v = shared::numbers(self.0.as_ref(), C::default())
let v = shared::numbers(self.0.as_ref(), IC::default())
.await
.context("failed to call `wrpc-test:integration/shared.numbers`")?;
assert_eq!(
Expand All @@ -541,14 +544,14 @@ where

info!("calling `wrpc-test:integration/shared.with-flags`");
let v =
shared::with_flags(self.0.as_ref(), C::default())
shared::with_flags(self.0.as_ref(), IC::default())
.await
.context("failed to call `wrpc-test:integration/shared.with-flags`")?;
assert_eq!(v, Abc::A | Abc::C);

let counter = Counter::new(
self.0.as_ref(),
C::default(),
IC::default(),
0,
)
.await
Expand All @@ -557,13 +560,13 @@ where
)?;
let counter_borrow = counter.as_borrow();

Counter::increment_by(self.0.as_ref(), C::default(), &counter_borrow, 1)
Counter::increment_by(self.0.as_ref(), IC::default(), &counter_borrow, 1)
.await
.context("failed to call `wrpc-test:integration/shared.[method]counter-increment-by`")?;

let count = Counter::get_count(
self.0.as_ref(),
C::default(),
IC::default(),
&counter_borrow,
)
.await
Expand All @@ -572,13 +575,13 @@ where
)?;
assert_eq!(count, 1);

Counter::increment_by(self.0.as_ref(), C::default(), &counter_borrow, 2)
Counter::increment_by(self.0.as_ref(), IC::default(), &counter_borrow, 2)
.await
.context("failed to call `wrpc-test:integration/shared.[method]counter-increment-by`")?;

let count = Counter::get_count(
self.0.as_ref(),
C::default(),
IC::default(),
&counter_borrow,
)
.await
Expand All @@ -587,14 +590,14 @@ where
)?;
assert_eq!(count, 3);

let second_counter = Counter::clone_counter(self.0.as_ref(), C::default(), &counter_borrow)
let second_counter = Counter::clone_counter(self.0.as_ref(), IC::default(), &counter_borrow)
.await
.context("failed to call `wrpc-test:integration/shared.[method]counter-clone-counter`")?;

let second_counter_borrow = second_counter.as_borrow();
let sum = Counter::sum(
self.0.as_ref(),
C::default(),
IC::default(),
&counter_borrow,
&second_counter_borrow,
)
Expand Down Expand Up @@ -651,7 +654,7 @@ where
// TODO: Remove the need for this
sleep(Duration::from_secs(2)).await;

let v = bar::bar(clt.as_ref(), C::default())
let v = bar::bar(clt.as_ref(), IC::default())
.await
.context("failed to call `wrpc-test:integration/test.bar.bar`")?;
assert_eq!(v, "bar");
Expand All @@ -663,11 +666,12 @@ where
}

#[instrument(skip_all, ret)]
async fn assert_dynamic<C, I, S>(clt: Arc<I>, srv: Arc<S>) -> anyhow::Result<()>
async fn assert_dynamic<IC, SC, I, S>(clt: Arc<I>, srv: Arc<S>) -> anyhow::Result<()>
where
C: Send + Sync + Default + 'static,
I: wrpc::Invoke<Context = C>,
S: wrpc::Serve<Context = C>,
IC: Send + Sync + Default + 'static,
SC: Send + Sync + Default + 'static,
I: wrpc::Invoke<Context = IC>,
S: wrpc::Serve<Context = SC>,
{
use core::pin::pin;

Expand Down Expand Up @@ -725,7 +729,7 @@ where
async {
info!("invoking `test.reset`");
clt.invoke_values_blocking::<_, _, (String,)>(
C::default(),
IC::default(),
"test",
"reset",
("arg",),
Expand All @@ -735,7 +739,7 @@ where
.expect_err("`test.reset` should have failed");
info!("invoking `test.reset`");
clt.invoke_values_blocking::<_, _, (String,)>(
C::default(),
IC::default(),
"test",
"reset",
("arg",),
Expand All @@ -745,7 +749,7 @@ where
.expect_err("`test.reset` should have failed");
info!("invoking `test.reset`");
clt.invoke_values_blocking::<_, _, (String,)>(
C::default(),
IC::default(),
"test",
"reset",
("arg",),
Expand Down Expand Up @@ -824,7 +828,7 @@ where
info!("invoking `test.sync`");
let returns = clt
.invoke_values_blocking(
C::default(),
IC::default(),
"test",
"sync",
(
Expand Down Expand Up @@ -966,7 +970,7 @@ where
info!("invoking `test.async`");
let (returns, io) = clt
.invoke_values(
C::default(),
IC::default(),
"test",
"async",
(a, b, c, d, e),
Expand Down
Loading