Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions example-service/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![deny(warnings, unused, dead_code)]

use clap::Parser;
use service::{WorldClient, init_tracing};
Expand Down Expand Up @@ -34,10 +35,13 @@ async fn main() -> anyhow::Result<()> {
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();

let hello = async move {
let mut context = context::current();
let mut context2 = context::current();

// Send the request twice, just to be safe! ;)
tokio::select! {
hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 }
hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 }
}
}
.instrument(tracing::info_span!("Two Hellos"))
Expand Down
2 changes: 2 additions & 0 deletions example-service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![deny(warnings, unused, dead_code)]

use opentelemetry::trace::TracerProvider as _;
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};

Expand Down
4 changes: 3 additions & 1 deletion example-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![deny(warnings, unused, dead_code)]

use clap::Parser;
use futures::{future, prelude::*};
Expand Down Expand Up @@ -35,7 +36,8 @@ struct Flags {
struct HelloServer(SocketAddr);

impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
type Context = context::Context;
async fn hello(self, _: &mut Self::Context, name: String) -> String {
let sleep_time =
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
time::sleep(sleep_time).await;
Expand Down
88 changes: 48 additions & 40 deletions plugins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![deny(warnings, unused, dead_code)]
#![recursion_limit = "512"]

extern crate proc_macro;
extern crate proc_macro2;
extern crate quote;
extern crate syn;

use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{ToTokens, format_ident, quote};
Expand Down Expand Up @@ -375,7 +371,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
/// # Example
///
/// ```no_run
/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context};
/// use tarpc::{client, context, transport, service, server::{self, Channel}, context::Context};
///
/// #[service]
/// pub trait Calculator {
Expand All @@ -401,7 +397,8 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
/// #[derive(Clone)]
/// struct CalculatorServer;
/// impl Calculator for CalculatorServer {
/// async fn add(self, context: Context, a: i32, b: i32) -> i32 {
/// type Context = context::Context;
/// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 {
/// a + b
/// }
/// }
Expand Down Expand Up @@ -547,26 +544,28 @@ impl ServiceGenerator<'_> {
} = self;

let rpc_fns = rpcs
.iter()
.zip(return_types.iter())
.map(
|(
RpcMethod {
attrs, ident, args, ..
},
output,
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
}
},
);
.iter()
.zip(return_types.iter())
.map(
|(
RpcMethod {
attrs, ident, args, ..
},
output,
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output;
}
},
);

let stub_doc = format!("The stub trait for service [`{service_ident}`].");
quote! {
#( #attrs )*
#vis trait #service_ident: ::core::marker::Sized {
type Context: ::tarpc::context::ExtractContext<::tarpc::context::Context>;

#( #rpc_fns )*

/// Returns a serving function to use with
Expand All @@ -577,11 +576,11 @@ impl ServiceGenerator<'_> {
}

#[doc = #stub_doc]
#vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
#vis trait #client_stub_ident<ClientCtx>: ::tarpc::client::stub::Stub<ClientCtx = ClientCtx, Req = #request_ident, Resp = #response_ident> {
}

impl<S> #client_stub_ident for S
where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
impl<S, ClientCtx> #client_stub_ident<ClientCtx> for S
where S: ::tarpc::client::stub::Stub<ClientCtx = ClientCtx, Req = #request_ident, Resp = #response_ident>
{
}
}
Expand Down Expand Up @@ -620,9 +619,9 @@ impl ServiceGenerator<'_> {
{
type Req = #request_ident;
type Resp = #response_ident;
type ServerCtx = S::Context;


async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident)
async fn serve(self, ctx: &mut Self::ServerCtx, req: #request_ident)
-> ::core::result::Result<#response_ident, ::tarpc::ServerError> {
match req {
#(
Expand Down Expand Up @@ -711,12 +710,19 @@ impl ServiceGenerator<'_> {

quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
#[derive(Debug)]
/// The client stub that makes RPC calls to the server. All request methods return
/// [Futures](::core::future::Future).
#vis struct #client_ident<
Stub = ::tarpc::client::Channel<#request_ident, #response_ident>
>(Stub);
ClientCtx,
Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx>
>(Stub, ::std::marker::PhantomData<ClientCtx>);

impl<ClientCtx, Stub: ::std::clone::Clone> ::std::clone::Clone for #client_ident<ClientCtx,Stub> {
fn clone(&self) -> Self {
Self(self.0.clone(), ::std::marker::PhantomData)
}
}
}
}

Expand All @@ -730,32 +736,33 @@ impl ServiceGenerator<'_> {
} = self;

