Skip to content
Open
27 changes: 10 additions & 17 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ use hyper::body::Bytes;
use hyper::http::{Extensions, HeaderMap};
use jsonrpsee_core::client::{
BatchResponse, ClientT, Error, IdKind, MethodResponse, RequestIdManager, Subscription, SubscriptionClientT,
generate_batch_id_range,
};
use jsonrpsee_core::middleware::layer::RpcLoggerLayer;
use jsonrpsee_core::middleware::{Batch, RpcServiceBuilder, RpcServiceT};
Expand Down Expand Up @@ -415,20 +414,19 @@ where
None => None,
};
let batch = batch.build()?;
let id = self.id_manager.next_request_id();
let id_range = generate_batch_id_range(id, batch.len() as u64)?;
let mut ids = Vec::new();

let mut batch_request = Batch::with_capacity(batch.len());
for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
let id = self.id_manager.as_id_kind().into_id(id);
let req = Request {
for (method, params) in batch.into_iter() {
let id = self.id_manager.next_request_id();
batch_request.push(Request {
jsonrpc: TwoPointZero,
method: method.into(),
params: params.map(StdCow::Owned),
id,
id: id.clone(),
extensions: Extensions::new(),
};
batch_request.push(req);
});
ids.push(id);
}

let rp = run_future_until_timeout(self.service.batch(batch_request), self.request_timeout).await?;
Expand All @@ -444,7 +442,7 @@ where
}

for rp in json_rps.into_iter() {
let id = rp.id().try_parse_inner_as_number()?;
let id = rp.id().clone();

let res = match ResponseSuccess::try_from(rp.into_inner()) {
Ok(r) => {
Expand All @@ -458,13 +456,8 @@ where
}
};

let maybe_elem = id
.checked_sub(id_range.start)
.and_then(|p| p.try_into().ok())
.and_then(|p: usize| batch_response.get_mut(p));

if let Some(elem) = maybe_elem {
*elem = res;
if let Some(pos) = ids.iter().position(|stored_id| stored_id == &id) {
batch_response[pos] = res;
} else {
return Err(InvalidRequestId::NotPendingRequest(id.to_string()).into());
}
Expand Down
39 changes: 39 additions & 0 deletions client/http-client/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::mocks::Id;
use jsonrpsee_types::error::ErrorObjectOwned;
use jsonrpsee_types::{Id as RequestId, IdGeneratorFn};

fn init_logger() {
let _ = tracing_subscriber::FmtSubscriber::builder()
Expand Down Expand Up @@ -251,6 +252,26 @@ async fn batch_request_out_of_order_response() {
assert_eq!(response, vec!["hello".to_string(), "goodbye".to_string(), "here's your swag".to_string()]);
}

#[tokio::test]
async fn batch_request_with_custom_id_out_of_order_response() {
let mut batch_request = BatchRequestBuilder::new();
batch_request.insert("say_hello", rpc_params![]).unwrap();
batch_request.insert("say_goodbye", rpc_params![0_u64, 1, 2]).unwrap();
batch_request.insert("get_swag", rpc_params![]).unwrap();
let server_response = r#"[{"jsonrpc":"2.0","result":"here's your swag","id":2}, {"jsonrpc":"2.0","result":"hello","id":0}, {"jsonrpc":"2.0","result":"goodbye","id":1}]"#.to_string();
let res = run_batch_request_with_custom_id::<String>(batch_request, server_response, generate_predictable_id)
.with_default_timeout()
.await
.unwrap()
.unwrap();
assert_eq!(res.num_successful_calls(), 3);
assert_eq!(res.num_failed_calls(), 0);
assert_eq!(res.len(), 3);
let response: Vec<_> = res.into_ok().unwrap().collect();

assert_eq!(response, vec!["hello".to_string(), "goodbye".to_string(), "here's your swag".to_string(),]);
}

async fn run_batch_request_with_response<T: Send + DeserializeOwned + std::fmt::Debug + Clone + 'static>(
batch: BatchRequestBuilder<'_>,
response: String,
Expand All @@ -276,3 +297,21 @@ fn assert_jsonrpc_error_response(err: ClientError, exp: ErrorObjectOwned) {
e => panic!("Expected error: \"{err}\", got: {e:?}"),
};
}

async fn run_batch_request_with_custom_id<T: Send + DeserializeOwned + std::fmt::Debug + Clone + 'static>(
batch: BatchRequestBuilder<'_>,
response: String,
id_generator: fn() -> RequestId<'static>,
) -> Result<BatchResponse<T>, ClientError> {
let server_addr = http_server_with_hardcoded_response(response).with_default_timeout().await.unwrap();
let uri = format!("http://{server_addr}");
let client =
HttpClientBuilder::default().id_format(IdKind::Custom(IdGeneratorFn::new(id_generator))).build(&uri).unwrap();
client.batch_request(batch).with_default_timeout().await.unwrap()
}

fn generate_predictable_id() -> RequestId<'static> {
static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
let id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
RequestId::Number(id.try_into().unwrap())
}
22 changes: 8 additions & 14 deletions core/src/client/async_client/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,42 +44,36 @@ use jsonrpsee_types::{
TwoPointZero,
};
use std::borrow::Cow;
use std::ops::Range;

