diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 541535e12c..55d6e66de1 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -28,7 +28,7 @@ use std::collections::hash_map::Entry; use std::fmt::{self, Debug}; use std::future::Future; use std::ops::{Deref, DerefMut}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use crate::error::RegisterMethodError; use crate::id_providers::RandomIntegerIdProvider; @@ -58,19 +58,34 @@ use super::IntoResponse; /// the `id`, `params`, a channel the function uses to communicate the result (or error) /// back to `jsonrpsee`, and the connection ID (useful for the websocket transport). pub type SyncMethod = Arc MethodResponse>; +// Weak reference to a synchronous method. +type WeakSyncMethod = Weak MethodResponse>; + /// Similar to [`SyncMethod`], but represents an asynchronous handler. pub type AsyncMethod<'a> = Arc< dyn Send + Sync + Fn(Id<'a>, Params<'a>, ConnectionId, MaxResponseSize, Extensions) -> BoxFuture<'a, MethodResponse>, >; +// Weak reference to an asynchronous method. +type WeakAsyncMethod<'a> = Weak< + dyn Send + + Sync + + Fn(Id<'a>, Params<'a>, ConnectionId, MaxResponseSize, Extensions) -> BoxFuture<'a, MethodResponse>, +>; /// Method callback for subscriptions. pub type SubscriptionMethod<'a> = Arc BoxFuture<'a, MethodResponse>>; +// Weak reference to a subscription method. +type WeakSubscriptionMethod<'a> = + Weak BoxFuture<'a, MethodResponse>>; // Method callback to unsubscribe. type UnsubscriptionMethod = Arc MethodResponse>; +// Weak reference to an unsubscription method. +type WeakUnsubscriptionMethod = + Weak MethodResponse>; /// Connection ID. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, serde::Deserialize, serde::Serialize)] @@ -142,6 +157,52 @@ impl CallOrSubscription { } } +/// The [`Weak`] equivalent of [`MethodCallback`]. This allows the [`Methods`] +/// map to store aliases to the method callbacks and ensures that any dangling +/// alias references are safely handled. +/// +/// This type is an implementation detail of the method aliasing system and +/// should not be used directly. +#[derive(Clone, Debug)] +pub enum WeakMethodCallback<'a> { + /// Synchronous method handler. + Sync(WeakSyncMethod), + /// Asynchronous method handler. + Async(WeakAsyncMethod<'a>), + /// Subscription method handler. + Subscription { + /// The subscription method handler. + method: WeakSubscriptionMethod<'a>, + /// The name of the associated unsubscription method. + unsubscribe_method_name: &'static str, + }, + /// Unsubscription method handler. + Unsubscription { + /// The unsubscription method handler, usually auto-generated on + /// subscription insertion. + method: WeakUnsubscriptionMethod, + /// The name of the associated subscription method. + subscribe_method_name: &'static str, + }, +} + +impl WeakMethodCallback<'static> { + /// Upgrade the weak reference to a strong reference, if the method has not + /// been dropped. See [`Weak::upgrade`]. + fn upgrade(&self) -> Option { + match self { + Self::Sync(w) => w.upgrade().map(MethodCallback::Sync), + Self::Async(w) => w.upgrade().map(MethodCallback::Async), + Self::Subscription { method, unsubscribe_method_name } => { + Some(MethodCallback::Subscription { method: method.upgrade()?, unsubscribe_method_name }) + } + Self::Unsubscription { method, subscribe_method_name } => { + Some(MethodCallback::Unsubscription { method: method.upgrade()?, subscribe_method_name }) + } + } + } +} + /// Callback wrapper that can be either sync or async. #[derive(Clone)] pub enum MethodCallback { @@ -150,9 +211,48 @@ pub enum MethodCallback { /// Asynchronous method handler. Async(AsyncMethod<'static>), /// Subscription method handler. - Subscription(SubscriptionMethod<'static>), + Subscription { + /// The subscription method handler. + method: SubscriptionMethod<'static>, + /// The name of the associated unsubscription method. + unsubscribe_method_name: &'static str, + }, /// Unsubscription method handler. - Unsubscription(UnsubscriptionMethod), + Unsubscription { + /// The unsubscription method handler, usually auto-generated on + /// subscription insertion. + method: UnsubscriptionMethod, + /// The name of the associated subscription method. + subscribe_method_name: &'static str, + }, + /// The method is an alias for another method. + Alias(WeakMethodCallback<'static>), +} + +impl MethodCallback { + /// Downgrade the method callback to a weak reference. This method will + /// return `None` if the method callback is an alias and the target of the + /// alias no longer exists. + fn downgrade(&self) -> Option> { + match self { + Self::Sync(cb) => Some(WeakMethodCallback::Sync(Arc::downgrade(cb))), + Self::Async(cb) => Some(WeakMethodCallback::Async(Arc::downgrade(cb))), + Self::Subscription { method, unsubscribe_method_name } => { + Some(WeakMethodCallback::Subscription { method: Arc::downgrade(method), unsubscribe_method_name }) + } + Self::Unsubscription { method, subscribe_method_name } => { + Some(WeakMethodCallback::Unsubscription { method: Arc::downgrade(method), subscribe_method_name }) + } + // We upgrade the alias to check that it still exists. + Self::Alias(alias) => alias.upgrade().as_ref().and_then(Self::downgrade), + } + } + + /// Create an alias for a method callback. This method will return `None` + /// if `self` is an alias and the target of the alias no longer exists. + fn to_alias(&self) -> Option { + self.downgrade().map(Self::Alias) + } } /// The kind of the JSON-RPC method call, it can be a subscription, method call or unknown. @@ -203,8 +303,9 @@ impl Debug for MethodCallback { match self { Self::Async(_) => write!(f, "Async"), Self::Sync(_) => write!(f, "Sync"), - Self::Subscription(_) => write!(f, "Subscription"), - Self::Unsubscription(_) => write!(f, "Unsubscription"), + Self::Subscription { .. } => write!(f, "Subscription"), + Self::Unsubscription { .. } => write!(f, "Unsubscription"), + Self::Alias(_) => write!(f, "Alias"), } } } @@ -221,6 +322,101 @@ impl Methods { Self::default() } + /// Helper for obtaining a mut ref to the callbacks HashMap. + /// + /// ## Warning + /// + /// Direct access to mut callbacks is discouraged and can result in + /// dangling handlers (e.g. a [`MethodCallback::Subscription`] with no + /// corresponding [`MethodCallback::Unsubscription`]). As such this method + /// should be used with caution. + /// + /// To access the underlying map correctly, you MUST NOT perform the + /// following operations without checking for dangling handlers: + /// + /// - `mut_callbacks().remove()` + /// - `mut_callbacks().insert()` + /// - `mut_callbacks().entry()` + /// - `mut_callbacks().drain()` + /// - `mut_callbacks().iter_mut()` + /// + /// Prefer using the provided methods: + /// - [`Self::remove`] + /// - [`Self::replace`] + /// - [`Self::verify_and_insert`] + /// - [`Self::replace_if`] + /// - [`Self::merge`] + /// - [`Self::merge_replace`] + /// - [`Self::merge_replace_if`] + fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> { + Arc::make_mut(&mut self.callbacks) + } + + /// Returns the method callback. + pub fn method(&self, method_name: &str) -> Option { + match self.callbacks.get(method_name) { + Some(MethodCallback::Alias(alias)) => alias.upgrade(), + other => other.cloned(), + } + } + + /// Returns an `Iterator` with all the method names registered on this server. + pub fn method_names(&self) -> impl Iterator + '_ { + self.callbacks.keys().copied() + } + + /// Returns the method callback along with its name. The returned name is same as the + /// `method_name`, but its lifetime bound is `'static`. + pub fn method_with_name(&self, method_name: &str) -> Option<(&'static str, MethodCallback)> { + self.callbacks.get_key_value(method_name).and_then(|(k, v)| { + let v = match v { + MethodCallback::Alias(alias) => alias.upgrade()?, + other => other.clone(), + }; + Some((*k, v)) + }) + } + + /// Register a new alias for a target method, or returns an error if the + /// alias name was already taken. If the target method does not exist, the + /// alias will not be registered. + pub fn register_alias( + &mut self, + alias: &'static str, + target_name: &'static str, + ) -> Result<(), RegisterMethodError> { + if let Some(method) = self.method(target_name).as_ref().and_then(MethodCallback::to_alias) { + self.verify_and_insert(alias, method)?; + } + Ok(()) + } + + /// Remove a method from the collection, returning the method removed (if + /// any). For subscription and unsubscribe methods, the `Subscription` + /// callback will be returned, regardless of whether the `method_name` + /// corresponds to the subscription or unsubscription method. + pub fn remove(&mut self, method_name: &str) -> Option { + let cb_mut = self.mut_callbacks(); + cb_mut.remove(method_name).map(|cb| { + if let MethodCallback::Subscription { ref unsubscribe_method_name, .. } = cb { + // remove and discard + cb_mut.remove(unsubscribe_method_name); + } + if let MethodCallback::Unsubscription { subscribe_method_name, .. } = cb { + return cb_mut.remove(subscribe_method_name).expect("map in invalid state - dangling unsubscribe"); + } + cb + }) + } + + /// Inserts the method callback for a given name, replacing any existing + /// method with the same name. + pub fn replace(&mut self, name: &'static str, callback: MethodCallback) -> Option { + let opt = self.remove(name); + self.mut_callbacks().insert(name, callback); + opt + } + /// Verifies that the method name is not already taken, and returns an error if it is. pub fn verify_method_name(&mut self, name: &'static str) -> Result<(), RegisterMethodError> { if self.callbacks.contains_key(name) { @@ -243,9 +439,26 @@ impl Methods { } } - /// Helper for obtaining a mut ref to the callbacks HashMap. - fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> { - Arc::make_mut(&mut self.callbacks) + /// Inserts the method callback for a given name. If a method with the same + /// name already exists, evaluates the provided `cond` with the method name + /// to determine whether to replace the existing method. + /// + /// Returns the previous method callback if it was replaced. + /// + /// Note: If a conflict exists, and `cond` evaluates to false, the method + /// will not be replaced and the new method will be dropped. + pub fn replace_if( + &mut self, + name: &'static str, + callback: MethodCallback, + cond: impl Fn(&'static str) -> bool, + ) -> Option { + if self.callbacks.contains_key(name) && !cond(name) { + return None; + } + + self.remove(name); + self.replace(name, callback) } /// Merge two [`Methods`]'s by adding all [`MethodCallback`]s from `other` into `self`. @@ -266,15 +479,66 @@ impl Methods { Ok(()) } - /// Returns the method callback. - pub fn method(&self, method_name: &str) -> Option<&MethodCallback> { - self.callbacks.get(method_name) + /// Merge two [`Methods`]'s by adding all [`MethodCallback`]s from `other` + /// into `self`, removing and returning any existing methods with the same + /// name. + pub fn merge_replace(&mut self, other: impl Into) -> Vec<(&'static str, MethodCallback)> { + let mut other = other.into(); + let mut removed = Vec::with_capacity(other.callbacks.len()); + + // NB: these loops must remain separate. To prevent inconsistent states + // with dangling unsubscribe callbacks, we must remove all names before + // inserting any new names. + for name in other.method_names() { + if let Some(prev) = self.remove(name) { + removed.push((name, prev)); + } + } + for (name, callback) in other.mut_callbacks().drain() { + self.replace(name, callback); + } + + removed } - /// Returns the method callback along with its name. The returned name is same as the - /// `method_name`, but its lifetime bound is `'static`. - pub fn method_with_name(&self, method_name: &str) -> Option<(&'static str, &MethodCallback)> { - self.callbacks.get_key_value(method_name).map(|(k, v)| (*k, v)) + /// Merge two [`Methods`]'s by adding all [`MethodCallback`]s from `other` + /// into `self`. If a method with the same name already exists, evaluates + /// the provided `cond` with the method name to determine whether to + /// replace the existing method. Uses [`Self::replace_if`] + /// internally. + /// + /// Returns a list of removed methods. + /// + /// Note: If a conflict exists, and `cond` evaluates to false, the method + /// will not be replaced and the new method will be dropped. Methods + /// dropped this way will **not** be included in the returned list. + pub fn merge_replace_if( + &mut self, + other: impl Into, + cond: impl Fn(&'static str) -> bool, + ) -> Vec<(&'static str, MethodCallback)> { + let mut other = other.into(); + let mut removed = Vec::with_capacity(other.callbacks.len()); + + // NB: these loops must remain separate. To prevent inconsistent states + // with dangling unsubscribe callbacks, we must remove all names before + // inserting any new names. + for name in other.method_names() { + if !cond(name) { + continue; + } + if let Some(prev) = self.remove(name) { + removed.push((name, prev)); + } + } + + for (name, callback) in other.mut_callbacks().drain() { + if !cond(name) { + continue; + } + self.replace(name, callback); + } + removed } /// Helper to call a method on the `RPC module` without having to spin up a server. @@ -377,7 +641,7 @@ impl Methods { Some(MethodCallback::Async(cb)) => { (cb)(id.into_owned(), params.into_owned(), conn_id, max_response_size, extensions).await } - Some(MethodCallback::Subscription(cb)) => { + Some(MethodCallback::Subscription { method: cb, .. }) => { let conn_state = SubscriptionState { conn_id, id_provider: &RandomIntegerIdProvider, subscription_permit }; let res = (cb)(id, params, MethodSink::new(tx.clone()), conn_state, extensions).await; @@ -390,7 +654,10 @@ impl Methods { res } - Some(MethodCallback::Unsubscription(cb)) => (cb)(id, params, conn_id, max_response_size, extensions), + Some(MethodCallback::Unsubscription { method: cb, .. }) => { + (cb)(id, params, conn_id, max_response_size, extensions) + } + Some(MethodCallback::Alias(_)) => unreachable!("alias resolved in `method_with_name"), }; let is_success = response.is_success(); @@ -463,11 +730,6 @@ impl Methods { Ok(Subscription { sub_id, rx }) } - - /// Returns an `Iterator` with all the method names registered on this server. - pub fn method_names(&self) -> impl Iterator + '_ { - self.callbacks.keys().copied() - } } impl Deref for RpcModule { @@ -521,6 +783,16 @@ impl From> for Methods { } impl RpcModule { + /// Return a mutable reference to the currently registered methods. + pub fn methods_mut(&mut self) -> &mut Methods { + &mut self.methods + } + + /// Remove a method by name. + fn remove_method(&mut self, method_name: &'static str) -> Option { + self.methods.remove(method_name) + } + /// Register a new synchronous RPC method, which computes the response with the given callback. /// /// ## Examples @@ -551,6 +823,20 @@ impl RpcModule { ) } + /// As [`Self::register_method`], but replaces and returns the method if + /// it already exists. + pub fn replace_method(&mut self, method_name: &'static str, callback: F) -> Option + where + R: IntoResponse + 'static, + F: Fn(Params, &Context, &Extensions) -> R + Send + Sync + 'static, + { + let prev = self.remove_method(method_name); + // Errors here can be ignored, as we know the method is not already + // registered (we just removed it). + let _ = self.register_method(method_name, callback); + prev + } + /// Register a new asynchronous RPC method, which computes the response with the given callback. /// /// ## Examples @@ -589,6 +875,25 @@ impl RpcModule { ) } + /// As [`Self::register_async_method`], but replaces and returns the method + /// if it already exists. + pub fn replace_async_method( + &mut self, + method_name: &'static str, + callback: Fun, + ) -> Option + where + R: IntoResponse + 'static, + Fut: Future + Send, + Fun: (Fn(Params<'static>, Arc, Extensions) -> Fut) + Clone + Send + Sync + 'static, + { + let prev = self.remove_method(method_name); + // Errors here can be ignored, as we know the method is not already + // registered (we just removed it). + let _ = self.register_async_method(method_name, callback); + prev + } + /// Register a new **blocking** synchronous RPC method, which computes the response with the given callback. /// Unlike the regular [`register_method`](RpcModule::register_method), this method can block its thread and perform /// expensive computations. @@ -632,6 +937,19 @@ impl RpcModule { Ok(callback) } + /// As [`Self::register_blocking_method`], but replaces and returns the method if it already exists. + pub fn replace_blocking_method(&mut self, method_name: &'static str, callback: F) -> Option + where + R: IntoResponse + 'static, + F: Fn(Params, Arc, &Extensions) -> R + Clone + Send + Sync + 'static, + { + let prev = self.remove_method(method_name); + // Errors here can be ignored, as we know the method is not already + // registered (we just removed it). + let _ = self.register_blocking_method(method_name, callback); + prev + } + /// Register a new publish/subscribe interface using JSON-RPC notifications. /// /// It implements the [ethereum pubsub specification](https://geth.ethereum.org/docs/rpc/pubsub) @@ -748,78 +1066,106 @@ impl RpcModule { // Subscribe let callback = { - self.methods.verify_and_insert( - subscribe_method_name, - MethodCallback::Subscription(Arc::new(move |id, params, method_sink, conn, extensions| { - let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() }; - - // response to the subscription call. - let (tx, rx) = oneshot::channel(); - let (accepted_tx, accepted_rx) = oneshot::channel(); - - let sub_id = uniq_sub.sub_id.clone(); - let method = notif_method_name; - - let sink = PendingSubscriptionSink { - inner: method_sink.clone(), - method: notif_method_name, - subscribers: subscribers.clone(), - uniq_sub, - id: id.clone().into_owned(), - subscribe: tx, - permit: conn.subscription_permit, + let method: SubscriptionMethod<'_> = Arc::new(move |id, params, method_sink, conn, extensions| { + let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() }; + + // response to the subscription call. + let (tx, rx) = oneshot::channel(); + let (accepted_tx, accepted_rx) = oneshot::channel(); + + let sub_id = uniq_sub.sub_id.clone(); + let method = notif_method_name; + + let sink = PendingSubscriptionSink { + inner: method_sink.clone(), + method: notif_method_name, + subscribers: subscribers.clone(), + uniq_sub, + id: id.clone().into_owned(), + subscribe: tx, + permit: conn.subscription_permit, + }; + + // The subscription callback is a future from the subscription + // definition and not the as same when the subscription call has been completed. + // + // This runs until the subscription callback has completed. + let sub_fut = callback(params.into_owned(), sink, ctx.clone(), extensions.clone()); + + tokio::spawn(async move { + // This will wait for the subscription future to be resolved + let response = match futures_util::future::try_join(sub_fut.map(|f| Ok(f)), accepted_rx).await { + Ok((r, _)) => r.into_response(), + // The accept call failed i.e, the subscription was not accepted. + Err(_) => return, }; - // The subscription callback is a future from the subscription - // definition and not the as same when the subscription call has been completed. - // - // This runs until the subscription callback has completed. - let sub_fut = callback(params.into_owned(), sink, ctx.clone(), extensions.clone()); - - tokio::spawn(async move { - // This will wait for the subscription future to be resolved - let response = match futures_util::future::try_join(sub_fut.map(|f| Ok(f)), accepted_rx).await { - Ok((r, _)) => r.into_response(), - // The accept call failed i.e, the subscription was not accepted. - Err(_) => return, - }; - - match response { - SubscriptionCloseResponse::Notif(msg) => { - let json = sub_message_to_json(msg, SubNotifResultOrError::Result, &sub_id, method); - let _ = method_sink.send(json).await; - } - SubscriptionCloseResponse::NotifErr(msg) => { - let json = sub_message_to_json(msg, SubNotifResultOrError::Error, &sub_id, method); - let _ = method_sink.send(json).await; - } - SubscriptionCloseResponse::None => (), + match response { + SubscriptionCloseResponse::Notif(msg) => { + let json = sub_message_to_json(msg, SubNotifResultOrError::Result, &sub_id, method); + let _ = method_sink.send(json).await; } - }); - - let id = id.clone().into_owned(); - - Box::pin(async move { - match rx.await { - Ok(mut rp) => { - // If the subscription was accepted then send a message - // to subscription task otherwise rely on the drop impl. - if rp.is_success() { - let _ = accepted_tx.send(()); - } - *rp.extensions_mut() = extensions; - rp + SubscriptionCloseResponse::NotifErr(msg) => { + let json = sub_message_to_json(msg, SubNotifResultOrError::Error, &sub_id, method); + let _ = method_sink.send(json).await; + } + SubscriptionCloseResponse::None => (), + } + }); + + let id = id.clone().into_owned(); + + Box::pin(async move { + match rx.await { + Ok(mut rp) => { + // If the subscription was accepted then send a message + // to subscription task otherwise rely on the drop impl. + if rp.is_success() { + let _ = accepted_tx.send(()); } - Err(_) => MethodResponse::error_with_extensions(id, ErrorCode::InternalError, extensions), + *rp.extensions_mut() = extensions; + rp } - }) - })), + Err(_) => MethodResponse::error_with_extensions(id, ErrorCode::InternalError, extensions), + } + }) + }); + self.methods.verify_and_insert( + subscribe_method_name, + MethodCallback::Subscription { method, unsubscribe_method_name }, )? }; Ok(callback) } + /// As [`Self::register_subscription`] but replaces and returns the method + /// if it already exists. + pub fn replace_subscription( + &mut self, + subscribe_method_name: &'static str, + notif_method_name: &'static str, + unsubscribe_method_name: &'static str, + callback: F, + ) -> Option + where + Context: Send + Sync + 'static, + F: (Fn(Params<'static>, PendingSubscriptionSink, Arc, Extensions) -> Fut) + + Send + + Sync + + Clone + + 'static, + Fut: Future + Send + 'static, + R: IntoSubscriptionCloseResponse + Send, + { + let prev = self.methods.remove(subscribe_method_name); + + // Errors here can be ignored, as we know the method is not already + // registered (we just removed it). + let _ = self.register_subscription(subscribe_method_name, notif_method_name, unsubscribe_method_name, callback); + prev + } + /// Similar to [`RpcModule::register_subscription`] but a little lower-level API /// where handling the subscription is managed the user i.e, polling the subscription /// such as spawning a separate task to do so. @@ -882,38 +1228,39 @@ impl RpcModule { // Subscribe let callback = { - self.methods.verify_and_insert( - subscribe_method_name, - MethodCallback::Subscription(Arc::new(move |id, params, method_sink, conn, extensions| { - let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() }; - - // response to the subscription call. - let (tx, rx) = oneshot::channel(); - - let sink = PendingSubscriptionSink { - inner: method_sink.clone(), - method: notif_method_name, - subscribers: subscribers.clone(), - uniq_sub, - id: id.clone().into_owned(), - subscribe: tx, - permit: conn.subscription_permit, - }; + let method: SubscriptionMethod<'_> = Arc::new(move |id, params, method_sink, conn, extensions| { + let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: conn.id_provider.next_id() }; + + // response to the subscription call. + let (tx, rx) = oneshot::channel(); + + let sink = PendingSubscriptionSink { + inner: method_sink.clone(), + method: notif_method_name, + subscribers: subscribers.clone(), + uniq_sub, + id: id.clone().into_owned(), + subscribe: tx, + permit: conn.subscription_permit, + }; - callback(params, sink, ctx.clone(), &extensions); + callback(params, sink, ctx.clone(), &extensions); - let id = id.clone().into_owned(); + let id = id.clone().into_owned(); - Box::pin(async move { - match rx.await { - Ok(mut rp) => { - *rp.extensions_mut() = extensions; - rp - } - Err(_) => MethodResponse::error_with_extensions(id, ErrorCode::InternalError, extensions), + Box::pin(async move { + match rx.await { + Ok(mut rp) => { + *rp.extensions_mut() = extensions; + rp } - }) - })), + Err(_) => MethodResponse::error_with_extensions(id, ErrorCode::InternalError, extensions), + } + }) + }); + self.methods.verify_and_insert( + subscribe_method_name, + MethodCallback::Subscription { method, unsubscribe_method_name }, )? }; @@ -939,65 +1286,58 @@ impl RpcModule { // Unsubscribe { let subscribers = subscribers.clone(); - self.methods.mut_callbacks().insert( - unsubscribe_method_name, - MethodCallback::Unsubscription(Arc::new(move |id, params, conn_id, max_response_size, extensions| { - let sub_id = match params.one::() { - Ok(sub_id) => sub_id, - Err(_) => { - tracing::warn!( - target: LOG_TARGET, - "Unsubscribe call `{}` failed: couldn't parse subscription id={:?} request id={:?}", - unsubscribe_method_name, - params, - id - ); - - return MethodResponse::response_with_extensions( - id, - ResponsePayload::success(false), - max_response_size, - extensions, - ); - } - }; - - let key = SubscriptionKey { conn_id, sub_id: sub_id.into_owned() }; - let result = subscribers.lock().remove(&key).is_some(); - - if !result { - tracing::debug!( + let method: UnsubscriptionMethod = Arc::new(move |id, params, conn_id, max_response_size, extensions| { + let sub_id = match params.one::() { + Ok(sub_id) => sub_id, + Err(_) => { + tracing::warn!( target: LOG_TARGET, - "Unsubscribe call `{}` subscription key={:?} not an active subscription", + "Unsubscribe call `{}` failed: couldn't parse subscription id={:?} request id={:?}", unsubscribe_method_name, - key, + params, + id + ); + + return MethodResponse::response_with_extensions( + id, + ResponsePayload::success(false), + max_response_size, + extensions, ); } + }; - MethodResponse::response(id, ResponsePayload::success(result), max_response_size) - })), - ); + let key = SubscriptionKey { conn_id, sub_id: sub_id.into_owned() }; + let result = subscribers.lock().remove(&key).is_some(); + + if !result { + tracing::debug!( + target: LOG_TARGET, + "Unsubscribe call `{}` subscription key={:?} not an active subscription", + unsubscribe_method_name, + key, + ); + } + + MethodResponse::response(id, ResponsePayload::success(result), max_response_size) + }); + self.methods + .mut_callbacks() + .insert(unsubscribe_method_name, MethodCallback::Unsubscription { method, subscribe_method_name }); } Ok(subscribers) } - /// Register an alias for an existing_method. Alias uniqueness is enforced. + /// Register a new alias for a target method, or returns an error if the + /// alias name was already taken. If the target method does not exist, the + /// alias will not be registered. pub fn register_alias( &mut self, alias: &'static str, existing_method: &'static str, ) -> Result<(), RegisterMethodError> { - self.methods.verify_method_name(alias)?; - - let callback = match self.methods.callbacks.get(existing_method) { - Some(callback) => callback.clone(), - None => return Err(RegisterMethodError::MethodNotFound(existing_method.into())), - }; - - self.methods.mut_callbacks().insert(alias, callback); - - Ok(()) + self.methods.register_alias(alias, existing_method) } } diff --git a/server/src/middleware/rpc/layer/rpc_service.rs b/server/src/middleware/rpc/layer/rpc_service.rs index 3b69bbb84a..b527d707d0 100644 --- a/server/src/middleware/rpc/layer/rpc_service.rs +++ b/server/src/middleware/rpc/layer/rpc_service.rs @@ -106,7 +106,7 @@ impl<'a> RpcServiceT<'a> for RpcService { let rp = (callback)(id, params, max_response_body_size, extensions); ResponseFuture::ready(rp) } - MethodCallback::Subscription(callback) => { + MethodCallback::Subscription { method: callback, .. } => { let RpcServiceCfg::CallsAndSubscriptions { bounded_subscriptions, sink, @@ -131,7 +131,7 @@ impl<'a> RpcServiceT<'a> for RpcService { ResponseFuture::ready(rp) } } - MethodCallback::Unsubscription(callback) => { + MethodCallback::Unsubscription { method: callback, .. } => { // Don't adhere to any resource or subscription limits; always let unsubscribing happen! let RpcServiceCfg::CallsAndSubscriptions { .. } = self.cfg else { @@ -143,6 +143,7 @@ impl<'a> RpcServiceT<'a> for RpcService { let rp = callback(id, params, conn_id, max_response_body_size, extensions); ResponseFuture::ready(rp) } + MethodCallback::Alias(_) => unreachable!("alias resolved in `method_with_name"), }, } }