From 5fe831d82d73fdcaf07e80d8ed81c43047223f2b Mon Sep 17 00:00:00 2001 From: tingold Date: Wed, 28 May 2025 13:25:34 -0400 Subject: [PATCH 1/3] initial commit --- Cargo.lock | 1 + pgdog/Cargo.toml | 2 +- pgdog/src/backend/pool/cluster.rs | 3 + pgdog/src/config/mod.rs | 15 +- pgdog/src/config/shards.rs | 230 ++++++++++++++++++ pgdog/src/frontend/router/sharding/context.rs | 12 +- .../router/sharding/context_builder.rs | 38 ++- pgdog/src/frontend/router/sharding/error.rs | 1 + pgdog/src/frontend/router/sharding/mod.rs | 49 ++-- .../src/frontend/router/sharding/operator.rs | 4 + pgdog/src/frontend/router/sharding/value.rs | 22 +- 11 files changed, 348 insertions(+), 29 deletions(-) create mode 100644 pgdog/src/config/shards.rs diff --git a/Cargo.lock b/Cargo.lock index 58a00a54..05e81387 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -386,6 +386,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-link", ] diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index 48104fb5..a12c60cb 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -49,7 +49,7 @@ uuid = { version = "1", features = ["v4"] } url = "2" ratatui = { version = "0.30.0-alpha.1", optional = true } rmp-serde = "1" -chrono = "0.4" +chrono = { version = "0.4", features = ["serde"] } hyper = { version = "1", features = ["full"] } http-body-util = "0.1" hyper-util = { version = "0.1", features = ["full"] } diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index c07a7e5f..d8c0027b 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -360,6 +360,9 @@ mod test { data_type: DataType::Bigint, centroids_path: None, centroid_probes: 1, + sharding_method: None, + shard_range_map: None, + shard_list_map: None }], vec!["sharded_omni".into()], false, diff --git a/pgdog/src/config/mod.rs b/pgdog/src/config/mod.rs index 09058a46..0dc46f9b 100644 --- a/pgdog/src/config/mod.rs +++ b/pgdog/src/config/mod.rs @@ -4,6 +4,7 @@ pub mod convert; pub mod error; pub mod overrides; pub mod url; +mod shards; use error::Error; pub use overrides::Overrides; @@ -21,6 +22,7 @@ use serde::{Deserialize, Serialize}; use tracing::info; use tracing::warn; +pub(crate) use crate::config::shards::{ShardingMethod, ShardListMap, ShardRangeMap}; use crate::net::messages::Vector; use crate::util::{human_duration_optional, random_string}; @@ -826,6 +828,12 @@ pub struct ShardedTable { /// How many centroids to probe. #[serde(default)] pub centroid_probes: usize, + #[serde(default)] + pub sharding_method: Option, + + pub shard_range_map: Option, + + pub shard_list_map: Option } impl ShardedTable { @@ -865,6 +873,10 @@ pub enum DataType { Bigint, Uuid, Vector, + // TODO: implement more types? + // String, + // DateTimeUTC + // Float } #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Default)] @@ -955,8 +967,8 @@ pub struct MultiTenant { #[cfg(test)] pub mod test { + use crate::backend::databases::init; - use super::*; pub fn load_test() { @@ -1052,4 +1064,5 @@ column = "tenant_id" assert_eq!(config.tcp.retries().unwrap(), 5); assert_eq!(config.multi_tenant.unwrap().column, "tenant_id"); } + } diff --git a/pgdog/src/config/shards.rs b/pgdog/src/config/shards.rs new file mode 100644 index 00000000..65b56856 --- /dev/null +++ b/pgdog/src/config/shards.rs @@ -0,0 +1,230 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::frontend::router::parser::{Shard}; + + +// ============================================================================= +// Serialization Helper Module +// ============================================================================= + +/// Helper module for (de)serializing maps with usize keys as strings +mod usize_map_keys_as_strings { + use super::*; + + pub fn serialize(map: &HashMap, serializer: S) -> Result + where + S: Serializer, + V: Serialize, + { + let string_map: HashMap = map + .iter() + .map(|(k, v)| (k.to_string(), v)) + .collect(); + string_map.serialize(serializer) + } + + pub fn deserialize<'de, D, V>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + V: Deserialize<'de>, + { + let string_map = HashMap::::deserialize(deserializer)?; + string_map + .into_iter() + .map(|(s, v)| { + s.parse::() + .map(|k| (k, v)) + .map_err(serde::de::Error::custom) + }) + .collect() + } +} + +// ============================================================================= +// Core Sharding Types +// ============================================================================= + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all = "snake_case")] +pub enum ShardingMethod { + #[default] + Hash, + Range, + List, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ShardRange { + pub start: Option, + pub end: Option, + #[serde(default)] + pub no_max: bool, + #[serde(default)] + pub no_min: bool, +} + +impl ShardRange { + /// Check if a value falls within this range + pub fn contains(&self, value: i64) -> bool { + // Check lower bound + if !self.no_min { + if let Some(start) = self.start { + if value < start { + return false; + } + } + } + + // Check upper bound + if !self.no_max { + if let Some(end) = self.end { + if value >= end { + // Using >= for exclusive upper bound + return false; + } + } + } + + true + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ShardList { + pub values: Vec, +} + +impl ShardList { + /// Check if a value is contained in this list + pub fn contains(&self, value: i64) -> bool { + self.values.contains(&value) + } +} + +// ============================================================================= +// Shard Map Types +// ============================================================================= + +/// A map of shard IDs to their range definitions +#[derive(Debug, Clone, PartialEq)] +pub struct ShardRangeMap(pub HashMap); + +impl ShardRangeMap { + pub fn new() -> Self { + Self::default() + } + + /// Find the shard key for a given value based on range containment + pub fn find_shard_key(&self, value: i64) -> Option { + for (shard_id, range) in &self.0 { + if range.contains(value) { + return Some(Shard::Direct(*shard_id)); + } + } + None + } +} + +impl Default for ShardRangeMap { + fn default() -> Self { + Self(HashMap::new()) + } +} + +impl Serialize for ShardRangeMap { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + usize_map_keys_as_strings::serialize(&self.0, serializer) + } +} + +impl<'de> Deserialize<'de> for ShardRangeMap { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(ShardRangeMap(usize_map_keys_as_strings::deserialize( + deserializer, + )?)) + } +} + +/// A map of shard IDs to their list definitions +#[derive(Debug, Clone, PartialEq)] +pub struct ShardListMap(pub HashMap); + +impl ShardListMap { + pub fn new() -> Self { + Self::default() + } + + /// Find the shard key for a given value based on list containment + pub fn find_shard_key(&self, value: i64) -> Option { + for (shard_id, list) in &self.0 { + if list.contains(value) { + return Some(Shard::Direct(*shard_id)); + } + } + None + } +} + +impl Default for ShardListMap { + fn default() -> Self { + Self(HashMap::new()) + } +} + +impl Serialize for ShardListMap { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + usize_map_keys_as_strings::serialize(&self.0, serializer) + } +} + +impl<'de> Deserialize<'de> for ShardListMap { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(ShardListMap(usize_map_keys_as_strings::deserialize( + deserializer, + )?)) + } +} + +// ============================================================================= +// Shardable Trait and Implementations +// ============================================================================= + +/// Trait for types that can provide sharding functionality +pub trait Shardable { + /// Get the shard ID for a given value + fn shard(&self, value: i64) -> Shard; +} + +impl Shardable for ShardRangeMap { + fn shard(&self, value: i64) -> Shard { + if self.0.is_empty() { + return Shard::All; + } + + self.find_shard_key(value).unwrap_or_else(|| Shard::All) + } +} + +impl Shardable for ShardListMap { + fn shard(&self, value: i64) -> Shard { + if self.0.is_empty() { + return Shard::All; + } + + self.find_shard_key(value).unwrap_or_else(|| Shard::All) + } +} \ No newline at end of file diff --git a/pgdog/src/frontend/router/sharding/context.rs b/pgdog/src/frontend/router/sharding/context.rs index 9687081b..051cd91b 100644 --- a/pgdog/src/frontend/router/sharding/context.rs +++ b/pgdog/src/frontend/router/sharding/context.rs @@ -1,5 +1,4 @@ use crate::frontend::router::parser::Shard; - use super::{Error, Operator, Value}; #[derive(Debug)] @@ -16,7 +15,6 @@ impl<'a> Context<'a> { return Ok(Shard::Direct(hash as usize % shards)); } } - Operator::Centroids { shards, probes, @@ -26,6 +24,16 @@ impl<'a> Context<'a> { return Ok(centroids.shard(&vector, *shards, *probes)); } } + Operator::Ranges(srm)=> { + if let Some(i) = self.value.int()? { + return Ok(srm.find_shard_key(i).unwrap()) + } + } + Operator::Lists(slm)=> { + if let Some(i) = self.value.int()? { + return Ok(slm.find_shard_key(i).unwrap()) + } + } } Ok(Shard::All) diff --git a/pgdog/src/frontend/router/sharding/context_builder.rs b/pgdog/src/frontend/router/sharding/context_builder.rs index 4dbb4458..b9330ed9 100644 --- a/pgdog/src/frontend/router/sharding/context_builder.rs +++ b/pgdog/src/frontend/router/sharding/context_builder.rs @@ -1,4 +1,4 @@ -use crate::config::{DataType, ShardedTable}; +use crate::config::{DataType, ShardedTable, ShardingMethod, ShardListMap, ShardRangeMap}; use super::{Centroids, Context, Data, Error, Operator, Value}; @@ -8,6 +8,10 @@ pub struct ContextBuilder<'a> { operator: Option>, centroids: Option>, probes: usize, + sharding_method: Option, + shard_range_map: Option, + shard_list_map: Option + } impl<'a> ContextBuilder<'a> { @@ -22,6 +26,11 @@ impl<'a> ContextBuilder<'a> { probes: table.centroid_probes, operator: None, value: None, + // added for list and range sharding + // todo: add lifetimes to these to avoid cloning + sharding_method: table.sharding_method.clone(), + shard_range_map: table.shard_range_map.clone(), + shard_list_map: table.shard_list_map.clone(), } } @@ -37,6 +46,9 @@ impl<'a> ContextBuilder<'a> { probes: 0, centroids: None, operator: None, + sharding_method: None, + shard_range_map: None, + shard_list_map: None, }) } else if uuid.valid() { Ok(Self { @@ -45,6 +57,9 @@ impl<'a> ContextBuilder<'a> { probes: 0, centroids: None, operator: None, + sharding_method: None, + shard_range_map: None, + shard_list_map: None, }) } else { Err(Error::IncompleteContext) @@ -57,9 +72,24 @@ impl<'a> ContextBuilder<'a> { shards, probes: self.probes, centroids, - }); - } else { - self.operator = Some(Operator::Shards(shards)) + }) + } else if let Some(method) = self.sharding_method.take() { + match method { + ShardingMethod::Hash => { + self.operator = Some(Operator::Shards(shards)); + return self + } + ShardingMethod::Range => { + if self.shard_range_map.is_some() { + self.operator = Some(Operator::Ranges(self.shard_range_map.clone().unwrap())) + } + } + ShardingMethod::List => { + if self.shard_list_map.is_some() { + self.operator = Some(Operator::Lists(self.shard_list_map.clone().unwrap())) + } + } + } } self } diff --git a/pgdog/src/frontend/router/sharding/error.rs b/pgdog/src/frontend/router/sharding/error.rs index b24275a4..b6b9258f 100644 --- a/pgdog/src/frontend/router/sharding/error.rs +++ b/pgdog/src/frontend/router/sharding/error.rs @@ -25,3 +25,4 @@ pub enum Error { #[error("wrong integer binary size")] IntegerSize, } + diff --git a/pgdog/src/frontend/router/sharding/mod.rs b/pgdog/src/frontend/router/sharding/mod.rs index 9b5abb35..db9265fc 100644 --- a/pgdog/src/frontend/router/sharding/mod.rs +++ b/pgdog/src/frontend/router/sharding/mod.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use uuid::Uuid; use crate::{ @@ -16,6 +17,7 @@ pub mod tables; pub mod value; pub mod vector; + pub use context::*; pub use context_builder::*; pub use error::Error; @@ -26,6 +28,7 @@ pub use vector::{Centroids, Distance}; use super::parser::Shard; + /// Hash `BIGINT`. pub fn bigint(id: i64) -> u64 { unsafe { ffi::hash_combine64(0, ffi::hashint8extended(id)) } @@ -41,6 +44,15 @@ pub fn uuid(uuid: Uuid) -> u64 { } } +pub fn bytes(bytes: Bytes) -> u64 { + unsafe { + ffi::hash_combine64( + 0, + ffi::hash_bytes_extended(bytes.as_ptr(), bytes.len() as i64), + ) + } +} + /// Shard a string value, parsing out a BIGINT, UUID, or vector. /// /// TODO: This is really not great, we should pass in the type oid @@ -69,26 +81,27 @@ pub(crate) fn shard_value( centroids: &Vec, centroid_probes: usize, ) -> Shard { - match data_type { - DataType::Bigint => value - .parse() - .map(|v| bigint(v) as usize % shards) - .ok() - .map(Shard::Direct) - .unwrap_or(Shard::All), - DataType::Uuid => value - .parse() - .map(|v| uuid(v) as usize % shards) - .ok() - .map(Shard::Direct) - .unwrap_or(Shard::All), - DataType::Vector => Vector::try_from(value) - .ok() - .map(|v| Centroids::from(centroids).shard(&v, shards, centroid_probes)) - .unwrap_or(Shard::All), - } + match data_type { + DataType::Bigint => value + .parse() + .map(|v| bigint(v) as usize % shards) + .ok() + .map(Shard::Direct) + .unwrap_or(Shard::All), + DataType::Uuid => value + .parse() + .map(|v| uuid(v) as usize % shards) + .ok() + .map(Shard::Direct) + .unwrap_or(Shard::All), + DataType::Vector => Vector::try_from(value) + .ok() + .map(|v| Centroids::from(centroids).shard(&v, shards, centroid_probes)) + .unwrap_or(Shard::All), + } } + pub(crate) fn shard_binary( bytes: &[u8], data_type: &DataType, diff --git a/pgdog/src/frontend/router/sharding/operator.rs b/pgdog/src/frontend/router/sharding/operator.rs index 380742d8..894259d3 100644 --- a/pgdog/src/frontend/router/sharding/operator.rs +++ b/pgdog/src/frontend/router/sharding/operator.rs @@ -1,3 +1,4 @@ +use crate::config::{ShardListMap, ShardRangeMap}; use super::Centroids; #[derive(Debug)] @@ -8,4 +9,7 @@ pub enum Operator<'a> { probes: usize, centroids: Centroids<'a>, }, + Lists(ShardListMap), + Ranges(ShardRangeMap), + } diff --git a/pgdog/src/frontend/router/sharding/value.rs b/pgdog/src/frontend/router/sharding/value.rs index 101ecb9a..164ee303 100644 --- a/pgdog/src/frontend/router/sharding/value.rs +++ b/pgdog/src/frontend/router/sharding/value.rs @@ -1,5 +1,4 @@ use std::str::{from_utf8, FromStr}; - use uuid::Uuid; use super::{bigint, uuid, Error}; @@ -71,6 +70,22 @@ impl<'a> Value<'a> { Ok(None) } } + + pub fn int(&self) -> Result, Error> { + match self.data_type { + DataType::Bigint => match self.data { + Data::Text(text) => Ok(Some(text.parse::()?)), + Data::Binary(data) => Ok(Some(match data.len() { + 2 => i16::from_be_bytes(data.try_into()?) as i64, + 4 => i32::from_be_bytes(data.try_into()?) as i64, + 8 => i64::from_be_bytes(data.try_into()?) as i64, + _ => return Err(Error::IntegerSize), + })), + Data::Integer(int) => Ok(Some(int)), + }, + _ => Ok(None), + } + } pub fn valid(&self) -> bool { match self.data_type { @@ -99,7 +114,7 @@ impl<'a> Value<'a> { 8 => i64::from_be_bytes(data.try_into()?) as i64, _ => return Err(Error::IntegerSize), }))), - Data::Integer(int) => Ok(Some(bigint(int))), + Data::Integer(int) => Ok(Some(bigint(int))), }, DataType::Uuid => match self.data { @@ -107,8 +122,9 @@ impl<'a> Value<'a> { Data::Binary(data) => Ok(Some(uuid(Uuid::from_bytes(data.try_into()?)))), Data::Integer(_) => Ok(None), }, - + DataType::Vector => Ok(None), + } } } From 402d0a6980a272fa72a3ab0b9768dc88b61fc9fb Mon Sep 17 00:00:00 2001 From: tingold Date: Wed, 28 May 2025 13:36:37 -0400 Subject: [PATCH 2/3] removing unused trait --- pgdog/src/config/shards.rs | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/pgdog/src/config/shards.rs b/pgdog/src/config/shards.rs index 65b56856..df10ca0b 100644 --- a/pgdog/src/config/shards.rs +++ b/pgdog/src/config/shards.rs @@ -198,33 +198,3 @@ impl<'de> Deserialize<'de> for ShardListMap { )?)) } } - -// ============================================================================= -// Shardable Trait and Implementations -// ============================================================================= - -/// Trait for types that can provide sharding functionality -pub trait Shardable { - /// Get the shard ID for a given value - fn shard(&self, value: i64) -> Shard; -} - -impl Shardable for ShardRangeMap { - fn shard(&self, value: i64) -> Shard { - if self.0.is_empty() { - return Shard::All; - } - - self.find_shard_key(value).unwrap_or_else(|| Shard::All) - } -} - -impl Shardable for ShardListMap { - fn shard(&self, value: i64) -> Shard { - if self.0.is_empty() { - return Shard::All; - } - - self.find_shard_key(value).unwrap_or_else(|| Shard::All) - } -} \ No newline at end of file From 8f7674c884655c814b3c0b3a300ac286f5948186 Mon Sep 17 00:00:00 2001 From: tingold Date: Wed, 28 May 2025 17:11:41 -0400 Subject: [PATCH 3/3] lists and ranges working --- .gitignore | 1 + pgdog/src/backend/pool/cluster.rs | 2 +- pgdog/src/config/mod.rs | 321 +++++++++++++++++- pgdog/src/config/shards.rs | 8 +- pgdog/src/frontend/router/sharding/context.rs | 10 +- .../router/sharding/context_builder.rs | 14 +- pgdog/src/frontend/router/sharding/error.rs | 1 - pgdog/src/frontend/router/sharding/mod.rs | 39 +-- .../src/frontend/router/sharding/operator.rs | 3 +- pgdog/src/frontend/router/sharding/value.rs | 7 +- 10 files changed, 354 insertions(+), 52 deletions(-) diff --git a/.gitignore b/.gitignore index 978060a6..309e8e34 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ toxi.log *.sqlite3 perf.data perf.data.old +/shard_test/ diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index d8c0027b..deb37527 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -362,7 +362,7 @@ mod test { centroid_probes: 1, sharding_method: None, shard_range_map: None, - shard_list_map: None + shard_list_map: None, }], vec!["sharded_omni".into()], false, diff --git a/pgdog/src/config/mod.rs b/pgdog/src/config/mod.rs index 0dc46f9b..7baa23df 100644 --- a/pgdog/src/config/mod.rs +++ b/pgdog/src/config/mod.rs @@ -3,8 +3,8 @@ pub mod convert; pub mod error; pub mod overrides; -pub mod url; mod shards; +pub mod url; use error::Error; pub use overrides::Overrides; @@ -22,7 +22,7 @@ use serde::{Deserialize, Serialize}; use tracing::info; use tracing::warn; -pub(crate) use crate::config::shards::{ShardingMethod, ShardListMap, ShardRangeMap}; +pub(crate) use crate::config::shards::{ShardListMap, ShardRangeMap, ShardingMethod}; use crate::net::messages::Vector; use crate::util::{human_duration_optional, random_string}; @@ -833,7 +833,7 @@ pub struct ShardedTable { pub shard_range_map: Option, - pub shard_list_map: Option + pub shard_list_map: Option, } impl ShardedTable { @@ -967,9 +967,10 @@ pub struct MultiTenant { #[cfg(test)] pub mod test { - - use crate::backend::databases::init; + use super::*; + use crate::backend::databases::init; + use crate::config::shards::ShardRange; pub fn load_test() { let mut config = ConfigAndUsers::default(); @@ -1065,4 +1066,314 @@ column = "tenant_id" assert_eq!(config.multi_tenant.unwrap().column, "tenant_id"); } + #[test] + fn test_load_sharded_table_with_range_map() { + let toml_str = r#" + database = "pgdog_sharded" + name = "range_sharded" + column = "user_id" + data_type = "bigint" + sharding_method = "range" + + [shard_range_map] + "0" = { start = 0, end = 1000 } + "1" = { start = 1000, end = 2000 } + "2" = { start = 2000, end = 3000 } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Verify basic fields + assert_eq!(table.database, "pgdog_sharded"); + assert_eq!(table.name, Some("range_sharded".to_string())); + assert_eq!(table.column, "user_id"); + assert_eq!(table.data_type, crate::config::DataType::Bigint); + assert_eq!(table.sharding_method, Some(ShardingMethod::Range)); + + // Verify shard_range_map + let range_map = table.shard_range_map.unwrap(); + assert_eq!(range_map.0.len(), 3); + + // Check first range + let range_0 = range_map.0.get(&0).unwrap(); + assert_eq!(range_0.start, Some(0)); + assert_eq!(range_0.end, Some(1000)); + assert_eq!(range_0.no_min, false); + assert_eq!(range_0.no_max, false); + + // Check second range + let range_1 = range_map.0.get(&1).unwrap(); + assert_eq!(range_1.start, Some(1000)); + assert_eq!(range_1.end, Some(2000)); + + // Check third range + let range_2 = range_map.0.get(&2).unwrap(); + assert_eq!(range_2.start, Some(2000)); + assert_eq!(range_2.end, Some(3000)); + + // Verify that shard_list_map is None + assert!(table.shard_list_map.is_none()); + } + + #[test] + fn test_load_sharded_table_with_list_map() { + let toml_str = r#" + database = "pgdog_sharded" + name = "list_sharded" + column = "category_id" + data_type = "bigint" + sharding_method = "list" + + [shard_list_map] + "0" = { values = [1, 3, 5, 7, 9] } + "1" = { values = [2, 4, 6, 8, 10] } + "2" = { values = [11, 12, 13, 14, 15] } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Verify basic fields + assert_eq!(table.database, "pgdog_sharded"); + assert_eq!(table.name, Some("list_sharded".to_string())); + assert_eq!(table.column, "category_id"); + assert_eq!(table.data_type, crate::config::DataType::Bigint); + assert_eq!(table.sharding_method, Some(ShardingMethod::List)); + + // Verify shard_list_map + let list_map = table.shard_list_map.unwrap(); + assert_eq!(list_map.0.len(), 3); + + // Check first list + let list_0 = list_map.0.get(&0).unwrap(); + assert_eq!(list_0.values, vec![1, 3, 5, 7, 9]); + + // Check second list + let list_1 = list_map.0.get(&1).unwrap(); + assert_eq!(list_1.values, vec![2, 4, 6, 8, 10]); + + // Check third list + let list_2 = list_map.0.get(&2).unwrap(); + assert_eq!(list_2.values, vec![11, 12, 13, 14, 15]); + + // Verify that shard_range_map is None + assert!(table.shard_range_map.is_none()); + } + + #[test] + fn test_load_sharded_table_with_special_range_flags() { + let toml_str = r#" + database = "pgdog_sharded" + name = "special_range_sharded" + column = "timestamp_id" + data_type = "bigint" + sharding_method = "range" + + [shard_range_map] + "0" = { start = 0, end = 1000 } + "1" = { start = 1000, no_max = true } + "2" = { no_min = true, end = 0 } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Verify shard_range_map with special flags + let range_map = table.shard_range_map.unwrap(); + assert_eq!(range_map.0.len(), 3); + + // Standard range + let range_0 = range_map.0.get(&0).unwrap(); + assert_eq!(range_0.start, Some(0)); + assert_eq!(range_0.end, Some(1000)); + assert_eq!(range_0.no_min, false); + assert_eq!(range_0.no_max, false); + + // Range with no maximum (unbounded upper) + let range_1 = range_map.0.get(&1).unwrap(); + assert_eq!(range_1.start, Some(1000)); + assert_eq!(range_1.end, None); + assert_eq!(range_1.no_min, false); + assert_eq!(range_1.no_max, true); + + // Range with no minimum (unbounded lower) + let range_2 = range_map.0.get(&2).unwrap(); + assert_eq!(range_2.start, None); + assert_eq!(range_2.end, Some(0)); + assert_eq!(range_2.no_min, true); + assert_eq!(range_2.no_max, false); + } + + #[test] + fn test_load_sharded_table_with_empty_list_values() { + let toml_str = r#" + database = "pgdog_sharded" + name = "empty_list_sharded" + column = "tag_id" + data_type = "bigint" + sharding_method = "list" + + [shard_list_map] + "0" = { values = [1, 2, 3] } + "1" = { values = [] } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Verify shard_list_map with an empty list + let list_map = table.shard_list_map.unwrap(); + assert_eq!(list_map.0.len(), 2); + + // Check first list + let list_0 = list_map.0.get(&0).unwrap(); + assert_eq!(list_0.values, vec![1, 2, 3]); + + // Check empty list + let list_1 = list_map.0.get(&1).unwrap(); + assert!(list_1.values.is_empty()); + } + + #[test] + fn test_load_sharded_table_with_invalid_shard_map_keys() { + let toml_str = r#" + database = "pgdog_sharded" + name = "invalid_keys" + column = "user_id" + data_type = "bigint" + sharding_method = "range" + + [shard_range_map] + "invalid" = { start = 0, end = 1000 } + "0" = { start = 1000, end = 2000 } + "#; + + let result = toml::from_str::(toml_str); + assert!(result.is_err()); + + // Verify the error message contains information about parsing failure + let error = result.unwrap_err().to_string(); + assert!(error.contains("invalid") || error.contains("parse")); + } + + #[test] + fn test_load_sharded_table_with_both_maps() { + let toml_str = r#" + database = "pgdog_sharded" + name = "dual_sharded" + column = "id" + data_type = "bigint" + sharding_method = "range" + + [shard_range_map] + "0" = { start = 0, end = 1000 } + "1" = { start = 1000, end = 2000 } + + [shard_list_map] + "0" = { values = [1, 3, 5] } + "1" = { values = [2, 4, 6] } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Both maps should be populated, but the actual sharding method used + // should be determined by the sharding_method field + assert_eq!(table.sharding_method, Some(ShardingMethod::Range)); + + // Verify both maps exist + assert!(table.shard_range_map.is_some()); + assert!(table.shard_list_map.is_some()); + + // Check range map + let range_map = table.shard_range_map.unwrap(); + assert_eq!(range_map.0.len(), 2); + + // Check list map + let list_map = table.shard_list_map.unwrap(); + assert_eq!(list_map.0.len(), 2); + } + + #[test] + fn test_load_sharded_table_without_sharding_method() { + let toml_str = r#" + database = "pgdog_sharded" + name = "implicit_hash" + column = "id" + data_type = "bigint" + + [shard_range_map] + "0" = { start = 0, end = 1000 } + "1" = { start = 1000, end = 2000 } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // If sharding_method is not specified, it should default to Hash + assert_eq!(table.sharding_method, None); + + // But the range map should still be populated + assert!(table.shard_range_map.is_some()); + let range_map = table.shard_range_map.unwrap(); + assert_eq!(range_map.0.len(), 2); + } + + #[test] + fn test_programmatically_create_and_serialize() { + // Create a ShardedTable with range map programmatically + let mut range_map = HashMap::new(); + range_map.insert( + 0, + ShardRange { + start: Some(0), + end: Some(1000), + no_min: false, + no_max: false, + }, + ); + range_map.insert( + 1, + ShardRange { + start: Some(1000), + end: None, + no_min: false, + no_max: true, + }, + ); + + let shard_range_map = ShardRangeMap(range_map); + + let table = ShardedTable { + database: "pgdog_sharded".to_string(), + name: Some("range_table".to_string()), + column: "id".to_string(), + data_type: crate::config::DataType::Bigint, + sharding_method: Some(ShardingMethod::Range), + shard_range_map: Some(shard_range_map), + shard_list_map: None, + primary: false, + centroids: Vec::new(), + centroids_path: None, + centroid_probes: 0, + }; + + // Serialize to TOML + let toml_str = toml::to_string(&table).unwrap(); + + // Deserialize back to validate + let parsed_table: ShardedTable = toml::from_str(&toml_str).unwrap(); + + // Verify the deserialized structure matches the original + assert_eq!(parsed_table.database, "pgdog_sharded"); + assert_eq!(parsed_table.name, Some("range_table".to_string())); + assert_eq!(parsed_table.sharding_method, Some(ShardingMethod::Range)); + + let parsed_range_map = parsed_table.shard_range_map.unwrap(); + assert_eq!(parsed_range_map.0.len(), 2); + + let range_0 = parsed_range_map.0.get(&0).unwrap(); + assert_eq!(range_0.start, Some(0)); + assert_eq!(range_0.end, Some(1000)); + + let range_1 = parsed_range_map.0.get(&1).unwrap(); + assert_eq!(range_1.start, Some(1000)); + assert_eq!(range_1.end, None); + assert_eq!(range_1.no_max, true); + } } diff --git a/pgdog/src/config/shards.rs b/pgdog/src/config/shards.rs index df10ca0b..a87c269c 100644 --- a/pgdog/src/config/shards.rs +++ b/pgdog/src/config/shards.rs @@ -2,8 +2,7 @@ use std::collections::HashMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use crate::frontend::router::parser::{Shard}; - +use crate::frontend::router::parser::Shard; // ============================================================================= // Serialization Helper Module @@ -18,10 +17,7 @@ mod usize_map_keys_as_strings { S: Serializer, V: Serialize, { - let string_map: HashMap = map - .iter() - .map(|(k, v)| (k.to_string(), v)) - .collect(); + let string_map: HashMap = map.iter().map(|(k, v)| (k.to_string(), v)).collect(); string_map.serialize(serializer) } diff --git a/pgdog/src/frontend/router/sharding/context.rs b/pgdog/src/frontend/router/sharding/context.rs index 051cd91b..c3f1bd3e 100644 --- a/pgdog/src/frontend/router/sharding/context.rs +++ b/pgdog/src/frontend/router/sharding/context.rs @@ -1,5 +1,5 @@ -use crate::frontend::router::parser::Shard; use super::{Error, Operator, Value}; +use crate::frontend::router::parser::Shard; #[derive(Debug)] pub struct Context<'a> { @@ -24,14 +24,14 @@ impl<'a> Context<'a> { return Ok(centroids.shard(&vector, *shards, *probes)); } } - Operator::Ranges(srm)=> { + Operator::Ranges(srm) => { if let Some(i) = self.value.int()? { - return Ok(srm.find_shard_key(i).unwrap()) + return Ok(srm.find_shard_key(i).unwrap()); } } - Operator::Lists(slm)=> { + Operator::Lists(slm) => { if let Some(i) = self.value.int()? { - return Ok(slm.find_shard_key(i).unwrap()) + return Ok(slm.find_shard_key(i).unwrap()); } } } diff --git a/pgdog/src/frontend/router/sharding/context_builder.rs b/pgdog/src/frontend/router/sharding/context_builder.rs index b9330ed9..e552b2d7 100644 --- a/pgdog/src/frontend/router/sharding/context_builder.rs +++ b/pgdog/src/frontend/router/sharding/context_builder.rs @@ -1,4 +1,4 @@ -use crate::config::{DataType, ShardedTable, ShardingMethod, ShardListMap, ShardRangeMap}; +use crate::config::{DataType, ShardListMap, ShardRangeMap, ShardedTable, ShardingMethod}; use super::{Centroids, Context, Data, Error, Operator, Value}; @@ -10,8 +10,7 @@ pub struct ContextBuilder<'a> { probes: usize, sharding_method: Option, shard_range_map: Option, - shard_list_map: Option - + shard_list_map: Option, } impl<'a> ContextBuilder<'a> { @@ -26,8 +25,8 @@ impl<'a> ContextBuilder<'a> { probes: table.centroid_probes, operator: None, value: None, - // added for list and range sharding - // todo: add lifetimes to these to avoid cloning + // added for list and range sharding + // todo: add lifetimes to these to avoid cloning sharding_method: table.sharding_method.clone(), shard_range_map: table.shard_range_map.clone(), shard_list_map: table.shard_list_map.clone(), @@ -77,11 +76,12 @@ impl<'a> ContextBuilder<'a> { match method { ShardingMethod::Hash => { self.operator = Some(Operator::Shards(shards)); - return self + return self; } ShardingMethod::Range => { if self.shard_range_map.is_some() { - self.operator = Some(Operator::Ranges(self.shard_range_map.clone().unwrap())) + self.operator = + Some(Operator::Ranges(self.shard_range_map.clone().unwrap())) } } ShardingMethod::List => { diff --git a/pgdog/src/frontend/router/sharding/error.rs b/pgdog/src/frontend/router/sharding/error.rs index b6b9258f..b24275a4 100644 --- a/pgdog/src/frontend/router/sharding/error.rs +++ b/pgdog/src/frontend/router/sharding/error.rs @@ -25,4 +25,3 @@ pub enum Error { #[error("wrong integer binary size")] IntegerSize, } - diff --git a/pgdog/src/frontend/router/sharding/mod.rs b/pgdog/src/frontend/router/sharding/mod.rs index db9265fc..94807a32 100644 --- a/pgdog/src/frontend/router/sharding/mod.rs +++ b/pgdog/src/frontend/router/sharding/mod.rs @@ -17,7 +17,6 @@ pub mod tables; pub mod value; pub mod vector; - pub use context::*; pub use context_builder::*; pub use error::Error; @@ -28,7 +27,6 @@ pub use vector::{Centroids, Distance}; use super::parser::Shard; - /// Hash `BIGINT`. pub fn bigint(id: i64) -> u64 { unsafe { ffi::hash_combine64(0, ffi::hashint8extended(id)) } @@ -81,27 +79,26 @@ pub(crate) fn shard_value( centroids: &Vec, centroid_probes: usize, ) -> Shard { - match data_type { - DataType::Bigint => value - .parse() - .map(|v| bigint(v) as usize % shards) - .ok() - .map(Shard::Direct) - .unwrap_or(Shard::All), - DataType::Uuid => value - .parse() - .map(|v| uuid(v) as usize % shards) - .ok() - .map(Shard::Direct) - .unwrap_or(Shard::All), - DataType::Vector => Vector::try_from(value) - .ok() - .map(|v| Centroids::from(centroids).shard(&v, shards, centroid_probes)) - .unwrap_or(Shard::All), - } + match data_type { + DataType::Bigint => value + .parse() + .map(|v| bigint(v) as usize % shards) + .ok() + .map(Shard::Direct) + .unwrap_or(Shard::All), + DataType::Uuid => value + .parse() + .map(|v| uuid(v) as usize % shards) + .ok() + .map(Shard::Direct) + .unwrap_or(Shard::All), + DataType::Vector => Vector::try_from(value) + .ok() + .map(|v| Centroids::from(centroids).shard(&v, shards, centroid_probes)) + .unwrap_or(Shard::All), + } } - pub(crate) fn shard_binary( bytes: &[u8], data_type: &DataType, diff --git a/pgdog/src/frontend/router/sharding/operator.rs b/pgdog/src/frontend/router/sharding/operator.rs index 894259d3..fd8c08bc 100644 --- a/pgdog/src/frontend/router/sharding/operator.rs +++ b/pgdog/src/frontend/router/sharding/operator.rs @@ -1,5 +1,5 @@ -use crate::config::{ShardListMap, ShardRangeMap}; use super::Centroids; +use crate::config::{ShardListMap, ShardRangeMap}; #[derive(Debug)] pub enum Operator<'a> { @@ -11,5 +11,4 @@ pub enum Operator<'a> { }, Lists(ShardListMap), Ranges(ShardRangeMap), - } diff --git a/pgdog/src/frontend/router/sharding/value.rs b/pgdog/src/frontend/router/sharding/value.rs index 164ee303..ee9f461d 100644 --- a/pgdog/src/frontend/router/sharding/value.rs +++ b/pgdog/src/frontend/router/sharding/value.rs @@ -70,7 +70,7 @@ impl<'a> Value<'a> { Ok(None) } } - + pub fn int(&self) -> Result, Error> { match self.data_type { DataType::Bigint => match self.data { @@ -114,7 +114,7 @@ impl<'a> Value<'a> { 8 => i64::from_be_bytes(data.try_into()?) as i64, _ => return Err(Error::IntegerSize), }))), - Data::Integer(int) => Ok(Some(bigint(int))), + Data::Integer(int) => Ok(Some(bigint(int))), }, DataType::Uuid => match self.data { @@ -122,9 +122,8 @@ impl<'a> Value<'a> { Data::Binary(data) => Ok(Some(uuid(Uuid::from_bytes(data.try_into()?)))), Data::Integer(_) => Ok(None), }, - + DataType::Vector => Ok(None), - } } }