diff --git a/pgdog/src/admin/mod.rs b/pgdog/src/admin/mod.rs index 9494e429..58f6e649 100644 --- a/pgdog/src/admin/mod.rs +++ b/pgdog/src/admin/mod.rs @@ -28,6 +28,7 @@ pub mod show_servers; pub mod show_stats; pub mod show_version; pub mod shutdown; +pub mod pause_traffic; pub use error::Error; diff --git a/pgdog/src/admin/parser.rs b/pgdog/src/admin/parser.rs index 4a44af68..f06b1d82 100644 --- a/pgdog/src/admin/parser.rs +++ b/pgdog/src/admin/parser.rs @@ -1,5 +1,7 @@ //! Admin command parser. +use crate::admin::pause_traffic::PauseTraffic; + use super::{ ban::Ban, pause::Pause, prelude::Message, probe::Probe, reconnect::Reconnect, reload::Reload, reset_query_cache::ResetQueryCache, set::Set, setup_schema::SetupSchema, @@ -14,6 +16,7 @@ use tracing::debug; /// Parser result. pub enum ParseResult { Pause(Pause), + PauseTraffic(PauseTraffic), Reconnect(Reconnect), ShowClients(ShowClients), Reload(Reload), @@ -41,6 +44,7 @@ impl ParseResult { match self { Pause(pause) => pause.execute().await, + PauseTraffic(pause_traffic) => pause_traffic.execute().await, Reconnect(reconnect) => reconnect.execute().await, ShowClients(show_clients) => show_clients.execute().await, Reload(reload) => reload.execute().await, @@ -68,6 +72,7 @@ impl ParseResult { match self { Pause(pause) => pause.name(), + PauseTraffic(pause_traffic) => pause_traffic.name(), Reconnect(reconnect) => reconnect.name(), ShowClients(show_clients) => show_clients.name(), Reload(reload) => reload.name(), @@ -101,6 +106,9 @@ impl Parser { Ok(match iter.next().ok_or(Error::Syntax)?.trim() { "pause" | "resume" => ParseResult::Pause(Pause::parse(&sql)?), + "pause_traffic" | "resume_traffic" => { + ParseResult::PauseTraffic(PauseTraffic::parse(&sql)?) + } "shutdown" => ParseResult::Shutdown(Shutdown::parse(&sql)?), "reconnect" => ParseResult::Reconnect(Reconnect::parse(&sql)?), "reload" => ParseResult::Reload(Reload::parse(&sql)?), diff --git a/pgdog/src/admin/pause_traffic.rs b/pgdog/src/admin/pause_traffic.rs new file mode 100644 index 00000000..f93ee732 --- /dev/null +++ b/pgdog/src/admin/pause_traffic.rs @@ -0,0 +1,49 @@ +//! Pause the traffic, ignore all traffic coming from the clients, and proccess it only when resumed. + +use tracing::info; + +use crate::frontend::comms::comms; + +use super::prelude::*; + +/// Pause traffic. +#[derive(Default)] +pub struct PauseTraffic { + resume: bool, +} + +#[async_trait] +impl Command for PauseTraffic { + fn parse(sql: &str) -> Result { + let parts = sql.split(" ").collect::>(); + + match parts[..] { + ["pause_traffic"] => Ok(Self::default()), + ["resume_traffic"] => Ok(Self { resume: true }), + _ => Err(Error::Syntax), + } + } + + async fn execute(&self) -> Result, Error> { + match self.resume { + true => { + comms().unpause_traffic(); + info!("Traffic resumed"); + } + false => { + comms().pause_traffic(); + info!("Traffic paused"); + } + } + + Ok(vec![]) + } + + fn name(&self) -> String { + if self.resume { + "RESUME".into() + } else { + "PAUSE".into() + } + } +} diff --git a/pgdog/src/frontend/client/mod.rs b/pgdog/src/frontend/client/mod.rs index ca5c2032..a42369d6 100644 --- a/pgdog/src/frontend/client/mod.rs +++ b/pgdog/src/frontend/client/mod.rs @@ -312,10 +312,34 @@ impl Client { let mut inner = Inner::new(self)?; let shutdown = self.comms.shutting_down(); + let pause_traffic = { + let shutdown = shutdown.clone(); + let comms = self.comms.clone(); + let unpausing_traffic_notif = comms.unpausing_traffic(); + #[cold] + async move || { + let unpausing_traffic_notif = unpausing_traffic_notif.notified(); + if !comms.is_traffic_paused() { + return false; + } + select! { + _ = unpausing_traffic_notif => { + // Traffic is unpaused. + false + } + // TODO: use ArcBool pattern to avoid race condition + _ = shutdown.notified() => { + // Shutdown requested. + true + } + } + } + }; loop { let query_timeout = self.timeouts.query_timeout(&inner.stats.state); select! { + // TODO: use ArcBool pattern to avoid race condition _ = shutdown.notified() => { if !inner.backend.connected() && inner.start_transaction.is_none() { break; @@ -324,6 +348,12 @@ impl Client { // Async messages. message = timeout(query_timeout, inner.backend.read()) => { + if self.comms.is_traffic_paused() { + // This returns true if shutdown is requested. + if pause_traffic().await { + break; + } + } let message = message??; let disconnect = self.server_message(&mut inner.get(), message).await?; if disconnect { @@ -332,6 +362,12 @@ impl Client { } buffer = self.buffer(&inner.stats.state) => { + if self.comms.is_traffic_paused() { + // This returns true if shutdown is requested. + if pause_traffic().await { + break; + } + } let event = buffer?; if !self.request_buffer.is_empty() { let disconnect = self.client_messages(inner.get()).await?; diff --git a/pgdog/src/frontend/comms.rs b/pgdog/src/frontend/comms.rs index 5b04efb9..b6cf4803 100644 --- a/pgdog/src/frontend/comms.rs +++ b/pgdog/src/frontend/comms.rs @@ -27,6 +27,8 @@ pub fn comms() -> Comms { /// Sync primitives shared between all clients. struct Global { shutdown: Arc, + traffic_is_paused: AtomicBool, + traffic_unpaused_notif: Arc, offline: AtomicBool, // This uses the FNV hasher, which is safe, // because BackendKeyData is randomly generated by us, @@ -56,6 +58,8 @@ impl Comms { shutdown: Arc::new(Notify::new()), offline: AtomicBool::new(false), clients: Mutex::new(HashMap::default()), + traffic_is_paused: false.into(), + traffic_unpaused_notif: Arc::new(Notify::new()), tracker: TaskTracker::new(), }), id: None, @@ -148,4 +152,27 @@ impl Comms { pub fn offline(&self) -> bool { self.global.offline.load(Ordering::Relaxed) } + + /// Traffic is/should be paused. + pub fn is_traffic_paused(&self) -> bool { + self.global.traffic_is_paused.load(Ordering::Relaxed) + } + + /// Pause traffic to all clients. + pub fn pause_traffic(&self) { + self.global.traffic_is_paused.store(true, Ordering::Relaxed); + } + + /// Unpause traffic to all clients. + pub fn unpause_traffic(&self) { + self.global + .traffic_is_paused + .store(false, Ordering::Relaxed); + self.global.traffic_unpaused_notif.notify_waiters(); + } + + /// Get the `Notify` to wait for unpausing traffic. + pub fn unpausing_traffic(&self) -> Arc { + self.global.traffic_unpaused_notif.clone() + } }