/// Attempts to process a batch response.
///
/// On success the result is sent to the frontend.
pub(crate) fn process_batch_response(
manager: &mut RequestManager,
rps: Vec<RawResponseOwned>,
range: Range<u64>,
ids: Vec<Id<'static>>,
) -> Result<(), InvalidRequestId> {
let mut responses = Vec::with_capacity(rps.len());

let start_idx = range.start;

let batch_state = match manager.complete_pending_batch(range.clone()) {
let batch_state = match manager.complete_pending_batch(ids.clone()) {
Some(state) => state,
None => {
tracing::debug!(target: LOG_TARGET, "Received unknown batch response");
return Err(InvalidRequestId::NotPendingRequest(format!("{:?}", range)));
return Err(InvalidRequestId::NotPendingRequest(format!("{:?}", ids)));
}
};

for _ in range {
for _ in &ids {
let err_obj = ErrorObject::borrowed(0, "", None);
responses.push(Response::new(jsonrpsee_types::ResponsePayload::error(err_obj), Id::Null).into());
}

for rp in rps {
let id = rp.id().try_parse_inner_as_number()?;
let maybe_elem =
id.checked_sub(start_idx).and_then(|p| p.try_into().ok()).and_then(|p: usize| responses.get_mut(p));

if let Some(elem) = maybe_elem {
*elem = rp;
let response_id = rp.id();
if let Some(pos) = ids.iter().position(|id| id == response_id) {
responses[pos] = rp.into()
} else {
return Err(InvalidRequestId::NotPendingRequest(rp.id().to_string()));
return Err(InvalidRequestId::NotPendingRequest(response_id.to_string()));
}
}

Expand Down
41 changes: 24 additions & 17 deletions core/src/client/async_client/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@
//! > **Note**: The spec allow number, string or null but this crate only supports numbers.
//! - SubscriptionId: unique ID generated by server

use std::{
collections::{HashMap, hash_map::Entry},
ops::Range,
};
use std::collections::{HashMap, hash_map::Entry};

use crate::{
client::{Error, RawResponseOwned, SubscriptionReceiver, SubscriptionSender},
Expand Down Expand Up @@ -90,7 +87,7 @@ pub(crate) struct RequestManager {
/// requests.
subscriptions: HashMap<SubscriptionId<'static>, RequestId>,
/// Pending batch requests.
batches: FxHashMap<Range<u64>, BatchState>,
batches: Vec<(Vec<Id<'static>>, BatchState)>,
/// Registered Methods for incoming notifications.
notification_handlers: HashMap<String, SubscriptionSink>,
}
Expand Down Expand Up @@ -123,15 +120,15 @@ impl RequestManager {
/// Returns `Ok` if the pending request was successfully inserted otherwise `Err`.
pub(crate) fn insert_pending_batch(
&mut self,
batch: Range<u64>,
batch: Vec<Id<'static>>,
send_back: PendingBatchOneshot,
) -> Result<(), PendingBatchOneshot> {
if let Entry::Vacant(v) = self.batches.entry(batch) {
v.insert(BatchState { send_back });
Ok(())
} else {
Err(send_back)
if self.batches.iter().any(|(existing_batch, _)| existing_batch == &batch) {
return Err(send_back);
}

self.batches.push((batch, BatchState { send_back }));
Ok(())
}

/// Tries to insert a new pending subscription and reserves a slot for a "potential" unsubscription request.
Expand Down Expand Up @@ -222,14 +219,24 @@ impl RequestManager {
/// Tries to complete a pending batch request.
///
/// Returns `Some` if the subscription was completed otherwise `None`.
pub(crate) fn complete_pending_batch(&mut self, batch: Range<u64>) -> Option<BatchState> {
match self.batches.entry(batch) {
Entry::Occupied(request) => {
let (_digest, state) = request.remove_entry();
Some(state)
pub(crate) fn complete_pending_batch(&mut self, batch: Vec<Id<'static>>) -> Option<BatchState> {
let mut matched_key = None;

for (key, _) in self.batches.iter() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, this is inefficient/annoying compared to the old code 🤔

if key.len() == batch.len() && batch.iter().all(|id| key.contains(id)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could just use Vec partialeq/eq impl here.

matched_key = Some(key.clone());
break;
}
_ => None,
}

if let Some(key) = matched_key {
if let Some(pos) = self.batches.iter().position(|(existing_batch, _)| existing_batch == &key) {
let (_, state) = self.batches.remove(pos);
return Some(state);
}
}

None
}

/// Tries to complete a pending call.
Expand Down
45 changes: 19 additions & 26 deletions core/src/client/async_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ use tokio::sync::{mpsc, oneshot};
use tower::layer::util::Identity;

use self::utils::{InactivityCheck, IntervalStream};
use super::{FrontToBack, IdKind, MethodResponse, RequestIdManager, generate_batch_id_range, subscription_channel};
use super::{FrontToBack, IdKind, MethodResponse, RequestIdManager, subscription_channel};

pub(crate) type Notification<'a> = jsonrpsee_types::Notification<'a, Option<Box<JsonRawValue>>>;

Expand Down Expand Up @@ -536,22 +536,23 @@ where
{
async {
let batch = batch.build()?;
let id = self.id_manager.next_request_id();
let id_range = generate_batch_id_range(id, batch.len() as u64)?;
let mut ids = Vec::new();

let mut b = Batch::with_capacity(batch.len());

for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
for (method, params) in batch.into_iter() {
let id = self.id_manager.next_request_id();
b.push(Request {
jsonrpc: TwoPointZero,
id: self.id_manager.as_id_kind().into_id(id),
id: id.clone(),
method: method.into(),
params: params.map(StdCow::Owned),
extensions: Extensions::new(),
});
ids.push(id);
}

b.extensions_mut().insert(IsBatch { id_range });
b.extensions_mut().insert(IsBatch { ids });

let fut = self.service.batch(b);
let json_values = self.run_future_until_timeout(fut).await?.into_batch().expect("Batch response");
Expand Down Expand Up @@ -709,46 +710,38 @@ fn handle_backend_messages<R: TransportReceiverT>(
}
Some(b'[') => {
// Batch response.
if let Ok(raw_responses) = serde_json::from_slice::<Vec<&JsonRawValue>>(raw) {
if let Ok(raw_responses) = serde_json::from_slice::<Vec<Box<JsonRawValue>>>(raw) {
let mut batch = Vec::with_capacity(raw_responses.len());

let mut range = None;
let mut ids = Vec::new();
let mut got_notif = false;

for r in raw_responses {
if let Ok(response) = serde_json::from_str::<Response<_>>(r.get()) {
let id = response.id.try_parse_inner_as_number()?;
batch.push(response.into_owned().into());
let json_string = r.get().to_string();

let r = range.get_or_insert(id..id);

if id < r.start {
r.start = id;
}

if id > r.end {
r.end = id;
}
} else if let Ok(response) = serde_json::from_str::<SubscriptionResponse<_>>(r.get()) {
if let Ok(response) = serde_json::from_str::<Response<_>>(&json_string) {
let id = response.id.clone().into_owned();
// let result = ResponseSuccess::try_from(response).map(|s| s.result);
batch.push(response.into_owned().into());
ids.push(id);
} else if let Ok(response) = serde_json::from_str::<SubscriptionResponse<_>>(&json_string) {
got_notif = true;
if let Some(sub_id) = process_subscription_response(&mut manager.lock(), response) {
messages.push(FrontToBack::SubscriptionClosed(sub_id));
}
} else if let Ok(response) = serde_json::from_slice::<SubscriptionError<_>>(raw) {
got_notif = true;
process_subscription_close_response(&mut manager.lock(), response);
} else if let Ok(notif) = serde_json::from_str::<Notification>(r.get()) {
} else if let Ok(notif) = serde_json::from_str::<Notification>(&json_string) {
got_notif = true;
process_notification(&mut manager.lock(), notif);
} else {
return Err(unparse_error(raw));
};
}

if let Some(mut range) = range {
// the range is exclusive so need to add one.
range.end += 1;
process_batch_response(&mut manager.lock(), batch, range)?;
if ids.len().gt(&0) {
process_batch_response(&mut manager.lock(), batch, ids)?;
} else if !got_notif {
return Err(EmptyBatchRequest.into());
}
Expand Down
6 changes: 3 additions & 3 deletions core/src/client/async_client/rpc_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ impl RpcServiceT for RpcService {
let (send_back_tx, send_back_rx) = oneshot::channel();

let raw = serde_json::to_string(&batch).map_err(client_err)?;
let id_range = batch
let ids = batch
.extensions()
.get::<IsBatch>()
.map(|b| b.id_range.clone())
.map(|b| b.ids.clone())
.expect("Batch ID range must be set in extensions");

tx.send(FrontToBack::Batch(BatchMessage { raw, ids: id_range, send_back: send_back_tx })).await?;
tx.send(FrontToBack::Batch(BatchMessage { raw, ids, send_back: send_back_tx })).await?;
let json = send_back_rx.await?.map_err(client_err)?;

Ok(MethodResponse::batch(json, batch.into_extensions()))
Expand Down
Loading