diff --git a/Cargo.toml b/Cargo.toml index 81cae45861..583cb517f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -98,3 +98,6 @@ rust_2024_compatibility = { level = "warn", priority = -1 } missing_docs = { level = "warn", priority = -1 } missing_debug_implementations = { level = "warn", priority = -1 } missing_copy_implementations = { level = "warn", priority = -1 } + +[workspace.lints.clippy] +manual_async_fn = { level = "allow", priority = -1 } diff --git a/benches/bench.rs b/benches/bench.rs index 06397cee5d..1f45168dcc 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -9,7 +9,7 @@ use helpers::fixed_client::{ ws_handshake, }; use helpers::{KIB, SUB_METHOD_NAME, UNSUB_METHOD_NAME}; -use jsonrpsee::types::{Id, RequestSer}; +use jsonrpsee::types::{Id, Request}; use pprof::criterion::{Output, PProfProfiler}; use tokio::runtime::Runtime as TokioRuntime; @@ -101,7 +101,7 @@ impl RequestType { } } -fn v2_serialize(req: RequestSer<'_>) -> String { +fn v2_serialize(req: Request<'_>) -> String { serde_json::to_string(&req).unwrap() } @@ -115,7 +115,7 @@ pub fn jsonrpsee_types_v2(crit: &mut Criterion) { b.iter(|| { let params = serde_json::value::RawValue::from_string("[1, 2]".to_string()).unwrap(); - let request = RequestSer::borrowed(&Id::Number(0), &"say_hello", Some(¶ms)); + let request = Request::borrowed("say_hello", Some(¶ms), Id::Number(0)); v2_serialize(request); }) }); @@ -128,7 +128,7 @@ pub fn jsonrpsee_types_v2(crit: &mut Criterion) { builder.insert(1u64).unwrap(); builder.insert(2u32).unwrap(); let params = builder.to_rpc_params().expect("Valid params"); - let request = RequestSer::borrowed(&Id::Number(0), &"say_hello", params.as_deref()); + let request = Request::borrowed("say_hello", params.as_deref(), Id::Number(0)); v2_serialize(request); }) }); @@ -138,7 +138,7 @@ pub fn jsonrpsee_types_v2(crit: &mut Criterion) { b.iter(|| { let params = serde_json::value::RawValue::from_string(r#"{"key": 1}"#.to_string()).unwrap(); - let request = RequestSer::borrowed(&Id::Number(0), &"say_hello", Some(¶ms)); + let request = Request::borrowed("say_hello", Some(¶ms), Id::Number(0)); v2_serialize(request); }) }); @@ -150,7 +150,7 @@ pub fn jsonrpsee_types_v2(crit: &mut Criterion) { let mut builder = ObjectParams::new(); builder.insert("key", 1u32).unwrap(); let params = builder.to_rpc_params().expect("Valid params"); - let request = RequestSer::borrowed(&Id::Number(0), &"say_hello", params.as_deref()); + let request = Request::borrowed("say_hello", params.as_deref(), Id::Number(0)); v2_serialize(request); }) }); diff --git a/client/http-client/Cargo.toml b/client/http-client/Cargo.toml index 27fc411ce2..58d198ed8e 100644 --- a/client/http-client/Cargo.toml +++ b/client/http-client/Cargo.toml @@ -31,7 +31,6 @@ serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["time"] } -tracing = { workspace = true } tower = { workspace = true, features = ["util"] } url = { workspace = true } diff --git a/client/http-client/src/client.rs b/client/http-client/src/client.rs index 1955f0452c..ddd291b95f 100644 --- a/client/http-client/src/client.rs +++ b/client/http-client/src/client.rs @@ -29,28 +29,32 @@ use std::fmt; use std::sync::Arc; use std::time::Duration; -use crate::transport::{self, Error as TransportError, HttpBackend, HttpTransportClient, HttpTransportClientBuilder}; -use crate::types::{NotificationSer, RequestSer, Response}; +use crate::rpc_service::RpcService; +use crate::transport::{self, Error as TransportError, HttpBackend, HttpTransportClientBuilder}; use crate::{HttpRequest, HttpResponse}; use async_trait::async_trait; use hyper::body::Bytes; -use hyper::http::HeaderMap; +use hyper::http::{Extensions, HeaderMap}; use jsonrpsee_core::client::{ - BatchResponse, ClientT, Error, IdKind, RequestIdManager, Subscription, SubscriptionClientT, generate_batch_id_range, + BatchResponse, ClientT, Error, IdKind, MethodResponse, RequestIdManager, Subscription, SubscriptionClientT, + generate_batch_id_range, }; +use jsonrpsee_core::middleware::layer::RpcLoggerLayer; +use jsonrpsee_core::middleware::{Batch, RpcServiceBuilder, RpcServiceT}; use jsonrpsee_core::params::BatchRequestBuilder; use jsonrpsee_core::traits::ToRpcParams; -use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES}; -use jsonrpsee_types::{ErrorObject, InvalidRequestId, ResponseSuccess, TwoPointZero}; +use jsonrpsee_core::{BoxError, TEN_MB_SIZE_BYTES}; +use jsonrpsee_types::{ErrorObject, InvalidRequestId, Notification, Request, ResponseSuccess, TwoPointZero}; use serde::de::DeserializeOwned; use tokio::sync::Semaphore; use tower::layer::util::Identity; use tower::{Layer, Service}; -use tracing::instrument; #[cfg(feature = "tls")] use crate::{CertificateStore, CustomCertStore}; +type Logger = tower::layer::util::Stack; + /// HTTP client builder. /// /// # Examples @@ -75,21 +79,21 @@ use crate::{CertificateStore, CustomCertStore}; /// } /// ``` #[derive(Clone, Debug)] -pub struct HttpClientBuilder { +pub struct HttpClientBuilder { max_request_size: u32, max_response_size: u32, request_timeout: Duration, #[cfg(feature = "tls")] certificate_store: CertificateStore, id_kind: IdKind, - max_log_length: u32, headers: HeaderMap, - service_builder: tower::ServiceBuilder, + service_builder: tower::ServiceBuilder, + rpc_middleware: RpcServiceBuilder, tcp_no_delay: bool, max_concurrent_requests: Option, } -impl HttpClientBuilder { +impl HttpClientBuilder { /// Set the maximum size of a request body in bytes. Default is 10 MiB. pub fn max_request_size(mut self, size: u32) -> Self { self.max_request_size = size; @@ -191,14 +195,6 @@ impl HttpClientBuilder { self } - /// Max length for logging for requests and responses in number characters. - /// - /// Logs bigger than this limit will be truncated. - pub fn set_max_logging_length(mut self, max: u32) -> Self { - self.max_log_length = max; - self - } - /// Set a custom header passed to the server with every request (default is none). /// /// The caller is responsible for checking that the headers do not conflict or are duplicated. @@ -215,17 +211,37 @@ impl HttpClientBuilder { self } + /// Set the RPC middleware. + pub fn set_rpc_middleware(self, rpc_builder: RpcServiceBuilder) -> HttpClientBuilder { + HttpClientBuilder { + #[cfg(feature = "tls")] + certificate_store: self.certificate_store, + id_kind: self.id_kind, + headers: self.headers, + max_request_size: self.max_request_size, + max_response_size: self.max_response_size, + service_builder: self.service_builder, + rpc_middleware: rpc_builder, + request_timeout: self.request_timeout, + tcp_no_delay: self.tcp_no_delay, + max_concurrent_requests: self.max_concurrent_requests, + } + } + /// Set custom tower middleware. - pub fn set_http_middleware(self, service_builder: tower::ServiceBuilder) -> HttpClientBuilder { + pub fn set_http_middleware( + self, + service_builder: tower::ServiceBuilder, + ) -> HttpClientBuilder { HttpClientBuilder { #[cfg(feature = "tls")] certificate_store: self.certificate_store, id_kind: self.id_kind, headers: self.headers, - max_log_length: self.max_log_length, max_request_size: self.max_request_size, max_response_size: self.max_response_size, service_builder, + rpc_middleware: self.rpc_middleware, request_timeout: self.request_timeout, tcp_no_delay: self.tcp_no_delay, max_concurrent_requests: self.max_concurrent_requests, @@ -233,16 +249,18 @@ impl HttpClientBuilder { } } -impl HttpClientBuilder +impl HttpClientBuilder where - L: Layer, + RpcMiddleware: Layer, Service = S2>, + for<'a> >>::Service: RpcServiceT, + HttpMiddleware: Layer, S: Service, Error = TransportError> + Clone, B: http_body::Body + Send + Unpin + 'static, B::Data: Send, B::Error: Into, { /// Build the HTTP client with target to connect to. - pub fn build(self, target: impl AsRef) -> Result, Error> { + pub fn build(self, target: impl AsRef) -> Result, Error> { let Self { max_request_size, max_response_size, @@ -251,17 +269,16 @@ where certificate_store, id_kind, headers, - max_log_length, service_builder, tcp_no_delay, + rpc_middleware, .. } = self; - let transport = HttpTransportClientBuilder { + let http = HttpTransportClientBuilder { max_request_size, max_response_size, headers, - max_log_length, tcp_no_delay, service_builder, #[cfg(feature = "tls")] @@ -275,15 +292,15 @@ where .map(|max_concurrent_requests| Arc::new(Semaphore::new(max_concurrent_requests))); Ok(HttpClient { - transport, + service: rpc_middleware.service(RpcService::new(http)), id_manager: Arc::new(RequestIdManager::new(id_kind)), - request_timeout, request_guard, + request_timeout, }) } } -impl Default for HttpClientBuilder { +impl Default for HttpClientBuilder { fn default() -> Self { Self { max_request_size: TEN_MB_SIZE_BYTES, @@ -292,38 +309,38 @@ impl Default for HttpClientBuilder { #[cfg(feature = "tls")] certificate_store: CertificateStore::Native, id_kind: IdKind::Number, - max_log_length: 4096, headers: HeaderMap::new(), service_builder: tower::ServiceBuilder::new(), + rpc_middleware: RpcServiceBuilder::default().rpc_logger(1024), tcp_no_delay: true, max_concurrent_requests: None, } } } -impl HttpClientBuilder { +impl HttpClientBuilder { /// Create a new builder. - pub fn new() -> HttpClientBuilder { + pub fn new() -> HttpClientBuilder { HttpClientBuilder::default() } } /// JSON-RPC HTTP Client that provides functionality to perform method calls and notifications. #[derive(Debug, Clone)] -pub struct HttpClient { - /// HTTP transport client. - transport: HttpTransportClient, - /// Request timeout. Defaults to 60sec. - request_timeout: Duration, +pub struct HttpClient { + /// HTTP service. + service: S, /// Request ID manager. id_manager: Arc, /// Concurrent requests limit guard. request_guard: Option>, + /// Request timeout. + request_timeout: Duration, } impl HttpClient { /// Create a builder for the HttpClient. - pub fn builder() -> HttpClientBuilder { + pub fn builder() -> HttpClientBuilder { HttpClientBuilder::new() } @@ -334,15 +351,10 @@ impl HttpClient { } #[async_trait] -impl ClientT for HttpClient +impl ClientT for HttpClient where - S: Service, Error = TransportError> + Send + Sync + Clone, - >::Future: Send, - B: http_body::Body + Send + Unpin + 'static, - B::Error: Into, - B::Data: Send, + S: RpcServiceT + Send + Sync, { - #[instrument(name = "notification", skip(self, params), level = "trace")] async fn notification(&self, method: &str, params: Params) -> Result<(), Error> where Params: ToRpcParams + Send, @@ -351,20 +363,17 @@ where Some(permit) => permit.acquire().await.ok(), None => None, }; - let params = params.to_rpc_params()?; - let notif = - serde_json::to_string(&NotificationSer::borrowed(&method, params.as_deref())).map_err(Error::ParseError)?; - - let fut = self.transport.send(notif); + let params = params.to_rpc_params()?.map(StdCow::Owned); - match tokio::time::timeout(self.request_timeout, fut).await { - Ok(Ok(ok)) => Ok(ok), - Err(_) => Err(Error::RequestTimeout), - Ok(Err(e)) => Err(Error::Transport(e.into())), - } + run_future_until_timeout( + self.service.notification(Notification::new(method.into(), params)), + self.request_timeout, + ) + .await + .map_err(|e| Error::Transport(e.into()))?; + Ok(()) } - #[instrument(name = "method_call", skip(self, params), level = "trace")] async fn request(&self, method: &str, params: Params) -> Result where R: DeserializeOwned, @@ -377,34 +386,20 @@ where let id = self.id_manager.next_request_id(); let params = params.to_rpc_params()?; - let request = RequestSer::borrowed(&id, &method, params.as_deref()); - let raw = serde_json::to_string(&request).map_err(Error::ParseError)?; - - let fut = self.transport.send_and_read_body(raw); - let body = match tokio::time::timeout(self.request_timeout, fut).await { - Ok(Ok(body)) => body, - Err(_e) => { - return Err(Error::RequestTimeout); - } - Ok(Err(e)) => { - return Err(Error::Transport(e.into())); - } - }; - - // NOTE: it's decoded first to `JsonRawValue` and then to `R` below to get - // a better error message if `R` couldn't be decoded. - let response = ResponseSuccess::try_from(serde_json::from_slice::>(&body)?)?; + let method_response = run_future_until_timeout( + self.service.call(Request::borrowed(method, params.as_deref(), id.clone())), + self.request_timeout, + ) + .await? + .into_method_call() + .expect("Method call must return a method call reponse; qed"); - let result = serde_json::from_str(response.result.get()).map_err(Error::ParseError)?; + let rp = ResponseSuccess::try_from(method_response.into_inner())?; - if response.id == id { - Ok(result) - } else { - Err(InvalidRequestId::NotPendingRequest(response.id.to_string()).into()) - } + let result = serde_json::from_str(rp.result.get()).map_err(Error::ParseError)?; + if rp.id == id { Ok(result) } else { Err(InvalidRequestId::NotPendingRequest(rp.id.to_string()).into()) } } - #[instrument(name = "batch", skip(self, batch), level = "trace")] async fn batch_request<'a, R>(&self, batch: BatchRequestBuilder<'a>) -> Result, Error> where R: DeserializeOwned + fmt::Debug + 'a, @@ -417,46 +412,42 @@ where let id = self.id_manager.next_request_id(); let id_range = generate_batch_id_range(id, batch.len() as u64)?; - let mut batch_request = Vec::with_capacity(batch.len()); + let mut batch_request = Batch::with_capacity(batch.len()); for ((method, params), id) in batch.into_iter().zip(id_range.clone()) { let id = self.id_manager.as_id_kind().into_id(id); - batch_request.push(RequestSer { + let req = Request { jsonrpc: TwoPointZero, - id, method: method.into(), params: params.map(StdCow::Owned), - }); + id, + extensions: Extensions::new(), + }; + batch_request.push(req); } - let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?); - - let body = match tokio::time::timeout(self.request_timeout, fut).await { - Ok(Ok(body)) => body, - Err(_e) => return Err(Error::RequestTimeout), - Ok(Err(e)) => return Err(Error::Transport(e.into())), - }; - - let json_rps: Vec> = serde_json::from_slice(&body).map_err(Error::ParseError)?; + let rp = run_future_until_timeout(self.service.batch(batch_request), self.request_timeout).await?; + let json_rps = rp.into_batch().expect("Batch must return a batch reponse; qed"); - let mut responses = Vec::with_capacity(json_rps.len()); - let mut successful_calls = 0; - let mut failed_calls = 0; + let mut batch_response = Vec::new(); + let mut success = 0; + let mut failed = 0; + // Fill the batch response with placeholder values. for _ in 0..json_rps.len() { - responses.push(Err(ErrorObject::borrowed(0, "", None))); + batch_response.push(Err(ErrorObject::borrowed(0, "", None))); } - for rp in json_rps { - let id = rp.id.try_parse_inner_as_number()?; + for rp in json_rps.into_iter() { + let id = rp.id().try_parse_inner_as_number()?; - let res = match ResponseSuccess::try_from(rp) { + let res = match ResponseSuccess::try_from(rp.into_inner()) { Ok(r) => { - let result = serde_json::from_str(r.result.get())?; - successful_calls += 1; - Ok(result) + let v = serde_json::from_str(r.result.get()).map_err(Error::ParseError)?; + success += 1; + Ok(v) } Err(err) => { - failed_calls += 1; + failed += 1; Err(err) } }; @@ -464,7 +455,7 @@ where let maybe_elem = id .checked_sub(id_range.start) .and_then(|p| p.try_into().ok()) - .and_then(|p: usize| responses.get_mut(p)); + .and_then(|p: usize| batch_response.get_mut(p)); if let Some(elem) = maybe_elem { *elem = res; @@ -473,22 +464,17 @@ where } } - Ok(BatchResponse::new(successful_calls, responses, failed_calls)) + Ok(BatchResponse::new(success, batch_response, failed)) } } #[async_trait] -impl SubscriptionClientT for HttpClient +impl SubscriptionClientT for HttpClient where - S: Service, Error = TransportError> + Send + Sync + Clone, - >::Future: Send, - B: http_body::Body + Send + Unpin + 'static, - B::Data: Send, - B::Error: Into, + S: RpcServiceT + Send + Sync, { /// Send a subscription request to the server. Not implemented for HTTP; will always return /// [`Error::HttpNotImplemented`]. - #[instrument(name = "subscription", fields(method = _subscribe_method), skip(self, _params, _subscribe_method, _unsubscribe_method), level = "trace")] async fn subscribe<'a, N, Params>( &self, _subscribe_method: &'a str, @@ -503,7 +489,6 @@ where } /// Subscribe to a specific method. Not implemented for HTTP; will always return [`Error::HttpNotImplemented`]. - #[instrument(name = "subscribe_method", fields(method = _method), skip(self, _method), level = "trace")] async fn subscribe_to_method<'a, N>(&self, _method: &'a str) -> Result, Error> where N: DeserializeOwned, @@ -511,3 +496,14 @@ where Err(Error::HttpNotImplemented) } } + +async fn run_future_until_timeout(fut: F, timeout: Duration) -> Result +where + F: std::future::Future>, +{ + match tokio::time::timeout(timeout, fut).await { + Ok(Ok(r)) => Ok(r), + Err(_) => Err(Error::RequestTimeout), + Ok(Err(e)) => Err(e), + } +} diff --git a/client/http-client/src/lib.rs b/client/http-client/src/lib.rs index 4d33a2c0c5..4bc7341ea7 100644 --- a/client/http-client/src/lib.rs +++ b/client/http-client/src/lib.rs @@ -36,6 +36,7 @@ #![cfg_attr(docsrs, feature(doc_cfg))] mod client; +mod rpc_service; /// HTTP transport. pub mod transport; diff --git a/client/http-client/src/rpc_service.rs b/client/http-client/src/rpc_service.rs new file mode 100644 index 0000000000..f0b5f3127c --- /dev/null +++ b/client/http-client/src/rpc_service.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; + +use hyper::body::Bytes; +use jsonrpsee_core::{ + BoxError, JsonRawValue, + client::{Error, MethodResponse}, + middleware::{Batch, Notification, Request, RpcServiceT}, +}; +use jsonrpsee_types::Response; +use tower::Service; + +use crate::{ + HttpRequest, HttpResponse, + transport::{Error as TransportError, HttpTransportClient}, +}; + +#[derive(Clone, Debug)] +pub struct RpcService { + service: Arc>, +} + +impl RpcService { + pub fn new(service: HttpTransportClient) -> Self { + Self { service: Arc::new(service) } + } +} + +impl RpcServiceT for RpcService +where + HttpMiddleware: + Service, Error = TransportError> + Clone + Send + Sync + 'static, + HttpMiddleware::Future: Send, + B: http_body::Body + Send + 'static, + B::Data: Send, + B::Error: Into, +{ + type Error = Error; + type Response = MethodResponse; + + fn call<'a>(&self, request: Request<'a>) -> impl Future> + Send + 'a { + let service = self.service.clone(); + + async move { + let raw = serde_json::to_string(&request)?; + let bytes = service.send_and_read_body(raw).await.map_err(|e| Error::Transport(e.into()))?; + let json_rp: Response> = serde_json::from_slice(&bytes)?; + Ok(MethodResponse::method_call(json_rp.into_owned().into(), request.extensions)) + } + } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + let service = self.service.clone(); + + async move { + let raw = serde_json::to_string(&batch)?; + let bytes = service.send_and_read_body(raw).await.map_err(|e| Error::Transport(e.into()))?; + let rp: Vec<_> = serde_json::from_slice::>>>(&bytes)? + .into_iter() + .map(|r| r.into_owned().into()) + .collect(); + + Ok(MethodResponse::batch(rp, batch.into_extensions())) + } + } + + fn notification<'a>( + &self, + notif: Notification<'a>, + ) -> impl Future> + Send + 'a { + let service = self.service.clone(); + + async move { + let raw = serde_json::to_string(¬if)?; + service.send(raw).await.map_err(|e| Error::Transport(e.into()))?; + Ok(MethodResponse::notification(notif.extensions)) + } + } +} diff --git a/client/http-client/src/tests.rs b/client/http-client/src/tests.rs index 5b620cf1b4..f6af48e623 100644 --- a/client/http-client/src/tests.rs +++ b/client/http-client/src/tests.rs @@ -173,16 +173,10 @@ async fn batch_request_with_failed_call_works() { assert_eq!(res.len(), 3); let successful_calls: Vec<_> = res.iter().filter_map(|r| r.as_ref().ok()).collect(); - let failed_calls: Vec<_> = res - .iter() - .filter_map(|r| match r { - Err(e) => Some(e), - _ => None, - }) - .collect(); + let failed_calls: Vec<_> = res.iter().filter_map(|r| r.clone().err()).collect(); assert_eq!(successful_calls, vec!["hello", "here's your swag"]); - assert_eq!(failed_calls, vec![&ErrorObject::from(ErrorCode::MethodNotFound)]); + assert_eq!(failed_calls, vec![ErrorObject::from(ErrorCode::MethodNotFound)]); } #[tokio::test] diff --git a/client/http-client/src/transport.rs b/client/http-client/src/transport.rs index 1d980a14f7..d9dce06982 100644 --- a/client/http-client/src/transport.rs +++ b/client/http-client/src/transport.rs @@ -13,7 +13,6 @@ use hyper_util::client::legacy::Client; use hyper_util::client::legacy::connect::HttpConnector; use hyper_util::rt::TokioExecutor; use jsonrpsee_core::BoxError; -use jsonrpsee_core::tracing::client::{rx_log_from_bytes, tx_log_from_str}; use jsonrpsee_core::{ TEN_MB_SIZE_BYTES, http_helpers::{self, HttpError}, @@ -93,10 +92,6 @@ pub struct HttpTransportClientBuilder { pub(crate) max_request_size: u32, /// Configurable max response body size pub(crate) max_response_size: u32, - /// Max length for logging for requests and responses - /// - /// Logs bigger than this limit will be truncated. - pub(crate) max_log_length: u32, /// Custom headers to pass with every request. pub(crate) headers: HeaderMap, /// Service builder @@ -119,7 +114,6 @@ impl HttpTransportClientBuilder { certificate_store: CertificateStore::Native, max_request_size: TEN_MB_SIZE_BYTES, max_response_size: TEN_MB_SIZE_BYTES, - max_log_length: 1024, headers: HeaderMap::new(), service_builder: tower::ServiceBuilder::new(), tcp_no_delay: true, @@ -163,21 +157,12 @@ impl HttpTransportClientBuilder { self } - /// Max length for logging for requests and responses in number characters. - /// - /// Logs bigger than this limit will be truncated. - pub fn set_max_logging_length(mut self, max: u32) -> Self { - self.max_log_length = max; - self - } - /// Configure a tower service. pub fn set_service(self, service: tower::ServiceBuilder) -> HttpTransportClientBuilder { HttpTransportClientBuilder { #[cfg(feature = "tls")] certificate_store: self.certificate_store, headers: self.headers, - max_log_length: self.max_log_length, max_request_size: self.max_request_size, max_response_size: self.max_response_size, service_builder: service, @@ -199,7 +184,6 @@ impl HttpTransportClientBuilder { certificate_store, max_request_size, max_response_size, - max_log_length, headers, service_builder, tcp_no_delay, @@ -286,7 +270,6 @@ impl HttpTransportClientBuilder { client: service_builder.service(client), max_request_size, max_response_size, - max_log_length, headers: cached_headers, }) } @@ -303,10 +286,6 @@ pub struct HttpTransportClient { max_request_size: u32, /// Configurable max response body size max_response_size: u32, - /// Max length for logging for requests and responses - /// - /// Logs bigger than this limit will be truncated. - max_log_length: u32, /// Custom headers to pass with every request. headers: HeaderMap, } @@ -340,22 +319,17 @@ where /// Send serialized message and wait until all bytes from the HTTP message body have been read. pub(crate) async fn send_and_read_body(&self, body: String) -> Result, Error> { - tx_log_from_str(&body, self.max_log_length); - let response = self.inner_send(body).await?; - let (parts, body) = response.into_parts(); + let (parts, body) = response.into_parts(); let (body, _is_single) = http_helpers::read_body(&parts.headers, body, self.max_response_size).await?; - rx_log_from_bytes(&body, self.max_log_length); - Ok(body) } /// Send serialized message without reading the HTTP message body. pub(crate) async fn send(&self, body: String) -> Result<(), Error> { - let _ = self.inner_send(body).await?; - + self.inner_send(body).await?; Ok(()) } } diff --git a/client/transport/Cargo.toml b/client/transport/Cargo.toml index 4c1e9d04a9..c641442d39 100644 --- a/client/transport/Cargo.toml +++ b/client/transport/Cargo.toml @@ -18,12 +18,12 @@ workspace = true [dependencies] jsonrpsee-core = { workspace = true, features = ["client"] } -tracing = { workspace = true } # optional thiserror = { workspace = true, optional = true } futures-util = { workspace = true, features = ["alloc"], optional = true } http = { workspace = true, optional = true } +tracing = { workspace = true, optional = true } tokio-util = { workspace = true, features = ["compat"], optional = true } tokio = { workspace = true, features = ["net", "time", "macros"], optional = true } pin-project = { workspace = true, optional = true } @@ -57,6 +57,7 @@ ws = [ "soketto", "pin-project", "thiserror", + "tracing", "url", ] web = [ diff --git a/client/wasm-client/Cargo.toml b/client/wasm-client/Cargo.toml index 574922b719..0d9b202ff8 100644 --- a/client/wasm-client/Cargo.toml +++ b/client/wasm-client/Cargo.toml @@ -20,6 +20,7 @@ workspace = true jsonrpsee-types = { workspace = true } jsonrpsee-client-transport = { workspace = true, features = ["web"] } jsonrpsee-core = { workspace = true, features = ["async-wasm-client"] } +tower = { workspace = true } [package.metadata.docs.rs] all-features = true diff --git a/client/wasm-client/src/lib.rs b/client/wasm-client/src/lib.rs index 5b9e8b97cc..ee959331c6 100644 --- a/client/wasm-client/src/lib.rs +++ b/client/wasm-client/src/lib.rs @@ -36,7 +36,11 @@ pub use jsonrpsee_types as types; use std::time::Duration; use jsonrpsee_client_transport::web; -use jsonrpsee_core::client::{ClientBuilder, Error, IdKind}; +use jsonrpsee_core::client::async_client::RpcService; +use jsonrpsee_core::client::{Error, IdKind}; +use jsonrpsee_core::middleware::{RpcServiceBuilder, layer::RpcLoggerLayer}; + +type Logger = tower::layer::util::Stack; /// Builder for [`Client`]. /// @@ -58,23 +62,23 @@ use jsonrpsee_core::client::{ClientBuilder, Error, IdKind}; /// } /// /// ``` -#[derive(Copy, Clone, Debug)] -pub struct WasmClientBuilder { +#[derive(Clone, Debug)] +pub struct WasmClientBuilder { id_kind: IdKind, max_concurrent_requests: usize, max_buffer_capacity_per_subscription: usize, - max_log_length: u32, request_timeout: Duration, + service_builder: RpcServiceBuilder, } impl Default for WasmClientBuilder { fn default() -> Self { Self { id_kind: IdKind::Number, - max_log_length: 4096, max_concurrent_requests: 256, max_buffer_capacity_per_subscription: 1024, request_timeout: Duration::from_secs(60), + service_builder: RpcServiceBuilder::default().rpc_logger(1024), } } } @@ -84,7 +88,9 @@ impl WasmClientBuilder { pub fn new() -> WasmClientBuilder { WasmClientBuilder::default() } +} +impl WasmClientBuilder { /// See documentation [`ClientBuilder::request_timeout`] (default is 60 seconds). pub fn request_timeout(mut self, timeout: Duration) -> Self { self.request_timeout = timeout; @@ -109,32 +115,39 @@ impl WasmClientBuilder { self } - /// Set maximum length for logging calls and responses. - /// - /// Logs bigger than this limit will be truncated. - pub fn set_max_logging_length(mut self, max: u32) -> Self { - self.max_log_length = max; - self + /// See documentation for [`ClientBuilder::set_rpc_middleware`]. + pub fn set_rpc_middleware(self, middleware: RpcServiceBuilder) -> WasmClientBuilder { + WasmClientBuilder { + id_kind: self.id_kind, + max_concurrent_requests: self.max_concurrent_requests, + max_buffer_capacity_per_subscription: self.max_buffer_capacity_per_subscription, + request_timeout: self.request_timeout, + service_builder: middleware, + } } /// Build the client with specified URL to connect to. - pub async fn build(self, url: impl AsRef) -> Result { + pub async fn build(self, url: impl AsRef) -> Result, Error> + where + L: tower::Layer + Clone + Send + Sync + 'static, + { let Self { - max_log_length, id_kind, request_timeout, max_concurrent_requests, max_buffer_capacity_per_subscription, + service_builder, } = self; let (sender, receiver) = web::connect(url).await.map_err(|e| Error::Transport(e.into()))?; - let builder = ClientBuilder::default() - .set_max_logging_length(max_log_length) + let client = Client::builder() .request_timeout(request_timeout) .id_format(id_kind) .max_buffer_capacity_per_subscription(max_buffer_capacity_per_subscription) - .max_concurrent_requests(max_concurrent_requests); + .max_concurrent_requests(max_concurrent_requests) + .set_rpc_middleware(service_builder) + .build_with_wasm(sender, receiver); - Ok(builder.build_with_wasm(sender, receiver)) + Ok(client) } } diff --git a/client/ws-client/Cargo.toml b/client/ws-client/Cargo.toml index a457359ac4..bfbb7a7cdf 100644 --- a/client/ws-client/Cargo.toml +++ b/client/ws-client/Cargo.toml @@ -22,6 +22,7 @@ jsonrpsee-types = { workspace = true } jsonrpsee-client-transport = { workspace = true, features = ["ws"] } jsonrpsee-core = { workspace = true, features = ["async-client"] } url = { workspace = true } +tower = { workspace = true } [dev-dependencies] tracing-subscriber = { workspace = true } diff --git a/client/ws-client/src/lib.rs b/client/ws-client/src/lib.rs index 255dfb1d3c..1935ffa23f 100644 --- a/client/ws-client/src/lib.rs +++ b/client/ws-client/src/lib.rs @@ -41,12 +41,15 @@ mod tests; pub use http::{HeaderMap, HeaderValue}; pub use jsonrpsee_core::client::Client as WsClient; pub use jsonrpsee_core::client::async_client::PingConfig; +pub use jsonrpsee_core::client::async_client::RpcService; +pub use jsonrpsee_core::middleware::RpcServiceBuilder; pub use jsonrpsee_types as types; use jsonrpsee_client_transport::ws::{AsyncRead, AsyncWrite, WsTransportClientBuilder}; use jsonrpsee_core::TEN_MB_SIZE_BYTES; use jsonrpsee_core::client::{ClientBuilder, Error, IdKind, MaybeSend, TransportReceiverT, TransportSenderT}; use std::time::Duration; +use tower::layer::util::Identity; use url::Url; #[cfg(feature = "tls")] @@ -81,7 +84,7 @@ use jsonrpsee_client_transport::ws::CertificateStore; /// /// ``` #[derive(Clone, Debug)] -pub struct WsClientBuilder { +pub struct WsClientBuilder { #[cfg(feature = "tls")] certificate_store: CertificateStore, max_request_size: u32, @@ -94,11 +97,11 @@ pub struct WsClientBuilder { max_buffer_capacity_per_subscription: usize, max_redirections: usize, id_kind: IdKind, - max_log_length: u32, tcp_no_delay: bool, + service_builder: RpcServiceBuilder, } -impl Default for WsClientBuilder { +impl Default for WsClientBuilder { fn default() -> Self { Self { #[cfg(feature = "tls")] @@ -113,18 +116,20 @@ impl Default for WsClientBuilder { max_buffer_capacity_per_subscription: 1024, max_redirections: 5, id_kind: IdKind::Number, - max_log_length: 4096, tcp_no_delay: true, + service_builder: RpcServiceBuilder::default(), } } } -impl WsClientBuilder { +impl WsClientBuilder { /// Create a new WebSocket client builder. - pub fn new() -> WsClientBuilder { + pub fn new() -> WsClientBuilder { WsClientBuilder::default() } +} +impl WsClientBuilder { /// Force to use a custom certificate store. /// /// # Optional @@ -259,29 +264,42 @@ impl WsClientBuilder { self } - /// Set maximum length for logging calls and responses. - /// - /// Logs bigger than this limit will be truncated. - pub fn set_max_logging_length(mut self, max: u32) -> Self { - self.max_log_length = max; - self - } - /// See documentation [`ClientBuilder::set_tcp_no_delay`] (default is true). pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self { self.tcp_no_delay = no_delay; self } + /// Set the RPC service builder. + pub fn set_rpc_middleware(self, service_builder: RpcServiceBuilder) -> WsClientBuilder { + WsClientBuilder { + #[cfg(feature = "tls")] + certificate_store: self.certificate_store, + max_request_size: self.max_request_size, + max_response_size: self.max_response_size, + request_timeout: self.request_timeout, + connection_timeout: self.connection_timeout, + ping_config: self.ping_config, + headers: self.headers, + max_concurrent_requests: self.max_concurrent_requests, + max_buffer_capacity_per_subscription: self.max_buffer_capacity_per_subscription, + max_redirections: self.max_redirections, + id_kind: self.id_kind, + tcp_no_delay: self.tcp_no_delay, + service_builder, + } + } + /// Build the [`WsClient`] with specified [`TransportSenderT`] [`TransportReceiverT`] parameters /// /// ## Panics /// /// Panics if being called outside of `tokio` runtime context. - pub fn build_with_transport(self, sender: S, receiver: R) -> WsClient + pub fn build_with_transport(self, sender: S, receiver: R) -> WsClient where S: TransportSenderT + Send, R: TransportReceiverT + Send, + RpcMiddleware: tower::Layer + Clone + Send + Sync + 'static, { let Self { max_concurrent_requests, @@ -289,8 +307,8 @@ impl WsClientBuilder { ping_config, max_buffer_capacity_per_subscription, id_kind, - max_log_length, tcp_no_delay, + service_builder, .. } = self; @@ -299,8 +317,8 @@ impl WsClientBuilder { .request_timeout(request_timeout) .max_concurrent_requests(max_concurrent_requests) .id_format(id_kind) - .set_max_logging_length(max_log_length) - .set_tcp_no_delay(tcp_no_delay); + .set_tcp_no_delay(tcp_no_delay) + .set_rpc_middleware(service_builder); if let Some(cfg) = ping_config { client = client.enable_ws_ping(cfg); @@ -314,9 +332,10 @@ impl WsClientBuilder { /// ## Panics /// /// Panics if being called outside of `tokio` runtime context. - pub async fn build_with_stream(self, url: impl AsRef, data_stream: T) -> Result + pub async fn build_with_stream(self, url: impl AsRef, data_stream: T) -> Result, Error> where T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static, + RpcMiddleware: tower::Layer + Clone + Send + Sync + 'static, { let transport_builder = WsTransportClientBuilder { #[cfg(feature = "tls")] @@ -343,7 +362,10 @@ impl WsClientBuilder { /// ## Panics /// /// Panics if being called outside of `tokio` runtime context. - pub async fn build(self, url: impl AsRef) -> Result { + pub async fn build(self, url: impl AsRef) -> Result, Error> + where + RpcMiddleware: tower::Layer + Clone + Send + Sync + 'static, + { let transport_builder = WsTransportClientBuilder { #[cfg(feature = "tls")] certificate_store: self.certificate_store.clone(), diff --git a/client/ws-client/src/tests.rs b/client/ws-client/src/tests.rs index 51560ffd68..eb0fd25110 100644 --- a/client/ws-client/src/tests.rs +++ b/client/ws-client/src/tests.rs @@ -359,16 +359,10 @@ async fn batch_request_with_failed_call_works() { assert_eq!(res.len(), 3); let successful_calls: Vec<_> = res.iter().filter_map(|r| r.as_ref().ok()).collect(); - let failed_calls: Vec<_> = res - .iter() - .filter_map(|r| match r { - Err(e) => Some(e), - _ => None, - }) - .collect(); + let failed_calls: Vec<_> = res.iter().filter_map(|r| r.clone().err()).collect(); assert_eq!(successful_calls, vec!["hello", "here's your swag"]); - assert_eq!(failed_calls, vec![&ErrorObject::from(ErrorCode::MethodNotFound)]); + assert_eq!(failed_calls, vec![ErrorObject::from(ErrorCode::MethodNotFound)]); } #[tokio::test] diff --git a/core/Cargo.toml b/core/Cargo.toml index e6c37ab565..5d1849366f 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -21,11 +21,11 @@ async-trait = { workspace = true } jsonrpsee-types = { workspace = true } thiserror = { workspace = true } serde = { workspace = true } -serde_json = { workspace = true, features = ["std"] } +serde_json = { workspace = true, features = ["std", "raw_value"] } tracing = { workspace = true } # optional deps -futures-util = { workspace = true, optional = true } +futures-util = { workspace = true, optional = true, features = ["alloc"] } http = { workspace = true, optional = true } bytes = { workspace = true, optional = true } http-body = { workspace = true, optional = true } @@ -34,6 +34,7 @@ rustc-hash = { workspace = true, optional = true } rand = { workspace = true, optional = true } parking_lot = { workspace = true, optional = true } tokio = { workspace = true, optional = true } +tower = { workspace = true, optional = true } futures-timer = { workspace = true, optional = true } tokio-stream = { workspace = true, optional = true } pin-project = { workspace = true, optional = true } @@ -44,11 +45,11 @@ wasm-bindgen-futures = { workspace = true, optional = true } [features] default = [] http-helpers = ["bytes", "futures-util", "http-body", "http-body-util", "http"] -server = ["futures-util/alloc", "rustc-hash/std", "parking_lot", "rand", "tokio/rt", "tokio/sync", "tokio/macros", "tokio/time", "http"] -client = ["futures-util/sink", "tokio/sync"] +server = ["futures-util", "rustc-hash/std", "parking_lot", "rand", "tokio/rt", "tokio/sync", "tokio/macros", "tokio/time", "tower", "http", "pin-project"] +client = ["futures-util/sink", "tokio/sync", "tower", "pin-project", "http"] async-client = [ "client", - "futures-util/alloc", + "futures-util", "rustc-hash", "tokio/macros", "tokio/rt", @@ -59,7 +60,7 @@ async-client = [ ] async-wasm-client = [ "client", - "futures-util/alloc", + "futures-util", "wasm-bindgen-futures", "rustc-hash/std", "futures-timer/wasm-bindgen", diff --git a/core/src/client/async_client/helpers.rs b/core/src/client/async_client/helpers.rs index 21c54b85a9..34be025887 100644 --- a/core/src/client/async_client/helpers.rs +++ b/core/src/client/async_client/helpers.rs @@ -25,34 +25,33 @@ // DEALINGS IN THE SOFTWARE. use crate::client::async_client::manager::{RequestManager, RequestStatus}; -use crate::client::async_client::{Notification, LOG_TARGET}; -use crate::client::{subscription_channel, Error, RequestMessage, TransportSenderT, TrySubscriptionSendError}; +use crate::client::async_client::{LOG_TARGET, Notification}; +use crate::client::{ + Error, RawResponseOwned, RequestMessage, TransportSenderT, TrySubscriptionSendError, subscription_channel, +}; use crate::params::ArrayParams; use crate::traits::ToRpcParams; use futures_timer::Delay; use futures_util::future::{self, Either}; +use http::Extensions; +use serde_json::value::RawValue; use tokio::sync::oneshot; use jsonrpsee_types::response::SubscriptionError; use jsonrpsee_types::{ - ErrorObject, Id, InvalidRequestId, RequestSer, Response, ResponseSuccess, SubscriptionId, SubscriptionResponse, + ErrorObject, Id, InvalidRequestId, Request, Response, ResponseSuccess, SubscriptionId, SubscriptionResponse, + TwoPointZero, }; -use serde_json::Value as JsonValue; +use std::borrow::Cow; use std::ops::Range; -#[derive(Debug, Clone)] -pub(crate) struct InnerBatchResponse { - pub(crate) id: u64, - pub(crate) result: Result>, -} - /// Attempts to process a batch response. /// /// On success the result is sent to the frontend. pub(crate) fn process_batch_response( manager: &mut RequestManager, - rps: Vec, + rps: Vec, range: Range, ) -> Result<(), InvalidRequestId> { let mut responses = Vec::with_capacity(rps.len()); @@ -69,17 +68,18 @@ pub(crate) fn process_batch_response( for _ in range { let err_obj = ErrorObject::borrowed(0, "", None); - responses.push(Err(err_obj)); + responses.push(Response::new(jsonrpsee_types::ResponsePayload::error(err_obj), Id::Null).into()); } for rp in rps { + let id = rp.id().try_parse_inner_as_number()?; let maybe_elem = - rp.id.checked_sub(start_idx).and_then(|p| p.try_into().ok()).and_then(|p: usize| responses.get_mut(p)); + id.checked_sub(start_idx).and_then(|p| p.try_into().ok()).and_then(|p: usize| responses.get_mut(p)); if let Some(elem) = maybe_elem { - *elem = rp.result; + *elem = rp; } else { - return Err(InvalidRequestId::NotPendingRequest(rp.id.to_string())); + return Err(InvalidRequestId::NotPendingRequest(rp.id().to_string())); } } @@ -93,7 +93,7 @@ pub(crate) fn process_batch_response( /// `None` is returned. pub(crate) fn process_subscription_response( manager: &mut RequestManager, - response: SubscriptionResponse, + response: SubscriptionResponse>, ) -> Option> { let sub_id = response.params.subscription.into_owned(); let request_id = match manager.get_request_id_by_subscription_id(&sub_id) { @@ -128,7 +128,7 @@ pub(crate) fn process_subscription_response( /// It's possible that the user closed down the subscription before the actual close response is received pub(crate) fn process_subscription_close_response( manager: &mut RequestManager, - response: SubscriptionError, + response: SubscriptionError<&RawValue>, ) { let sub_id = response.params.subscription.into_owned(); match manager.get_request_id_by_subscription_id(&sub_id) { @@ -173,11 +173,10 @@ pub(crate) fn process_notification(manager: &mut RequestManager, notif: Notifica /// Returns `Err(_)` if the response couldn't be handled. pub(crate) fn process_single_response( manager: &mut RequestManager, - response: Response, + response: RawResponseOwned, max_capacity_per_subscription: usize, ) -> Result, InvalidRequestId> { - let response_id = response.id.clone().into_owned(); - let result = ResponseSuccess::try_from(response).map(|s| s.result).map_err(Error::Call); + let response_id = response.id().clone().into_owned(); match manager.request_status(&response_id) { RequestStatus::PendingMethodCall => { @@ -187,7 +186,7 @@ pub(crate) fn process_single_response( None => return Err(InvalidRequestId::NotPendingRequest(response_id.to_string())), }; - let _ = send_back_oneshot.send(result); + let _ = send_back_oneshot.send(Ok(response)); Ok(None) } RequestStatus::PendingSubscription => { @@ -195,16 +194,20 @@ pub(crate) fn process_single_response( .complete_pending_subscription(response_id.clone()) .ok_or(InvalidRequestId::NotPendingRequest(response_id.to_string()))?; - let sub_id = result.map(|r| SubscriptionId::try_from(r).ok()); + let result = ResponseSuccess::try_from(response.into_inner()); - let sub_id = match sub_id { - Ok(Some(sub_id)) => sub_id, - Ok(None) => { - let _ = send_back_oneshot.send(Err(Error::InvalidSubscriptionId)); + let json = match result { + Ok(s) => s.result, + Err(e) => { + let _ = send_back_oneshot.send(Err(Error::Call(e))); return Ok(None); } + }; + + let sub_id = match serde_json::from_str::(json.get()) { + Ok(s) => s.into_owned(), Err(e) => { - let _ = send_back_oneshot.send(Err(e)); + let _ = send_back_oneshot.send(Err(e.into())); return Ok(None); } }; @@ -254,7 +257,14 @@ pub(crate) fn build_unsubscribe_message( params.insert(sub_id).ok()?; let params = params.to_rpc_params().ok()?; - let raw = serde_json::to_string(&RequestSer::owned(unsub_req_id.clone(), unsub, params)).ok()?; + let raw = serde_json::to_string(&Request { + jsonrpc: TwoPointZero, + id: unsub_req_id.clone(), + method: unsub.into(), + params: params.map(Cow::Owned), + extensions: Extensions::new(), + }) + .ok()?; Some(RequestMessage { raw, id: unsub_req_id, send_back: None }) } diff --git a/core/src/client/async_client/manager.rs b/core/src/client/async_client/manager.rs index f638b80164..333d745804 100644 --- a/core/src/client/async_client/manager.rs +++ b/core/src/client/async_client/manager.rs @@ -33,17 +33,16 @@ //! - SubscriptionId: unique ID generated by server use std::{ - collections::{hash_map::Entry, HashMap}, + collections::{HashMap, hash_map::Entry}, ops::Range, }; use crate::{ - client::{BatchEntry, Error, SubscriptionReceiver, SubscriptionSender}, + client::{Error, RawResponseOwned, SubscriptionReceiver, SubscriptionSender}, error::RegisterMethodError, }; -use jsonrpsee_types::{Id, SubscriptionId}; +use jsonrpsee_types::{Id, InvalidRequestId, SubscriptionId}; use rustc_hash::FxHashMap; -use serde_json::value::Value as JsonValue; use tokio::sync::oneshot; #[derive(Debug)] @@ -66,8 +65,8 @@ pub(crate) enum RequestStatus { Invalid, } -type PendingCallOneshot = Option>>; -type PendingBatchOneshot = oneshot::Sender>, Error>>; +type PendingCallOneshot = Option>>; +type PendingBatchOneshot = oneshot::Sender, InvalidRequestId>>; type PendingSubscriptionOneshot = oneshot::Sender), Error>>; type SubscriptionSink = SubscriptionSender; type UnsubscribeMethod = String; @@ -315,11 +314,7 @@ impl RequestManager { /// /// Returns `Some` if the `request_id` was registered as a subscription otherwise `None`. pub(crate) fn as_subscription_mut(&mut self, request_id: &RequestId) -> Option<&mut SubscriptionSink> { - if let Some(Kind::Subscription((_, sink, _))) = self.requests.get_mut(request_id) { - Some(sink) - } else { - None - } + if let Some(Kind::Subscription((_, sink, _))) = self.requests.get_mut(request_id) { Some(sink) } else { None } } /// Get a mutable reference to underlying `Sink` in order to send incoming notifications to the subscription. @@ -341,14 +336,13 @@ impl RequestManager { mod tests { use crate::client::subscription_channel; - use super::{Error, RequestManager}; + use super::RequestManager; use jsonrpsee_types::{Id, SubscriptionId}; - use serde_json::Value as JsonValue; use tokio::sync::oneshot; #[test] fn insert_remove_pending_request_works() { - let (request_tx, _) = oneshot::channel::>(); + let (request_tx, _) = oneshot::channel(); let mut manager = RequestManager::new(); assert!(manager.insert_pending_call(Id::Number(0), Some(request_tx)).is_ok()); @@ -360,26 +354,30 @@ mod tests { let (pending_sub_tx, _) = oneshot::channel(); let (sub_tx, _) = subscription_channel(1); let mut manager = RequestManager::new(); - assert!(manager - .insert_pending_subscription(Id::Number(1), Id::Number(2), pending_sub_tx, "unsubscribe_method".into()) - .is_ok()); + assert!( + manager + .insert_pending_subscription(Id::Number(1), Id::Number(2), pending_sub_tx, "unsubscribe_method".into()) + .is_ok() + ); let (unsub_req_id, _send_back_oneshot, unsubscribe_method) = manager.complete_pending_subscription(Id::Number(1)).unwrap(); assert_eq!(unsub_req_id, Id::Number(2)); - assert!(manager - .insert_subscription( - Id::Number(1), - Id::Number(2), - SubscriptionId::Str("uniq_id_from_server".into()), - sub_tx, - unsubscribe_method - ) - .is_ok()); + assert!( + manager + .insert_subscription( + Id::Number(1), + Id::Number(2), + SubscriptionId::Str("uniq_id_from_server".into()), + sub_tx, + unsubscribe_method + ) + .is_ok() + ); assert!(manager.as_subscription_mut(&Id::Number(1)).is_some()); - assert!(manager - .remove_subscription(Id::Number(1), SubscriptionId::Str("uniq_id_from_server".into())) - .is_some()); + assert!( + manager.remove_subscription(Id::Number(1), SubscriptionId::Str("uniq_id_from_server".into())).is_some() + ); } #[test] @@ -389,12 +387,16 @@ mod tests { let (tx3, _) = oneshot::channel(); let (tx4, _) = oneshot::channel(); let mut manager = RequestManager::new(); - assert!(manager - .insert_pending_subscription(Id::Str("1".into()), Id::Str("1".into()), tx1, "unsubscribe_method".into()) - .is_err()); - assert!(manager - .insert_pending_subscription(Id::Str("0".into()), Id::Str("1".into()), tx2, "unsubscribe_method".into()) - .is_ok()); + assert!( + manager + .insert_pending_subscription(Id::Str("1".into()), Id::Str("1".into()), tx1, "unsubscribe_method".into()) + .is_err() + ); + assert!( + manager + .insert_pending_subscription(Id::Str("0".into()), Id::Str("1".into()), tx2, "unsubscribe_method".into()) + .is_ok() + ); assert!( manager .insert_pending_subscription( @@ -429,18 +431,22 @@ mod tests { let mut manager = RequestManager::new(); assert!(manager.insert_pending_call(Id::Number(0), Some(request_tx1)).is_ok()); assert!(manager.insert_pending_call(Id::Number(0), Some(request_tx2)).is_err()); - assert!(manager - .insert_pending_subscription(Id::Number(0), Id::Number(1), pending_sub_tx, "beef".to_string()) - .is_err()); - assert!(manager - .insert_subscription( - Id::Number(0), - Id::Number(99), - SubscriptionId::Num(137), - sub_tx, - "bibimbap".to_string() - ) - .is_err()); + assert!( + manager + .insert_pending_subscription(Id::Number(0), Id::Number(1), pending_sub_tx, "beef".to_string()) + .is_err() + ); + assert!( + manager + .insert_subscription( + Id::Number(0), + Id::Number(99), + SubscriptionId::Num(137), + sub_tx, + "bibimbap".to_string() + ) + .is_err() + ); assert!(manager.remove_subscription(Id::Number(0), SubscriptionId::Num(137)).is_none()); assert!(manager.complete_pending_subscription(Id::Number(0)).is_none()); @@ -455,23 +461,29 @@ mod tests { let (sub_tx, _) = subscription_channel(1); let mut manager = RequestManager::new(); - assert!(manager - .insert_pending_subscription(Id::Number(99), Id::Number(100), pending_sub_tx1, "beef".to_string()) - .is_ok()); + assert!( + manager + .insert_pending_subscription(Id::Number(99), Id::Number(100), pending_sub_tx1, "beef".to_string()) + .is_ok() + ); assert!(manager.insert_pending_call(Id::Number(99), Some(request_tx)).is_err()); - assert!(manager - .insert_pending_subscription(Id::Number(99), Id::Number(1337), pending_sub_tx2, "vegan".to_string()) - .is_err()); - - assert!(manager - .insert_subscription( - Id::Number(99), - Id::Number(100), - SubscriptionId::Num(0), - sub_tx, - "bibimbap".to_string() - ) - .is_err()); + assert!( + manager + .insert_pending_subscription(Id::Number(99), Id::Number(1337), pending_sub_tx2, "vegan".to_string()) + .is_err() + ); + + assert!( + manager + .insert_subscription( + Id::Number(99), + Id::Number(100), + SubscriptionId::Num(0), + sub_tx, + "bibimbap".to_string() + ) + .is_err() + ); assert!(manager.remove_subscription(Id::Number(99), SubscriptionId::Num(0)).is_none()); assert!(manager.complete_pending_call(Id::Number(99)).is_none()); @@ -487,15 +499,33 @@ mod tests { let mut manager = RequestManager::new(); - assert!(manager - .insert_subscription(Id::Number(3), Id::Number(4), SubscriptionId::Num(0), sub_tx1, "bibimbap".to_string()) - .is_ok()); - assert!(manager - .insert_subscription(Id::Number(3), Id::Number(4), SubscriptionId::Num(1), sub_tx2, "bibimbap".to_string()) - .is_err()); - assert!(manager - .insert_pending_subscription(Id::Number(3), Id::Number(4), pending_sub_tx, "beef".to_string()) - .is_err()); + assert!( + manager + .insert_subscription( + Id::Number(3), + Id::Number(4), + SubscriptionId::Num(0), + sub_tx1, + "bibimbap".to_string() + ) + .is_ok() + ); + assert!( + manager + .insert_subscription( + Id::Number(3), + Id::Number(4), + SubscriptionId::Num(1), + sub_tx2, + "bibimbap".to_string() + ) + .is_err() + ); + assert!( + manager + .insert_pending_subscription(Id::Number(3), Id::Number(4), pending_sub_tx, "beef".to_string()) + .is_err() + ); assert!(manager.insert_pending_call(Id::Number(3), Some(request_tx)).is_err()); assert!(manager.remove_subscription(Id::Number(3), SubscriptionId::Num(7)).is_none()); diff --git a/core/src/client/async_client/mod.rs b/core/src/client/async_client/mod.rs index 3ed6353dea..663510c6ab 100644 --- a/core/src/client/async_client/mod.rs +++ b/core/src/client/async_client/mod.rs @@ -28,45 +28,50 @@ mod helpers; mod manager; +mod rpc_service; mod utils; -use crate::client::async_client::helpers::{process_subscription_close_response, InnerBatchResponse}; +pub use rpc_service::{Error as RpcServiceError, RpcService}; + +use crate::JsonRawValue; +use crate::client::async_client::helpers::process_subscription_close_response; use crate::client::async_client::utils::MaybePendingFutures; use crate::client::{ - BatchMessage, BatchResponse, ClientT, Error, ReceivedMessage, RegisterNotificationMessage, RequestMessage, - Subscription, SubscriptionClientT, SubscriptionKind, SubscriptionMessage, TransportReceiverT, TransportSenderT, + BatchResponse, ClientT, Error, ReceivedMessage, RegisterNotificationMessage, Subscription, SubscriptionClientT, + SubscriptionKind, TransportReceiverT, TransportSenderT, }; use crate::error::RegisterMethodError; +use crate::middleware::layer::RpcLoggerLayer; +use crate::middleware::{Batch, IsBatch, IsSubscription, Request, RpcServiceBuilder, RpcServiceT}; use crate::params::{BatchRequestBuilder, EmptyBatchRequest}; -use crate::tracing::client::{rx_log_from_json, tx_log_from_str}; use crate::traits::ToRpcParams; -use crate::JsonRawValue; use std::borrow::Cow as StdCow; +use async_trait::async_trait; use core::time::Duration; +use futures_util::Stream; +use futures_util::future::{self, Either}; +use futures_util::stream::StreamExt; use helpers::{ build_unsubscribe_message, call_with_timeout, process_batch_response, process_notification, process_single_response, process_subscription_response, stop_subscription, }; +use http::Extensions; +use jsonrpsee_types::response::SubscriptionError; use jsonrpsee_types::{InvalidRequestId, ResponseSuccess, TwoPointZero}; +use jsonrpsee_types::{Response, SubscriptionResponse}; use manager::RequestManager; -use std::sync::Arc; - -use async_trait::async_trait; -use futures_timer::Delay; -use futures_util::future::{self, Either}; -use futures_util::stream::StreamExt; -use futures_util::Stream; -use jsonrpsee_types::response::{ResponsePayload, SubscriptionError}; -use jsonrpsee_types::{NotificationSer, RequestSer, Response, SubscriptionResponse}; use serde::de::DeserializeOwned; +use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; -use tracing::instrument; +use tower::layer::util::Identity; use self::utils::{InactivityCheck, IntervalStream}; -use super::{generate_batch_id_range, subscription_channel, FrontToBack, IdKind, RequestIdManager}; +use super::{FrontToBack, IdKind, MethodResponse, RequestIdManager, generate_batch_id_range, subscription_channel}; + +pub(crate) type Notification<'a> = jsonrpsee_types::Notification<'a, Option>>; -pub(crate) type Notification<'a> = jsonrpsee_types::Notification<'a, Option>; +type Logger = tower::layer::util::Stack; const LOG_TARGET: &str = "jsonrpsee-client"; const NOT_POISONED: &str = "Not poisoned; qed"; @@ -112,7 +117,7 @@ impl PingConfig { } /// Configure how long to wait for the WebSocket pong. - /// When this limit is expired it's regarded as inresponsive. + /// When this limit is expired it's regarded as unresponsive. /// /// You may configure how many times the connection is allowed to /// be inactive by [`PingConfig::max_failures`]. @@ -177,15 +182,15 @@ impl ErrorFromBack { } /// Builder for [`Client`]. -#[derive(Debug, Copy, Clone)] -pub struct ClientBuilder { +#[derive(Debug, Clone)] +pub struct ClientBuilder { request_timeout: Duration, max_concurrent_requests: usize, max_buffer_capacity_per_subscription: usize, id_kind: IdKind, - max_log_length: u32, ping_config: Option, tcp_no_delay: bool, + service_builder: RpcServiceBuilder, } impl Default for ClientBuilder { @@ -195,19 +200,21 @@ impl Default for ClientBuilder { max_concurrent_requests: 256, max_buffer_capacity_per_subscription: 1024, id_kind: IdKind::Number, - max_log_length: 4096, ping_config: None, tcp_no_delay: true, + service_builder: RpcServiceBuilder::default().rpc_logger(1024), } } } -impl ClientBuilder { - /// Create a builder for the client. +impl ClientBuilder { + /// Create a new client builder. pub fn new() -> ClientBuilder { ClientBuilder::default() } +} +impl ClientBuilder { /// Set request timeout (default is 60 seconds). pub fn request_timeout(mut self, timeout: Duration) -> Self { self.request_timeout = timeout; @@ -243,14 +250,6 @@ impl ClientBuilder { self } - /// Set maximum length for logging calls and responses. - /// - /// Logs bigger than this limit will be truncated. - pub fn set_max_logging_length(mut self, max: u32) -> Self { - self.max_log_length = max; - self - } - /// Enable WebSocket ping/pong on the client. /// /// This only works if the transport supports WebSocket pings. @@ -279,6 +278,22 @@ impl ClientBuilder { self } + /// Configure the client to a specific RPC middleware which + /// runs for every JSON-RPC call. + /// + /// This is useful for adding a custom logger or something similar. + pub fn set_rpc_middleware(self, service_builder: RpcServiceBuilder) -> ClientBuilder { + ClientBuilder { + request_timeout: self.request_timeout, + max_concurrent_requests: self.max_concurrent_requests, + max_buffer_capacity_per_subscription: self.max_buffer_capacity_per_subscription, + id_kind: self.id_kind, + ping_config: self.ping_config, + tcp_no_delay: self.tcp_no_delay, + service_builder, + } + } + /// Build the client with given transport. /// /// ## Panics @@ -286,10 +301,11 @@ impl ClientBuilder { /// Panics if called outside of `tokio` runtime context. #[cfg(feature = "async-client")] #[cfg_attr(docsrs, doc(cfg(feature = "async-client")))] - pub fn build_with_tokio(self, sender: S, receiver: R) -> Client + pub fn build_with_tokio(self, sender: S, receiver: R) -> Client where S: TransportSenderT + Send, R: TransportReceiverT + Send, + L: tower::Layer + Clone + Send + Sync + 'static, { let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests); let disconnect_reason = SharedDisconnectReason::default(); @@ -344,10 +360,10 @@ impl ClientBuilder { Client { to_back: to_back.clone(), + service: self.service_builder.service(RpcService::new(to_back.clone())), request_timeout: self.request_timeout, error: ErrorFromBack::new(to_back, disconnect_reason), id_manager: RequestIdManager::new(self.id_kind), - max_log_length: self.max_log_length, on_exit: Some(client_dropped_tx), } } @@ -355,10 +371,11 @@ impl ClientBuilder { /// Build the client with given transport. #[cfg(all(feature = "async-wasm-client", target_arch = "wasm32"))] #[cfg_attr(docsrs, doc(cfg(feature = "async-wasm-client")))] - pub fn build_with_wasm(self, sender: S, receiver: R) -> Client + pub fn build_with_wasm(self, sender: S, receiver: R) -> Client where S: TransportSenderT, R: TransportReceiverT, + L: tower::Layer + Clone + Send + Sync + 'static, { use futures_util::stream::Pending; @@ -402,10 +419,10 @@ impl ClientBuilder { Client { to_back: to_back.clone(), + service: self.service_builder.service(RpcService::new(to_back.clone())), request_timeout: self.request_timeout, error: ErrorFromBack::new(to_back, disconnect_reason), id_manager: RequestIdManager::new(self.id_kind), - max_log_length: self.max_log_length, on_exit: Some(client_dropped_tx), } } @@ -413,7 +430,7 @@ impl ClientBuilder { /// Generic asynchronous client. #[derive(Debug)] -pub struct Client { +pub struct Client { /// Channel to send requests to the background task. to_back: mpsc::Sender, error: ErrorFromBack, @@ -421,25 +438,38 @@ pub struct Client { request_timeout: Duration, /// Request ID manager. id_manager: RequestIdManager, - /// Max length for logging for requests and responses. - /// - /// Entries bigger than this limit will be truncated. - max_log_length: u32, /// When the client is dropped a message is sent to the background thread. on_exit: Option>, + service: L, } -impl Client { - /// Create a builder for the server. +impl Client { + /// Create a builder for the client. pub fn builder() -> ClientBuilder { ClientBuilder::new() } +} +impl Client { /// Checks if the client is connected to the target. pub fn is_connected(&self) -> bool { !self.to_back.is_closed() } + async fn run_future_until_timeout( + &self, + fut: impl Future>, + ) -> Result { + tokio::pin!(fut); + + match futures_util::future::select(fut, futures_timer::Delay::new(self.request_timeout)).await { + Either::Left((Ok(r), _)) => Ok(r), + Either::Left((Err(RpcServiceError::Client(e)), _)) => Err(e), + Either::Left((Err(RpcServiceError::FetchFromBackend), _)) => Err(self.on_disconnect().await), + Either::Right(_) => Err(Error::RequestTimeout), + } + } + /// Completes when the client is disconnected or the client's background task encountered an error. /// If the client is already disconnected, the future produced by this method will complete immediately. /// @@ -456,7 +486,7 @@ impl Client { } } -impl Drop for Client { +impl Drop for Client { fn drop(&mut self) { if let Some(e) = self.on_exit.take() { let _ = e.send(()); @@ -465,68 +495,36 @@ impl Drop for Client { } #[async_trait] -impl ClientT for Client { - #[instrument(name = "notification", skip(self, params), level = "trace")] +impl ClientT for Client +where + L: RpcServiceT + Send + Sync, +{ async fn notification(&self, method: &str, params: Params) -> Result<(), Error> where Params: ToRpcParams + Send, { // NOTE: we use this to guard against max number of concurrent requests. let _req_id = self.id_manager.next_request_id(); - let params = params.to_rpc_params()?; - let notif = NotificationSer::borrowed(&method, params.as_deref()); - - let raw = serde_json::to_string(¬if).map_err(Error::ParseError)?; - tx_log_from_str(&raw, self.max_log_length); - - let sender = self.to_back.clone(); - let fut = sender.send(FrontToBack::Notification(raw)); - - tokio::pin!(fut); - - match future::select(fut, Delay::new(self.request_timeout)).await { - Either::Left((Ok(()), _)) => Ok(()), - Either::Left((Err(_), _)) => Err(self.on_disconnect().await), - Either::Right((_, _)) => Err(Error::RequestTimeout), - } + let params = params.to_rpc_params()?.map(StdCow::Owned); + let fut = self.service.notification(jsonrpsee_types::Notification::new(method.into(), params)); + self.run_future_until_timeout(fut).await?; + Ok(()) } - #[instrument(name = "method_call", skip(self, params), level = "trace")] async fn request(&self, method: &str, params: Params) -> Result where R: DeserializeOwned, Params: ToRpcParams + Send, { - let (send_back_tx, send_back_rx) = oneshot::channel(); let id = self.id_manager.next_request_id(); - let params = params.to_rpc_params()?; - let raw = - serde_json::to_string(&RequestSer::borrowed(&id, &method, params.as_deref())).map_err(Error::ParseError)?; - tx_log_from_str(&raw, self.max_log_length); - - if self - .to_back - .clone() - .send(FrontToBack::Request(RequestMessage { raw, id: id.clone(), send_back: Some(send_back_tx) })) - .await - .is_err() - { - return Err(self.on_disconnect().await); - } - - let json_value = match call_with_timeout(self.request_timeout, send_back_rx).await { - Ok(Ok(v)) => v, - Ok(Err(err)) => return Err(err), - Err(_) => return Err(self.on_disconnect().await), - }; - - rx_log_from_json(&Response::new(ResponsePayload::success_borrowed(&json_value), id), self.max_log_length); + let fut = self.service.call(Request::borrowed(method, params.as_deref(), id.clone())); + let rp = self.run_future_until_timeout(fut).await?.into_method_call().expect("Method call response"); + let success = ResponseSuccess::try_from(rp.into_inner())?; - serde_json::from_value(json_value).map_err(Error::ParseError) + serde_json::from_str(success.result.get()).map_err(Into::into) } - #[instrument(name = "batch", skip(self, batch), level = "trace")] async fn batch_request<'a, R>(&self, batch: BatchRequestBuilder<'a>) -> Result, Error> where R: DeserializeOwned, @@ -535,50 +533,31 @@ impl ClientT for Client { let id = self.id_manager.next_request_id(); let id_range = generate_batch_id_range(id, batch.len() as u64)?; - let mut batches = Vec::with_capacity(batch.len()); + let mut b = Batch::with_capacity(batch.len()); + for ((method, params), id) in batch.into_iter().zip(id_range.clone()) { - let id = self.id_manager.as_id_kind().into_id(id); - batches.push(RequestSer { + b.push(Request { jsonrpc: TwoPointZero, - id, + id: self.id_manager.as_id_kind().into_id(id), method: method.into(), params: params.map(StdCow::Owned), + extensions: Extensions::new(), }); } - let (send_back_tx, send_back_rx) = oneshot::channel(); - - let raw = serde_json::to_string(&batches).map_err(Error::ParseError)?; + b.extensions_mut().insert(IsBatch { id_range }); - tx_log_from_str(&raw, self.max_log_length); - - if self - .to_back - .clone() - .send(FrontToBack::Batch(BatchMessage { raw, ids: id_range, send_back: send_back_tx })) - .await - .is_err() - { - return Err(self.on_disconnect().await); - } - - let res = call_with_timeout(self.request_timeout, send_back_rx).await; - let json_values = match res { - Ok(Ok(v)) => v, - Ok(Err(err)) => return Err(err), - Err(_) => return Err(self.on_disconnect().await), - }; - - rx_log_from_json(&json_values, self.max_log_length); + let fut = self.service.batch(b); + let json_values = self.run_future_until_timeout(fut).await?.into_batch().expect("Batch response"); let mut responses = Vec::with_capacity(json_values.len()); let mut successful_calls = 0; let mut failed_calls = 0; for json_val in json_values { - match json_val { + match ResponseSuccess::try_from(json_val.into_inner()) { Ok(val) => { - let result: R = serde_json::from_value(val).map_err(Error::ParseError)?; + let result: R = serde_json::from_str(val.result.get()).map_err(Error::ParseError)?; responses.push(Ok(result)); successful_calls += 1; } @@ -593,12 +572,14 @@ impl ClientT for Client { } #[async_trait] -impl SubscriptionClientT for Client { +impl SubscriptionClientT for Client +where + L: RpcServiceT + Send + Sync, +{ /// Send a subscription request to the server. /// /// The `subscribe_method` and `params` are used to ask for the subscription towards the /// server. The `unsubscribe_method` is used to close the subscription. - #[instrument(name = "subscription", fields(method = subscribe_method), skip(self, params, subscribe_method, unsubscribe_method), level = "trace")] async fn subscribe<'a, Notif, Params>( &self, subscribe_method: &'a str, @@ -613,45 +594,29 @@ impl SubscriptionClientT for Client { return Err(RegisterMethodError::SubscriptionNameConflict(unsubscribe_method.to_owned()).into()); } - let id_sub = self.id_manager.next_request_id(); - let id_unsub = self.id_manager.next_request_id(); + let req_id_sub = self.id_manager.next_request_id(); + let req_id_unsub = self.id_manager.next_request_id(); let params = params.to_rpc_params()?; - let raw = serde_json::to_string(&RequestSer::borrowed(&id_sub, &subscribe_method, params.as_deref())) - .map_err(Error::ParseError)?; - - tx_log_from_str(&raw, self.max_log_length); - - let (send_back_tx, send_back_rx) = tokio::sync::oneshot::channel(); - if self - .to_back - .clone() - .send(FrontToBack::Subscribe(SubscriptionMessage { - raw, - subscribe_id: id_sub, - unsubscribe_id: id_unsub.clone(), - unsubscribe_method: unsubscribe_method.to_owned(), - send_back: send_back_tx, - })) - .await - .is_err() - { - return Err(self.on_disconnect().await); - } + let mut ext = Extensions::new(); + ext.insert(IsSubscription::new(req_id_sub.clone(), req_id_unsub, unsubscribe_method.to_owned())); - let (notifs_rx, sub_id) = match call_with_timeout(self.request_timeout, send_back_rx).await { - Ok(Ok(val)) => val, - Ok(Err(err)) => return Err(err), - Err(_) => return Err(self.on_disconnect().await), + let req = Request { + jsonrpc: TwoPointZero, + id: req_id_sub, + method: subscribe_method.into(), + params: params.map(StdCow::Owned), + extensions: ext, }; - rx_log_from_json(&Response::new(ResponsePayload::success_borrowed(&sub_id), id_unsub), self.max_log_length); + let fut = self.service.call(req); + let (sub_id, notifs_rx) = + self.run_future_until_timeout(fut).await?.into_subscription().expect("Subscription response"); Ok(Subscription::new(self.to_back.clone(), notifs_rx, SubscriptionKind::Subscription(sub_id))) } /// Subscribe to a specific method. - #[instrument(name = "subscribe_method", skip(self), level = "trace")] async fn subscribe_to_method<'a, N>(&self, method: &'a str) -> Result, Error> where N: DeserializeOwned, @@ -699,12 +664,17 @@ fn handle_backend_messages( let first_non_whitespace = raw.iter().find(|byte| !byte.is_ascii_whitespace()); let mut messages = Vec::new(); + tracing::trace!(target: LOG_TARGET, "rx: {}", serde_json::from_slice::<&JsonRawValue>(raw).map_or("", |v| v.get())); + match first_non_whitespace { Some(b'{') => { // Single response to a request. if let Ok(single) = serde_json::from_slice::>(raw) { - let maybe_unsub = - process_single_response(&mut manager.lock(), single, max_buffer_capacity_per_subscription)?; + let maybe_unsub = process_single_response( + &mut manager.lock(), + single.into_owned().into(), + max_buffer_capacity_per_subscription, + )?; if let Some(unsub) = maybe_unsub { return Ok(vec![FrontToBack::Request(unsub)]); @@ -738,8 +708,7 @@ fn handle_backend_messages( for r in raw_responses { if let Ok(response) = serde_json::from_str::>(r.get()) { let id = response.id.try_parse_inner_as_number()?; - let result = ResponseSuccess::try_from(response).map(|s| s.result); - batch.push(InnerBatchResponse { id, result }); + batch.push(response.into_owned().into()); let r = range.get_or_insert(id..id); @@ -812,7 +781,7 @@ async fn handle_frontend_messages( FrontToBack::Batch(batch) => { if let Err(send_back) = manager.lock().insert_pending_batch(batch.ids.clone(), batch.send_back) { tracing::debug!(target: LOG_TARGET, "Batch request already pending: {:?}", batch.ids); - let _ = send_back.send(Err(InvalidRequestId::Occupied(format!("{:?}", batch.ids)).into())); + let _ = send_back.send(Err(InvalidRequestId::Occupied(format!("{:?}", batch.ids)))); return Ok(()); } @@ -828,7 +797,7 @@ async fn handle_frontend_messages( tracing::debug!(target: LOG_TARGET, "Denied duplicate method call"); if let Some(s) = send_back { - let _ = s.send(Err(InvalidRequestId::Occupied(request.id.to_string()).into())); + let _ = s.send(Err(InvalidRequestId::Occupied(request.id.to_string()))); } return Ok(()); } diff --git a/core/src/client/async_client/rpc_service.rs b/core/src/client/async_client/rpc_service.rs new file mode 100644 index 0000000000..efcbe23451 --- /dev/null +++ b/core/src/client/async_client/rpc_service.rs @@ -0,0 +1,139 @@ +use crate::{ + client::{ + BatchMessage, Error as ClientError, FrontToBack, MethodResponse, RequestMessage, SubscriptionMessage, + SubscriptionResponse, + }, + middleware::{Batch, IsBatch, IsSubscription, Notification, Request, RpcServiceT}, +}; + +use jsonrpsee_types::{Response, ResponsePayload}; +use tokio::sync::{mpsc, oneshot}; + +/// RpcService error. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Client error. + #[error(transparent)] + Client(#[from] ClientError), + #[error("Fetch from backend")] + /// Internal error state when the underlying channel is closed + /// and the error reason needs to be fetched from the backend. + FetchFromBackend, +} + +impl From> for Error { + fn from(_: mpsc::error::SendError) -> Self { + Error::FetchFromBackend + } +} + +impl From for Error { + fn from(_: oneshot::error::RecvError) -> Self { + Error::FetchFromBackend + } +} + +/// RpcService implementation for the async client. +#[derive(Debug, Clone)] +pub struct RpcService(mpsc::Sender); + +impl RpcService { + // This is a private interface but we need to expose it for the async client + // to be able to create the service. + #[allow(private_interfaces)] + pub(crate) fn new(tx: mpsc::Sender) -> Self { + Self(tx) + } +} + +impl RpcServiceT for RpcService { + type Response = MethodResponse; + type Error = Error; + + fn call<'a>(&self, request: Request<'a>) -> impl Future> + Send + 'a { + let tx = self.0.clone(); + + async move { + let raw = serde_json::to_string(&request).map_err(client_err)?; + + match request.extensions.get::() { + Some(sub) => { + let (send_back_tx, send_back_rx) = tokio::sync::oneshot::channel(); + + tx.clone() + .send(FrontToBack::Subscribe(SubscriptionMessage { + raw, + subscribe_id: sub.sub_req_id(), + unsubscribe_id: sub.unsub_req_id(), + unsubscribe_method: sub.unsubscribe_method().to_owned(), + send_back: send_back_tx, + })) + .await?; + + let (subscribe_rx, sub_id) = send_back_rx.await??; + + let s = serde_json::value::to_raw_value(&sub_id).map_err(client_err)?; + + Ok(MethodResponse::subscription( + SubscriptionResponse { + rp: Response::new(ResponsePayload::success(s), request.id.clone().into_owned()).into(), + sub_id, + stream: subscribe_rx, + }, + request.extensions, + )) + } + None => { + let (send_back_tx, send_back_rx) = oneshot::channel(); + + tx.send(FrontToBack::Request(RequestMessage { + raw, + send_back: Some(send_back_tx), + id: request.id.clone().into_owned(), + })) + .await?; + let rp = send_back_rx.await?.map_err(client_err)?; + + Ok(MethodResponse::method_call(rp, request.extensions)) + } + } + } + } + + fn batch<'a>(&self, mut batch: Batch<'a>) -> impl Future> + Send + 'a { + let tx = self.0.clone(); + + async move { + let (send_back_tx, send_back_rx) = oneshot::channel(); + + let raw = serde_json::to_string(&batch).map_err(client_err)?; + let id_range = batch + .extensions() + .get::() + .map(|b| b.id_range.clone()) + .expect("Batch ID range must be set in extensions"); + + tx.send(FrontToBack::Batch(BatchMessage { raw, ids: id_range, send_back: send_back_tx })).await?; + let json = send_back_rx.await?.map_err(client_err)?; + + Ok(MethodResponse::batch(json, batch.into_extensions())) + } + } + + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + let tx = self.0.clone(); + + async move { + let raw = serde_json::to_string(&n).map_err(client_err)?; + tx.send(FrontToBack::Notification(raw)).await?; + Ok(MethodResponse::notification(n.extensions)) + } + } +} + +fn client_err(err: impl Into) -> Error { + Error::Client(err.into()) +} diff --git a/core/src/client/mod.rs b/core/src/client/mod.rs index 976b822e2f..93c27cfaa4 100644 --- a/core/src/client/mod.rs +++ b/core/src/client/mod.rs @@ -32,6 +32,7 @@ cfg_async_client! { } pub mod error; + pub use error::Error; use std::fmt; @@ -42,20 +43,27 @@ use std::sync::{Arc, RwLock}; use std::task::{self, Poll}; use tokio::sync::mpsc::error::TrySendError; +use crate::middleware::ToJson; use crate::params::BatchRequestBuilder; use crate::traits::ToRpcParams; + use async_trait::async_trait; use core::marker::PhantomData; use futures_util::stream::{Stream, StreamExt}; -use jsonrpsee_types::{ErrorObject, Id, SubscriptionId}; +use http::Extensions; +use jsonrpsee_types::{ErrorObject, Id, InvalidRequestId, SubscriptionId}; +use serde::Serialize; use serde::de::DeserializeOwned; -use serde_json::Value as JsonValue; +use serde_json::value::RawValue; use tokio::sync::{mpsc, oneshot}; /// Shared state whether a subscription has lagged or not. #[derive(Debug, Clone)] pub(crate) struct SubscriptionLagged(Arc>); +/// Owned version of [`RawResponse`]. +pub type RawResponseOwned = RawResponse<'static>; + impl SubscriptionLagged { /// Create a new [`SubscriptionLagged`]. pub(crate) fn new() -> Self { @@ -269,7 +277,7 @@ pub struct Subscription { is_closed: bool, /// Channel to send requests to the background task. to_back: mpsc::Sender, - /// Channel from which we receive notifications from the server, as encoded `JsonValue`s. + /// Channel from which we receive notifications from the server, as encoded JSON. rx: SubscriptionReceiver, /// Callback kind. kind: Option, @@ -320,11 +328,7 @@ impl Subscription { return None; } - if lagged { - Some(SubscriptionCloseReason::Lagged) - } else { - Some(SubscriptionCloseReason::ConnectionClosed) - } + if lagged { Some(SubscriptionCloseReason::Lagged) } else { Some(SubscriptionCloseReason::ConnectionClosed) } } } @@ -336,7 +340,7 @@ struct BatchMessage { /// Request IDs. ids: Range, /// One-shot channel over which we send back the result of this request. - send_back: oneshot::Sender>, Error>>, + send_back: oneshot::Sender, InvalidRequestId>>, } /// Request message. @@ -347,7 +351,7 @@ struct RequestMessage { /// Request ID. id: Id<'static>, /// One-shot channel over which we send back the result of this request. - send_back: Option>>, + send_back: Option>>, } /// Subscription message. @@ -425,7 +429,7 @@ where type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll> { let res = match futures_util::ready!(self.rx.poll_next_unpin(cx)) { - Some(v) => Some(serde_json::from_value::(v)), + Some(v) => Some(serde_json::from_str::(v.get())), None => { self.is_closed = true; None @@ -607,17 +611,17 @@ enum TrySubscriptionSendError { #[error("The subscription is closed")] Closed, #[error("A subscription message was dropped")] - TooSlow(JsonValue), + TooSlow(Box), } #[derive(Debug)] pub(crate) struct SubscriptionSender { - inner: mpsc::Sender, + inner: mpsc::Sender>, lagged: SubscriptionLagged, } impl SubscriptionSender { - fn send(&self, msg: JsonValue) -> Result<(), TrySubscriptionSendError> { + fn send(&self, msg: Box) -> Result<(), TrySubscriptionSendError> { match self.inner.try_send(msg) { Ok(_) => Ok(()), Err(TrySendError::Closed(_)) => Err(TrySubscriptionSendError::Closed), @@ -631,12 +635,12 @@ impl SubscriptionSender { #[derive(Debug)] pub(crate) struct SubscriptionReceiver { - inner: mpsc::Receiver, + inner: mpsc::Receiver>, lagged: SubscriptionLagged, } impl Stream for SubscriptionReceiver { - type Item = JsonValue; + type Item = Box; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll> { self.inner.poll_recv(cx) @@ -650,3 +654,213 @@ fn subscription_channel(max_buf_size: usize) -> (SubscriptionSender, Subscriptio (SubscriptionSender { inner: tx, lagged: lagged_tx }, SubscriptionReceiver { inner: rx, lagged: lagged_rx }) } + +/// Represents the kind of response that can be received from the server. +#[derive(Debug)] +pub enum MethodResponseKind { + /// Method call response. + MethodCall(RawResponseOwned), + /// Subscription response. + Subscription(SubscriptionResponse), + /// Notification response (no payload). + Notification, + /// Batch response. + Batch(Vec), +} + +/// Represents an active subscription returned by the server. +#[derive(Debug)] +pub struct SubscriptionResponse { + /// The ID of the subscription. + sub_id: SubscriptionId<'static>, + // The receiver is used to receive notifications from the server and shouldn't be exposed to the user + // from the middleware. + stream: SubscriptionReceiver, + /// The raw response from the server (mostly used for middleware). + rp: RawResponseOwned, +} + +impl SubscriptionResponse { + /// Get the subscription ID. + pub fn subscription_id(&self) -> &SubscriptionId<'static> { + &self.sub_id + } + + /// Get the raw response. + pub fn response(&self) -> &RawResponseOwned { + &self.rp + } +} + +/// Represents a response from the server which can be a method call, notification or batch. +#[derive(Debug)] +pub struct MethodResponse { + extensions: Extensions, + inner: MethodResponseKind, +} + +impl MethodResponse { + /// Create a new method response. + pub fn method_call(rp: RawResponseOwned, extensions: Extensions) -> Self { + Self { inner: MethodResponseKind::MethodCall(rp), extensions } + } + + /// Create a new subscription response. + pub fn subscription(sub: SubscriptionResponse, extensions: Extensions) -> Self { + Self { inner: MethodResponseKind::Subscription(sub), extensions } + } + + /// Create a new notification response. + pub fn notification(extensions: Extensions) -> Self { + Self { inner: MethodResponseKind::Notification, extensions } + } + + /// Create a new batch response. + pub fn batch(json: Vec, extensions: Extensions) -> Self { + Self { inner: MethodResponseKind::Batch(json), extensions } + } + + /// Get the method call if this response is a method call. + pub fn into_method_call(self) -> Option { + match self.inner { + MethodResponseKind::MethodCall(call) => Some(call), + _ => None, + } + } + + /// Get the batch if this response is a batch. + pub fn into_batch(self) -> Option> { + match self.inner { + MethodResponseKind::Batch(batch) => Some(batch), + _ => None, + } + } + + /// Get the subscription if this response is a subscription. + fn into_subscription(self) -> Option<(SubscriptionId<'static>, SubscriptionReceiver)> { + match self.inner { + MethodResponseKind::Subscription(s) => Some((s.sub_id, s.stream)), + _ => None, + } + } + + /// Returns whether this response is a method call. + pub fn is_method_call(&self) -> bool { + matches!(self.inner, MethodResponseKind::MethodCall(_)) + } + + /// Returns whether this response is a notification. + pub fn is_notification(&self) -> bool { + matches!(self.inner, MethodResponseKind::Notification) + } + + /// Returns whether this response is a batch. + pub fn is_batch(&self) -> bool { + matches!(self.inner, MethodResponseKind::Batch(_)) + } + + /// Returns whether this response is a subscription. + pub fn is_subscription(&self) -> bool { + matches!(self.inner, MethodResponseKind::Subscription { .. }) + } + + /// Returns a reference to the associated extensions. + pub fn extensions(&self) -> &Extensions { + &self.extensions + } + + /// Returns a mutable reference to the associated extensions. + pub fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } +} + +impl std::ops::Deref for MethodResponse { + type Target = MethodResponseKind; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl ToJson for MethodResponse { + fn to_json(&self) -> Result, serde_json::Error> { + match &self.inner { + MethodResponseKind::MethodCall(call) => call.to_json(), + MethodResponseKind::Notification => Ok(Box::::default()), + MethodResponseKind::Batch(json) => serde_json::value::to_raw_value(json), + MethodResponseKind::Subscription(s) => serde_json::value::to_raw_value(&s.rp), + } + } +} + +/// A raw JSON-RPC response object which can be either a JSON-RPC success or error response. +/// +/// This is a wrapper around the `jsonrpsee_types::Response` type for ease of use +/// for middleware client implementations. +#[derive(Debug)] +pub struct RawResponse<'a>(jsonrpsee_types::Response<'a, Box>); + +impl Serialize for RawResponse<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.0.serialize(serializer) + } +} + +impl<'a> From>> for RawResponse<'a> { + fn from(r: jsonrpsee_types::Response<'a, Box>) -> Self { + Self(r) + } +} + +impl<'a> RawResponse<'a> { + /// Whether this response is successful JSON-RPC response. + pub fn is_success(&self) -> bool { + match self.0.payload { + jsonrpsee_types::ResponsePayload::Success(_) => true, + jsonrpsee_types::ResponsePayload::Error(_) => false, + } + } + + /// Extract the error object from the response if it is an error. + pub fn as_error(&self) -> Option<&ErrorObject<'_>> { + match self.0.payload { + jsonrpsee_types::ResponsePayload::Error(ref err) => Some(err), + _ => None, + } + } + + // Extract the result field the response if it is a success. + /// + /// Omits JSON-RPC specific fields like `jsonrpc` and `id`. + pub fn as_success(&self) -> Option<&RawValue> { + match self.0.payload { + jsonrpsee_types::ResponsePayload::Success(ref res) => Some(res), + _ => None, + } + } + + /// Get the request ID. + pub fn id(&self) -> &Id<'a> { + &self.0.id + } + + /// Consume the response and extract the inner value. + pub fn into_inner(self) -> jsonrpsee_types::Response<'a, Box> { + self.0 + } + + /// Convert the response into an owned version. + pub fn into_owned(self) -> RawResponseOwned { + RawResponse(self.0.into_owned()) + } +} + +impl ToJson for RawResponse<'_> { + fn to_json(&self) -> Result, serde_json::Error> { + serde_json::value::to_raw_value(&self.0) + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index 4d9aea31d8..374aae2fda 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -56,8 +56,10 @@ cfg_client! { pub use client::Error as ClientError; } -/// Shared tracing helpers to trace RPC calls. -pub mod tracing; +cfg_client_or_server! { + pub mod middleware; +} + pub use async_trait::async_trait; pub use error::{RegisterMethodError, StringError}; diff --git a/server/src/middleware/rpc/layer/either.rs b/core/src/middleware/layer/either.rs similarity index 68% rename from server/src/middleware/rpc/layer/either.rs rename to core/src/middleware/layer/either.rs index 01209323c3..81fad8cf01 100644 --- a/server/src/middleware/rpc/layer/either.rs +++ b/core/src/middleware/layer/either.rs @@ -31,7 +31,7 @@ //! work to implement tower::Layer for //! external types such as future::Either. -use crate::middleware::rpc::RpcServiceT; +use crate::middleware::{Batch, Notification, RpcServiceT}; use jsonrpsee_types::Request; /// [`tower::util::Either`] but @@ -59,17 +59,35 @@ where } } -impl<'a, A, B> RpcServiceT<'a> for Either +impl RpcServiceT for Either where - A: RpcServiceT<'a> + Send + 'a, - B: RpcServiceT<'a> + Send + 'a, + A: RpcServiceT + Send, + B: RpcServiceT + Send, { - type Future = futures_util::future::Either; + type Error = A::Error; + type Response = A::Response; - fn call(&self, request: Request<'a>) -> Self::Future { + fn call<'a>(&self, request: Request<'a>) -> impl Future> + Send + 'a { match self { Either::Left(service) => futures_util::future::Either::Left(service.call(request)), Either::Right(service) => futures_util::future::Either::Right(service.call(request)), } } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + match self { + Either::Left(service) => futures_util::future::Either::Left(service.batch(batch)), + Either::Right(service) => futures_util::future::Either::Right(service.batch(batch)), + } + } + + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + match self { + Either::Left(service) => futures_util::future::Either::Left(service.notification(n)), + Either::Right(service) => futures_util::future::Either::Right(service.notification(n)), + } + } } diff --git a/core/src/middleware/layer/logger.rs b/core/src/middleware/layer/logger.rs new file mode 100644 index 0000000000..344e6ff648 --- /dev/null +++ b/core/src/middleware/layer/logger.rs @@ -0,0 +1,157 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! RPC Logger layer. + +use crate::middleware::{Batch, Notification, RpcServiceT, ToJson}; + +use futures_util::Future; +use jsonrpsee_types::Request; +use serde_json::value::RawValue; +use tracing::Instrument; + +/// RPC logger layer. +#[derive(Copy, Clone, Debug)] +pub struct RpcLoggerLayer(u32); + +impl RpcLoggerLayer { + /// Create a new logging layer. + pub fn new(max: u32) -> Self { + Self(max) + } +} + +impl tower::Layer for RpcLoggerLayer { + type Service = RpcLogger; + + fn layer(&self, service: S) -> Self::Service { + RpcLogger { service, max: self.0 } + } +} + +/// A middleware that logs each RPC call and response. +#[derive(Debug)] +pub struct RpcLogger { + max: u32, + service: S, +} + +impl RpcServiceT for RpcLogger +where + S: RpcServiceT + Send + Sync + Clone + 'static, + S::Error: std::fmt::Debug + Send, + S::Response: ToJson, +{ + type Error = S::Error; + type Response = S::Response; + + #[tracing::instrument(name = "method_call", skip_all, fields(method = request.method_name()), level = "trace")] + fn call<'a>(&self, request: Request<'a>) -> impl Future> + Send + 'a { + let json = serde_json::value::to_raw_value(&request); + let json_str = unwrap_json_str_or_invalid(&json); + tracing::trace!(target: "jsonrpsee", "request = {}", truncate_at_char_boundary(json_str, self.max as usize)); + + let service = self.service.clone(); + let max = self.max; + + async move { + let rp = service.call(request).await; + + if let Ok(ref rp) = rp { + let json = rp.to_json(); + let json_str = unwrap_json_str_or_invalid(&json); + tracing::trace!(target: "jsonrpsee", "response = {}", truncate_at_char_boundary(json_str, max as usize)); + } + rp + } + .in_current_span() + } + + #[tracing::instrument(name = "batch", skip_all, fields(method = "batch"), level = "trace")] + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + let json = serde_json::value::to_raw_value(&batch); + let json_str = unwrap_json_str_or_invalid(&json); + tracing::trace!(target: "jsonrpsee", "batch request = {}", truncate_at_char_boundary(json_str, self.max as usize)); + let service = self.service.clone(); + let max = self.max; + + async move { + let rp = service.batch(batch).await; + + if let Ok(ref rp) = rp { + let json = rp.to_json(); + let json_str = unwrap_json_str_or_invalid(&json); + tracing::trace!(target: "jsonrpsee", "batch response = {}", truncate_at_char_boundary(json_str, max as usize)); + } + rp + } + .in_current_span() + } + + #[tracing::instrument(name = "notification", skip_all, fields(method = &*n.method), level = "trace")] + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + let json = serde_json::value::to_raw_value(&n); + let json_str = unwrap_json_str_or_invalid(&json); + tracing::trace!(target: "jsonrpsee", "notification request = {}", truncate_at_char_boundary(json_str, self.max as usize)); + + self.service.notification(n).in_current_span() + } +} + +fn unwrap_json_str_or_invalid(json: &Result, serde_json::Error>) -> &str { + match json { + Ok(s) => s.get(), + Err(_) => "", + } +} + +/// Find the next char boundary to truncate at. +fn truncate_at_char_boundary(s: &str, max: usize) -> &str { + if s.len() < max { + return s; + } + + match s.char_indices().nth(max) { + None => s, + Some((idx, _)) => &s[..idx], + } +} + +#[cfg(test)] +mod tests { + use super::truncate_at_char_boundary; + + #[test] + fn truncate_at_char_boundary_works() { + assert_eq!(truncate_at_char_boundary("ボルテックス", 0), ""); + assert_eq!(truncate_at_char_boundary("ボルテックス", 4), "ボルテッ"); + assert_eq!(truncate_at_char_boundary("ボルテックス", 100), "ボルテックス"); + assert_eq!(truncate_at_char_boundary("hola-hola", 4), "hola"); + } +} diff --git a/server/src/middleware/rpc/layer/mod.rs b/core/src/middleware/layer/mod.rs similarity index 55% rename from server/src/middleware/rpc/layer/mod.rs rename to core/src/middleware/layer/mod.rs index 16a37c8852..f9b886bdd9 100644 --- a/server/src/middleware/rpc/layer/mod.rs +++ b/core/src/middleware/layer/mod.rs @@ -26,41 +26,8 @@ //! Specific middleware layer implementation provided by jsonrpsee. -pub mod either; -pub mod logger; -pub mod rpc_service; +mod either; +mod logger; +pub use either::*; pub use logger::*; -pub use rpc_service::*; - -use std::pin::Pin; -use std::task::{Context, Poll}; - -use futures_util::future::{Either, Future}; -use jsonrpsee_core::server::MethodResponse; -use pin_project::pin_project; - -/// Response which may be ready or a future. -#[derive(Debug)] -#[pin_project] -pub struct ResponseFuture(#[pin] futures_util::future::Either>); - -impl ResponseFuture { - /// Returns a future that resolves to a response. - pub fn future(f: F) -> ResponseFuture { - ResponseFuture(Either::Left(f)) - } - - /// Return a response which is already computed. - pub fn ready(response: MethodResponse) -> ResponseFuture { - ResponseFuture(Either::Right(std::future::ready(response))) - } -} - -impl> Future for ResponseFuture { - type Output = MethodResponse; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().0.poll(cx) - } -} diff --git a/core/src/middleware/mod.rs b/core/src/middleware/mod.rs new file mode 100644 index 0000000000..180d2bdbd3 --- /dev/null +++ b/core/src/middleware/mod.rs @@ -0,0 +1,461 @@ +//! Middleware for the RPC service. + +pub mod layer; + +use futures_util::future::{Either, Future}; +use jsonrpsee_types::{ErrorObject, Id}; +use pin_project::pin_project; +use serde::Serialize; +use serde::ser::SerializeSeq; +use serde_json::value::RawValue; +use tower::layer::LayerFn; +use tower::layer::util::{Identity, Stack}; + +use std::borrow::Cow; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Re-export types from `jsonrpsee_types` crate for convenience. +pub type Notification<'a> = jsonrpsee_types::Notification<'a, Option>>; +/// Re-export types from `jsonrpsee_types` crate for convenience. +pub use jsonrpsee_types::{Extensions, Request}; + +/// Error response that can used to indicate an error in JSON-RPC request batch request. +/// This is used in the [`Batch`] type to indicate an error in the batch entry. +#[derive(Debug)] +pub struct BatchEntryErr<'a>(jsonrpsee_types::Response<'a, ()>); + +impl<'a> BatchEntryErr<'a> { + /// Create a new error response. + pub fn new(id: Id<'a>, err: ErrorObject<'a>) -> Self { + let payload = jsonrpsee_types::ResponsePayload::Error(err); + let response = jsonrpsee_types::Response::new(payload, id); + Self(response) + } + + /// Get the parts of the error response.q + pub fn into_parts(self) -> (ErrorObject<'a>, Id<'a>) { + let err = match self.0.payload { + jsonrpsee_types::ResponsePayload::Error(err) => err, + _ => unreachable!("BatchEntryErr can only be created from error payload; qed"), + }; + (err, self.0.id) + } +} + +/// A batch of JSON-RPC requests. +#[derive(Debug, Default)] +pub struct Batch<'a> { + inner: Vec, BatchEntryErr<'a>>>, + extensions: Option, +} + +impl std::fmt::Display for Batch<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let fmt = serde_json::to_string(self).map_err(|_| std::fmt::Error)?; + f.write_str(&fmt) + } +} + +impl<'a> IntoIterator for Batch<'a> { + type Item = Result, BatchEntryErr<'a>>; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} + +impl<'a> Batch<'a> { + /// Create a new batch from a vector of batch entries. + pub fn from(entries: Vec, BatchEntryErr<'a>>>) -> Self { + Self { inner: entries, extensions: None } + } + + /// Create a new empty batch. + pub fn new() -> Self { + Self { inner: Vec::new(), extensions: None } + } + + /// Create a new empty batch with the at least capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self { inner: Vec::with_capacity(capacity), extensions: None } + } + + /// Push a new batch entry to the batch. + pub fn push(&mut self, req: Request<'a>) { + match self.extensions { + Some(ref mut ext) => { + ext.extend(req.extensions().clone()); + } + None => { + self.extensions = Some(req.extensions().clone()); + } + }; + + self.inner.push(Ok(BatchEntry::Call(req))); + } + + /// Get the length of the batch. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Returns whether the batch is empty. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Get an iterator over the batch. + pub fn iter(&self) -> impl Iterator, BatchEntryErr<'a>>> { + self.inner.iter() + } + + /// Get a mutable iterator over the batch. + pub fn iter_mut(&mut self) -> impl Iterator, BatchEntryErr<'a>>> { + self.inner.iter_mut() + } + + /// Consume the batch and and return the parts. + pub fn into_extensions(self) -> Extensions { + match self.extensions { + Some(ext) => ext, + None => self.extensions_from_iter(), + } + } + + /// Get a reference to the extensions of the batch. + pub fn extensions(&mut self) -> &Extensions { + if self.extensions.is_none() { + self.extensions = Some(self.extensions_from_iter()); + } + + self.extensions.as_ref().expect("Extensions inserted above; qed") + } + + /// Get a mutable reference to the extensions of the batch. + pub fn extensions_mut(&mut self) -> &mut Extensions { + if self.extensions.is_none() { + self.extensions = Some(self.extensions_from_iter()); + } + + self.extensions.as_mut().expect("Extensions inserted above; qed") + } + + fn extensions_from_iter(&self) -> Extensions { + let mut ext = Extensions::new(); + for entry in self.inner.iter().flatten() { + ext.extend(entry.extensions().clone()); + } + ext + } +} + +impl Serialize for Batch<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_seq(Some(self.inner.len()))?; + for entry in &self.inner { + match entry { + Ok(entry) => seq.serialize_element(entry)?, + Err(err) => seq.serialize_element(&err.0)?, + } + } + seq.end() + } +} + +#[derive(Debug, Clone)] +/// A marker type to indicate that the request is a subscription for the [`RpcServiceT::call`] method. +pub struct IsSubscription { + sub_id: Id<'static>, + unsub_id: Id<'static>, + unsub_method: String, +} + +impl IsSubscription { + /// Create a new [`IsSubscription`] instance. + pub fn new(sub_id: Id<'static>, unsub_id: Id<'static>, unsub_method: String) -> Self { + Self { sub_id, unsub_id, unsub_method } + } + + /// Get the request id of the subscription calls. + pub fn sub_req_id(&self) -> Id<'static> { + self.sub_id.clone() + } + + /// Get the request id of the unsubscription call. + pub fn unsub_req_id(&self) -> Id<'static> { + self.unsub_id.clone() + } + + /// Get the unsubscription method name. + pub fn unsubscribe_method(&self) -> &str { + &self.unsub_method + } +} + +/// An extension type for the [`RpcServiceT::batch`] for the expected id range of the batch entries. +#[derive(Debug, Clone)] +pub struct IsBatch { + /// The range of ids for the batch entries. + pub id_range: std::ops::Range, +} + +/// A batch entry specific for the [`RpcServiceT::batch`] method to support both +/// method calls and notifications. +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub enum BatchEntry<'a> { + /// A regular JSON-RPC call. + Call(Request<'a>), + /// A JSON-RPC notification. + Notification(Notification<'a>), +} + +impl<'a> BatchEntry<'a> { + /// Get a reference to extensions of the batch entry. + pub fn extensions(&self) -> &Extensions { + match self { + BatchEntry::Call(req) => req.extensions(), + BatchEntry::Notification(n) => n.extensions(), + } + } + + /// Get a mut reference to extensions of the batch entry. + pub fn extensions_mut(&mut self) -> &mut Extensions { + match self { + BatchEntry::Call(req) => req.extensions_mut(), + BatchEntry::Notification(n) => n.extensions_mut(), + } + } + + /// Get the method name of the batch entry. + pub fn method_name(&self) -> &str { + match self { + BatchEntry::Call(req) => req.method_name(), + BatchEntry::Notification(n) => n.method_name(), + } + } + + /// Set the method name of the batch entry. + pub fn set_method_name(&mut self, name: impl Into>) { + match self { + BatchEntry::Call(req) => { + req.method = name.into(); + } + BatchEntry::Notification(n) => { + n.method = name.into(); + } + } + } + + /// Get the params of the batch entry (may be empty). + pub fn params(&self) -> Option<&Cow<'a, RawValue>> { + match self { + BatchEntry::Call(req) => req.params.as_ref(), + BatchEntry::Notification(n) => n.params.as_ref(), + } + } + + /// Set the params of the batch entry. + pub fn set_params(&mut self, params: Option>) { + match self { + BatchEntry::Call(req) => { + req.params = params.map(Cow::Owned); + } + BatchEntry::Notification(n) => { + n.params = params.map(Cow::Owned); + } + } + } + + /// Consume the batch entry and extract the extensions. + pub fn into_extensions(self) -> Extensions { + match self { + BatchEntry::Call(req) => req.extensions, + BatchEntry::Notification(n) => n.extensions, + } + } +} + +/// Represent a JSON-RPC service that can process JSON-RPC calls, notifications, and batch requests. +/// +/// This trait is similar to [`tower::Service`] but it's specialized for JSON-RPC operations. +/// +/// The response type is a future that resolves to a `Result` mainly because this trait is +/// intended to by used by both client and server implementations. +/// +/// In the server implementation, the error is infallible but in the client implementation, the error +/// can occur due to I/O errors or JSON-RPC protocol errors. +/// +/// Such that server implementations must use `std::convert::Infallible` as the error type because +/// the underlying service requires that and otherwise one will get a compiler error that tries to +/// explain that. +pub trait RpcServiceT { + /// The error type. + type Error: std::fmt::Debug; + + /// The response type + type Response: ToJson; + + /// Processes a single JSON-RPC call, which may be a subscription or regular call. + fn call<'a>(&self, request: Request<'a>) -> impl Future> + Send + 'a; + + /// Processes multiple JSON-RPC calls at once, similar to `RpcServiceT::call`. + /// + /// This method wraps `RpcServiceT::call` and `RpcServiceT::notification`, + /// but the root RPC service does not inherently recognize custom implementations + /// of these methods. + /// + /// As a result, if you have custom logic for individual calls or notifications, + /// you must duplicate that implementation in this method or no middleware will be applied + /// for calls inside the batch. + fn batch<'a>(&self, requests: Batch<'a>) -> impl Future> + Send + 'a; + + /// Similar to `RpcServiceT::call` but processes a JSON-RPC notification. + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a; +} + +/// Interface for types that can be serialized into JSON. +pub trait ToJson { + /// Convert the type into a JSON value. + fn to_json(&self) -> Result, serde_json::Error>; +} + +/// Similar to [`tower::ServiceBuilder`] but doesn't +/// support any tower middleware implementations. +#[derive(Debug, Clone)] +pub struct RpcServiceBuilder(tower::ServiceBuilder); + +impl Default for RpcServiceBuilder { + fn default() -> Self { + RpcServiceBuilder(tower::ServiceBuilder::new()) + } +} + +impl RpcServiceBuilder { + /// Create a new [`RpcServiceBuilder`]. + pub fn new() -> Self { + Self(tower::ServiceBuilder::new()) + } +} + +impl RpcServiceBuilder { + /// Optionally add a new layer `T` to the [`RpcServiceBuilder`]. + /// + /// See the documentation for [`tower::ServiceBuilder::option_layer`] for more details. + pub fn option_layer(self, layer: Option) -> RpcServiceBuilder, L>> { + let layer = if let Some(layer) = layer { Either::Left(layer) } else { Either::Right(Identity::new()) }; + self.layer(layer) + } + + /// Add a new layer `T` to the [`RpcServiceBuilder`]. + /// + /// See the documentation for [`tower::ServiceBuilder::layer`] for more details. + pub fn layer(self, layer: T) -> RpcServiceBuilder> { + RpcServiceBuilder(self.0.layer(layer)) + } + + /// Add a [`tower::Layer`] built from a function that accepts a service and returns another service. + /// + /// See the documentation for [`tower::ServiceBuilder::layer_fn`] for more details. + pub fn layer_fn(self, f: F) -> RpcServiceBuilder, L>> { + RpcServiceBuilder(self.0.layer_fn(f)) + } + + /// Add a logging layer to [`RpcServiceBuilder`] + /// + /// This logs each request and response for every call. + /// + pub fn rpc_logger(self, max_log_len: u32) -> RpcServiceBuilder> { + RpcServiceBuilder(self.0.layer(layer::RpcLoggerLayer::new(max_log_len))) + } + + /// Wrap the service `S` with the middleware. + pub fn service(&self, service: S) -> L::Service + where + L: tower::Layer, + { + self.0.service(service) + } +} + +/// Response which may be ready or a future. +#[derive(Debug)] +#[pin_project] +pub struct ResponseFuture(#[pin] futures_util::future::Either>>); + +impl ResponseFuture { + /// Returns a future that resolves to a response. + pub fn future(f: F) -> ResponseFuture { + ResponseFuture(Either::Left(f)) + } + + /// Return a response which is already computed. + pub fn ready(response: R) -> ResponseFuture { + ResponseFuture(Either::Right(std::future::ready(Ok(response)))) + } +} + +impl Future for ResponseFuture +where + F: Future>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().0.poll(cx) { + Poll::Ready(rp) => Poll::Ready(rp), + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod tests { + use jsonrpsee_types::{ErrorCode, ErrorObject}; + + #[test] + fn serialize_batch_entry() { + use super::{BatchEntry, Notification, Request}; + use jsonrpsee_types::Id; + + let req = Request::borrowed("say_hello", None, Id::Number(1)); + let batch_entry = BatchEntry::Call(req.clone()); + assert_eq!( + serde_json::to_string(&batch_entry).unwrap(), + "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"say_hello\"}", + ); + + let notification = Notification::new("say_hello".into(), None); + let batch_entry = BatchEntry::Notification(notification.clone()); + assert_eq!( + serde_json::to_string(&batch_entry).unwrap(), + "{\"jsonrpc\":\"2.0\",\"method\":\"say_hello\",\"params\":null}", + ); + } + + #[test] + fn serialize_batch_works() { + use super::{Batch, BatchEntry, BatchEntryErr, Notification, Request}; + use jsonrpsee_types::Id; + + let req = Request::borrowed("say_hello", None, Id::Number(1)); + let notification = Notification::new("say_hello".into(), None); + let batch = Batch::from(vec![ + Ok(BatchEntry::Call(req)), + Ok(BatchEntry::Notification(notification)), + Err(BatchEntryErr::new(Id::Number(2), ErrorObject::from(ErrorCode::InvalidRequest))), + ]); + assert_eq!( + serde_json::to_string(&batch).unwrap(), + r#"[{"jsonrpc":"2.0","id":1,"method":"say_hello"},{"jsonrpc":"2.0","method":"say_hello","params":null},{"jsonrpc":"2.0","id":2,"error":{"code":-32600,"message":"Invalid request"}}]"#, + ); + } +} diff --git a/core/src/server/helpers.rs b/core/src/server/helpers.rs index 1135155d30..cf480312ea 100644 --- a/core/src/server/helpers.rs +++ b/core/src/server/helpers.rs @@ -24,7 +24,6 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use std::io; use std::time::Duration; use jsonrpsee_types::{ErrorCode, ErrorObject, Id, InvalidRequest, Response, ResponsePayload}; @@ -32,51 +31,6 @@ use tokio::sync::mpsc; use super::{DisconnectError, SendTimeoutError, SubscriptionMessage, TrySendError}; -/// Bounded writer that allows writing at most `max_len` bytes. -/// -/// ``` -/// use std::io::Write; -/// -/// use jsonrpsee_core::server::helpers::BoundedWriter; -/// -/// let mut writer = BoundedWriter::new(10); -/// (&mut writer).write("hello".as_bytes()).unwrap(); -/// assert_eq!(std::str::from_utf8(&writer.into_bytes()).unwrap(), "hello"); -/// ``` -#[derive(Debug, Clone)] -pub struct BoundedWriter { - max_len: usize, - buf: Vec, -} - -impl BoundedWriter { - /// Create a new bounded writer. - pub fn new(max_len: usize) -> Self { - Self { max_len, buf: Vec::with_capacity(128) } - } - - /// Consume the writer and extract the written bytes. - pub fn into_bytes(self) -> Vec { - self.buf - } -} - -impl io::Write for &mut BoundedWriter { - fn write(&mut self, buf: &[u8]) -> io::Result { - let len = self.buf.len() + buf.len(); - if self.max_len >= len { - self.buf.extend_from_slice(buf); - Ok(buf.len()) - } else { - Err(io::Error::new(io::ErrorKind::OutOfMemory, "Memory capacity exceeded")) - } - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - /// Sink that is used to send back the result to the server for a specific method. #[derive(Clone, Debug)] pub struct MethodSink { @@ -171,26 +125,3 @@ pub fn prepare_error(data: &[u8]) -> (Id<'_>, ErrorCode) { Err(_) => (Id::Null, ErrorCode::ParseError), } } - -#[cfg(test)] -mod tests { - use crate::server::BoundedWriter; - use jsonrpsee_types::{Id, Response, ResponsePayload}; - - #[test] - fn bounded_serializer_work() { - let mut writer = BoundedWriter::new(100); - let result = ResponsePayload::success(&"success"); - let rp = &Response::new(result, Id::Number(1)); - - assert!(serde_json::to_writer(&mut writer, rp).is_ok()); - assert_eq!(String::from_utf8(writer.into_bytes()).unwrap(), r#"{"jsonrpc":"2.0","id":1,"result":"success"}"#); - } - - #[test] - fn bounded_serializer_cap_works() { - let mut writer = BoundedWriter::new(100); - // NOTE: `"` is part of the serialization so 101 characters. - assert!(serde_json::to_writer(&mut writer, &"x".repeat(99)).is_err()); - } -} diff --git a/core/src/server/method_response.rs b/core/src/server/method_response.rs index d0ba38b05b..9e8d7dc9f2 100644 --- a/core/src/server/method_response.rs +++ b/core/src/server/method_response.rs @@ -24,23 +24,30 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::server::{BoundedWriter, LOG_TARGET}; +//! Represent a method response. + +const LOG_TARGET: &str = "jsonrpsee-core"; + +use std::io; use std::task::Poll; use futures_util::{Future, FutureExt}; use http::Extensions; use jsonrpsee_types::error::{ - reject_too_big_batch_response, ErrorCode, ErrorObject, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG, + ErrorCode, ErrorObject, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG, reject_too_big_batch_response, }; use jsonrpsee_types::{ErrorObjectOwned, Id, Response, ResponsePayload as InnerResponsePayload}; use serde::Serialize; -use serde_json::value::to_raw_value; +use serde_json::value::{RawValue, to_raw_value}; + +use crate::middleware::ToJson; #[derive(Debug, Clone)] enum ResponseKind { MethodCall, Subscription, Batch, + Notification, } /// Represents a response to a method call. @@ -52,7 +59,7 @@ enum ResponseKind { #[derive(Debug)] pub struct MethodResponse { /// Serialized JSON-RPC response, - result: String, + json: Box, /// Indicates whether the call was successful or not. success_or_error: MethodResponseResult, /// Indicates whether the call was a subscription response. @@ -64,6 +71,24 @@ pub struct MethodResponse { extensions: Extensions, } +impl AsRef for MethodResponse { + fn as_ref(&self) -> &str { + self.json.get() + } +} + +impl ToJson for MethodResponse { + fn to_json(&self) -> Result, serde_json::Error> { + Ok(self.json.clone()) + } +} + +impl std::fmt::Display for MethodResponse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.json) + } +} + impl MethodResponse { /// Returns whether the call was successful. pub fn is_success(&self) -> bool { @@ -85,24 +110,29 @@ impl MethodResponse { matches!(self.kind, ResponseKind::MethodCall) } + /// Returns whether the response is a notification response. + pub fn is_notification(&self) -> bool { + matches!(self.kind, ResponseKind::Notification) + } + /// Returns whether the response is a batch response. pub fn is_batch(&self) -> bool { matches!(self.kind, ResponseKind::Batch) } - /// Consume the method response and extract the serialized response. - pub fn into_result(self) -> String { - self.result + /// Consume the method response and extract the serialized JSON response. + pub fn into_json(self) -> Box { + self.json } - /// Extract the serialized response as a String. - pub fn to_result(&self) -> String { - self.result.clone() + /// Get the serialized JSON response. + pub fn to_json(&self) -> Box { + self.json.clone() } /// Consume the method response and extract the parts. - pub fn into_parts(self) -> (String, Option, Extensions) { - (self.result, self.on_close, self.extensions) + pub fn into_parts(self) -> (Box, Option, Extensions) { + (self.json, self.on_close, self.extensions) } /// Get the error code @@ -112,15 +142,15 @@ impl MethodResponse { self.success_or_error.as_error_code() } - /// Get a reference to the serialized response. - pub fn as_result(&self) -> &str { - &self.result + /// Get a reference to the serialized JSON response. + pub fn as_json(&self) -> &RawValue { + &self.json } /// Create a method response from [`BatchResponse`]. pub fn from_batch(batch: BatchResponse) -> Self { Self { - result: batch.result, + json: batch.json, success_or_error: MethodResponseResult::Success, kind: ResponseKind::Batch, on_close: None, @@ -161,8 +191,9 @@ impl MethodResponse { Ok(_) => { // Safety - serde_json does not emit invalid UTF-8. let result = unsafe { String::from_utf8_unchecked(writer.into_bytes()) }; + let json = RawValue::from_string(result).expect("JSON serialization infallible; qed"); - Self { result, success_or_error, kind, on_close: rp.on_exit, extensions: Extensions::new() } + Self { json, success_or_error, kind, on_close: rp.on_exit, extensions: Extensions::new() } } Err(err) => { tracing::error!(target: LOG_TARGET, "Error serializing response: {:?}", err); @@ -176,11 +207,11 @@ impl MethodResponse { OVERSIZED_RESPONSE_MSG, data.as_deref(), )); - let result = - serde_json::to_string(&Response::new(err, id)).expect("JSON serialization infallible; qed"); + let json = serde_json::value::to_raw_value(&Response::new(err, id)) + .expect("JSON serialization infallible; qed"); Self { - result, + json, success_or_error: MethodResponseResult::Failed(err_code), kind, on_close: rp.on_exit, @@ -189,10 +220,10 @@ impl MethodResponse { } else { let err = ErrorCode::InternalError; let payload = jsonrpsee_types::ResponsePayload::<()>::error(err); - let result = - serde_json::to_string(&Response::new(payload, id)).expect("JSON serialization infallible; qed"); + let json = serde_json::value::to_raw_value(&Response::new(payload, id)) + .expect("JSON serialization infallible; qed"); Self { - result, + json, success_or_error: MethodResponseResult::Failed(err.code()), kind, on_close: rp.on_exit, @@ -216,9 +247,10 @@ impl MethodResponse { let err: ErrorObject = err.into(); let err_code = err.code(); let err = InnerResponsePayload::<()>::error_borrowed(err); - let result = serde_json::to_string(&Response::new(err, id)).expect("JSON serialization infallible; qed"); + let json = + serde_json::value::to_raw_value(&Response::new(err, id)).expect("JSON serialization infallible; qed"); Self { - result, + json, success_or_error: MethodResponseResult::Failed(err_code), kind: ResponseKind::MethodCall, on_close: None, @@ -226,12 +258,23 @@ impl MethodResponse { } } + /// Create notification response which is a response that doesn't expect a reply. + pub fn notification() -> Self { + Self { + json: Box::::default(), + success_or_error: MethodResponseResult::Success, + kind: ResponseKind::Notification, + on_close: None, + extensions: Extensions::new(), + } + } + /// Returns a reference to the associated extensions. pub fn extensions(&self) -> &Extensions { &self.extensions } - /// Returns a reference to the associated extensions. + /// Returns a mut reference to the associated extensions. pub fn extensions_mut(&mut self) -> &mut Extensions { &mut self.extensions } @@ -300,13 +343,13 @@ impl BatchResponseBuilder { pub fn append(&mut self, response: MethodResponse) -> Result<(), MethodResponse> { // `,` will occupy one extra byte for each entry // on the last item the `,` is replaced by `]`. - let len = response.result.len() + self.result.len() + 1; + let len = response.json.get().len() + self.result.len() + 1; self.extensions.extend(response.extensions); if len > self.max_response_size { Err(MethodResponse::error(Id::Null, reject_too_big_batch_response(self.max_response_size))) } else { - self.result.push_str(&response.result); + self.result.push_str(response.json.get()); self.result.push(','); Ok(()) } @@ -321,13 +364,14 @@ impl BatchResponseBuilder { pub fn finish(mut self) -> BatchResponse { if self.result.len() == 1 { BatchResponse { - result: batch_response_error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest)), + json: batch_response_error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest)), extensions: self.extensions, } } else { self.result.pop(); self.result.push(']'); - BatchResponse { result: self.result, extensions: self.extensions } + let json = RawValue::from_string(self.result).expect("JSON serialization infallible; qed"); + BatchResponse { json, extensions: self.extensions } } } } @@ -335,14 +379,14 @@ impl BatchResponseBuilder { /// Serialized batch response. #[derive(Debug, Clone)] pub struct BatchResponse { - result: String, + json: Box, extensions: Extensions, } /// Create a JSON-RPC error response. -pub fn batch_response_error(id: Id, err: impl Into>) -> String { +pub fn batch_response_error(id: Id, err: impl Into>) -> Box { let err = InnerResponsePayload::<()>::error_borrowed(err); - serde_json::to_string(&Response::new(err, id)).expect("JSON serialization infallible; qed") + serde_json::value::to_raw_value(&Response::new(err, id)).expect("JSON serialization infallible; qed") } /// Similar to [`jsonrpsee_types::ResponsePayload`] but possible to with an async-like @@ -440,11 +484,11 @@ pub struct MethodResponseFuture(tokio::sync::oneshot::Receiver); /// was succesful or not. #[derive(Debug, Copy, Clone)] pub enum NotifyMsg { - /// The response was succesfully processed. + /// The response was successfully processed. Ok, /// The response was the wrong kind /// such an error response when - /// one expected a succesful response. + /// one expected a successful response. Err, } @@ -470,29 +514,63 @@ impl Future for MethodResponseFuture { } } +/// Bounded writer that allows writing at most `max_len` bytes. +#[derive(Debug, Clone)] +struct BoundedWriter { + max_len: usize, + buf: Vec, +} + +impl BoundedWriter { + /// Create a new bounded writer. + pub fn new(max_len: usize) -> Self { + Self { max_len, buf: Vec::with_capacity(128) } + } + + /// Consume the writer and extract the written bytes. + pub fn into_bytes(self) -> Vec { + self.buf + } +} + +impl io::Write for &mut BoundedWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + let len = self.buf.len() + buf.len(); + if self.max_len >= len { + self.buf.extend_from_slice(buf); + Ok(buf.len()) + } else { + Err(io::Error::new(io::ErrorKind::OutOfMemory, "Memory capacity exceeded")) + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + #[cfg(test)] mod tests { - use super::{BatchResponseBuilder, MethodResponse, ResponsePayload}; - use jsonrpsee_types::Id; + use super::{BatchResponseBuilder, BoundedWriter, Id, MethodResponse, ResponsePayload}; #[test] fn batch_with_single_works() { let method = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a"), usize::MAX); - assert_eq!(method.result.len(), 37); + assert_eq!(method.json.get().len(), 37); // Recall a batch appends two bytes for the `[]`. let mut builder = BatchResponseBuilder::new_with_limit(39); builder.append(method).unwrap(); let batch = builder.finish(); - assert_eq!(batch.result, r#"[{"jsonrpc":"2.0","id":1,"result":"a"}]"#) + assert_eq!(batch.json.get(), r#"[{"jsonrpc":"2.0","id":1,"result":"a"}]"#) } #[test] fn batch_with_multiple_works() { let m1 = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a"), usize::MAX); let m11 = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a"), usize::MAX); - assert_eq!(m1.result.len(), 37); + assert_eq!(m1.json.get().len(), 37); // Recall a batch appends two bytes for the `[]` and one byte for `,` to append a method call. // so it should be 2 + (37 * n) + (n-1) @@ -502,7 +580,7 @@ mod tests { builder.append(m11).unwrap(); let batch = builder.finish(); - assert_eq!(batch.result, r#"[{"jsonrpc":"2.0","id":1,"result":"a"},{"jsonrpc":"2.0","id":1,"result":"a"}]"#) + assert_eq!(batch.json.get(), r#"[{"jsonrpc":"2.0","id":1,"result":"a"},{"jsonrpc":"2.0","id":1,"result":"a"}]"#) } #[test] @@ -510,17 +588,36 @@ mod tests { let batch = BatchResponseBuilder::new_with_limit(1024).finish(); let exp_err = r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid request"}}"#; - assert_eq!(batch.result, exp_err); + assert_eq!(batch.json.get(), exp_err); } #[test] fn batch_too_big() { let method = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a".repeat(28)), 128); - assert_eq!(method.result.len(), 64); + assert_eq!(method.json.get().len(), 64); let batch = BatchResponseBuilder::new_with_limit(63).append(method).unwrap_err(); let exp_err = r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32011,"message":"The batch response was too large","data":"Exceeded max limit of 63"}}"#; - assert_eq!(batch.result, exp_err); + assert_eq!(batch.json.get(), exp_err); + } + + #[test] + fn bounded_serializer_work() { + use jsonrpsee_types::{Response, ResponsePayload}; + + let mut writer = BoundedWriter::new(100); + let result = ResponsePayload::success(&"success"); + let rp = &Response::new(result, Id::Number(1)); + + assert!(serde_json::to_writer(&mut writer, rp).is_ok()); + assert_eq!(String::from_utf8(writer.into_bytes()).unwrap(), r#"{"jsonrpc":"2.0","id":1,"result":"success"}"#); + } + + #[test] + fn bounded_serializer_cap_works() { + let mut writer = BoundedWriter::new(100); + // NOTE: `"` is part of the serialization so 101 characters. + assert!(serde_json::to_writer(&mut writer, &"x".repeat(99)).is_err()); } } diff --git a/core/src/server/mod.rs b/core/src/server/mod.rs index 8a187fafd3..12e1fbe544 100644 --- a/core/src/server/mod.rs +++ b/core/src/server/mod.rs @@ -30,7 +30,7 @@ mod error; /// Helpers. pub mod helpers; -/// Method response related types. +/// Method response. mod method_response; /// JSON-RPC "modules" group sets of methods that belong together and handles method/subscription registration. mod rpc_module; diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 4fcd5aab32..38937f13a8 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -33,15 +33,14 @@ use std::sync::Arc; use crate::error::RegisterMethodError; use crate::id_providers::RandomIntegerIdProvider; use crate::server::helpers::MethodSink; -use crate::server::method_response::MethodResponse; use crate::server::subscription::{ - sub_message_to_json, BoundedSubscriptions, IntoSubscriptionCloseResponse, PendingSubscriptionSink, - SubNotifResultOrError, Subscribers, Subscription, SubscriptionCloseResponse, SubscriptionKey, SubscriptionPermit, - SubscriptionState, + BoundedSubscriptions, IntoSubscriptionCloseResponse, PendingSubscriptionSink, SubNotifResultOrError, Subscribers, + Subscription, SubscriptionCloseResponse, SubscriptionKey, SubscriptionPermit, SubscriptionState, + sub_message_to_json, }; -use crate::server::{ResponsePayload, LOG_TARGET}; +use crate::server::{LOG_TARGET, MethodResponse, ResponsePayload}; use crate::traits::ToRpcParams; -use futures_util::{future::BoxFuture, FutureExt}; +use futures_util::{FutureExt, future::BoxFuture}; use http::Extensions; use jsonrpsee_types::error::{ErrorCode, ErrorObject}; use jsonrpsee_types::{ @@ -49,6 +48,7 @@ use jsonrpsee_types::{ }; use rustc_hash::FxHashMap; use serde::de::DeserializeOwned; +use serde_json::value::RawValue; use tokio::sync::{mpsc, oneshot}; use super::IntoResponse; @@ -95,7 +95,7 @@ pub type MaxResponseSize = usize; /// A tuple containing: /// - Call result as a `String`, /// - a [`mpsc::UnboundedReceiver`] to receive future subscription results -pub type RawRpcResponse = (String, mpsc::Receiver); +pub type RawRpcResponse = (Box, mpsc::Receiver); /// The error that can occur when [`Methods::call`] or [`Methods::subscribe`] is invoked. #[derive(thiserror::Error, Debug)] @@ -307,11 +307,11 @@ impl Methods { params: Params, ) -> Result { let params = params.to_rpc_params()?; - let req = Request::new(method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0)); + let req = Request::borrowed(method, params.as_ref().map(|p| p.as_ref()), Id::Number(0)); tracing::trace!(target: LOG_TARGET, "[Methods::call] Method: {:?}, params: {:?}", method, params); let (rp, _) = self.inner_call(req, 1, mock_subscription_permit()).await; - let rp = serde_json::from_str::>(&rp)?; + let rp = serde_json::from_str::>(rp.get())?; ResponseSuccess::try_from(rp).map(|s| s.result).map_err(|e| MethodsError::JsonRpc(e.into_owned())) } @@ -339,7 +339,7 @@ impl Methods { /// }).unwrap(); /// let (resp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"hi","id":0}"#, 1).await.unwrap(); /// // If the response is an error converting it to `Success` will fail. - /// let resp: Success = serde_json::from_str::>(&resp).unwrap().try_into().unwrap(); + /// let resp: Success = serde_json::from_str::>(resp.get()).unwrap().try_into().unwrap(); /// let sub_resp = stream.recv().await.unwrap(); /// assert_eq!( /// format!(r#"{{"jsonrpc":"2.0","method":"hi","params":{{"subscription":{},"result":"one answer"}}}}"#, resp.result), @@ -351,7 +351,7 @@ impl Methods { &self, request: &str, buf_size: usize, - ) -> Result<(String, mpsc::Receiver), serde_json::Error> { + ) -> Result<(Box, mpsc::Receiver), serde_json::Error> { tracing::trace!("[Methods::raw_json_request] Request: {:?}", request); let req: Request = serde_json::from_str(request)?; let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await; @@ -455,18 +455,15 @@ impl Methods { buf_size: usize, ) -> Result { let params = params.to_rpc_params()?; - let req = Request::new(sub_method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0)); + let req = Request::borrowed(sub_method, params.as_ref().map(|p| p.as_ref()), Id::Number(0)); tracing::trace!(target: LOG_TARGET, "[Methods::subscribe] Method: {}, params: {:?}", sub_method, params); let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await; + let as_success: ResponseSuccess<&RawValue> = serde_json::from_str::>(resp.get())?.try_into()?; + let sub_id: RpcSubscriptionId = serde_json::from_str(as_success.result.get())?; - // TODO: hack around the lifetime on the `SubscriptionId` by deserialize first to serde_json::Value. - let as_success: ResponseSuccess = serde_json::from_str::>(&resp)?.try_into()?; - - let sub_id = as_success.result.try_into().map_err(|_| MethodsError::InvalidSubscriptionId(resp.clone()))?; - - Ok(Subscription { sub_id, rx }) + Ok(Subscription { sub_id: sub_id.into_owned(), rx }) } /// Returns an `Iterator` with all the method names registered on this server. diff --git a/core/src/server/subscription.rs b/core/src/server/subscription.rs index deb581ed28..6ffd5f3c71 100644 --- a/core/src/server/subscription.rs +++ b/core/src/server/subscription.rs @@ -28,17 +28,17 @@ use super::helpers::MethodSink; use super::{MethodResponse, MethodsError, ResponsePayload}; +use crate::server::LOG_TARGET; use crate::server::error::{DisconnectError, PendingSubscriptionAcceptError, SendTimeoutError, TrySendError}; use crate::server::rpc_module::ConnectionId; -use crate::server::LOG_TARGET; use crate::{error::StringError, traits::IdProvider}; use jsonrpsee_types::SubscriptionPayload; -use jsonrpsee_types::{response::SubscriptionError, ErrorObjectOwned, Id, SubscriptionId, SubscriptionResponse}; +use jsonrpsee_types::{ErrorObjectOwned, Id, SubscriptionId, SubscriptionResponse, response::SubscriptionError}; use parking_lot::Mutex; use rustc_hash::FxHashMap; -use serde::{de::DeserializeOwned, Serialize}; +use serde::{Serialize, de::DeserializeOwned}; use std::{sync::Arc, time::Duration}; -use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot}; /// Type-alias for subscribers. pub type Subscribers = Arc)>>>; @@ -259,7 +259,7 @@ impl PendingSubscriptionSink { /// once reject has been called. pub async fn reject(self, err: impl Into) { let err = MethodResponse::subscription_error(self.id, err.into()); - _ = self.inner.send(err.to_result()).await; + _ = self.inner.send(err.as_json().get().to_owned()).await; _ = self.subscribe.send(err); } @@ -282,7 +282,7 @@ impl PendingSubscriptionSink { // // The same message is sent twice here because one is sent directly to the transport layer and // the other one is sent internally to accept the subscription. - self.inner.send(response.to_result()).await.map_err(|_| PendingSubscriptionAcceptError)?; + self.inner.send(response.as_json().get().to_owned()).await.map_err(|_| PendingSubscriptionAcceptError)?; self.subscribe.send(response).map_err(|_| PendingSubscriptionAcceptError)?; if success { @@ -297,7 +297,9 @@ impl PendingSubscriptionSink { _permit: Arc::new(self.permit), }) } else { - panic!("The subscription response was too big; adjust the `max_response_size` or change Subscription ID generation"); + panic!( + "The subscription response was too big; adjust the `max_response_size` or change Subscription ID generation" + ); } } diff --git a/core/src/tracing.rs b/core/src/tracing.rs deleted file mode 100644 index 4967f1a8e6..0000000000 --- a/core/src/tracing.rs +++ /dev/null @@ -1,124 +0,0 @@ -use serde::Serialize; -use tracing::Level; - -const CLIENT: &str = "jsonrpsee-client"; -const SERVER: &str = "jsonrpsee-server"; - -/// Logging with jsonrpsee client target. -pub mod client { - use super::*; - - /// Helper for writing trace logs from str. - pub fn tx_log_from_str(s: impl AsRef, max: u32) { - if tracing::enabled!(Level::TRACE) { - let msg = truncate_at_char_boundary(s.as_ref(), max as usize); - tracing::trace!(target: CLIENT, send = msg); - } - } - - /// Helper for writing trace logs from JSON. - pub fn tx_log_from_json(s: &impl Serialize, max: u32) { - if tracing::enabled!(Level::TRACE) { - let json = serde_json::to_string(s).unwrap_or_default(); - let msg = truncate_at_char_boundary(&json, max as usize); - tracing::trace!(target: CLIENT, send = msg); - } - } - - /// Helper for writing trace logs from str. - pub fn rx_log_from_str(s: impl AsRef, max: u32) { - if tracing::enabled!(Level::TRACE) { - let msg = truncate_at_char_boundary(s.as_ref(), max as usize); - tracing::trace!(target: CLIENT, recv = msg); - } - } - - /// Helper for writing trace logs from JSON. - pub fn rx_log_from_json(s: &impl Serialize, max: u32) { - if tracing::enabled!(Level::TRACE) { - let res = serde_json::to_string(s).unwrap_or_default(); - let msg = truncate_at_char_boundary(res.as_str(), max as usize); - tracing::trace!(target: CLIENT, recv = msg); - } - } - - /// Helper for writing trace logs from bytes. - pub fn rx_log_from_bytes(bytes: &[u8], max: u32) { - if tracing::enabled!(Level::TRACE) { - let res = serde_json::from_slice::(bytes).unwrap_or_default(); - rx_log_from_json(&res, max); - } - } -} - -/// Logging with jsonrpsee server target. -pub mod server { - use super::*; - - /// Helper for writing trace logs from str. - pub fn tx_log_from_str(s: impl AsRef, max: u32) { - if tracing::enabled!(Level::TRACE) { - let msg = truncate_at_char_boundary(s.as_ref(), max as usize); - tracing::trace!(target: SERVER, send = msg); - } - } - - /// Helper for writing trace logs from JSON. - pub fn tx_log_from_json(s: &impl Serialize, max: u32) { - if tracing::enabled!(Level::TRACE) { - let json = serde_json::to_string(s).unwrap_or_default(); - let msg = truncate_at_char_boundary(&json, max as usize); - tracing::trace!(target: SERVER, send = msg); - } - } - - /// Helper for writing trace logs from str. - pub fn rx_log_from_str(s: impl AsRef, max: u32) { - if tracing::enabled!(Level::TRACE) { - let msg = truncate_at_char_boundary(s.as_ref(), max as usize); - tracing::trace!(target: SERVER, recv = msg); - } - } - - /// Helper for writing trace logs from JSON. - pub fn rx_log_from_json(s: &impl Serialize, max: u32) { - if tracing::enabled!(Level::TRACE) { - let res = serde_json::to_string(s).unwrap_or_default(); - let msg = truncate_at_char_boundary(res.as_str(), max as usize); - tracing::trace!(target: SERVER, recv = msg); - } - } - - /// Helper for writing trace logs from bytes. - pub fn rx_log_from_bytes(bytes: &[u8], max: u32) { - if tracing::enabled!(Level::TRACE) { - let res = serde_json::from_slice::(bytes).unwrap_or_default(); - rx_log_from_json(&res, max); - } - } -} - -/// Find the next char boundary to truncate at. -fn truncate_at_char_boundary(s: &str, max: usize) -> &str { - if s.len() < max { - return s; - } - - match s.char_indices().nth(max) { - None => s, - Some((idx, _)) => &s[..idx], - } -} - -#[cfg(test)] -mod tests { - use super::truncate_at_char_boundary; - - #[test] - fn truncate_at_char_boundary_works() { - assert_eq!(truncate_at_char_boundary("ボルテックス", 0), ""); - assert_eq!(truncate_at_char_boundary("ボルテックス", 4), "ボルテッ"); - assert_eq!(truncate_at_char_boundary("ボルテックス", 100), "ボルテックス"); - assert_eq!(truncate_at_char_boundary("hola-hola", 4), "hola"); - } -} diff --git a/examples/examples/core_client.rs b/examples/examples/core_client.rs index c7483b46be..7c7380bcb5 100644 --- a/examples/examples/core_client.rs +++ b/examples/examples/core_client.rs @@ -27,7 +27,7 @@ use std::net::SocketAddr; use jsonrpsee::client_transport::ws::{Url, WsTransportClientBuilder}; -use jsonrpsee::core::client::{Client, ClientBuilder, ClientT}; +use jsonrpsee::core::client::{ClientBuilder, ClientT}; use jsonrpsee::rpc_params; use jsonrpsee::server::{RpcModule, Server}; @@ -42,7 +42,7 @@ async fn main() -> anyhow::Result<()> { let uri = Url::parse(&format!("ws://{}", addr))?; let (tx, rx) = WsTransportClientBuilder::default().build(uri).await?; - let client: Client = ClientBuilder::default().build_with_tokio(tx, rx); + let client = ClientBuilder::default().build_with_tokio(tx, rx); let response: String = client.request("say_hello", rpc_params![]).await?; tracing::info!("response: {:?}", response); diff --git a/examples/examples/http.rs b/examples/examples/http.rs index f9adcb01d5..02722db1f7 100644 --- a/examples/examples/http.rs +++ b/examples/examples/http.rs @@ -29,6 +29,7 @@ use std::time::Duration; use hyper::body::Bytes; use jsonrpsee::core::client::ClientT; +use jsonrpsee::core::middleware::RpcServiceBuilder; use jsonrpsee::http_client::HttpClient; use jsonrpsee::rpc_params; use jsonrpsee::server::{RpcModule, Server}; @@ -58,7 +59,9 @@ async fn main() -> anyhow::Result<()> { .on_response(DefaultOnResponse::new().include_headers(true).latency_unit(LatencyUnit::Micros)), ); - let client = HttpClient::builder().set_http_middleware(middleware).build(url)?; + let rpc = RpcServiceBuilder::new().rpc_logger(1024); + + let client = HttpClient::builder().set_http_middleware(middleware).set_rpc_middleware(rpc).build(url)?; let params = rpc_params![1_u64, 2, 3]; let response: Result = client.request("say_hello", params).await; tracing::info!("r: {:?}", response); diff --git a/examples/examples/jsonrpsee_as_service.rs b/examples/examples/jsonrpsee_as_service.rs index ab91c6b304..2a71ebda37 100644 --- a/examples/examples/jsonrpsee_as_service.rs +++ b/examples/examples/jsonrpsee_as_service.rs @@ -31,6 +31,7 @@ //! The typical use-case for this is when one wants to have //! access to HTTP related things. +use std::convert::Infallible; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -39,9 +40,9 @@ use futures::FutureExt; use hyper::HeaderMap; use hyper::header::AUTHORIZATION; use jsonrpsee::core::async_trait; +use jsonrpsee::core::middleware::{Batch, BatchEntry, BatchEntryErr, Notification, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::http_client::HttpClient; use jsonrpsee::proc_macros::rpc; -use jsonrpsee::server::middleware::rpc::{ResponseFuture, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::server::{ ServerConfig, ServerHandle, StopHandle, TowerServiceBuilder, serve_with_graceful_shutdown, stop_channel, }; @@ -62,6 +63,10 @@ struct Metrics { success_http_calls: Arc, } +fn auth_reject_error() -> ErrorObjectOwned { + ErrorObject::owned(-32999, "HTTP Authorization header is missing", None::<()>) +} + #[derive(Clone)] struct AuthorizationMiddleware { headers: HeaderMap, @@ -70,26 +75,96 @@ struct AuthorizationMiddleware { transport_label: &'static str, } -impl<'a, S> RpcServiceT<'a> for AuthorizationMiddleware +impl AuthorizationMiddleware { + /// Authorize the request by checking the `Authorization` header. + /// + /// + /// In this example for simplicity, the authorization value is not checked + // and used because it's just a toy example. + fn auth_method_call(&self, req: &Request<'_>) -> bool { + if req.method_name() == "trusted_call" { + let Some(Ok(_)) = self.headers.get(AUTHORIZATION).map(|auth| auth.to_str()) else { return false }; + } + + true + } + + /// Authorize the notification by checking the `Authorization` header. + /// + /// Because notifications are not expected to return a response, we + /// return a `MethodResponse` by injecting an error into the extensions + /// which could be read by other middleware or the server. + fn auth_notif(&self, notif: &Notification<'_>) -> bool { + if notif.method_name() == "trusted_call" { + let Some(Ok(_)) = self.headers.get(AUTHORIZATION).map(|auth| auth.to_str()) else { return false }; + } + + true + } +} + +impl RpcServiceT for AuthorizationMiddleware where - S: Send + Clone + Sync + RpcServiceT<'a>, + S: RpcServiceT + Send + Clone + Sync + 'static, + S::Response: Send, { - type Future = ResponseFuture; + type Error = S::Error; + type Response = S::Response; + + fn call<'a>(&self, req: Request<'a>) -> impl Future> + Send + 'a { + let this = self.clone(); + let auth_ok = this.auth_method_call(&req); + + async move { + // If the authorization header is missing, it's recommended to + // to return the response as MethodResponse::error instead of + // returning an error from the service. + // + // This way the error is returned as a JSON-RPC error + if !auth_ok { + return Ok(MethodResponse::error(req.id, auth_reject_error())); + } + this.inner.call(req).await + } + } - fn call(&self, req: Request<'a>) -> Self::Future { - if req.method_name() == "trusted_call" { - let Some(Ok(_)) = self.headers.get(AUTHORIZATION).map(|auth| auth.to_str()) else { - let rp = MethodResponse::error(req.id, ErrorObject::borrowed(-32000, "Authorization failed", None)); - return ResponseFuture::ready(rp); - }; + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + // Check the authorization header for each entry in the batch. + let entries: Vec<_> = batch + .into_iter() + .filter_map(|entry| match entry { + Ok(BatchEntry::Call(req)) => { + if self.auth_method_call(&req) { + Some(Ok(BatchEntry::Call(req))) + } else { + // If the authorization header is missing, we return + // a JSON-RPC error instead of an error from the service. + Some(Err(BatchEntryErr::new(req.id, auth_reject_error()))) + } + } + Ok(BatchEntry::Notification(notif)) => { + if self.auth_notif(¬if) { + Some(Ok(BatchEntry::Notification(notif))) + } else { + // Just filter out the notification if the auth fails + // because notifications are not expected to return a response. + None + } + } + // Errors which could happen such as invalid JSON-RPC call + // or invalid JSON are just passed through. + Err(err) => Some(Err(err)), + }) + .collect(); - // In this example for simplicity, the authorization value is not checked - // and used because it's just a toy example. + self.inner.batch(Batch::from(entries)).boxed() + } - ResponseFuture::future(self.inner.call(req)) - } else { - ResponseFuture::future(self.inner.call(req)) - } + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + self.inner.notification(n) } } @@ -234,10 +309,7 @@ async fn run_server(metrics: Metrics) -> anyhow::Result { async move { tracing::info!("Opened WebSocket connection"); metrics.opened_ws_connections.fetch_add(1, Ordering::Relaxed); - // https://github.com/rust-lang/rust/issues/102211 the error type can't be inferred - // to be `Box` so we need to convert it to a concrete type - // as workaround. - svc.call(req).await.map_err(|e| anyhow::anyhow!("{:?}", e)) + svc.call(req).await } .boxed() } else { @@ -252,10 +324,7 @@ async fn run_server(metrics: Metrics) -> anyhow::Result { } tracing::info!("Closed HTTP connection"); - // https://github.com/rust-lang/rust/issues/102211 the error type can't be inferred - // to be `Box` so we need to convert it to a concrete type - // as workaround. - rp.map_err(|e| anyhow::anyhow!("{:?}", e)) + rp } .boxed() } diff --git a/examples/examples/jsonrpsee_server_close_connection_from_rpc_handler.rs b/examples/examples/jsonrpsee_server_close_connection_from_rpc_handler.rs index e9c91a08b8..d834d35f81 100644 --- a/examples/examples/jsonrpsee_server_close_connection_from_rpc_handler.rs +++ b/examples/examples/jsonrpsee_server_close_connection_from_rpc_handler.rs @@ -35,10 +35,11 @@ use std::sync::Arc; use std::sync::atomic::{AtomicU32, Ordering}; use futures::FutureExt; +use jsonrpsee::core::middleware::RpcServiceBuilder; use jsonrpsee::core::{SubscriptionResult, async_trait}; use jsonrpsee::proc_macros::rpc; use jsonrpsee::server::{ - ConnectionGuard, ConnectionState, HttpRequest, RpcServiceBuilder, ServerConfig, ServerHandle, StopHandle, http, + ConnectionGuard, ConnectionState, HttpRequest, ServerConfig, ServerHandle, StopHandle, http, serve_with_graceful_shutdown, stop_channel, ws, }; use jsonrpsee::types::ErrorObjectOwned; diff --git a/examples/examples/jsonrpsee_server_low_level_api.rs b/examples/examples/jsonrpsee_server_low_level_api.rs index f30e8981e3..75d3f0bebc 100644 --- a/examples/examples/jsonrpsee_server_low_level_api.rs +++ b/examples/examples/jsonrpsee_server_low_level_api.rs @@ -45,16 +45,15 @@ use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use futures::FutureExt; -use futures::future::BoxFuture; use jsonrpsee::core::async_trait; +use jsonrpsee::core::middleware::{Batch, Notification, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::http_client::HttpClient; use jsonrpsee::proc_macros::rpc; -use jsonrpsee::server::middleware::rpc::RpcServiceT; use jsonrpsee::server::{ - ConnectionGuard, ConnectionState, RpcServiceBuilder, ServerConfig, ServerHandle, StopHandle, http, - serve_with_graceful_shutdown, stop_channel, ws, + ConnectionGuard, ConnectionState, ServerConfig, ServerHandle, StopHandle, http, serve_with_graceful_shutdown, + stop_channel, ws, }; -use jsonrpsee::types::{ErrorObject, ErrorObjectOwned, Request}; +use jsonrpsee::types::{ErrorObject, ErrorObjectOwned, Id, Request}; use jsonrpsee::ws_client::WsClientBuilder; use jsonrpsee::{MethodResponse, Methods}; use tokio::net::TcpListener; @@ -73,13 +72,14 @@ struct CallLimit { state: mpsc::Sender<()>, } -impl<'a, S> RpcServiceT<'a> for CallLimit +impl RpcServiceT for CallLimit where - S: Send + Sync + RpcServiceT<'a> + Clone + 'static, + S: RpcServiceT + Send + Sync + Clone + 'static, { - type Future = BoxFuture<'a, MethodResponse>; + type Error = S::Error; + type Response = S::Response; - fn call(&self, req: Request<'a>) -> Self::Future { + fn call<'a>(&self, req: Request<'a>) -> impl Future> + Send + 'a { let count = self.count.clone(); let state = self.state.clone(); let service = self.service.clone(); @@ -89,14 +89,47 @@ where if *lock >= 10 { let _ = state.try_send(()); - MethodResponse::error(req.id, ErrorObject::borrowed(-32000, "RPC rate limit", None)) + Ok(MethodResponse::error(req.id, ErrorObject::borrowed(-32000, "RPC rate limit", None))) } else { let rp = service.call(req).await; *lock += 1; rp } } - .boxed() + } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + let count = self.count.clone(); + let state = self.state.clone(); + let service = self.service.clone(); + + async move { + let mut lock = count.lock().await; + let batch_len = batch.len(); + + if *lock >= 10 + batch_len { + let _ = state.try_send(()); + Ok(MethodResponse::error(Id::Null, ErrorObject::borrowed(-32000, "RPC rate limit", None))) + } else { + let rp = service.batch(batch).await; + *lock += batch_len; + rp + } + } + } + + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + let count = self.count.clone(); + let service = self.service.clone(); + + // A notification is not expected to return a response so the result here doesn't matter + // rather than other middlewares may not be invoked. + async move { + if *count.lock().await >= 10 { Ok(MethodResponse::notification()) } else { service.notification(n).await } + } } } diff --git a/examples/examples/rpc_middleware.rs b/examples/examples/rpc_middleware.rs index 585a4372e8..166c36d5d6 100644 --- a/examples/examples/rpc_middleware.rs +++ b/examples/examples/rpc_middleware.rs @@ -36,17 +36,20 @@ //! Contrary the HTTP middleware does only apply per HTTP request and //! may be handy in some scenarios such CORS but if you want to access //! to the actual JSON-RPC details this is the middleware to use. +//! +//! This example enables the same middleware for both the server and client which +//! can be confusing when one runs this but it is just to demonstrate the API. +//! +//! That the middleware is applied to the server and client in the same way. use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; -use futures::FutureExt; -use futures::future::BoxFuture; use jsonrpsee::core::client::ClientT; +use jsonrpsee::core::middleware::{Batch, Notification, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::rpc_params; -use jsonrpsee::server::middleware::rpc::{RpcServiceBuilder, RpcServiceT}; -use jsonrpsee::server::{MethodResponse, RpcModule, Server}; +use jsonrpsee::server::{RpcModule, Server}; use jsonrpsee::types::Request; use jsonrpsee::ws_client::WsClientBuilder; @@ -56,26 +59,44 @@ use jsonrpsee::ws_client::WsClientBuilder; pub struct CallsPerConn { service: S, count: Arc, + role: &'static str, } -impl<'a, S> RpcServiceT<'a> for CallsPerConn +impl RpcServiceT for CallsPerConn where - S: RpcServiceT<'a> + Send + Sync + Clone + 'static, + S: RpcServiceT + Send + Sync + Clone + 'static, { - type Future = BoxFuture<'a, MethodResponse>; + type Error = S::Error; + type Response = S::Response; - fn call(&self, req: Request<'a>) -> Self::Future { + fn call<'a>( + &self, + req: Request<'a>, + ) -> impl Future::Error>> + Send + 'a { let count = self.count.clone(); let service = self.service.clone(); + let role = self.role; async move { let rp = service.call(req).await; count.fetch_add(1, Ordering::SeqCst); - let count = count.load(Ordering::SeqCst); - println!("the server has processed calls={count} on the connection"); + println!("{role} processed calls={} on the connection", count.load(Ordering::SeqCst)); rp } - .boxed() + } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + let len = batch.len(); + self.count.fetch_add(len, Ordering::SeqCst); + println!("{} processed calls={} on the connection", self.role, self.count.load(Ordering::SeqCst)); + self.service.batch(batch) + } + + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + self.service.notification(n) } } @@ -83,41 +104,72 @@ where pub struct GlobalCalls { service: S, count: Arc, + role: &'static str, } -impl<'a, S> RpcServiceT<'a> for GlobalCalls +impl RpcServiceT for GlobalCalls where - S: RpcServiceT<'a> + Send + Sync + Clone + 'static, + S: RpcServiceT + Send + Sync + Clone + 'static, { - type Future = BoxFuture<'a, MethodResponse>; + type Error = S::Error; + type Response = S::Response; - fn call(&self, req: Request<'a>) -> Self::Future { + fn call<'a>(&self, req: Request<'a>) -> impl Future> + Send + 'a { let count = self.count.clone(); let service = self.service.clone(); + let role = self.role; async move { let rp = service.call(req).await; count.fetch_add(1, Ordering::SeqCst); - let count = count.load(Ordering::SeqCst); - println!("the server has processed calls={count} in total"); + println!("{role} processed calls={} in total", count.load(Ordering::SeqCst)); + rp } - .boxed() + } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + let len = batch.len(); + self.count.fetch_add(len, Ordering::SeqCst); + println!("{}, processed calls={} in total", self.role, self.count.load(Ordering::SeqCst)); + self.service.batch(batch) + } + + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + self.service.notification(n) } } #[derive(Clone)] -pub struct Logger(S); +pub struct Logger { + service: S, + role: &'static str, +} -impl<'a, S> RpcServiceT<'a> for Logger +impl RpcServiceT for Logger where - S: RpcServiceT<'a> + Send + Sync, + S: RpcServiceT + Send + Sync, { - type Future = S::Future; + type Error = S::Error; + type Response = S::Response; + + fn call<'a>(&self, req: Request<'a>) -> impl Future> + Send + 'a { + println!("{} logger middleware: method `{}`", self.role, req.method); + self.service.call(req) + } - fn call(&self, req: Request<'a>) -> Self::Future { - println!("logger middleware: method `{}`", req.method); - self.0.call(req) + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + println!("{} logger middleware: batch {batch}", self.role); + self.service.batch(batch) + } + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + self.service.notification(n) } } @@ -132,7 +184,14 @@ async fn main() -> anyhow::Result<()> { let url = format!("ws://{}", addr); for _ in 0..2 { - let client = WsClientBuilder::default().build(&url).await?; + let global_cnt = Arc::new(AtomicUsize::new(0)); + let rpc_middleware = RpcServiceBuilder::new() + .layer_fn(|service| Logger { service, role: "client" }) + // This state is created per connection. + .layer_fn(|service| CallsPerConn { service, count: Default::default(), role: "client" }) + // This state is shared by all connections. + .layer_fn(move |service| GlobalCalls { service, count: global_cnt.clone(), role: "client" }); + let client = WsClientBuilder::new().set_rpc_middleware(rpc_middleware).build(&url).await?; let response: String = client.request("say_hello", rpc_params![]).await?; println!("response: {:?}", response); let _response: Result = client.request("unknown_method", rpc_params![]).await; @@ -149,11 +208,11 @@ async fn run_server() -> anyhow::Result { let global_cnt = Arc::new(AtomicUsize::new(0)); let rpc_middleware = RpcServiceBuilder::new() - .layer_fn(Logger) + .layer_fn(|service| Logger { service, role: "server" }) // This state is created per connection. - .layer_fn(|service| CallsPerConn { service, count: Default::default() }) + .layer_fn(|service| CallsPerConn { service, count: Default::default(), role: "server" }) // This state is shared by all connections. - .layer_fn(move |service| GlobalCalls { service, count: global_cnt.clone() }); + .layer_fn(move |service| GlobalCalls { service, count: global_cnt.clone(), role: "server" }); let server = Server::builder().set_rpc_middleware(rpc_middleware).build("127.0.0.1:0").await?; let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _, _| "lo")?; diff --git a/examples/examples/rpc_middleware_client.rs b/examples/examples/rpc_middleware_client.rs new file mode 100644 index 0000000000..1ce6d9158f --- /dev/null +++ b/examples/examples/rpc_middleware_client.rs @@ -0,0 +1,177 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! jsonrpsee supports two kinds of middlewares `http_middleware` and `rpc_middleware`. +//! +//! This example demonstrates how to use the `rpc_middleware` which applies for each +//! JSON-RPC method calls, notifications and batch requests. +//! +//! This example demonstrates how to use the `rpc_middleware` for the client +//! and you may benefit specifying the response type to `core::client::MethodResponse` +//! to actually inspect the response instead of using the serialized JSON-RPC response. + +use std::net::SocketAddr; +use std::ops::Deref; +use std::sync::{Arc, Mutex}; + +use jsonrpsee::core::client::{ClientT, MethodResponse, MethodResponseKind}; +use jsonrpsee::core::middleware::{Batch, Notification, RpcServiceBuilder, RpcServiceT}; +use jsonrpsee::rpc_params; +use jsonrpsee::server::{RpcModule, Server}; +use jsonrpsee::types::{ErrorCode, ErrorObject, Request}; +use jsonrpsee::ws_client::WsClientBuilder; + +#[derive(Default)] +struct InnerMetrics { + method_calls_success: usize, + method_calls_failure: usize, + notifications: usize, + batch_calls: usize, +} + +#[derive(Clone)] +pub struct Metrics { + service: S, + metrics: Arc>, +} + +impl std::fmt::Debug for InnerMetrics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InnerMetrics") + .field("method_calls_success", &self.method_calls_success) + .field("method_calls_failure", &self.method_calls_failure) + .field("notifications", &self.notifications) + .field("batch_calls", &self.batch_calls) + .finish() + } +} + +impl Metrics { + pub fn new(service: S) -> Self { + Self { service, metrics: Arc::new(Mutex::new(InnerMetrics::default())) } + } +} + +// NOTE: We are using MethodResponse as the response type here to be able to inspect the response +// and not just the serialized JSON-RPC response. This is not necessary if you only care about +// the serialized JSON-RPC response. +impl RpcServiceT for Metrics +where + S: RpcServiceT + Send + Sync + Clone + 'static, +{ + type Error = S::Error; + type Response = S::Response; + + fn call<'a>(&self, req: Request<'a>) -> impl Future> + Send + 'a { + let m = self.metrics.clone(); + let service = self.service.clone(); + + async move { + let rp = service.call(req).await; + + // Access to inner response via the deref implementation. + match rp.as_ref().map(|r| r.deref()) { + Ok(MethodResponseKind::MethodCall(r)) => { + if r.is_success() { + m.lock().unwrap().method_calls_success += 1; + } else { + m.lock().unwrap().method_calls_failure += 1; + } + } + Ok(e) => unreachable!("Unexpected response type {e:?}"), + Err(e) => { + m.lock().unwrap().method_calls_failure += 1; + tracing::error!("Error: {:?}", e); + } + } + + rp + } + } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + self.metrics.lock().unwrap().batch_calls += 1; + self.service.batch(batch) + } + + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + self.metrics.lock().unwrap().notifications += 1; + self.service.notification(n) + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); + + let addr = run_server().await?; + let url = format!("ws://{}", addr); + + let metrics = Arc::new(Mutex::new(InnerMetrics::default())); + + for _ in 0..2 { + let metrics = metrics.clone(); + let rpc_middleware = + RpcServiceBuilder::new().layer_fn(move |s| Metrics { service: s, metrics: metrics.clone() }); + let client = WsClientBuilder::new().set_rpc_middleware(rpc_middleware).build(&url).await?; + let _: Result = client.request("say_hello", rpc_params![]).await; + let _: Result = client.request("unknown_method", rpc_params![]).await; + let _: Result = client.request("thready", rpc_params![4]).await; + let _: Result = client.request("mul", rpc_params![4]).await; + let _: Result = client.request("err", rpc_params![4]).await; + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + println!("Metrics: {:?}", metrics.lock().unwrap()); + + Ok(()) +} + +async fn run_server() -> anyhow::Result { + let server = Server::builder().build("127.0.0.1:0").await?; + let mut module = RpcModule::new(()); + module.register_method("say_hello", |_, _, _| "lo")?; + module.register_method("mul", |params, _, _| { + let count: usize = params.one().unwrap(); + count * 2 + })?; + module.register_method("error", |_, _, _| ErrorObject::from(ErrorCode::InternalError))?; + let addr = server.local_addr()?; + let handle = server.start(module); + + // In this example we don't care about doing shutdown so let's it run forever. + // You may use the `ServerHandle` to shut it down or manage it yourself. + tokio::spawn(handle.stopped()); + + Ok(addr) +} diff --git a/examples/examples/rpc_middleware_modify_request.rs b/examples/examples/rpc_middleware_modify_request.rs index f8fbb4bf9a..fb907e85bd 100644 --- a/examples/examples/rpc_middleware_modify_request.rs +++ b/examples/examples/rpc_middleware_modify_request.rs @@ -25,38 +25,81 @@ // DEALINGS IN THE SOFTWARE. use jsonrpsee::core::client::ClientT; +use jsonrpsee::core::middleware::{Batch, BatchEntry, Notification, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::server::Server; -use jsonrpsee::server::middleware::rpc::{RpcServiceBuilder, RpcServiceT}; use jsonrpsee::types::Request; use jsonrpsee::ws_client::WsClientBuilder; use jsonrpsee::{RpcModule, rpc_params}; use std::borrow::Cow as StdCow; use std::net::SocketAddr; +fn modify_method_call(req: &mut Request<'_>) { + // Example how to modify the params in the call. + if req.method == "say_hello" { + // It's a bit awkward to create new params in the request + // but this shows how to do it. + let raw_value = serde_json::value::to_raw_value("myparams").unwrap(); + req.params = Some(StdCow::Owned(raw_value)); + } + // Re-direct all calls that isn't `say_hello` to `say_goodbye` + else if req.method != "say_hello" { + req.method = "say_goodbye".into(); + } +} + +fn modify_notif(n: &mut Notification<'_>) { + // Example how to modify the params in the notification. + if n.method == "say_hello" { + // It's a bit awkward to create new params in the request + // but this shows how to do it. + let raw_value = serde_json::value::to_raw_value("myparams").unwrap(); + n.params = Some(StdCow::Owned(raw_value)); + } + // Re-direct all notifs that isn't `say_hello` to `say_goodbye` + else if n.method != "say_hello" { + n.method = "say_goodbye".into(); + } +} + #[derive(Clone)] pub struct ModifyRequestIf(S); -impl<'a, S> RpcServiceT<'a> for ModifyRequestIf +impl RpcServiceT for ModifyRequestIf where - S: Send + Sync + RpcServiceT<'a>, + S: Send + Sync + RpcServiceT, { - type Future = S::Future; - - fn call(&self, mut req: Request<'a>) -> Self::Future { - // Example how to modify the params in the call. - if req.method == "say_hello" { - // It's a bit awkward to create new params in the request - // but this shows how to do it. - let raw_value = serde_json::value::to_raw_value("myparams").unwrap(); - req.params = Some(StdCow::Owned(raw_value)); - } - // Re-direct all calls that isn't `say_hello` to `say_goodbye` - else if req.method != "say_hello" { - req.method = "say_goodbye".into(); - } + type Error = S::Error; + type Response = S::Response; + fn call<'a>(&self, mut req: Request<'a>) -> impl Future> + Send + 'a { + modify_method_call(&mut req); self.0.call(req) } + + fn batch<'a>(&self, mut batch: Batch<'a>) -> impl Future> + Send + 'a { + for call in batch.iter_mut() { + match call { + Ok(BatchEntry::Call(call)) => { + modify_method_call(call); + } + Ok(BatchEntry::Notification(n)) => { + modify_notif(n); + } + // Invalid request, we don't care about it. + Err(_err) => {} + } + } + + self.0.batch(batch) + } + + fn notification<'a>( + &self, + mut n: Notification<'a>, + ) -> impl Future> + Send + 'a { + modify_notif(&mut n); + self.0.notification(n) + } } #[tokio::main] diff --git a/examples/examples/rpc_middleware_rate_limiting.rs b/examples/examples/rpc_middleware_rate_limiting.rs index dd83792805..c82fa549d0 100644 --- a/examples/examples/rpc_middleware_rate_limiting.rs +++ b/examples/examples/rpc_middleware_rate_limiting.rs @@ -32,8 +32,10 @@ //! such as `Arc` use jsonrpsee::core::client::ClientT; +use jsonrpsee::core::middleware::{ + Batch, BatchEntry, BatchEntryErr, Notification, ResponseFuture, RpcServiceBuilder, RpcServiceT, +}; use jsonrpsee::server::Server; -use jsonrpsee::server::middleware::rpc::{ResponseFuture, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::types::{ErrorObject, Request}; use jsonrpsee::ws_client::WsClientBuilder; use jsonrpsee::{MethodResponse, RpcModule, rpc_params}; @@ -78,54 +80,83 @@ impl RateLimit { state: Arc::new(Mutex::new(State::Allow { until: Instant::now() + period, rem: num + 1 })), } } -} -impl<'a, S> RpcServiceT<'a> for RateLimit -where - S: Send + RpcServiceT<'a>, -{ - // Instead of `Boxing` the future in this example - // we are using a jsonrpsee's ResponseFuture future - // type to avoid those extra allocations. - type Future = ResponseFuture; - - fn call(&self, req: Request<'a>) -> Self::Future { + fn rate_limit_deny(&self) -> bool { let now = Instant::now(); - - let is_denied = { - let mut lock = self.state.lock().unwrap(); - let next_state = match *lock { - State::Deny { until } => { - if now > until { - State::Allow { until: now + self.rate.period, rem: self.rate.num - 1 } - } else { - State::Deny { until } - } + let mut lock = self.state.lock().unwrap(); + let next_state = match *lock { + State::Deny { until } => { + if now > until { + State::Allow { until: now + self.rate.period, rem: self.rate.num - 1 } + } else { + State::Deny { until } } - State::Allow { until, rem } => { - if now > until { - State::Allow { until: now + self.rate.period, rem: self.rate.num - 1 } - } else { - let n = rem - 1; - if n > 0 { - State::Allow { until: now + self.rate.period, rem: n } - } else { - State::Deny { until } - } - } + } + State::Allow { until, rem } => { + if now > until { + State::Allow { until: now + self.rate.period, rem: self.rate.num - 1 } + } else { + let n = rem - 1; + if n > 0 { State::Allow { until: now + self.rate.period, rem: n } } else { State::Deny { until } } } - }; - - *lock = next_state; - matches!(next_state, State::Deny { .. }) + } }; - if is_denied { + *lock = next_state; + matches!(next_state, State::Deny { .. }) + } +} + +impl RpcServiceT for RateLimit +where + S: Send + RpcServiceT + 'static, + S::Error: Send, +{ + type Error = S::Error; + type Response = S::Response; + + fn call<'a>(&self, req: Request<'a>) -> impl Future> + Send + 'a { + if self.rate_limit_deny() { ResponseFuture::ready(MethodResponse::error(req.id, ErrorObject::borrowed(-32000, "RPC rate limit", None))) } else { ResponseFuture::future(self.service.call(req)) } } + + fn batch<'a>(&self, mut batch: Batch<'a>) -> impl Future> + Send + 'a { + // If the rate limit is reached then we modify each entry + // in the batch to be a request with an error. + // + // This makes sure that the client will receive an error + // for each request in the batch. + if self.rate_limit_deny() { + for entry in batch.iter_mut() { + let id = match entry { + Ok(BatchEntry::Call(req)) => req.id.clone(), + Ok(BatchEntry::Notification(_)) => continue, + Err(_) => continue, + }; + + // This will create a new error response for batch and replace the method call + *entry = Err(BatchEntryErr::new(id, ErrorObject::borrowed(-32000, "RPC rate limit", None))); + } + } + + self.service.batch(batch) + } + + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + if self.rate_limit_deny() { + // Notifications are not expected to return a response so just ignore + // if the rate limit is reached. + ResponseFuture::ready(MethodResponse::notification()) + } else { + ResponseFuture::future(self.service.notification(n)) + } + } } #[tokio::main] diff --git a/examples/examples/server_with_connection_details.rs b/examples/examples/server_with_connection_details.rs index 4692aca347..9291f45c0b 100644 --- a/examples/examples/server_with_connection_details.rs +++ b/examples/examples/server_with_connection_details.rs @@ -26,15 +26,13 @@ use std::net::SocketAddr; -use jsonrpsee::ConnectionId; -use jsonrpsee::Extensions; -use jsonrpsee::core::SubscriptionResult; -use jsonrpsee::core::async_trait; +use jsonrpsee::core::middleware::{Batch, Notification, Request, RpcServiceT}; +use jsonrpsee::core::{SubscriptionResult, async_trait}; use jsonrpsee::proc_macros::rpc; -use jsonrpsee::server::middleware::rpc::RpcServiceT; use jsonrpsee::server::{PendingSubscriptionSink, SubscriptionMessage}; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; use jsonrpsee::ws_client::WsClientBuilder; +use jsonrpsee::{ConnectionId, Extensions}; #[rpc(server, client)] pub trait Rpc { @@ -51,15 +49,29 @@ pub trait Rpc { struct LoggingMiddleware(S); -impl<'a, S: RpcServiceT<'a>> RpcServiceT<'a> for LoggingMiddleware { - type Future = S::Future; +impl RpcServiceT for LoggingMiddleware { + type Error = S::Error; + type Response = S::Response; - fn call(&self, request: jsonrpsee::types::Request<'a>) -> Self::Future { + fn call<'a>(&self, request: Request<'a>) -> impl Future> + Send + 'a { tracing::info!("Received request: {:?}", request); assert!(request.extensions().get::().is_some()); self.0.call(request) } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + tracing::info!("Received batch: {:?}", batch); + self.0.batch(batch) + } + + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + tracing::info!("Received notif: {:?}", n); + self.0.notification(n) + } } pub struct RpcServerImpl; diff --git a/examples/examples/ws.rs b/examples/examples/ws.rs index 4309d1a90a..61173162a4 100644 --- a/examples/examples/ws.rs +++ b/examples/examples/ws.rs @@ -27,7 +27,8 @@ use std::net::SocketAddr; use jsonrpsee::core::client::ClientT; -use jsonrpsee::server::{RpcServiceBuilder, Server}; +use jsonrpsee::core::middleware::RpcServiceBuilder; +use jsonrpsee::server::Server; use jsonrpsee::ws_client::WsClientBuilder; use jsonrpsee::{RpcModule, rpc_params}; use tracing_subscriber::util::SubscriberInitExt; @@ -42,7 +43,8 @@ async fn main() -> anyhow::Result<()> { let addr = run_server().await?; let url = format!("ws://{}", addr); - let client = WsClientBuilder::default().build(&url).await?; + let rpc_middleware = RpcServiceBuilder::new().rpc_logger(1024); + let client = WsClientBuilder::new().set_rpc_middleware(rpc_middleware).build(&url).await?; let response: String = client.request("say_hello", rpc_params![]).await?; tracing::info!("response: {:?}", response); diff --git a/examples/examples/ws_pubsub_broadcast.rs b/examples/examples/ws_pubsub_broadcast.rs index 1365ed5b9b..1e57f09905 100644 --- a/examples/examples/ws_pubsub_broadcast.rs +++ b/examples/examples/ws_pubsub_broadcast.rs @@ -32,6 +32,7 @@ use futures::StreamExt; use futures::future::{self, Either}; use jsonrpsee::PendingSubscriptionSink; use jsonrpsee::core::client::{Subscription, SubscriptionClientT}; +use jsonrpsee::core::middleware::RpcServiceBuilder; use jsonrpsee::core::server::SubscriptionMessage; use jsonrpsee::rpc_params; use jsonrpsee::server::{RpcModule, Server, ServerConfig}; @@ -51,8 +52,10 @@ async fn main() -> anyhow::Result<()> { let addr = run_server().await?; let url = format!("ws://{}", addr); - let client1 = WsClientBuilder::default().build(&url).await?; - let client2 = WsClientBuilder::default().build(&url).await?; + let client1 = + WsClientBuilder::default().set_rpc_middleware(RpcServiceBuilder::new().rpc_logger(1024)).build(&url).await?; + let client2 = + WsClientBuilder::default().set_rpc_middleware(RpcServiceBuilder::new().rpc_logger(1024)).build(&url).await?; let sub1: Subscription = client1.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await?; let sub2: Subscription = client2.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await?; @@ -67,7 +70,11 @@ async fn main() -> anyhow::Result<()> { async fn run_server() -> anyhow::Result { // let's configure the server only hold 5 messages in memory. let config = ServerConfig::builder().set_message_buffer_capacity(5).build(); - let server = Server::builder().set_config(config).build("127.0.0.1:0").await?; + let server = Server::builder() + .set_config(config) + .set_rpc_middleware(RpcServiceBuilder::new().rpc_logger(1024)) + .build("127.0.0.1:0") + .await?; let (tx, _rx) = broadcast::channel::(16); let mut module = RpcModule::new(tx.clone()); diff --git a/proc-macros/src/lib.rs b/proc-macros/src/lib.rs index 9fb60483eb..fe2946c37c 100644 --- a/proc-macros/src/lib.rs +++ b/proc-macros/src/lib.rs @@ -172,8 +172,7 @@ pub(crate) mod visitor; /// /// - `name` (mandatory): name of the RPC method. Does not have to be the same as the Rust method name. /// - `aliases`: list of name aliases for the RPC method as a comma separated string. -/// Aliases are processed ignoring the namespace, so add the complete name, including the -/// namespace. +/// Aliases are processed ignoring the namespace, so add the complete name, including the namespace. /// - `blocking`: when set method execution will always spawn on a dedicated thread. Only usable with non-`async` methods. /// - `param_kind`: kind of structure to use for parameter passing. Can be "array" or "map", defaults to "array". /// @@ -193,9 +192,9 @@ pub(crate) mod visitor; /// /// - `name` (mandatory): name of the RPC method. Does not have to be the same as the Rust method name. /// - `unsubscribe` (optional): name of the RPC method to unsubscribe from the subscription. Must not be the same as `name`. -/// This is generated for you if the subscription name starts with `subscribe`. +/// This is generated for you if the subscription name starts with `subscribe`. /// - `aliases` (optional): aliases for `name`. Aliases are processed ignoring the namespace, -/// so add the complete name, including the namespace. +/// so add the complete name, including the namespace. /// - `unsubscribe_aliases` (optional): Similar to `aliases` but for `unsubscribe`. /// - `item` (mandatory): type of items yielded by the subscription. Note that it must be the type, not string. /// - `param_kind`: kind of structure to use for parameter passing. Can be "array" or "map", defaults to "array". diff --git a/proc-macros/tests/ui/correct/basic.rs b/proc-macros/tests/ui/correct/basic.rs index 1144674c29..dc5180b86b 100644 --- a/proc-macros/tests/ui/correct/basic.rs +++ b/proc-macros/tests/ui/correct/basic.rs @@ -4,12 +4,12 @@ use std::net::SocketAddr; use jsonrpsee::core::client::ClientT; use jsonrpsee::core::params::ArrayParams; -use jsonrpsee::core::{async_trait, RpcResult, SubscriptionResult}; +use jsonrpsee::core::{RpcResult, SubscriptionResult, async_trait}; use jsonrpsee::proc_macros::rpc; use jsonrpsee::server::SubscriptionMessage; use jsonrpsee::types::ErrorObject; use jsonrpsee::ws_client::*; -use jsonrpsee::{rpc_params, Extensions, PendingSubscriptionSink}; +use jsonrpsee::{Extensions, PendingSubscriptionSink, rpc_params}; #[rpc(client, server, namespace = "foo")] pub trait Rpc { @@ -135,10 +135,10 @@ impl RpcServer for RpcServerImpl { pub async fn server() -> SocketAddr { use hyper_util::rt::{TokioExecutor, TokioIo}; - use jsonrpsee::server::middleware::rpc::RpcServiceT; - use jsonrpsee::server::{stop_channel, RpcServiceBuilder}; + use jsonrpsee::core::middleware::{Batch, Notification, RpcServiceBuilder, RpcServiceT}; + use jsonrpsee::server::stop_channel; use std::convert::Infallible; - use std::sync::{atomic::AtomicU32, Arc}; + use std::sync::{Arc, atomic::AtomicU32}; use tower::Service; #[derive(Debug, Clone)] @@ -147,16 +147,31 @@ pub async fn server() -> SocketAddr { connection_id: u32, } - impl<'a, S> RpcServiceT<'a> for ConnectionDetails + impl RpcServiceT for ConnectionDetails where - S: RpcServiceT<'a>, + S: RpcServiceT, { - type Future = S::Future; + type Error = S::Error; + type Response = S::Response; - fn call(&self, mut request: jsonrpsee::types::Request<'a>) -> Self::Future { + fn call<'a>( + &self, + mut request: jsonrpsee::types::Request<'a>, + ) -> impl Future> + Send + 'a { request.extensions_mut().insert(self.connection_id); self.inner.call(request) } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + self.inner.batch(batch) + } + + fn notification<'a>( + &self, + notif: Notification<'a>, + ) -> impl Future> + Send + 'a { + self.inner.notification(notif) + } } let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await.unwrap(); diff --git a/server/src/lib.rs b/server/src/lib.rs index 5267b5c5b9..023d9a8046 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -46,7 +46,6 @@ pub use jsonrpsee_core::error::RegisterMethodError; pub use jsonrpsee_core::server::*; pub use jsonrpsee_core::{id_providers::*, traits::IdProvider}; pub use jsonrpsee_types as types; -pub use middleware::rpc::RpcServiceBuilder; pub use server::{ BatchRequestConfig, Builder as ServerBuilder, ConnectionState, PingConfig, Server, ServerConfig, ServerConfigBuilder, TowerService, TowerServiceBuilder, diff --git a/server/src/middleware/http/proxy_get_request.rs b/server/src/middleware/http/proxy_get_request.rs index 0ae473ff0a..bd25090fea 100644 --- a/server/src/middleware/http/proxy_get_request.rs +++ b/server/src/middleware/http/proxy_get_request.rs @@ -35,7 +35,7 @@ use hyper::header::{ACCEPT, CONTENT_TYPE}; use hyper::http::HeaderValue; use hyper::{Method, StatusCode, Uri}; use jsonrpsee_core::BoxError; -use jsonrpsee_types::{ErrorCode, ErrorObject, Id, RequestSer}; +use jsonrpsee_types::{ErrorCode, ErrorObject, Id, Request}; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; @@ -160,8 +160,8 @@ where req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json")); // Adjust the body to reflect the method call. - let bytes = serde_json::to_vec(&RequestSer::borrowed(&Id::Number(0), method, None)) - .expect("Valid request; qed"); + let bytes = + serde_json::to_vec(&Request::borrowed(method, None, Id::Number(0))).expect("Valid request; qed"); let req = req.map(|_| HttpBody::from(bytes)); // Call the inner service and get a future that resolves to the response. diff --git a/server/src/middleware/mod.rs b/server/src/middleware/mod.rs index dff6948aa3..93b66a7b29 100644 --- a/server/src/middleware/mod.rs +++ b/server/src/middleware/mod.rs @@ -28,5 +28,4 @@ /// HTTP related middleware. pub mod http; -/// JSON-RPC specific middleware. pub mod rpc; diff --git a/server/src/middleware/rpc/layer/rpc_service.rs b/server/src/middleware/rpc.rs similarity index 61% rename from server/src/middleware/rpc/layer/rpc_service.rs rename to server/src/middleware/rpc.rs index 34fe1bdd1f..8b2180a122 100644 --- a/server/src/middleware/rpc/layer/rpc_service.rs +++ b/server/src/middleware/rpc.rs @@ -26,18 +26,20 @@ //! JSON-RPC service middleware. -use super::ResponseFuture; +pub use jsonrpsee_core::middleware::*; +pub use jsonrpsee_core::server::MethodResponse; + +use std::convert::Infallible; use std::sync::Arc; use crate::ConnectionId; -use crate::middleware::rpc::RpcServiceT; -use futures_util::future::BoxFuture; +use futures_util::future::FutureExt; use jsonrpsee_core::server::{ - BoundedSubscriptions, MethodCallback, MethodResponse, MethodSink, Methods, SubscriptionState, + BatchResponseBuilder, BoundedSubscriptions, MethodCallback, MethodSink, Methods, SubscriptionState, }; use jsonrpsee_core::traits::IdProvider; +use jsonrpsee_types::ErrorObject; use jsonrpsee_types::error::{ErrorCode, reject_too_many_subscriptions}; -use jsonrpsee_types::{ErrorObject, Request}; /// JSON-RPC service middleware. #[derive(Clone, Debug)] @@ -74,12 +76,11 @@ impl RpcService { } } -impl<'a> RpcServiceT<'a> for RpcService { - // The rpc module is already boxing the futures and - // it's used to under the hood by the RpcService. - type Future = ResponseFuture>; +impl RpcServiceT for RpcService { + type Error = Infallible; + type Response = MethodResponse; - fn call(&self, req: Request<'a>) -> Self::Future { + fn call<'a>(&self, req: Request<'a>) -> impl Future> + Send + 'a { let conn_id = self.conn_id; let max_response_body_size = self.max_response_body_size; @@ -90,19 +91,18 @@ impl<'a> RpcServiceT<'a> for RpcService { None => { let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)).with_extensions(extensions); - ResponseFuture::ready(rp) + async move { Ok(rp) }.boxed() } Some((_name, method)) => match method { MethodCallback::Async(callback) => { let params = params.into_owned(); let id = id.into_owned(); - let fut = (callback)(id, params, conn_id, max_response_body_size, extensions); - ResponseFuture::future(fut) + (callback)(id, params, conn_id, max_response_body_size, extensions).map(Ok).boxed() } MethodCallback::Sync(callback) => { let rp = (callback)(id, params, max_response_body_size, extensions); - ResponseFuture::ready(rp) + async move { Ok(rp) }.boxed() } MethodCallback::Subscription(callback) => { let RpcServiceCfg::CallsAndSubscriptions { @@ -115,20 +115,19 @@ impl<'a> RpcServiceT<'a> for RpcService { tracing::warn!("Subscriptions not supported"); let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)) .with_extensions(extensions); - return ResponseFuture::ready(rp); + return async move { Ok(rp) }.boxed(); }; if let Some(p) = bounded_subscriptions.acquire() { let conn_state = SubscriptionState { conn_id, id_provider: &*id_provider.clone(), subscription_permit: p }; - let fut = callback(id.clone(), params, sink, conn_state, extensions); - ResponseFuture::future(fut) + callback(id.clone(), params, sink, conn_state, extensions).map(Ok).boxed() } else { let max = bounded_subscriptions.max(); let rp = MethodResponse::error(id, reject_too_many_subscriptions(max)).with_extensions(extensions); - ResponseFuture::ready(rp) + async move { Ok(rp) }.boxed() } } MethodCallback::Unsubscription(callback) => { @@ -138,13 +137,68 @@ impl<'a> RpcServiceT<'a> for RpcService { tracing::warn!("Subscriptions not supported"); let rp = MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)) .with_extensions(extensions); - return ResponseFuture::ready(rp); + return async move { Ok(rp) }.boxed(); }; let rp = callback(id, params, conn_id, max_response_body_size, extensions); - ResponseFuture::ready(rp) + async move { Ok(rp) }.boxed() } }, } } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + let mut batch_rp = BatchResponseBuilder::new_with_limit(self.max_response_body_size); + let service = self.clone(); + async move { + let mut got_notification = false; + + for batch_entry in batch.into_iter() { + match batch_entry { + Ok(BatchEntry::Call(req)) => { + let rp = match service.call(req).await { + Ok(rp) => rp, + Err(e) => match e {}, + }; + if let Err(err) = batch_rp.append(rp) { + return Ok(err); + } + } + Ok(BatchEntry::Notification(n)) => { + got_notification = true; + match service.notification(n).await { + Ok(rp) => rp, + Err(e) => match e {}, + }; + } + Err(err) => { + let (err, id) = err.into_parts(); + let rp = MethodResponse::error(id, err); + if let Err(err) = batch_rp.append(rp) { + return Ok(err); + } + } + } + } + + // If the batch is empty and we got a notification, we return an empty response. + if batch_rp.is_empty() && got_notification { + Ok(MethodResponse::notification()) + } + // An empty batch is regarded as an invalid request here. + else { + Ok(MethodResponse::from_batch(batch_rp.finish())) + } + } + } + + fn notification<'a>( + &self, + n: Notification<'a>, + ) -> impl Future> + Send + 'a { + // The notification should not be replied to with a response + // but we propogate the extensions to the response which can be useful + // for example HTTP transport to set the headers. + async move { Ok(MethodResponse::notification().with_extensions(n.extensions)) } + } } diff --git a/server/src/middleware/rpc/layer/logger.rs b/server/src/middleware/rpc/layer/logger.rs deleted file mode 100644 index 50f6a616e3..0000000000 --- a/server/src/middleware/rpc/layer/logger.rs +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any -// person obtaining a copy of this software and associated -// documentation files (the "Software"), to deal in the -// Software without restriction, including without -// limitation the rights to use, copy, modify, merge, -// publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software -// is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice -// shall be included in all copies or substantial portions -// of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! RPC Logger layer. - -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use futures_util::Future; -use jsonrpsee_core::{ - server::MethodResponse, - tracing::server::{rx_log_from_json, tx_log_from_str}, -}; -use jsonrpsee_types::Request; -use pin_project::pin_project; -use tracing::{Instrument, instrument::Instrumented}; - -use crate::middleware::rpc::RpcServiceT; - -/// RPC logger layer. -#[derive(Copy, Clone, Debug)] -pub struct RpcLoggerLayer(u32); - -impl RpcLoggerLayer { - /// Create a new logging layer. - pub fn new(max: u32) -> Self { - Self(max) - } -} - -impl tower::Layer for RpcLoggerLayer { - type Service = RpcLogger; - - fn layer(&self, service: S) -> Self::Service { - RpcLogger { service, max: self.0 } - } -} - -/// A middleware that logs each RPC call and response. -#[derive(Debug)] -pub struct RpcLogger { - max: u32, - service: S, -} - -impl<'a, S> RpcServiceT<'a> for RpcLogger -where - S: RpcServiceT<'a>, -{ - type Future = Instrumented>; - - #[tracing::instrument(name = "method_call", skip_all, fields(method = request.method_name()), level = "trace")] - fn call(&self, request: Request<'a>) -> Self::Future { - rx_log_from_json(&request, self.max); - - ResponseFuture { fut: self.service.call(request), max: self.max }.in_current_span() - } -} - -/// Response future to log the response for a method call. -#[pin_project] -pub struct ResponseFuture { - #[pin] - fut: F, - max: u32, -} - -impl std::fmt::Debug for ResponseFuture { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("ResponseFuture") - } -} - -impl> Future for ResponseFuture { - type Output = F::Output; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let max = self.max; - let fut = self.project().fut; - - let res = fut.poll(cx); - if let Poll::Ready(rp) = &res { - tx_log_from_str(rp.as_result(), max); - } - res - } -} diff --git a/server/src/middleware/rpc/mod.rs b/server/src/middleware/rpc/mod.rs deleted file mode 100644 index d9014cf781..0000000000 --- a/server/src/middleware/rpc/mod.rs +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any -// person obtaining a copy of this software and associated -// documentation files (the "Software"), to deal in the -// Software without restriction, including without -// limitation the rights to use, copy, modify, merge, -// publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software -// is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice -// shall be included in all copies or substantial portions -// of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Various middleware implementations for JSON-RPC specific purposes. - -pub mod layer; -pub use layer::*; - -use futures_util::Future; -use jsonrpsee_core::server::MethodResponse; -use jsonrpsee_types::Request; -use layer::either::Either; - -use tower::layer::LayerFn; -use tower::layer::util::{Identity, Stack}; - -/// Similar to the [`tower::Service`] but specific for jsonrpsee and -/// doesn't requires `&mut self` for performance reasons. -pub trait RpcServiceT<'a> { - /// The future response value. - type Future: Future + Send; - - /// Process a single JSON-RPC call it may be a subscription or regular call. - /// In this interface they are treated in the same way but it's possible to - /// distinguish those based on the `MethodResponse`. - fn call(&self, request: Request<'a>) -> Self::Future; -} - -/// Similar to [`tower::ServiceBuilder`] but doesn't -/// support any tower middleware implementations. -#[derive(Debug, Clone)] -pub struct RpcServiceBuilder(tower::ServiceBuilder); - -impl Default for RpcServiceBuilder { - fn default() -> Self { - RpcServiceBuilder(tower::ServiceBuilder::new()) - } -} - -impl RpcServiceBuilder { - /// Create a new [`RpcServiceBuilder`]. - pub fn new() -> Self { - Self(tower::ServiceBuilder::new()) - } -} - -impl RpcServiceBuilder { - /// Optionally add a new layer `T` to the [`RpcServiceBuilder`]. - /// - /// See the documentation for [`tower::ServiceBuilder::option_layer`] for more details. - pub fn option_layer(self, layer: Option) -> RpcServiceBuilder, L>> { - let layer = if let Some(layer) = layer { Either::Left(layer) } else { Either::Right(Identity::new()) }; - self.layer(layer) - } - - /// Add a new layer `T` to the [`RpcServiceBuilder`]. - /// - /// See the documentation for [`tower::ServiceBuilder::layer`] for more details. - pub fn layer(self, layer: T) -> RpcServiceBuilder> { - RpcServiceBuilder(self.0.layer(layer)) - } - - /// Add a [`tower::Layer`] built from a function that accepts a service and returns another service. - /// - /// See the documentation for [`tower::ServiceBuilder::layer_fn`] for more details. - pub fn layer_fn(self, f: F) -> RpcServiceBuilder, L>> { - RpcServiceBuilder(self.0.layer_fn(f)) - } - - /// Add a logging layer to [`RpcServiceBuilder`] - /// - /// This logs each request and response for every call. - /// - pub fn rpc_logger(self, max_log_len: u32) -> RpcServiceBuilder> { - RpcServiceBuilder(self.0.layer(RpcLoggerLayer::new(max_log_len))) - } - - /// Wrap the service `S` with the middleware. - pub(crate) fn service(&self, service: S) -> L::Service - where - L: tower::Layer, - { - self.0.service(service) - } -} diff --git a/server/src/server.rs b/server/src/server.rs index c65966318c..bd23c9a9a1 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -24,6 +24,7 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use std::convert::Infallible; use std::error::Error as StdError; use std::future::Future; use std::net::{SocketAddr, TcpListener as StdTcpListener}; @@ -34,10 +35,10 @@ use std::task::Poll; use std::time::Duration; use crate::future::{ConnectionGuard, ServerHandle, SessionClose, SessionClosedFuture, StopHandle, session_close}; -use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; +use crate::middleware::rpc::{RpcService, RpcServiceCfg}; use crate::transport::ws::BackgroundTaskParams; use crate::transport::{http, ws}; -use crate::utils::deserialize; +use crate::utils::deserialize_with_ext; use crate::{Extensions, HttpBody, HttpRequest, HttpResponse, LOG_TARGET}; use futures_util::future::{self, Either, FutureExt}; @@ -46,17 +47,16 @@ use futures_util::io::{BufReader, BufWriter}; use hyper::body::Bytes; use hyper_util::rt::{TokioExecutor, TokioIo}; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; +use jsonrpsee_core::middleware::{Batch, BatchEntry, BatchEntryErr, RpcServiceBuilder, RpcServiceT}; use jsonrpsee_core::server::helpers::prepare_error; -use jsonrpsee_core::server::{ - BatchResponseBuilder, BoundedSubscriptions, ConnectionId, MethodResponse, MethodSink, Methods, -}; +use jsonrpsee_core::server::{BoundedSubscriptions, ConnectionId, MethodResponse, MethodSink, Methods}; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES}; use jsonrpsee_types::error::{ BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, ErrorCode, reject_too_big_batch_request, }; -use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification}; +use jsonrpsee_types::{ErrorObject, Id}; use soketto::handshake::http::is_upgrade_request; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{OwnedSemaphorePermit, mpsc, watch}; @@ -65,11 +65,11 @@ use tower::layer::util::Identity; use tower::{Layer, Service}; use tracing::{Instrument, instrument}; -type Notif<'a> = Notification<'a, Option<&'a JsonRawValue>>; - /// Default maximum connections allowed. const MAX_CONNECTIONS: u32 = 100; +type Notif<'a> = Option>; + /// JSON RPC server. pub struct Server { listener: TcpListener, @@ -101,7 +101,7 @@ impl Server { impl Server where RpcMiddleware: tower::Layer + Clone + Send + 'static, - for<'a> >::Service: RpcServiceT<'a>, + for<'a> >::Service: RpcServiceT, HttpMiddleware: Layer> + Send + 'static, >>::Service: Send + Clone + Service, Error = BoxError>, @@ -665,11 +665,8 @@ impl Builder { /// use std::{time::Instant, net::SocketAddr, sync::Arc}; /// use std::sync::atomic::{Ordering, AtomicUsize}; /// - /// use jsonrpsee_server::middleware::rpc::{RpcServiceT, RpcService, RpcServiceBuilder}; - /// use jsonrpsee_server::{ServerBuilder, MethodResponse}; - /// use jsonrpsee_core::async_trait; - /// use jsonrpsee_types::Request; - /// use futures_util::future::BoxFuture; + /// use jsonrpsee_server::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceT, MethodResponse, Notification, Request, Batch}; + /// use jsonrpsee_server::ServerBuilder; /// /// #[derive(Clone)] /// struct MyMiddleware { @@ -677,23 +674,33 @@ impl Builder { /// count: Arc, /// } /// - /// impl<'a, S> RpcServiceT<'a> for MyMiddleware - /// where S: RpcServiceT<'a> + Send + Sync + Clone + 'static, + /// impl RpcServiceT for MyMiddleware + /// where S: RpcServiceT + Send + Sync + Clone + 'static, /// { - /// type Future = BoxFuture<'a, MethodResponse>; + /// type Error = S::Error; + /// type Response = S::Response; /// - /// fn call(&self, req: Request<'a>) -> Self::Future { + /// fn call<'a>(&self, req: Request<'a>) -> impl Future> + Send + 'a { /// tracing::info!("MyMiddleware processed call {}", req.method); /// let count = self.count.clone(); /// let service = self.service.clone(); /// - /// Box::pin(async move { + /// async move { /// let rp = service.call(req).await; /// // Modify the state. /// count.fetch_add(1, Ordering::Relaxed); /// rp - /// }) + /// } + /// } + /// + /// fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + /// self.service.batch(batch) + /// } + /// + /// fn notification<'a>(&self, notif: Notification<'a>) -> impl Future> + Send + 'a { + /// self.service.notification(notif) /// } + /// /// } /// /// // Create a state per connection @@ -929,8 +936,7 @@ impl TowerService impl Service> for TowerService where RpcMiddleware: for<'a> tower::Layer + Clone, - >::Service: Send + Sync + 'static, - for<'a> >::Service: RpcServiceT<'a>, + >::Service: RpcServiceT + Send + Sync + 'static, HttpMiddleware: Layer> + Send + 'static, >>::Service: Send + Service, Response = HttpResponse, Error = Box<(dyn StdError + Send + Sync + 'static)>>, @@ -966,8 +972,8 @@ pub struct TowerServiceNoHttp { impl Service> for TowerServiceNoHttp where RpcMiddleware: for<'a> tower::Layer, - >::Service: Send + Sync + 'static, - for<'a> >::Service: RpcServiceT<'a>, + >::Service: + RpcServiceT + Send + Sync + 'static, Body: http_body::Body + Send + 'static, Body::Error: Into, { @@ -1103,9 +1109,7 @@ where )); Box::pin(async move { - let rp = - http::call_with_service(request, batch_config, max_request_size, rpc_service, max_response_size) - .await; + let rp = http::call_with_service(request, batch_config, max_request_size, rpc_service).await; // NOTE: The `conn guard` must be held until the response is processed // to respect the `max_connections` limit. drop(conn); @@ -1231,22 +1235,32 @@ pub(crate) async fn handle_rpc_call( body: &[u8], is_single: bool, batch_config: BatchRequestConfig, - max_response_size: u32, rpc_service: &S, extensions: Extensions, -) -> Option +) -> MethodResponse where - for<'a> S: RpcServiceT<'a> + Send, + S: RpcServiceT + Send, + ::Error: std::fmt::Debug, { // Single request or notification if is_single { - if let Ok(req) = deserialize::from_slice_with_extensions(body, extensions) { - Some(rpc_service.call(req).await) - } else if let Ok(_notif) = serde_json::from_slice::(body) { - None + if let Ok(req) = deserialize_with_ext::call::from_slice(body, &extensions) { + match rpc_service.call(req).await { + Ok(rp) => rp, + Err(err) => match err {}, + } + } else if let Ok(notif) = deserialize_with_ext::notif::from_slice::(body, &extensions) { + match rpc_service.notification(notif).await { + Ok(rp) => rp, + Err(e) => { + // We don't care about if middleware/service encountered if it's an notification. + tracing::debug!(target: LOG_TARGET, "Notification error: {:?}", e); + return MethodResponse::notification(); + } + } } else { let (id, code) = prepare_error(body); - Some(MethodResponse::error(id, ErrorObject::from(code))) + MethodResponse::error(id, ErrorObject::from(code)) } } // Batch of requests. @@ -1257,53 +1271,40 @@ where Id::Null, ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None), ); - return Some(rp); + return rp; } BatchRequestConfig::Limit(limit) => limit as usize, BatchRequestConfig::Unlimited => usize::MAX, }; - if let Ok(batch) = serde_json::from_slice::>(body) { - if batch.len() > max_len { - return Some(MethodResponse::error(Id::Null, reject_too_big_batch_request(max_len))); + if let Ok(unchecked_batch) = serde_json::from_slice::>(body) { + if unchecked_batch.len() > max_len { + return MethodResponse::error(Id::Null, reject_too_big_batch_request(max_len)); } - let mut got_notif = false; - let mut batch_response = BatchResponseBuilder::new_with_limit(max_response_size as usize); - - for call in batch { - if let Ok(req) = deserialize::from_str_with_extensions(call.get(), extensions.clone()) { - let rp = rpc_service.call(req).await; + let mut batch = Vec::with_capacity(unchecked_batch.len()); - if let Err(too_large) = batch_response.append(rp) { - return Some(too_large); - } - } else if let Ok(_notif) = serde_json::from_str::(call.get()) { - // notifications should not be answered. - got_notif = true; + for call in unchecked_batch { + if let Ok(req) = deserialize_with_ext::call::from_str(call.get(), &extensions) { + batch.push(Ok(BatchEntry::Call(req))); + } else if let Ok(notif) = deserialize_with_ext::notif::from_str::(call.get(), &extensions) { + batch.push(Ok(BatchEntry::Notification(notif))); } else { - // valid JSON but could be not parsable as `InvalidRequest` - let id = match serde_json::from_str::(call.get()) { + let id = match serde_json::from_str::(call.get()) { Ok(err) => err.id, Err(_) => Id::Null, }; - if let Err(too_large) = - batch_response.append(MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest))) - { - return Some(too_large); - } + batch.push(Err(BatchEntryErr::new(id, ErrorCode::InvalidRequest.into()))); } } - if got_notif && batch_response.is_empty() { - None - } else { - let batch_rp = batch_response.finish(); - Some(MethodResponse::from_batch(batch_rp)) + match rpc_service.batch(Batch::from(batch)).await { + Ok(rp) => rp, + Err(e) => match e {}, } } else { - Some(MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError))) + MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError)) } } } diff --git a/server/src/tests/http.rs b/server/src/tests/http.rs index 3bc6595d78..53ff655bb9 100644 --- a/server/src/tests/http.rs +++ b/server/src/tests/http.rs @@ -26,14 +26,14 @@ use std::net::SocketAddr; -use crate::middleware::rpc::{RpcServiceBuilder, RpcServiceT}; use crate::types::Request; use crate::{ - BatchRequestConfig, HttpBody, HttpRequest, HttpResponse, MethodResponse, RegisterMethodError, RpcModule, - ServerBuilder, ServerConfig, ServerHandle, + BatchRequestConfig, HttpBody, HttpRequest, HttpResponse, RegisterMethodError, RpcModule, ServerBuilder, + ServerConfig, ServerHandle, }; -use futures_util::future::{BoxFuture, Future, FutureExt}; +use futures_util::future::{Future, FutureExt}; use hyper::body::Bytes; +use jsonrpsee_core::middleware::{Batch, Notification, RpcServiceBuilder, RpcServiceT}; use jsonrpsee_core::{BoxError, RpcResult}; use jsonrpsee_test_utils::TimeoutFutureExt; use jsonrpsee_test_utils::helpers::*; @@ -57,20 +57,40 @@ struct InjectExt { service: S, } -impl<'a, S> RpcServiceT<'a> for InjectExt +impl RpcServiceT for InjectExt where - S: Send + Sync + RpcServiceT<'a> + Clone + 'static, + S: Send + Sync + RpcServiceT + Clone + 'static, { - type Future = BoxFuture<'a, MethodResponse>; + type Error = S::Error; + type Response = S::Response; - fn call(&self, mut req: Request<'a>) -> Self::Future { + fn call<'a>(&self, mut req: Request<'a>) -> impl Future> + Send + 'a { if req.method_name().contains("err") { req.extensions_mut().insert(StatusCode::IM_A_TEAPOT); } else { req.extensions_mut().insert(StatusCode::OK); } - self.service.call(req).boxed() + self.service.call(req) + } + + fn batch<'a>(&self, mut batch: Batch<'a>) -> impl Future> + Send + 'a { + if let Some(Ok(last)) = batch.iter_mut().last() { + if last.method_name().contains("err") { + last.extensions_mut().insert(StatusCode::IM_A_TEAPOT); + } else { + last.extensions_mut().insert(StatusCode::OK); + } + } + + self.service.batch(batch) + } + + fn notification<'a>( + &self, + _: Notification<'a>, + ) -> impl Future> + Send + 'a { + async { panic!("Not used for tests") } } } @@ -296,7 +316,7 @@ async fn batched_notifications() { let response = http_request(req.into(), uri).with_default_timeout().await.unwrap().unwrap(); assert_eq!(response.status, StatusCode::OK); // Note: on HTTP acknowledge the notification with an empty response. - assert_eq!(response.body, ""); + assert_eq!(response.body, "null"); } #[tokio::test] @@ -492,7 +512,7 @@ async fn notif_works() { let req = r#"{"jsonrpc":"2.0","method":"bar"}"#; let response = http_request(req.into(), uri).with_default_timeout().await.unwrap().unwrap(); assert_eq!(response.status, StatusCode::OK); - assert_eq!(response.body, ""); + assert_eq!(response.body, "null"); } #[tokio::test] diff --git a/server/src/transport/http.rs b/server/src/transport/http.rs index 060cdcfb6b..0c2562f3b0 100644 --- a/server/src/transport/http.rs +++ b/server/src/transport/http.rs @@ -1,6 +1,8 @@ +use std::convert::Infallible; + use crate::{ BatchRequestConfig, ConnectionState, HttpRequest, HttpResponse, LOG_TARGET, - middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}, + middleware::rpc::{RpcService, RpcServiceCfg}, server::{ServerConfig, handle_rpc_call}, }; use http::Method; @@ -8,7 +10,8 @@ use hyper::body::{Body, Bytes}; use jsonrpsee_core::{ BoxError, http_helpers::{HttpError, read_body}, - server::Methods, + middleware::{RpcServiceBuilder, RpcServiceT}, + server::{MethodResponse, Methods}, }; /// Checks that content type of received request is valid for JSON-RPC. @@ -42,9 +45,8 @@ where B: http_body::Body + Send + 'static, B::Data: Send, B::Error: Into, - L: for<'a> tower::Layer, - >::Service: Send + Sync + 'static, - for<'a> >::Service: RpcServiceT<'a>, + L: tower::Layer, + >::Service: RpcServiceT + Send, { let ServerConfig { max_response_body_size, batch_requests_config, max_request_body_size, .. } = server_cfg; @@ -55,9 +57,7 @@ where RpcServiceCfg::OnlyCalls, )); - let rp = - call_with_service(request, batch_requests_config, max_request_body_size, rpc_service, max_response_body_size) - .await; + let rp = call_with_service(request, batch_requests_config, max_request_body_size, rpc_service).await; drop(conn); @@ -72,13 +72,12 @@ pub async fn call_with_service( batch_config: BatchRequestConfig, max_request_size: u32, rpc_service: S, - max_response_size: u32, ) -> HttpResponse where B: http_body::Body + Send + 'static, B::Data: Send, B::Error: Into, - for<'a> S: RpcServiceT<'a> + Send, + S: RpcServiceT + Send, { // Only the `POST` method is allowed. match *request.method() { @@ -95,15 +94,11 @@ where } }; - if let Some(rp) = - handle_rpc_call(&body, is_single, batch_config, max_response_size, &rpc_service, parts.extensions).await - { - response::from_method_response(rp) - } else { - // If the response is empty it means that it was a notification or empty batch. - // For HTTP these are just ACK:ed with a empty body. - response::ok_response("") - } + let rp = handle_rpc_call(&body, is_single, batch_config, &rpc_service, parts.extensions).await; + + // If the response is empty it means that it was a notification or empty batch. + // For HTTP these are just ACK:ed with a empty body. + response::from_method_response(rp) } // Error scenarios: Method::POST => response::unsupported_content_type(), @@ -193,8 +188,7 @@ pub mod response { /// This will include the body and extensions from the method response. pub fn from_method_response(rp: MethodResponse) -> HttpResponse { let (body, _, extensions) = rp.into_parts(); - - let mut rp = from_template(hyper::StatusCode::OK, body, JSON); + let mut rp = from_template(hyper::StatusCode::OK, String::from(Box::::from(body)), JSON); rp.extensions_mut().extend(extensions); rp } diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index ff688f6dcf..ba55a807ab 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,8 +1,9 @@ +use std::convert::Infallible; use std::sync::Arc; use std::time::Instant; use crate::future::{IntervalStream, SessionClose}; -use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; +use crate::middleware::rpc::{RpcService, RpcServiceCfg}; use crate::server::{ConnectionState, ServerConfig, handle_rpc_call}; use crate::{HttpBody, HttpRequest, HttpResponse, LOG_TARGET, PingConfig}; @@ -11,7 +12,8 @@ use futures_util::io::{BufReader, BufWriter}; use futures_util::{Future, StreamExt, TryStreamExt}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; -use jsonrpsee_core::server::{BoundedSubscriptions, MethodSink, Methods}; +use jsonrpsee_core::middleware::{RpcServiceBuilder, RpcServiceT}; +use jsonrpsee_core::server::{BoundedSubscriptions, MethodResponse, MethodSink, Methods}; use jsonrpsee_types::Id; use jsonrpsee_types::error::{ErrorCode, reject_too_big_request}; use soketto::connection::Error as SokettoError; @@ -62,7 +64,7 @@ pub(crate) struct BackgroundTaskParams { pub(crate) async fn background_task(params: BackgroundTaskParams) where - for<'a> S: RpcServiceT<'a> + Send + Sync + 'static, + S: RpcServiceT + Send + Sync + 'static, { let BackgroundTaskParams { server_cfg, @@ -76,8 +78,7 @@ where mut on_session_close, extensions, } = params; - let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = - server_cfg; + let ServerConfig { ping_config, batch_requests_config, max_request_body_size, .. } = server_cfg; let (conn_tx, conn_rx) = oneshot::channel(); @@ -157,30 +158,23 @@ where } }; - if let Some(rp) = handle_rpc_call( - &data[idx..], - is_single, - batch_requests_config, - max_response_body_size, - &*rpc_service, - extensions, - ) - .await - { - if !rp.is_subscription() { - let is_success = rp.is_success(); - let (serialized_rp, mut on_close, _) = rp.into_parts(); - - // The connection is closed, just quit. - if sink.send(serialized_rp).await.is_err() { - return; - } + let rp = handle_rpc_call(&data[idx..], is_single, batch_requests_config, &*rpc_service, extensions).await; - // Notify that the message has been sent out to the internal - // WebSocket buffer. - if let Some(n) = on_close.take() { - n.notify(is_success); - } + // Subscriptions are handled by the subscription callback and + // "ordinary notifications" should not be sent back to the client. + if rp.is_method_call() || rp.is_batch() { + let is_success = rp.is_success(); + let (json, mut on_close, _) = rp.into_parts(); + + // The connection is closed, just quit. + if sink.send(String::from(Box::::from(json))).await.is_err() { + return; + } + + // Notify that the message has been sent out to the internal + // WebSocket buffer. + if let Some(n) = on_close.take() { + n.notify(is_success); } } }); @@ -388,7 +382,8 @@ async fn graceful_shutdown( /// /// ```no_run /// use jsonrpsee_server::{ws, ServerConfig, Methods, ConnectionState, HttpRequest, HttpResponse}; -/// use jsonrpsee_server::middleware::rpc::{RpcServiceBuilder, RpcServiceT, RpcService}; +/// use jsonrpsee_server::middleware::rpc::{RpcServiceBuilder, RpcServiceT, RpcService, MethodResponse}; +/// use std::convert::Infallible; /// /// async fn handle_websocket_conn( /// req: HttpRequest, @@ -399,9 +394,8 @@ async fn graceful_shutdown( /// mut disconnect: tokio::sync::mpsc::Receiver<()> /// ) -> HttpResponse /// where -/// L: for<'a> tower::Layer + 'static, -/// >::Service: Send + Sync + 'static, -/// for<'a> >::Service: RpcServiceT<'a> + 'static, +/// L: tower::Layer + 'static, +/// >::Service: RpcServiceT + Send + Sync + 'static, /// { /// match ws::connect(req, server_cfg, methods, conn, rpc_middleware).await { /// Ok((rp, conn_fut)) => { @@ -427,9 +421,9 @@ pub async fn connect( rpc_middleware: RpcServiceBuilder, ) -> Result<(HttpResponse, impl Future), HttpResponse> where - L: for<'a> tower::Layer, - >::Service: Send + Sync + 'static, - for<'a> >::Service: RpcServiceT<'a>, + L: tower::Layer, + >::Service: + RpcServiceT + Send + Sync + 'static, { let mut server = soketto::handshake::http::Server::new(); diff --git a/server/src/utils.rs b/server/src/utils.rs index cf2dc6daaa..86902a84d6 100644 --- a/server/src/utils.rs +++ b/server/src/utils.rs @@ -141,25 +141,58 @@ where } } -/// Helpers to deserialize a request with extensions. -pub(crate) mod deserialize { - /// Helper to deserialize a request with extensions. - pub(crate) fn from_slice_with_extensions( - data: &[u8], - extensions: http::Extensions, - ) -> Result { - let mut req: jsonrpsee_types::Request = serde_json::from_slice(data)?; - *req.extensions_mut() = extensions; - Ok(req) +/// Deserialize calls, notifications and responses with HTTP extensions. +pub mod deserialize_with_ext { + /// Method call. + pub mod call { + use jsonrpsee_types::Request; + + /// Wrapper over `serde_json::from_slice` that sets the extensions. + pub fn from_slice<'a>( + data: &'a [u8], + extensions: &'a http::Extensions, + ) -> Result, serde_json::Error> { + let mut req: Request = serde_json::from_slice(data)?; + *req.extensions_mut() = extensions.clone(); + Ok(req) + } + + /// Wrapper over `serde_json::from_str` that sets the extensions. + pub fn from_str<'a>(data: &'a str, extensions: &'a http::Extensions) -> Result, serde_json::Error> { + let mut req: Request = serde_json::from_str(data)?; + *req.extensions_mut() = extensions.clone(); + Ok(req) + } } - /// Helper to deserialize a request with extensions. - pub(crate) fn from_str_with_extensions( - data: &str, - extensions: http::Extensions, - ) -> Result { - let mut req: jsonrpsee_types::Request = serde_json::from_str(data)?; - *req.extensions_mut() = extensions; - Ok(req) + /// Notification. + pub mod notif { + use jsonrpsee_types::Notification; + + /// Wrapper over `serde_json::from_slice` that sets the extensions. + pub fn from_slice<'a, T>( + data: &'a [u8], + extensions: &'a http::Extensions, + ) -> Result, serde_json::Error> + where + T: serde::Deserialize<'a>, + { + let mut notif: Notification = serde_json::from_slice(data)?; + *notif.extensions_mut() = extensions.clone(); + Ok(notif) + } + + /// Wrapper over `serde_json::from_str` that sets the extensions. + pub fn from_str<'a, T>( + data: &'a str, + extensions: &http::Extensions, + ) -> Result, serde_json::Error> + where + T: serde::Deserialize<'a>, + { + let mut notif: Notification = serde_json::from_str(data)?; + *notif.extensions_mut() = extensions.clone(); + Ok(notif) + } } } diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 906f8affec..11a600b81a 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -35,12 +35,11 @@ use std::time::Duration; use fast_socks5::client::Socks5Stream; use fast_socks5::server; use futures::{SinkExt, Stream, StreamExt}; +use jsonrpsee::core::middleware::{Batch, Notification, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::server::middleware::http::ProxyGetRequestLayer; - -use jsonrpsee::server::middleware::rpc::RpcServiceT; use jsonrpsee::server::{ - ConnectionGuard, PendingSubscriptionSink, RpcModule, RpcServiceBuilder, Server, ServerBuilder, ServerHandle, - SubscriptionMessage, TrySendError, serve_with_graceful_shutdown, stop_channel, + ConnectionGuard, PendingSubscriptionSink, RpcModule, Server, ServerBuilder, ServerHandle, SubscriptionMessage, + TrySendError, serve_with_graceful_shutdown, stop_channel, }; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; use jsonrpsee::{Methods, SubscriptionCloseResponse}; @@ -150,16 +149,31 @@ pub async fn server() -> SocketAddr { connection_id: u32, } - impl<'a, S> RpcServiceT<'a> for ConnectionDetails + impl RpcServiceT for ConnectionDetails where - S: RpcServiceT<'a>, + S: RpcServiceT, { - type Future = S::Future; + type Error = S::Error; + type Response = S::Response; - fn call(&self, mut request: jsonrpsee::types::Request<'a>) -> Self::Future { + fn call<'a>( + &self, + mut request: jsonrpsee::types::Request<'a>, + ) -> impl Future> + Send + 'a { request.extensions_mut().insert(self.connection_id); self.inner.call(request) } + + fn batch<'a>(&self, batch: Batch<'a>) -> impl Future> + Send + 'a { + self.inner.batch(batch) + } + + fn notification<'a>( + &self, + _: Notification<'a>, + ) -> impl Future> + Send + 'a { + async { panic!("Not used for tests") } + } } let mut module = RpcModule::new(()); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 80923c80d4..7c6d2741cc 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -767,15 +767,9 @@ async fn ws_batch_works() { assert_eq!(res.num_failed_calls(), 1); let ok_responses: Vec<_> = res.iter().filter_map(|r| r.as_ref().ok()).collect(); - let err_responses: Vec<_> = res - .iter() - .filter_map(|r| match r { - Err(e) => Some(e), - _ => None, - }) - .collect(); + let err_responses: Vec<_> = res.iter().filter_map(|r| r.clone().err()).collect(); assert_eq!(ok_responses, vec!["hello"]); - assert_eq!(err_responses, vec![&ErrorObject::borrowed(UNKNOWN_ERROR_CODE, "err", None)]); + assert_eq!(err_responses, vec![ErrorObject::borrowed(UNKNOWN_ERROR_CODE, "err", None)]); } #[tokio::test] @@ -807,15 +801,9 @@ async fn http_batch_works() { assert_eq!(res.num_failed_calls(), 1); let ok_responses: Vec<_> = res.iter().filter_map(|r| r.as_ref().ok()).collect(); - let err_responses: Vec<_> = res - .iter() - .filter_map(|r| match r { - Err(e) => Some(e), - _ => None, - }) - .collect(); + let err_responses: Vec<_> = res.iter().filter_map(|r| r.clone().err()).collect(); assert_eq!(ok_responses, vec!["hello"]); - assert_eq!(err_responses, vec![&ErrorObject::borrowed(UNKNOWN_ERROR_CODE, "err", None)]); + assert_eq!(err_responses, vec![ErrorObject::borrowed(UNKNOWN_ERROR_CODE, "err", None)]); } #[tokio::test] @@ -1466,10 +1454,10 @@ async fn server_ws_low_api_works() { async fn run_server() -> anyhow::Result { use futures_util::future::FutureExt; - use jsonrpsee::core::BoxError; + use jsonrpsee::core::{BoxError, middleware::RpcServiceBuilder}; use jsonrpsee::server::{ - ConnectionGuard, ConnectionState, Methods, ServerConfig, StopHandle, http, - middleware::rpc::RpcServiceBuilder, serve_with_graceful_shutdown, stop_channel, ws, + ConnectionGuard, ConnectionState, Methods, ServerConfig, StopHandle, http, serve_with_graceful_shutdown, + stop_channel, ws, }; let listener = tokio::net::TcpListener::bind(std::net::SocketAddr::from(([127, 0, 0, 1], 0))).await?; diff --git a/tests/tests/metrics.rs b/tests/tests/metrics.rs index f7c0796278..eb4c74dd4b 100644 --- a/tests/tests/metrics.rs +++ b/tests/tests/metrics.rs @@ -34,18 +34,15 @@ use std::net::SocketAddr; use std::sync::{Arc, Mutex}; use std::time::Duration; -use futures::FutureExt; -use futures::future::BoxFuture; use helpers::init_logger; -use jsonrpsee::RpcModule; +use jsonrpsee::core::middleware::{Batch, Notification, Request, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::core::{ClientError, client::ClientT}; use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::proc_macros::rpc; -use jsonrpsee::server::middleware::rpc::{RpcServiceBuilder, RpcServiceT}; use jsonrpsee::server::{Server, ServerHandle}; -use jsonrpsee::types::{ErrorObject, ErrorObjectOwned, Id, Request}; +use jsonrpsee::types::{ErrorObject, ErrorObjectOwned, Id}; use jsonrpsee::ws_client::WsClientBuilder; -use jsonrpsee::{MethodResponse, rpc_params}; +use jsonrpsee::{MethodResponse, RpcModule, rpc_params}; use tokio::time::sleep; #[derive(Default, Clone)] @@ -62,13 +59,14 @@ pub struct CounterMiddleware { counter: Arc>, } -impl<'a, S> RpcServiceT<'a> for CounterMiddleware +impl RpcServiceT for CounterMiddleware where - S: RpcServiceT<'a> + Send + Sync + Clone + 'static, + S: RpcServiceT + Send + Sync + Clone + 'static, { - type Future = BoxFuture<'a, MethodResponse>; + type Error = S::Error; + type Response = S::Response; - fn call(&self, request: Request<'a>) -> Self::Future { + fn call<'a>(&self, request: Request<'a>) -> impl Future> + Send + 'a { let counter = self.counter.clone(); let service = self.service.clone(); @@ -88,14 +86,24 @@ where { let mut n = counter.lock().unwrap(); n.requests.1 += 1; - if rp.is_success() { + if rp.as_ref().is_ok_and(|r| r.is_success()) { n.calls.get_mut(&name).unwrap().1.push(id.into_owned()); } } rp } - .boxed() + } + + fn batch<'a>(&self, _: Batch<'a>) -> impl Future> + Send + 'a { + async { panic!("Not used for tests") } + } + + fn notification<'a>( + &self, + _: Notification<'a>, + ) -> impl Future> + Send + 'a { + async { panic!("Not used for tests") } } } diff --git a/tests/tests/proc_macros.rs b/tests/tests/proc_macros.rs index ff0a6863c1..e746bc1ab7 100644 --- a/tests/tests/proc_macros.rs +++ b/tests/tests/proc_macros.rs @@ -262,7 +262,7 @@ async fn macro_optional_param_parsing() { .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_optional_params","params":{"a":22,"c":50},"id":0}"#, 1) .await .unwrap(); - assert_eq!(resp, r#"{"jsonrpc":"2.0","id":0,"result":"Called with: 22, None, Some(50)"}"#); + assert_eq!(resp.get(), r#"{"jsonrpc":"2.0","id":0,"result":"Called with: 22, None, Some(50)"}"#); } #[tokio::test] @@ -286,14 +286,14 @@ async fn macro_zero_copy_cow() { .unwrap(); // std::borrow::Cow always deserialized to owned variant here - assert_eq!(resp, r#"{"jsonrpc":"2.0","id":0,"result":"Zero copy params: false"}"#); + assert_eq!(resp.get(), r#"{"jsonrpc":"2.0","id":0,"result":"Zero copy params: false"}"#); // serde_json will have to allocate a new string to replace `\t` with byte 0x09 (tab) let (resp, _) = module .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_zero_copy_cow","params":["\tfoo"],"id":0}"#, 1) .await .unwrap(); - assert_eq!(resp, r#"{"jsonrpc":"2.0","id":0,"result":"Zero copy params: false"}"#); + assert_eq!(resp.get(), r#"{"jsonrpc":"2.0","id":0,"result":"Zero copy params: false"}"#); } #[tokio::test] diff --git a/tests/tests/rpc_module.rs b/tests/tests/rpc_module.rs index 7f7e2d1e4b..ccbc8fb1cb 100644 --- a/tests/tests/rpc_module.rs +++ b/tests/tests/rpc_module.rs @@ -383,13 +383,13 @@ async fn subscribe_unsubscribe_without_server() { let unsub_req = format!("{{\"jsonrpc\":\"2.0\",\"method\":\"my_unsub\",\"params\":[{}],\"id\":1}}", ser_id); let (resp, _) = module.raw_json_request(&unsub_req, 1).await.unwrap(); - assert_eq!(resp, r#"{"jsonrpc":"2.0","id":1,"result":true}"#); + assert_eq!(resp.get(), r#"{"jsonrpc":"2.0","id":1,"result":true}"#); // Unsubscribe already performed; should be error. let unsub_req = format!("{{\"jsonrpc\":\"2.0\",\"method\":\"my_unsub\",\"params\":[{}],\"id\":1}}", ser_id); let (resp, _) = module.raw_json_request(&unsub_req, 2).await.unwrap(); - assert_eq!(resp, r#"{"jsonrpc":"2.0","id":1,"result":false}"#); + assert_eq!(resp.get(), r#"{"jsonrpc":"2.0","id":1,"result":false}"#); } let sub1 = subscribe_and_assert(&module); @@ -429,7 +429,7 @@ async fn reject_works() { .unwrap(); let (rp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"my_sub","id":0}"#, 1).await.unwrap(); - assert_eq!(rp, r#"{"jsonrpc":"2.0","id":0,"error":{"code":-32700,"message":"rejected"}}"#); + assert_eq!(rp.get(), r#"{"jsonrpc":"2.0","id":0,"error":{"code":-32700,"message":"rejected"}}"#); assert!(stream.recv().await.is_none()); } @@ -520,7 +520,7 @@ async fn serialize_sub_error_adds_extra_string_quotes() { .unwrap(); let (rp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"my_sub","id":0}"#, 1).await.unwrap(); - let resp = serde_json::from_str::>(&rp).unwrap(); + let resp = serde_json::from_str::>(rp.get()).unwrap(); let sub_resp = stream.recv().await.unwrap(); let resp = match resp.payload { @@ -565,7 +565,7 @@ async fn subscription_close_response_works() { { let (rp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"my_sub","params":[1],"id":0}"#, 1).await.unwrap(); - let resp = serde_json::from_str::>(&rp).unwrap(); + let resp = serde_json::from_str::>(rp.get()).unwrap(); let sub_id = match resp.payload { ResponsePayload::Success(val) => val, @@ -629,7 +629,7 @@ async fn method_response_notify_on_completion() { // Low level call should also work. let (rp, _) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"hey","params":["success"],"id":0}"#, 1).await.unwrap(); - assert_eq!(rp, r#"{"jsonrpc":"2.0","id":0,"result":"lo"}"#); + assert_eq!(rp.get(), r#"{"jsonrpc":"2.0","id":0,"result":"lo"}"#); assert!(matches!(rx.recv().await, Some(Ok(_)))); // Error call should return a failed notification. diff --git a/types/src/lib.rs b/types/src/lib.rs index 48b6d87692..c5dcd5771a 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -44,6 +44,7 @@ pub mod response; pub mod error; pub use error::{ErrorCode, ErrorObject, ErrorObjectOwned}; +pub use http::Extensions; pub use params::{Id, InvalidRequestId, Params, ParamsSequence, SubscriptionId, TwoPointZero}; -pub use request::{InvalidRequest, Notification, NotificationSer, Request, RequestSer}; +pub use request::{InvalidRequest, Notification, Request}; pub use response::{Response, ResponsePayload, SubscriptionPayload, SubscriptionResponse, Success as ResponseSuccess}; diff --git a/types/src/request.rs b/types/src/request.rs index b954586c26..ece555fc38 100644 --- a/types/src/request.rs +++ b/types/src/request.rs @@ -50,6 +50,7 @@ pub struct Request<'a> { pub method: Cow<'a, str>, /// Parameter values of the request. #[serde(borrow)] + #[serde(skip_serializing_if = "Option::is_none")] pub params: Option>, /// The request's extensions. #[serde(skip)] @@ -57,9 +58,26 @@ pub struct Request<'a> { } impl<'a> Request<'a> { - /// Create a new [`Request`]. - pub fn new(method: Cow<'a, str>, params: Option<&'a RawValue>, id: Id<'a>) -> Self { - Self { jsonrpc: TwoPointZero, id, method, params: params.map(Cow::Borrowed), extensions: Extensions::new() } + /// Create new borrowed [`Request`]. + pub fn borrowed(method: &'a str, params: Option<&'a RawValue>, id: Id<'a>) -> Self { + Self { + jsonrpc: TwoPointZero, + id, + method: Cow::Borrowed(method), + params: params.map(Cow::Borrowed), + extensions: Extensions::new(), + } + } + + /// Create new owned [`Request`]. + pub fn owned(method: String, params: Option>, id: Id<'a>) -> Self { + Self { + jsonrpc: TwoPointZero, + id, + method: Cow::Owned(method), + params: params.map(Cow::Owned), + extensions: Extensions::new(), + } } /// Get the ID of the request. @@ -89,7 +107,7 @@ impl<'a> Request<'a> { } /// JSON-RPC Invalid request as defined in the [spec](https://www.jsonrpc.org/specification#request-object). -#[derive(Deserialize, Debug, PartialEq, Eq)] +#[derive(Deserialize, Debug, PartialEq, Eq, Clone)] pub struct InvalidRequest<'a> { /// Request ID #[serde(borrow)] @@ -107,75 +125,41 @@ pub struct Notification<'a, T> { pub method: Cow<'a, str>, /// Parameter values of the request. pub params: T, + /// Extensions of the notification. + #[serde(skip)] + pub extensions: Extensions, } impl<'a, T> Notification<'a, T> { /// Create a new [`Notification`]. pub fn new(method: Cow<'a, str>, params: T) -> Self { - Self { jsonrpc: TwoPointZero, method, params } + Self { jsonrpc: TwoPointZero, method, params, extensions: Extensions::new() } } -} -/// Serializable [JSON-RPC object](https://www.jsonrpc.org/specification#request-object). -#[derive(Serialize, Debug, Clone)] -pub struct RequestSer<'a> { - /// JSON-RPC version. - pub jsonrpc: TwoPointZero, - /// Request ID - pub id: Id<'a>, - /// Name of the method to be invoked. - // NOTE: as this type only implements serialize `#[serde(borrow)]` is not needed. - pub method: Cow<'a, str>, - /// Parameter values of the request. - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option>, -} - -impl<'a> RequestSer<'a> { - /// Create a borrowed serializable JSON-RPC method call. - pub fn borrowed(id: &'a Id<'a>, method: &'a impl AsRef, params: Option<&'a RawValue>) -> Self { - Self { - jsonrpc: TwoPointZero, - id: id.clone(), - method: method.as_ref().into(), - params: params.map(Cow::Borrowed), - } + /// Get the method name of the request. + pub fn method_name(&self) -> &str { + &self.method } - /// Create a owned serializable JSON-RPC method call. - pub fn owned(id: Id<'a>, method: impl Into, params: Option>) -> Self { - Self { jsonrpc: TwoPointZero, id, method: method.into().into(), params: params.map(Cow::Owned) } + /// Returns a reference to the associated extensions. + pub fn extensions(&self) -> &Extensions { + &self.extensions } -} -/// Serializable [JSON-RPC notification object](https://www.jsonrpc.org/specification#request-object). -#[derive(Serialize, Debug, Clone)] -pub struct NotificationSer<'a> { - /// JSON-RPC version. - pub jsonrpc: TwoPointZero, - /// Name of the method to be invoked. - // NOTE: as this type only implements serialize `#[serde(borrow)]` is not needed. - pub method: Cow<'a, str>, - /// Parameter values of the request. - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option>, -} - -impl<'a> NotificationSer<'a> { - /// Create a borrowed serializable JSON-RPC notification. - pub fn borrowed(method: &'a impl AsRef, params: Option<&'a RawValue>) -> Self { - Self { jsonrpc: TwoPointZero, method: method.as_ref().into(), params: params.map(Cow::Borrowed) } + /// Get the params of the request. + pub fn params(&self) -> &T { + &self.params } - /// Create an owned serializable JSON-RPC notification. - pub fn owned(method: impl Into, params: Option>) -> Self { - Self { jsonrpc: TwoPointZero, method: method.into().into(), params: params.map(Cow::Owned) } + /// Returns a reference to the associated extensions. + pub fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions } } #[cfg(test)] mod test { - use super::{Cow, Id, InvalidRequest, Notification, NotificationSer, Request, RequestSer, TwoPointZero}; + use super::{Id, InvalidRequest, Notification, Request, TwoPointZero}; use serde_json::value::RawValue; fn assert_request<'a>(request: Request<'a>, id: Id<'a>, method: &str, params: Option<&str>) { @@ -273,32 +257,10 @@ mod test { ]; for (ser, id, params, method) in test_vector.iter().cloned() { - let request = serde_json::to_string(&RequestSer { - jsonrpc: TwoPointZero, - method: method.into(), - id: id.unwrap_or(Id::Null), - params: params.map(Cow::Owned), - }) - .unwrap(); + let request = + serde_json::to_string(&Request::borrowed(method, params.as_deref(), id.unwrap_or(Id::Null))).unwrap(); assert_eq!(&request, ser); } } - - #[test] - fn serialize_notif() { - let exp = r#"{"jsonrpc":"2.0","method":"say_hello","params":["hello"]}"#; - let params = Some(RawValue::from_string(r#"["hello"]"#.into()).unwrap()); - let req = NotificationSer::owned("say_hello", params); - let ser = serde_json::to_string(&req).unwrap(); - assert_eq!(exp, ser); - } - - #[test] - fn serialize_notif_escaped_method_name() { - let exp = r#"{"jsonrpc":"2.0","method":"\"method\""}"#; - let req = NotificationSer::owned("\"method\"", None); - let ser = serde_json::to_string(&req).unwrap(); - assert_eq!(exp, ser); - } } diff --git a/types/src/response.rs b/types/src/response.rs index 3ebb3df640..1a6f143149 100644 --- a/types/src/response.rs +++ b/types/src/response.rs @@ -34,6 +34,7 @@ use crate::error::ErrorCode; use crate::params::{Id, SubscriptionId, TwoPointZero}; use crate::request::Notification; use crate::{ErrorObject, ErrorObjectOwned}; +use http::Extensions; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -45,17 +46,39 @@ pub struct Response<'a, T: Clone> { pub payload: ResponsePayload<'a, T>, /// Request ID pub id: Id<'a>, + /// Extensions + pub extensions: Extensions, } impl<'a, T: Clone> Response<'a, T> { /// Create a new [`Response`]. pub fn new(payload: ResponsePayload<'a, T>, id: Id<'a>) -> Response<'a, T> { - Response { jsonrpc: Some(TwoPointZero), payload, id } + Response { jsonrpc: Some(TwoPointZero), payload, id, extensions: Extensions::new() } + } + + /// Create a new [`Response`] with extensions + pub fn new_with_extensions(payload: ResponsePayload<'a, T>, id: Id<'a>, ext: Extensions) -> Response<'a, T> { + Response { jsonrpc: Some(TwoPointZero), payload, id, extensions: ext } } /// Create an owned [`Response`]. pub fn into_owned(self) -> Response<'static, T> { - Response { jsonrpc: self.jsonrpc, payload: self.payload.into_owned(), id: self.id.into_owned() } + Response { + jsonrpc: self.jsonrpc, + payload: self.payload.into_owned(), + id: self.id.into_owned(), + extensions: self.extensions, + } + } + + /// Get the extensions of the response. + pub fn extensions(&self) -> &Extensions { + &self.extensions + } + + /// Get the mutable ref to the extensions of the response. + pub fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions } } @@ -293,14 +316,27 @@ where (_, Some(_), Some(_)) => { return Err(serde::de::Error::duplicate_field("result and error are mutually exclusive")); } - (Some(jsonrpc), Some(result), None) => { - Response { jsonrpc, payload: ResponsePayload::Success(result), id } - } - (Some(jsonrpc), None, Some(err)) => Response { jsonrpc, payload: ResponsePayload::Error(err), id }, - (None, Some(result), _) => { - Response { jsonrpc: None, payload: ResponsePayload::Success(result), id } + (Some(jsonrpc), Some(result), None) => Response { + jsonrpc, + payload: ResponsePayload::Success(result), + id, + extensions: Extensions::new(), + }, + (Some(jsonrpc), None, Some(err)) => { + Response { jsonrpc, payload: ResponsePayload::Error(err), id, extensions: Extensions::new() } } - (None, _, Some(err)) => Response { jsonrpc: None, payload: ResponsePayload::Error(err), id }, + (None, Some(result), _) => Response { + jsonrpc: None, + payload: ResponsePayload::Success(result), + id, + extensions: Extensions::new(), + }, + (None, _, Some(err)) => Response { + jsonrpc: None, + payload: ResponsePayload::Error(err), + id, + extensions: Extensions::new(), + }, (_, None, None) => return Err(serde::de::Error::missing_field("result/error")), }; @@ -340,6 +376,8 @@ where #[cfg(test)] mod tests { + use http::Extensions; + use super::{Id, Response, TwoPointZero}; use crate::{ErrorObjectOwned, response::ResponsePayload}; @@ -349,6 +387,7 @@ mod tests { jsonrpc: Some(TwoPointZero), payload: ResponsePayload::success("ok"), id: Id::Number(1), + extensions: Extensions::new(), }) .unwrap(); let exp = r#"{"jsonrpc":"2.0","id":1,"result":"ok"}"#; @@ -361,6 +400,7 @@ mod tests { jsonrpc: Some(TwoPointZero), payload: ResponsePayload::<()>::error(ErrorObjectOwned::owned(1, "lo", None::<()>)), id: Id::Number(1), + extensions: Extensions::new(), }) .unwrap(); let exp = r#"{"jsonrpc":"2.0","id":1,"error":{"code":1,"message":"lo"}}"#; @@ -373,6 +413,7 @@ mod tests { jsonrpc: None, payload: ResponsePayload::success("ok"), id: Id::Number(1), + extensions: Extensions::new(), }) .unwrap(); let exp = r#"{"id":1,"result":"ok"}"#; @@ -381,8 +422,12 @@ mod tests { #[test] fn deserialize_success_call() { - let exp = - Response { jsonrpc: Some(TwoPointZero), payload: ResponsePayload::success(99_u64), id: Id::Number(11) }; + let exp = Response { + jsonrpc: Some(TwoPointZero), + payload: ResponsePayload::success(99_u64), + id: Id::Number(11), + extensions: Extensions::new(), + }; let dsr: Response = serde_json::from_str(r#"{"jsonrpc":"2.0", "result":99, "id":11}"#).unwrap(); assert_eq!(dsr.jsonrpc, exp.jsonrpc); assert_eq!(dsr.payload, exp.payload); @@ -395,6 +440,7 @@ mod tests { jsonrpc: Some(TwoPointZero), payload: ResponsePayload::error(ErrorObjectOwned::owned(1, "lo", None::<()>)), id: Id::Number(11), + extensions: Extensions::new(), }; let dsr: Response<()> = serde_json::from_str(r#"{"jsonrpc":"2.0","error":{"code":1,"message":"lo"},"id":11}"#).unwrap(); @@ -405,7 +451,12 @@ mod tests { #[test] fn deserialize_call_missing_version_field() { - let exp = Response { jsonrpc: None, payload: ResponsePayload::success(99_u64), id: Id::Number(11) }; + let exp = Response { + jsonrpc: None, + payload: ResponsePayload::success(99_u64), + id: Id::Number(11), + extensions: Extensions::new(), + }; let dsr: Response = serde_json::from_str(r#"{"jsonrpc":null, "result":99, "id":11}"#).unwrap(); assert_eq!(dsr.jsonrpc, exp.jsonrpc); assert_eq!(dsr.payload, exp.payload); @@ -414,7 +465,12 @@ mod tests { #[test] fn deserialize_with_unknown_field() { - let exp = Response { jsonrpc: None, payload: ResponsePayload::success(99_u64), id: Id::Number(11) }; + let exp = Response { + jsonrpc: None, + payload: ResponsePayload::success(99_u64), + id: Id::Number(11), + extensions: Extensions::new(), + }; let dsr: Response = serde_json::from_str(r#"{"jsonrpc":null, "result":99, "id":11, "unknown":11}"#).unwrap(); assert_eq!(dsr.jsonrpc, exp.jsonrpc);