quote! {
impl #client_ident {
impl<ClientCtx> #client_ident<ClientCtx> {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: ::tarpc::client::Config, transport: T)
-> ::tarpc::client::NewClient<
Self,
::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, T>
>
where
T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
T: ::tarpc::Transport<::tarpc::ClientMessage<ClientCtx, #request_ident>, ::tarpc::Response<ClientCtx, #response_ident>>
{
let new_client = ::tarpc::client::new(config, transport);
::tarpc::client::NewClient {
client: #client_ident(new_client.client),
client: #client_ident(new_client.client, ::std::marker::PhantomData),
dispatch: new_client.dispatch,
}
}
}

impl<Stub> ::core::convert::From<Stub> for #client_ident<Stub>
impl<ClientCtx, Stub> ::core::convert::From<Stub> for #client_ident<ClientCtx, Stub>
where Stub: ::tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
Resp = #response_ident,
ClientCtx = ClientCtx>
{
/// Returns a new client stub that sends requests over the given transport.
fn from(stub: Stub) -> Self {
#client_ident(stub)
#client_ident::<ClientCtx, Stub>(stub, ::std::marker::PhantomData)
}

}
Expand All @@ -778,15 +785,16 @@ impl ServiceGenerator<'_> {
} = self;

quote! {
impl<Stub> #client_ident<Stub>
impl<ClientCtx, Stub> #client_ident<ClientCtx, Stub>
where Stub: ::tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
Resp = #response_ident,
ClientCtx = ClientCtx>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*)
#vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ClientCtx, #( #args ),*)
-> impl ::core::future::Future<Output = ::core::result::Result<#return_types, ::tarpc::client::RpcError>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = self.0.call(ctx, request);
Expand Down
17 changes: 10 additions & 7 deletions plugins/tests/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ fn att_service_trait() {
}

impl Foo for () {
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
type Context = context::Context;
async fn two_part(self, _: &mut Self::Context, s: String, i: i32) -> (String, i32) {
(s, i)
}

async fn bar(self, _: context::Context, s: String) -> String {
async fn bar(self, _: &mut Self::Context, s: String) -> String {
s
}

async fn baz(self, _: context::Context) {}
async fn baz(self, _: &mut Self::Context) {}
}
}

Expand All @@ -37,20 +38,21 @@ fn raw_idents() {
}

impl r#trait for () {
type Context = context::Context;
async fn r#await(
self,
_: context::Context,
_: &mut Self::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
}

async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
async fn r#fn(self, _: &mut Self::Context, r#impl: r#yield) -> r#yield {
r#impl
}

async fn r#async(self, _: context::Context) {}
async fn r#async(self, _: &mut Self::Context) {}
}
}

Expand All @@ -64,7 +66,8 @@ fn service_with_cfg_rpc() {
}

impl Foo for () {
async fn foo(self, _: context::Context) {}
type Context = context::Context;
async fn foo(self, _: &mut Self::Context) {}
}
}

Expand Down
6 changes: 4 additions & 2 deletions tarpc/examples/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![deny(warnings, unused, dead_code)]

use flate2::{Compression, read::DeflateDecoder, write::DeflateEncoder};
use futures::{Sink, SinkExt, Stream, StreamExt, TryStreamExt, prelude::*};
Expand Down Expand Up @@ -108,7 +109,8 @@ pub trait World {
struct HelloServer;

impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
type Context = context::Context;
async fn hello(self, _: &mut Self::Context, name: String) -> String {
format!("Hey, {name}!")
}
}
Expand All @@ -134,7 +136,7 @@ async fn main() -> anyhow::Result<()> {

println!(
"{}",
client.hello(context::current(), "friend".into()).await?
client.hello(&mut context::current(), "friend".into()).await?
);
Ok(())
}
8 changes: 5 additions & 3 deletions tarpc/examples/custom_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.
#![deny(warnings, unused, dead_code)]

use futures::prelude::*;
use tarpc::context::Context;
use tarpc::{context};
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
Expand All @@ -21,7 +22,8 @@ pub trait PingService {
struct Service;

impl PingService for Service {
async fn ping(self, _: Context) {}
type Context = context::Context;
async fn ping(self, _: &mut Self::Context) {}
}

#[tokio::main]
Expand Down Expand Up @@ -52,7 +54,7 @@ async fn main() -> anyhow::Result<()> {
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
PingServiceClient::new(Default::default(), transport)
.spawn()
.ping(tarpc::context::current())
.ping(&mut context::current())
.await?;

Ok(())
Expand Down
Loading