diff --git a/crates/transport-nats/src/lib.rs b/crates/transport-nats/src/lib.rs index 69ff45dba..042fb8868 100644 --- a/crates/transport-nats/src/lib.rs +++ b/crates/transport-nats/src/lib.rs @@ -1012,6 +1012,21 @@ impl AsyncWrite for ParamWriter { } } +#[derive(Debug)] +pub struct NatsContext { + pub headers: Option, + pub subject: Subject, +} + +impl Default for NatsContext { + fn default() -> Self { + Self { + headers: None, + subject: Subject::from(""), + } + } +} + impl wrpc_transport::Invoke for Client { type Context = Option; type Outgoing = ParamWriter; @@ -1140,11 +1155,12 @@ async fn handle_message( reply: tx, payload, headers, + subject, .. }: async_nats::Message, paths: &[Box<[Option]>], tasks: Arc>, -) -> anyhow::Result<(Option, SubjectWriter, Reader)> { +) -> anyhow::Result<(NatsContext, SubjectWriter, Reader)> { let tx = tx.context("peer did not specify a reply subject")?; let mut cmds = Vec::with_capacity(paths.len().saturating_add(1)); @@ -1180,7 +1196,7 @@ async fn handle_message( .await .context("failed to publish handshake accept")?; Ok(( - headers, + NatsContext { headers, subject }, SubjectWriter::new( nats.clone(), Subject::from(result_subject(&tx)), @@ -1201,7 +1217,7 @@ async fn handle_message( } impl wrpc_transport::Serve for Client { - type Context = Option; + type Context = NatsContext; type Outgoing = SubjectWriter; type Incoming = Reader; diff --git a/examples/rust/hello-nats-server/src/main.rs b/examples/rust/hello-nats-server/src/main.rs index fa8742985..875069be8 100644 --- a/examples/rust/hello-nats-server/src/main.rs +++ b/examples/rust/hello-nats-server/src/main.rs @@ -32,10 +32,10 @@ struct Args { #[derive(Clone, Copy)] struct Server; -impl bindings::exports::wrpc_examples::hello::handler::Handler> +impl bindings::exports::wrpc_examples::hello::handler::Handler for Server { - async fn hello(&self, _: Option) -> anyhow::Result { + async fn hello(&self, _: wrpc_transport_nats::NatsContext) -> anyhow::Result { Ok("hello from Rust".to_string()) } } diff --git a/examples/rust/streams-nats-server/src/main.rs b/examples/rust/streams-nats-server/src/main.rs index c5f15e889..5a9339bd3 100644 --- a/examples/rust/streams-nats-server/src/main.rs +++ b/examples/rust/streams-nats-server/src/main.rs @@ -35,12 +35,12 @@ struct Args { #[derive(Clone, Copy)] struct Server; -impl bindings::exports::wrpc_examples::streams::handler::Handler> +impl bindings::exports::wrpc_examples::streams::handler::Handler for Server { async fn echo( &self, - _cx: Option, + _cx: wrpc_transport_nats::NatsContext, Req { numbers, bytes }: Req, ) -> anyhow::Result<( Pin> + Send>>, diff --git a/tests/rust.rs b/tests/rust.rs index 699f287e5..7a628c25c 100644 --- a/tests/rust.rs +++ b/tests/rust.rs @@ -22,11 +22,12 @@ use wrpc_transport::frame::{AcceptExt as _, Oneshot}; use wrpc_transport::{Accept, InvokeExt as _, ResourceBorrow, ResourceOwn, ServeExt as _}; #[instrument(skip_all, ret)] -async fn assert_bindgen_async(clt: Arc, srv: Arc) -> anyhow::Result<()> +async fn assert_bindgen_async(clt: Arc, srv: Arc) -> anyhow::Result<()> where - C: Send + Sync + Default, - I: wrpc::Invoke + 'static, - S: wrpc::Serve + Send + 'static, + IC: Send + Sync + Default, + SC: Send + Sync + Default, + I: wrpc::Invoke + 'static, + S: wrpc::Serve + Send + 'static, { let span = Span::current(); let (shutdown_tx, shutdown_rx) = oneshot::channel(); @@ -190,11 +191,12 @@ where } #[instrument(skip_all, ret)] -async fn assert_bindgen_sync(clt: Arc, srv: Arc) -> anyhow::Result<()> +async fn assert_bindgen_sync(clt: Arc, srv: Arc) -> anyhow::Result<()> where - C: Send + Sync + Default, - I: wrpc::Invoke + 'static, - S: wrpc::Serve + Send + 'static, + IC: Send + Sync + Default, + SC: Send + Sync + Default, + I: wrpc::Invoke + 'static, + S: wrpc::Serve + Send + 'static, { let span = Span::current(); let (shutdown_tx, shutdown_rx) = oneshot::channel(); @@ -494,33 +496,34 @@ where // TODO: Remove the need for this sleep(Duration::from_secs(1)).await; - impl exports::bar::Handler for Component + impl exports::bar::Handler for Component where - C: Send + Sync + Default, - T: wrpc::Invoke, + IC: Send + Sync + Default, + SC: Send + Sync + Default, + T: wrpc::Invoke, { - async fn bar(&self, _cx: C) -> anyhow::Result { + async fn bar(&self, _cx: SC) -> anyhow::Result { use shared::Abc; info!("calling `wrpc-test:integration/test.foo.f`"); - foo::foo(self.0.as_ref(), C::default(), "foo") + foo::foo(self.0.as_ref(), IC::default(), "foo") .await .context("failed to call `wrpc-test:integration/test.foo.foo`")?; info!("calling `wrpc-test:integration/test.f`"); - let v = f(self.0.as_ref(), C::default(), "foo") + let v = f(self.0.as_ref(), IC::default(), "foo") .await .context("failed to call `wrpc-test:integration/test.f`")?; assert_eq!(v, 42); info!("calling `wrpc-test:integration/shared.fallible`"); - let v = shared::fallible(self.0.as_ref(), C::default()) + let v = shared::fallible(self.0.as_ref(), IC::default()) .await .context("failed to call `wrpc-test:integration/shared.fallible`")?; assert_eq!(v, Ok(true)); info!("calling `wrpc-test:integration/shared.numbers`"); - let v = shared::numbers(self.0.as_ref(), C::default()) + let v = shared::numbers(self.0.as_ref(), IC::default()) .await .context("failed to call `wrpc-test:integration/shared.numbers`")?; assert_eq!( @@ -541,14 +544,14 @@ where info!("calling `wrpc-test:integration/shared.with-flags`"); let v = - shared::with_flags(self.0.as_ref(), C::default()) + shared::with_flags(self.0.as_ref(), IC::default()) .await .context("failed to call `wrpc-test:integration/shared.with-flags`")?; assert_eq!(v, Abc::A | Abc::C); let counter = Counter::new( self.0.as_ref(), - C::default(), + IC::default(), 0, ) .await @@ -557,13 +560,13 @@ where )?; let counter_borrow = counter.as_borrow(); - Counter::increment_by(self.0.as_ref(), C::default(), &counter_borrow, 1) + Counter::increment_by(self.0.as_ref(), IC::default(), &counter_borrow, 1) .await .context("failed to call `wrpc-test:integration/shared.[method]counter-increment-by`")?; let count = Counter::get_count( self.0.as_ref(), - C::default(), + IC::default(), &counter_borrow, ) .await @@ -572,13 +575,13 @@ where )?; assert_eq!(count, 1); - Counter::increment_by(self.0.as_ref(), C::default(), &counter_borrow, 2) + Counter::increment_by(self.0.as_ref(), IC::default(), &counter_borrow, 2) .await .context("failed to call `wrpc-test:integration/shared.[method]counter-increment-by`")?; let count = Counter::get_count( self.0.as_ref(), - C::default(), + IC::default(), &counter_borrow, ) .await @@ -587,14 +590,14 @@ where )?; assert_eq!(count, 3); - let second_counter = Counter::clone_counter(self.0.as_ref(), C::default(), &counter_borrow) + let second_counter = Counter::clone_counter(self.0.as_ref(), IC::default(), &counter_borrow) .await .context("failed to call `wrpc-test:integration/shared.[method]counter-clone-counter`")?; let second_counter_borrow = second_counter.as_borrow(); let sum = Counter::sum( self.0.as_ref(), - C::default(), + IC::default(), &counter_borrow, &second_counter_borrow, ) @@ -651,7 +654,7 @@ where // TODO: Remove the need for this sleep(Duration::from_secs(2)).await; - let v = bar::bar(clt.as_ref(), C::default()) + let v = bar::bar(clt.as_ref(), IC::default()) .await .context("failed to call `wrpc-test:integration/test.bar.bar`")?; assert_eq!(v, "bar"); @@ -663,11 +666,12 @@ where } #[instrument(skip_all, ret)] -async fn assert_dynamic(clt: Arc, srv: Arc) -> anyhow::Result<()> +async fn assert_dynamic(clt: Arc, srv: Arc) -> anyhow::Result<()> where - C: Send + Sync + Default + 'static, - I: wrpc::Invoke, - S: wrpc::Serve, + IC: Send + Sync + Default + 'static, + SC: Send + Sync + Default + 'static, + I: wrpc::Invoke, + S: wrpc::Serve, { use core::pin::pin; @@ -725,7 +729,7 @@ where async { info!("invoking `test.reset`"); clt.invoke_values_blocking::<_, _, (String,)>( - C::default(), + IC::default(), "test", "reset", ("arg",), @@ -735,7 +739,7 @@ where .expect_err("`test.reset` should have failed"); info!("invoking `test.reset`"); clt.invoke_values_blocking::<_, _, (String,)>( - C::default(), + IC::default(), "test", "reset", ("arg",), @@ -745,7 +749,7 @@ where .expect_err("`test.reset` should have failed"); info!("invoking `test.reset`"); clt.invoke_values_blocking::<_, _, (String,)>( - C::default(), + IC::default(), "test", "reset", ("arg",), @@ -824,7 +828,7 @@ where info!("invoking `test.sync`"); let returns = clt .invoke_values_blocking( - C::default(), + IC::default(), "test", "sync", ( @@ -966,7 +970,7 @@ where info!("invoking `test.async`"); let (returns, io) = clt .invoke_values( - C::default(), + IC::default(), "test", "async", (a, b, c, d, e),