diff --git a/pgdog/src/config/mod.rs b/pgdog/src/config/mod.rs index 6f9d794b..e156c9f3 100644 --- a/pgdog/src/config/mod.rs +++ b/pgdog/src/config/mod.rs @@ -352,6 +352,11 @@ pub struct General { /// Dry run for sharding. Parse the query, route to shard 0. #[serde(default)] pub dry_run: bool, + /// Require explicit shard selection for every query. + /// When enabled, queries without explicit shard selection will fail with an error. + /// This prevents unintended cross-shard queries in multi-tenant environments. + #[serde(default)] + pub require_shard_selection: bool, /// Idle timeout. #[serde(default = "General::idle_timeout")] pub idle_timeout: u64, @@ -456,6 +461,7 @@ impl Default for General { query_timeout: Self::default_query_timeout(), checkout_timeout: Self::checkout_timeout(), dry_run: bool::default(), + require_shard_selection: bool::default(), idle_timeout: Self::idle_timeout(), client_idle_timeout: Self::default_client_idle_timeout(), mirror_queue: Self::mirror_queue(), diff --git a/pgdog/src/frontend/router/parser/error.rs b/pgdog/src/frontend/router/parser/error.rs index a1e26e73..8d3906ee 100644 --- a/pgdog/src/frontend/router/parser/error.rs +++ b/pgdog/src/frontend/router/parser/error.rs @@ -53,4 +53,7 @@ pub enum Error { #[error("missing parameter: ${0}")] MissingParameter(usize), + + #[error("shard selection is required but not provided")] + ShardSelectionRequired, } diff --git a/pgdog/src/frontend/router/parser/query.rs b/pgdog/src/frontend/router/parser/query.rs index e5d19693..2ec89233 100644 --- a/pgdog/src/frontend/router/parser/query.rs +++ b/pgdog/src/frontend/router/parser/query.rs @@ -430,6 +430,15 @@ impl QueryParser { debug!("query router decision: {:#?}", command); + // Check if explicit shard selection is required + if config().config.general.require_shard_selection && shards > 1 { + if let Command::Query(ref route) = command { + if route.is_all_shards() { + return Err(Error::ShardSelectionRequired); + } + } + } + if dry_run { let default_route = Route::write(None); cache.record_command( @@ -1454,4 +1463,58 @@ mod test { assert_eq!(route.shard(), &Shard::All); } + + #[test] + fn test_require_shard_selection() { + // Save original config + let original_config = crate::config::config().clone(); + + // Test that queries without shard selection fail when required + let mut config = crate::config::ConfigAndUsers::default(); + config.config.general.require_shard_selection = true; + // Add multiple shards to trigger the requirement check + config.config.databases = vec![ + crate::config::Database { + name: "test".to_string(), + shard: 0, + host: "localhost".to_string(), + ..Default::default() + }, + crate::config::Database { + name: "test".to_string(), + shard: 1, + host: "localhost".to_string(), + ..Default::default() + }, + ]; + crate::config::set(config).unwrap(); + + // Test that a query without explicit shard selection fails + let query = "SELECT * FROM sharded WHERE value = 'test'"; + let mut query_parser = QueryParser::default(); + let buffer = Buffer::from(vec![Query::new(query).into()]); + let cluster = Cluster::new_test(); + let mut stmt = PreparedStatements::default(); + let params = Parameters::default(); + let context = RouterContext::new(&buffer, &cluster, &mut stmt, ¶ms, false).unwrap(); + + let result = query_parser.parse(context); + assert!(matches!(result, Err(Error::ShardSelectionRequired))); + + // Test that a query with explicit shard selection works + let query_with_shard = "/* pgdog_shard: 1 */ SELECT * FROM sharded WHERE value = 'test'"; + let buffer_with_shard = Buffer::from(vec![Query::new(query_with_shard).into()]); + let context_with_shard = RouterContext::new(&buffer_with_shard, &cluster, &mut stmt, ¶ms, false).unwrap(); + + let mut query_parser2 = QueryParser::default(); + let result_with_shard = query_parser2.parse(context_with_shard); + assert!(result_with_shard.is_ok()); + + if let Ok(Command::Query(route)) = result_with_shard { + assert_eq!(route.shard(), &Shard::Direct(1)); + } + + // Restore original config + crate::config::set((*original_config).clone()).unwrap(); + } }