diff --git a/client/http-client/src/client.rs b/client/http-client/src/client.rs index 1648157ce7..1d58804929 100644 --- a/client/http-client/src/client.rs +++ b/client/http-client/src/client.rs @@ -36,7 +36,6 @@ use hyper::body::Bytes; use hyper::http::{Extensions, HeaderMap}; use jsonrpsee_core::client::{ BatchResponse, ClientT, Error, IdKind, MethodResponse, RequestIdManager, Subscription, SubscriptionClientT, - generate_batch_id_range, }; use jsonrpsee_core::middleware::layer::RpcLoggerLayer; use jsonrpsee_core::middleware::{Batch, RpcServiceBuilder, RpcServiceT}; @@ -415,20 +414,19 @@ where None => None, }; let batch = batch.build()?; - let id = self.id_manager.next_request_id(); - let id_range = generate_batch_id_range(id, batch.len() as u64)?; + let mut ids = Vec::new(); let mut batch_request = Batch::with_capacity(batch.len()); - for ((method, params), id) in batch.into_iter().zip(id_range.clone()) { - let id = self.id_manager.as_id_kind().into_id(id); - let req = Request { + for (method, params) in batch.into_iter() { + let id = self.id_manager.next_request_id(); + batch_request.push(Request { jsonrpc: TwoPointZero, method: method.into(), params: params.map(StdCow::Owned), - id, + id: id.clone(), extensions: Extensions::new(), - }; - batch_request.push(req); + }); + ids.push(id); } let rp = run_future_until_timeout(self.service.batch(batch_request), self.request_timeout).await?; @@ -444,7 +442,7 @@ where } for rp in json_rps.into_iter() { - let id = rp.id().try_parse_inner_as_number()?; + let id = rp.id().clone(); let res = match ResponseSuccess::try_from(rp.into_inner()) { Ok(r) => { @@ -458,13 +456,8 @@ where } }; - let maybe_elem = id - .checked_sub(id_range.start) - .and_then(|p| p.try_into().ok()) - .and_then(|p: usize| batch_response.get_mut(p)); - - if let Some(elem) = maybe_elem { - *elem = res; + if let Some(pos) = ids.iter().position(|stored_id| stored_id == &id) { + batch_response[pos] = res; } else { return Err(InvalidRequestId::NotPendingRequest(id.to_string()).into()); } diff --git a/client/http-client/src/tests.rs b/client/http-client/src/tests.rs index f6af48e623..d094b65df9 100644 --- a/client/http-client/src/tests.rs +++ b/client/http-client/src/tests.rs @@ -34,6 +34,7 @@ use jsonrpsee_test_utils::TimeoutFutureExt; use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::mocks::Id; use jsonrpsee_types::error::ErrorObjectOwned; +use jsonrpsee_types::{Id as RequestId, IdGeneratorFn}; fn init_logger() { let _ = tracing_subscriber::FmtSubscriber::builder() @@ -251,6 +252,26 @@ async fn batch_request_out_of_order_response() { assert_eq!(response, vec!["hello".to_string(), "goodbye".to_string(), "here's your swag".to_string()]); } +#[tokio::test] +async fn batch_request_with_custom_id_out_of_order_response() { + let mut batch_request = BatchRequestBuilder::new(); + batch_request.insert("say_hello", rpc_params![]).unwrap(); + batch_request.insert("say_goodbye", rpc_params![0_u64, 1, 2]).unwrap(); + batch_request.insert("get_swag", rpc_params![]).unwrap(); + let server_response = r#"[{"jsonrpc":"2.0","result":"here's your swag","id":2}, {"jsonrpc":"2.0","result":"hello","id":0}, {"jsonrpc":"2.0","result":"goodbye","id":1}]"#.to_string(); + let res = run_batch_request_with_custom_id::(batch_request, server_response, generate_predictable_id) + .with_default_timeout() + .await + .unwrap() + .unwrap(); + assert_eq!(res.num_successful_calls(), 3); + assert_eq!(res.num_failed_calls(), 0); + assert_eq!(res.len(), 3); + let response: Vec<_> = res.into_ok().unwrap().collect(); + + assert_eq!(response, vec!["hello".to_string(), "goodbye".to_string(), "here's your swag".to_string(),]); +} + async fn run_batch_request_with_response( batch: BatchRequestBuilder<'_>, response: String, @@ -276,3 +297,21 @@ fn assert_jsonrpc_error_response(err: ClientError, exp: ErrorObjectOwned) { e => panic!("Expected error: \"{err}\", got: {e:?}"), }; } + +async fn run_batch_request_with_custom_id( + batch: BatchRequestBuilder<'_>, + response: String, + id_generator: fn() -> RequestId<'static>, +) -> Result, ClientError> { + let server_addr = http_server_with_hardcoded_response(response).with_default_timeout().await.unwrap(); + let uri = format!("http://{server_addr}"); + let client = + HttpClientBuilder::default().id_format(IdKind::Custom(IdGeneratorFn::new(id_generator))).build(&uri).unwrap(); + client.batch_request(batch).with_default_timeout().await.unwrap() +} + +fn generate_predictable_id() -> RequestId<'static> { + static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); + let id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + RequestId::Number(id.try_into().unwrap()) +} diff --git a/core/src/client/async_client/helpers.rs b/core/src/client/async_client/helpers.rs index 34be025887..ef8dbbc162 100644 --- a/core/src/client/async_client/helpers.rs +++ b/core/src/client/async_client/helpers.rs @@ -44,7 +44,6 @@ use jsonrpsee_types::{ TwoPointZero, }; use std::borrow::Cow; -use std::ops::Range; /// Attempts to process a batch response. /// @@ -52,34 +51,29 @@ use std::ops::Range; pub(crate) fn process_batch_response( manager: &mut RequestManager, rps: Vec, - range: Range, + ids: Vec>, ) -> Result<(), InvalidRequestId> { let mut responses = Vec::with_capacity(rps.len()); - let start_idx = range.start; - - let batch_state = match manager.complete_pending_batch(range.clone()) { + let batch_state = match manager.complete_pending_batch(ids.clone()) { Some(state) => state, None => { tracing::debug!(target: LOG_TARGET, "Received unknown batch response"); - return Err(InvalidRequestId::NotPendingRequest(format!("{:?}", range))); + return Err(InvalidRequestId::NotPendingRequest(format!("{:?}", ids))); } }; - for _ in range { + for _ in &ids { let err_obj = ErrorObject::borrowed(0, "", None); responses.push(Response::new(jsonrpsee_types::ResponsePayload::error(err_obj), Id::Null).into()); } for rp in rps { - let id = rp.id().try_parse_inner_as_number()?; - let maybe_elem = - id.checked_sub(start_idx).and_then(|p| p.try_into().ok()).and_then(|p: usize| responses.get_mut(p)); - - if let Some(elem) = maybe_elem { - *elem = rp; + let response_id = rp.id(); + if let Some(pos) = ids.iter().position(|id| id == response_id) { + responses[pos] = rp.into() } else { - return Err(InvalidRequestId::NotPendingRequest(rp.id().to_string())); + return Err(InvalidRequestId::NotPendingRequest(response_id.to_string())); } } diff --git a/core/src/client/async_client/manager.rs b/core/src/client/async_client/manager.rs index 333d745804..31cf934878 100644 --- a/core/src/client/async_client/manager.rs +++ b/core/src/client/async_client/manager.rs @@ -32,10 +32,7 @@ //! > **Note**: The spec allow number, string or null but this crate only supports numbers. //! - SubscriptionId: unique ID generated by server -use std::{ - collections::{HashMap, hash_map::Entry}, - ops::Range, -}; +use std::collections::{HashMap, hash_map::Entry}; use crate::{ client::{Error, RawResponseOwned, SubscriptionReceiver, SubscriptionSender}, @@ -90,7 +87,7 @@ pub(crate) struct RequestManager { /// requests. subscriptions: HashMap, RequestId>, /// Pending batch requests. - batches: FxHashMap, BatchState>, + batches: Vec<(Vec>, BatchState)>, /// Registered Methods for incoming notifications. notification_handlers: HashMap, } @@ -123,15 +120,15 @@ impl RequestManager { /// Returns `Ok` if the pending request was successfully inserted otherwise `Err`. pub(crate) fn insert_pending_batch( &mut self, - batch: Range, + batch: Vec>, send_back: PendingBatchOneshot, ) -> Result<(), PendingBatchOneshot> { - if let Entry::Vacant(v) = self.batches.entry(batch) { - v.insert(BatchState { send_back }); - Ok(()) - } else { - Err(send_back) + if self.batches.iter().any(|(existing_batch, _)| existing_batch == &batch) { + return Err(send_back); } + + self.batches.push((batch, BatchState { send_back })); + Ok(()) } /// Tries to insert a new pending subscription and reserves a slot for a "potential" unsubscription request. @@ -222,14 +219,24 @@ impl RequestManager { /// Tries to complete a pending batch request. /// /// Returns `Some` if the subscription was completed otherwise `None`. - pub(crate) fn complete_pending_batch(&mut self, batch: Range) -> Option { - match self.batches.entry(batch) { - Entry::Occupied(request) => { - let (_digest, state) = request.remove_entry(); - Some(state) + pub(crate) fn complete_pending_batch(&mut self, batch: Vec>) -> Option { + let mut matched_key = None; + + for (key, _) in self.batches.iter() { + if key.len() == batch.len() && batch.iter().all(|id| key.contains(id)) { + matched_key = Some(key.clone()); + break; } - _ => None, } + + if let Some(key) = matched_key { + if let Some(pos) = self.batches.iter().position(|(existing_batch, _)| existing_batch == &key) { + let (_, state) = self.batches.remove(pos); + return Some(state); + } + } + + None } /// Tries to complete a pending call. diff --git a/core/src/client/async_client/mod.rs b/core/src/client/async_client/mod.rs index a5117fe4a5..2536a91494 100644 --- a/core/src/client/async_client/mod.rs +++ b/core/src/client/async_client/mod.rs @@ -66,7 +66,7 @@ use tokio::sync::{mpsc, oneshot}; use tower::layer::util::Identity; use self::utils::{InactivityCheck, IntervalStream}; -use super::{FrontToBack, IdKind, MethodResponse, RequestIdManager, generate_batch_id_range, subscription_channel}; +use super::{FrontToBack, IdKind, MethodResponse, RequestIdManager, subscription_channel}; pub(crate) type Notification<'a> = jsonrpsee_types::Notification<'a, Option>>; @@ -536,22 +536,23 @@ where { async { let batch = batch.build()?; - let id = self.id_manager.next_request_id(); - let id_range = generate_batch_id_range(id, batch.len() as u64)?; + let mut ids = Vec::new(); let mut b = Batch::with_capacity(batch.len()); - for ((method, params), id) in batch.into_iter().zip(id_range.clone()) { + for (method, params) in batch.into_iter() { + let id = self.id_manager.next_request_id(); b.push(Request { jsonrpc: TwoPointZero, - id: self.id_manager.as_id_kind().into_id(id), + id: id.clone(), method: method.into(), params: params.map(StdCow::Owned), extensions: Extensions::new(), }); + ids.push(id); } - b.extensions_mut().insert(IsBatch { id_range }); + b.extensions_mut().insert(IsBatch { ids }); let fut = self.service.batch(b); let json_values = self.run_future_until_timeout(fut).await?.into_batch().expect("Batch response"); @@ -709,27 +710,21 @@ fn handle_backend_messages( } Some(b'[') => { // Batch response. - if let Ok(raw_responses) = serde_json::from_slice::>(raw) { + if let Ok(raw_responses) = serde_json::from_slice::>>(raw) { let mut batch = Vec::with_capacity(raw_responses.len()); - let mut range = None; + let mut ids = Vec::new(); let mut got_notif = false; for r in raw_responses { - if let Ok(response) = serde_json::from_str::>(r.get()) { - let id = response.id.try_parse_inner_as_number()?; - batch.push(response.into_owned().into()); + let json_string = r.get().to_string(); - let r = range.get_or_insert(id..id); - - if id < r.start { - r.start = id; - } - - if id > r.end { - r.end = id; - } - } else if let Ok(response) = serde_json::from_str::>(r.get()) { + if let Ok(response) = serde_json::from_str::>(&json_string) { + let id = response.id.clone().into_owned(); + // let result = ResponseSuccess::try_from(response).map(|s| s.result); + batch.push(response.into_owned().into()); + ids.push(id); + } else if let Ok(response) = serde_json::from_str::>(&json_string) { got_notif = true; if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) { messages.push(FrontToBack::SubscriptionClosed(sub_id)); @@ -737,7 +732,7 @@ fn handle_backend_messages( } else if let Ok(response) = serde_json::from_slice::>(raw) { got_notif = true; process_subscription_close_response(&mut manager.lock(), response); - } else if let Ok(notif) = serde_json::from_str::(r.get()) { + } else if let Ok(notif) = serde_json::from_str::(&json_string) { got_notif = true; process_notification(&mut manager.lock(), notif); } else { @@ -745,10 +740,8 @@ fn handle_backend_messages( }; } - if let Some(mut range) = range { - // the range is exclusive so need to add one. - range.end += 1; - process_batch_response(&mut manager.lock(), batch, range)?; + if ids.len().gt(&0) { + process_batch_response(&mut manager.lock(), batch, ids)?; } else if !got_notif { return Err(EmptyBatchRequest.into()); } diff --git a/core/src/client/async_client/rpc_service.rs b/core/src/client/async_client/rpc_service.rs index efcbe23451..44702c5b8c 100644 --- a/core/src/client/async_client/rpc_service.rs +++ b/core/src/client/async_client/rpc_service.rs @@ -107,13 +107,13 @@ impl RpcServiceT for RpcService { let (send_back_tx, send_back_rx) = oneshot::channel(); let raw = serde_json::to_string(&batch).map_err(client_err)?; - let id_range = batch + let ids = batch .extensions() .get::() - .map(|b| b.id_range.clone()) + .map(|b| b.ids.clone()) .expect("Batch ID range must be set in extensions"); - tx.send(FrontToBack::Batch(BatchMessage { raw, ids: id_range, send_back: send_back_tx })).await?; + tx.send(FrontToBack::Batch(BatchMessage { raw, ids, send_back: send_back_tx })).await?; let json = send_back_rx.await?.map_err(client_err)?; Ok(MethodResponse::batch(json, batch.into_extensions())) diff --git a/core/src/client/mod.rs b/core/src/client/mod.rs index 68958aac22..5c84de5f09 100644 --- a/core/src/client/mod.rs +++ b/core/src/client/mod.rs @@ -34,10 +34,10 @@ cfg_async_client! { pub mod error; pub use error::Error; +use jsonrpsee_types::request::IdGeneratorFn; use std::fmt; use std::future::Future; -use std::ops::Range; use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, RwLock}; @@ -338,7 +338,7 @@ struct BatchMessage { /// Serialized batch request. raw: String, /// Request IDs. - ids: Range, + ids: Vec>, /// One-shot channel over which we send back the result of this request. send_back: oneshot::Sender, InvalidRequestId>>, } @@ -473,7 +473,17 @@ impl RequestIdManager { /// Attempts to get the next request ID. pub fn next_request_id(&self) -> Id<'static> { - self.id_kind.into_id(self.current_id.next()) + match self.id_kind { + IdKind::Number => { + let id = self.current_id.next(); + Id::Number(id) + } + IdKind::String => { + let id = self.current_id.next(); + Id::Str(format!("{id}").into()) + } + IdKind::Custom(generator) => generator.call(), + } } /// Get a handle to the `IdKind`. @@ -489,16 +499,8 @@ pub enum IdKind { String, /// Number. Number, -} - -impl IdKind { - /// Generate an `Id` from number. - pub fn into_id(self, id: u64) -> Id<'static> { - match self { - IdKind::Number => Id::Number(id), - IdKind::String => Id::Str(format!("{id}").into()), - } - } + /// Custom generator. + Custom(IdGeneratorFn), } #[derive(Debug)] @@ -517,16 +519,6 @@ impl CurrentId { } } -/// Generate a range of IDs to be used in a batch request. -pub fn generate_batch_id_range(id: Id, len: u64) -> Result, Error> { - let id_start = id.try_parse_inner_as_number()?; - let id_end = id_start - .checked_add(len) - .ok_or_else(|| Error::Custom("BatchID range wrapped; restart the client or try again later".to_string()))?; - - Ok(id_start..id_end) -} - /// Represent a single entry in a batch response. pub type BatchEntry<'a, R> = Result>; diff --git a/core/src/middleware/mod.rs b/core/src/middleware/mod.rs index 180d2bdbd3..17f9dd8461 100644 --- a/core/src/middleware/mod.rs +++ b/core/src/middleware/mod.rs @@ -200,8 +200,8 @@ impl IsSubscription { /// An extension type for the [`RpcServiceT::batch`] for the expected id range of the batch entries. #[derive(Debug, Clone)] pub struct IsBatch { - /// The range of ids for the batch entries. - pub id_range: std::ops::Range, + /// The ids for the batch entries. + pub ids: Vec>, } /// A batch entry specific for the [`RpcServiceT::batch`] method to support both diff --git a/examples/examples/core_client_with_request_id.rs b/examples/examples/core_client_with_request_id.rs new file mode 100644 index 0000000000..0e5b4b72bc --- /dev/null +++ b/examples/examples/core_client_with_request_id.rs @@ -0,0 +1,74 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::net::SocketAddr; + +use jsonrpsee::client_transport::ws::{Url, WsTransportClientBuilder}; +use jsonrpsee::core::client::{ClientBuilder, ClientT, IdKind}; +use jsonrpsee::rpc_params; +use jsonrpsee::server::{RpcModule, Server}; +use jsonrpsee::types::{Id, IdGeneratorFn}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); + + let addr = run_server().await?; + let uri = Url::parse(&format!("ws://{}", addr))?; + + let custom_generator = IdGeneratorFn::new(generate_timestamp_id); + + let (tx, rx) = WsTransportClientBuilder::default().build(uri).await?; + let client = ClientBuilder::default().id_format(IdKind::Custom(custom_generator)).build_with_tokio(tx, rx); + let response: String = client.request("say_hello", rpc_params![]).await?; + tracing::info!("response: {:?}", response); + + Ok(()) +} + +async fn run_server() -> anyhow::Result { + let server = Server::builder().build("127.0.0.1:0").await?; + let mut module = RpcModule::new(()); + module.register_method("say_hello", |_, _, _| "lo")?; + let addr = server.local_addr()?; + + let handle = server.start(module); + + // In this example we don't care about doing shutdown so let's it run forever. + // You may use the `ServerHandle` to shut it down or manage it yourself. + tokio::spawn(handle.stopped()); + + Ok(addr) +} + +fn generate_timestamp_id() -> Id<'static> { + let timestamp = + std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).expect("Time went backwards").as_secs(); + Id::Number(timestamp) +} diff --git a/types/src/lib.rs b/types/src/lib.rs index c5dcd5771a..4781d658ac 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -46,5 +46,5 @@ pub mod error; pub use error::{ErrorCode, ErrorObject, ErrorObjectOwned}; pub use http::Extensions; pub use params::{Id, InvalidRequestId, Params, ParamsSequence, SubscriptionId, TwoPointZero}; -pub use request::{InvalidRequest, Notification, Request}; +pub use request::{IdGeneratorFn, InvalidRequest, Notification, Request}; pub use response::{Response, ResponsePayload, SubscriptionPayload, SubscriptionResponse, Success as ResponseSuccess}; diff --git a/types/src/request.rs b/types/src/request.rs index ece555fc38..ecb1ccb16a 100644 --- a/types/src/request.rs +++ b/types/src/request.rs @@ -27,7 +27,10 @@ //! Types to handle JSON-RPC requests according to the [spec](https://www.jsonrpc.org/specification#request-object). //! Some types come with a "*Ser" variant that implements [`serde::Serialize`]; these are used in the client. -use std::borrow::Cow; +use std::{ + borrow::Cow, + fmt::{Debug, Formatter, Result}, +}; use crate::{ Params, @@ -157,6 +160,35 @@ impl<'a, T> Notification<'a, T> { } } +/// Custom id generator function +pub struct IdGeneratorFn(fn() -> Id<'static>); + +impl IdGeneratorFn { + /// Creates a new `IdGeneratorFn` from a function pointer. + pub fn new(generator: fn() -> Id<'static>) -> Self { + IdGeneratorFn(generator) + } + + /// Calls the id generator function + pub fn call(&self) -> Id<'static> { + (self.0)() + } +} + +impl Copy for IdGeneratorFn {} + +impl Clone for IdGeneratorFn { + fn clone(&self) -> Self { + *self + } +} + +impl Debug for IdGeneratorFn { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + f.write_str("") + } +} + #[cfg(test)] mod test { use super::{Id, InvalidRequest, Notification, Request, TwoPointZero};