diff --git a/optd-cost-model/src/common/types.rs b/optd-cost-model/src/common/types.rs index 2c4d4ae..1e92355 100644 --- a/optd-cost-model/src/common/types.rs +++ b/optd-cost-model/src/common/types.rs @@ -49,3 +49,33 @@ impl Display for EpochId { write!(f, "Epoch#{}", self.0) } } + +impl From for i32 { + fn from(id: GroupId) -> i32 { + id.0 as i32 + } +} + +impl From for i32 { + fn from(id: ExprId) -> i32 { + id.0 as i32 + } +} + +impl From for i32 { + fn from(id: TableId) -> i32 { + id.0 as i32 + } +} + +impl From for i32 { + fn from(id: AttrId) -> i32 { + id.0 as i32 + } +} + +impl From for i32 { + fn from(id: EpochId) -> i32 { + id.0 as i32 + } +} diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index f081d4c..a635b66 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -2,7 +2,10 @@ use common::{ nodes::{ArcPredicateNode, PhysicalNodeType}, types::{AttrId, EpochId, ExprId, GroupId, TableId}, }; -use optd_persistent::cost_model::interface::{Stat, StatType}; +use optd_persistent::{ + cost_model::interface::{Stat, StatType}, + BackendError, +}; pub mod common; pub mod cost; @@ -32,10 +35,25 @@ pub struct EstimatedStatistic(pub u64); pub type CostModelResult = Result; +#[derive(Debug)] +pub enum SemanticError { + // TODO: Add more error types + UnknownStatisticType, + VersionedStatisticNotFound, + AttributeNotFound(TableId, i32), // (table_id, attribute_base_index) +} + #[derive(Debug)] pub enum CostModelError { // TODO: Add more error types - ORMError, + ORMError(BackendError), + SemanticError(SemanticError), +} + +impl From for CostModelError { + fn from(err: BackendError) -> Self { + CostModelError::ORMError(err) + } } pub trait CostModel: 'static + Send + Sync { diff --git a/optd-cost-model/src/stats/counter.rs b/optd-cost-model/src/stats/counter.rs index baa32ab..65a2d63 100644 --- a/optd-cost-model/src/stats/counter.rs +++ b/optd-cost-model/src/stats/counter.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; /// The Counter structure to track exact frequencies of fixed elements. #[serde_with::serde_as] -#[derive(Serialize, Deserialize, Debug)] +#[derive(Default, Serialize, Deserialize, Debug)] pub struct Counter { #[serde_as(as = "HashMap")] counts: HashMap, // The exact counts of an element T. @@ -33,7 +33,7 @@ where } // Inserts an element in the Counter if it is being tracked. - pub fn insert_element(&mut self, elem: T, occ: i32) { + fn insert_element(&mut self, elem: T, occ: i32) { if let Some(frequency) = self.counts.get_mut(&elem) { *frequency += occ; } diff --git a/optd-cost-model/src/stats/mod.rs b/optd-cost-model/src/stats/mod.rs index 287b20a..0b1396a 100644 --- a/optd-cost-model/src/stats/mod.rs +++ b/optd-cost-model/src/stats/mod.rs @@ -72,6 +72,7 @@ impl MostCommonValues { } #[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "type")] pub enum Distribution { TDigest(tdigest::TDigest), // Add more types here... @@ -116,8 +117,61 @@ impl AttributeCombValueStats { } } -impl From for AttributeCombValueStats { - fn from(value: serde_json::Value) -> Self { - serde_json::from_value(value).unwrap() +#[cfg(test)] +mod tests { + use super::{Counter, MostCommonValues}; + use crate::{common::values::Value, stats::AttributeCombValue}; + use serde_json::json; + + #[test] + fn test_most_common_values() { + let elem1 = vec![Some(Value::Int32(1))]; + let elem2 = vec![Some(Value::Int32(2))]; + let mut counter = Counter::new(&[elem1.clone(), elem2.clone()]); + + let elems = vec![elem2.clone(), elem1.clone(), elem2.clone(), elem2.clone()]; + counter.aggregate(&elems); + + let mcvs = MostCommonValues::Counter(counter); + assert_eq!(mcvs.freq(&elem1), Some(0.25)); + assert_eq!(mcvs.freq(&elem2), Some(0.75)); + assert_eq!(mcvs.total_freq(), 1.0); + + let elem1_cloned = elem1.clone(); + let pred1 = Box::new(move |x: &AttributeCombValue| x == &elem1_cloned); + let pred2 = Box::new(move |x: &AttributeCombValue| x != &elem1); + assert_eq!(mcvs.freq_over_pred(pred1), 0.25); + assert_eq!(mcvs.freq_over_pred(pred2), 0.75); + + assert_eq!(mcvs.cnt(), 2); + } + + #[test] + fn test_most_common_values_serde() { + let elem1 = vec![Some(Value::Int32(1))]; + let elem2 = vec![Some(Value::Int32(2))]; + let mut counter = Counter::new(&[elem1.clone(), elem2.clone()]); + + let elems = vec![elem2.clone(), elem1.clone(), elem2.clone(), elem2.clone()]; + counter.aggregate(&elems); + + let mcvs = MostCommonValues::Counter(counter); + let serialized = serde_json::to_value(&mcvs).unwrap(); + println!("serialized: {:?}", serialized); + + let deserialized: MostCommonValues = serde_json::from_value(serialized).unwrap(); + assert_eq!(mcvs.freq(&elem1), Some(0.25)); + assert_eq!(mcvs.freq(&elem2), Some(0.75)); + assert_eq!(mcvs.total_freq(), 1.0); + + let elem1_cloned = elem1.clone(); + let pred1 = Box::new(move |x: &AttributeCombValue| x == &elem1_cloned); + let pred2 = Box::new(move |x: &AttributeCombValue| x != &elem1); + assert_eq!(mcvs.freq_over_pred(pred1), 0.25); + assert_eq!(mcvs.freq_over_pred(pred2), 0.75); + + assert_eq!(mcvs.cnt(), 2); } + + // TODO: Add tests for Distribution } diff --git a/optd-cost-model/src/storage.rs b/optd-cost-model/src/storage.rs index 2a53d7c..5538618 100644 --- a/optd-cost-model/src/storage.rs +++ b/optd-cost-model/src/storage.rs @@ -1,6 +1,16 @@ +#![allow(unused_variables)] use std::sync::Arc; -use optd_persistent::CostModelStorageLayer; +use optd_persistent::{ + cost_model::interface::{Attr, StatType}, + CostModelStorageLayer, +}; + +use crate::{ + common::types::TableId, + stats::{counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, + CostModelResult, +}; /// TODO: documentation pub struct CostModelStorageManager { @@ -9,8 +19,113 @@ pub struct CostModelStorageManager { } impl CostModelStorageManager { - /// TODO: documentation pub fn new(backend_manager: Arc) -> Self { Self { backend_manager } } + + /// Gets the attribute information for a given table and attribute base index. + /// + /// TODO: if we have memory cache, + /// we should add the reference. (&Attr) + pub async fn get_attribute_info( + &self, + table_id: TableId, + attr_base_index: i32, + ) -> CostModelResult> { + Ok(self + .backend_manager + .get_attribute(table_id.into(), attr_base_index) + .await?) + } + + /// Gets the latest statistics for a given table. + /// + /// TODO: Currently, in `AttributeCombValueStats`, only `Distribution` is optional. + /// This poses a question about the behavior of the system if there is no corresponding + /// `MostCommonValues`, `ndistinct`, or other statistics. We should have a clear + /// specification about the behavior of the system in the presence of missing statistics. + /// + /// TODO: if we have memory cache, + /// we should add the reference. (&AttributeCombValueStats) + /// + /// TODO: Shall we pass in an epoch here to make sure that the statistics are from the same + /// epoch? + pub async fn get_attributes_comb_statistics( + &self, + table_id: TableId, + attr_base_indices: &[i32], + ) -> CostModelResult> { + let dist: Option = self + .backend_manager + .get_stats_for_attr_indices_based( + table_id.into(), + attr_base_indices.to_vec(), + StatType::Distribution, + None, + ) + .await? + .map(|json| serde_json::from_value(json).unwrap()); + + let mcvs = self + .backend_manager + .get_stats_for_attr_indices_based( + table_id.into(), + attr_base_indices.to_vec(), + StatType::MostCommonValues, + None, + ) + .await? + .map(|json| serde_json::from_value(json).unwrap()) + .unwrap_or_else(|| MostCommonValues::Counter(Counter::default())); + + let ndistinct = self + .backend_manager + .get_stats_for_attr_indices_based( + table_id.into(), + attr_base_indices.to_vec(), + StatType::Cardinality, + None, + ) + .await? + .map(|json| serde_json::from_value(json).unwrap()) + .unwrap_or(0); + + let table_row_count = self + .backend_manager + .get_stats_for_attr_indices_based( + table_id.into(), + attr_base_indices.to_vec(), + StatType::TableRowCount, + None, + ) + .await? + .map(|json| serde_json::from_value(json).unwrap()) + .unwrap_or(0); + let non_null_count = self + .backend_manager + .get_stats_for_attr_indices_based( + table_id.into(), + attr_base_indices.to_vec(), + StatType::NonNullCount, + None, + ) + .await? + .map(|json| serde_json::from_value(json).unwrap()) + .unwrap_or(0); + + // FIXME: Only minimal checks for invalid values is conducted here. We should have + // much clear specification about the behavior of the system in the presence of + // invalid statistics. + let null_frac = if table_row_count == 0 { + 0.0 + } else { + 1.0 - (non_null_count as f64 / table_row_count as f64) + }; + + Ok(Some(AttributeCombValueStats::new( + mcvs, ndistinct, null_frac, dist, + ))) + } } + +// TODO: add some tests, especially cover the error cases. diff --git a/optd-persistent/src/bin/init.rs b/optd-persistent/src/bin/init.rs index e39bd56..9cc07e2 100644 --- a/optd-persistent/src/bin/init.rs +++ b/optd-persistent/src/bin/init.rs @@ -63,7 +63,7 @@ async fn init_all_tables() -> Result<(), sea_orm::error::DbErr> { name: Set("user_id".to_owned()), compression_method: Set("N".to_owned()), variant_tag: Set(AttrType::Integer as i32), - base_attribute_number: Set(1), + base_attribute_number: Set(0), is_not_null: Set(true), }; let attribute2 = attribute::ActiveModel { @@ -72,7 +72,7 @@ async fn init_all_tables() -> Result<(), sea_orm::error::DbErr> { name: Set("username".to_owned()), compression_method: Set("N".to_owned()), variant_tag: Set(AttrType::Varchar as i32), - base_attribute_number: Set(2), + base_attribute_number: Set(1), is_not_null: Set(true), }; attribute::Entity::insert(attribute1) diff --git a/optd-persistent/src/cost_model/catalog/mock_catalog.rs b/optd-persistent/src/cost_model/catalog/mock_catalog.rs index f79f930..5b2e28e 100644 --- a/optd-persistent/src/cost_model/catalog/mock_catalog.rs +++ b/optd-persistent/src/cost_model/catalog/mock_catalog.rs @@ -115,7 +115,7 @@ impl MockCatalog { let statistics: Vec = vec![ MockStatistic { id: 1, - stat_type: StatType::NotNullCount as i32, + stat_type: StatType::NonNullCount as i32, stat_value: json!(100), attr_ids: vec![1], table_id: None, @@ -123,7 +123,7 @@ impl MockCatalog { }, MockStatistic { id: 2, - stat_type: StatType::NotNullCount as i32, + stat_type: StatType::NonNullCount as i32, stat_value: json!(200), attr_ids: vec![2], table_id: None, diff --git a/optd-persistent/src/cost_model/interface.rs b/optd-persistent/src/cost_model/interface.rs index 92d7e44..a03087f 100644 --- a/optd-persistent/src/cost_model/interface.rs +++ b/optd-persistent/src/cost_model/interface.rs @@ -1,7 +1,6 @@ #![allow(dead_code, unused_imports)] use crate::entities::cascades_group; -use crate::entities::event::Model as event_model; use crate::entities::logical_expression; use crate::entities::physical_expression; use crate::StorageResult; @@ -11,6 +10,13 @@ use sea_orm_migration::prelude::*; use serde_json::json; use std::sync::Arc; +pub type GroupId = i32; +pub type TableId = i32; +pub type AttrId = i32; +pub type ExprId = i32; +pub type EpochId = i32; +pub type StatId = i32; + /// TODO: documentation pub enum CatalogSource { Iceberg(), @@ -40,13 +46,22 @@ pub enum ConstraintType { } /// TODO: documentation +#[derive(Copy, Clone, Debug, PartialEq)] pub enum StatType { - /// `TableRowCount` only applies to table statistics. + /// The row count in a table. `TableRowCount` only applies to table statistics. TableRowCount, - NotNullCount, + /// The number of non-null values in a column. + NonNullCount, + /// The number of distinct values in a column. Cardinality, + /// The minimum value in a column. Min, + /// The maximum value in a column. Max, + /// The frequency of each value in a column. + MostCommonValues, + /// The distribution of values in a column. + Distribution, } /// TODO: documentation @@ -58,9 +73,9 @@ pub enum EpochOption { } /// TODO: documentation -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Stat { - pub stat_type: i32, + pub stat_type: StatType, pub stat_value: Json, pub attr_ids: Vec, pub table_id: Option, @@ -76,38 +91,36 @@ pub struct Cost { pub estimated_statistic: i32, } +#[derive(Clone, Debug)] +pub struct Attr { + pub table_id: i32, + pub name: String, + pub compression_method: String, + pub attr_type: i32, + pub base_index: i32, + pub nullable: bool, +} + /// TODO: documentation #[trait_variant::make(Send)] pub trait CostModelStorageLayer { - type GroupId; - type TableId; - type AttrId; - type ExprId; - type EpochId; - type StatId; + async fn create_new_epoch(&self, source: String, data: String) -> StorageResult; - // TODO: Change EpochId to event::Model::epoch_id - async fn create_new_epoch(&self, source: String, data: String) -> StorageResult; - - async fn update_stats_from_catalog(&self, c: CatalogSource) -> StorageResult; + async fn update_stats_from_catalog(&self, c: CatalogSource) -> StorageResult; async fn update_stats( &self, stat: Stat, epoch_option: EpochOption, - ) -> StorageResult>; + ) -> StorageResult>; - async fn store_cost( - &self, - expr_id: Self::ExprId, - cost: Cost, - epoch_id: Self::EpochId, - ) -> StorageResult<()>; + async fn store_cost(&self, expr_id: ExprId, cost: Cost, epoch_id: EpochId) + -> StorageResult<()>; async fn store_expr_stats_mappings( &self, - expr_id: Self::ExprId, - stat_ids: Vec, + expr_id: ExprId, + stat_ids: Vec, ) -> StorageResult<()>; /// Get the statistics for a given table. @@ -115,10 +128,9 @@ pub trait CostModelStorageLayer { /// If `epoch_id` is None, it will return the latest statistics. async fn get_stats_for_table( &self, - table_id: Self::TableId, - // TODO: Add enum for stat_type - stat_type: i32, - epoch_id: Option, + table_id: TableId, + stat_type: StatType, + epoch_id: Option, ) -> StorageResult>; /// Get the (joint) statistics for one or more attributes. @@ -126,16 +138,33 @@ pub trait CostModelStorageLayer { /// If `epoch_id` is None, it will return the latest statistics. async fn get_stats_for_attr( &self, - attr_ids: Vec, - stat_type: i32, - epoch_id: Option, + attr_ids: Vec, + stat_type: StatType, + epoch_id: Option, + ) -> StorageResult>; + + /// Get the (joint) statistics for one or more attributes based on attribute base indices. + /// + /// If `epoch_id` is None, it will return the latest statistics. + async fn get_stats_for_attr_indices_based( + &self, + table_id: TableId, + attr_base_indices: Vec, + stat_type: StatType, + epoch_id: Option, ) -> StorageResult>; async fn get_cost_analysis( &self, - expr_id: Self::ExprId, - epoch_id: Self::EpochId, + expr_id: ExprId, + epoch_id: EpochId, ) -> StorageResult>; - async fn get_cost(&self, expr_id: Self::ExprId) -> StorageResult>; + async fn get_cost(&self, expr_id: ExprId) -> StorageResult>; + + async fn get_attribute( + &self, + table_id: TableId, + attribute_base_index: i32, + ) -> StorageResult>; } diff --git a/optd-persistent/src/cost_model/orm.rs b/optd-persistent/src/cost_model/orm.rs index b6d1cdc..d172c14 100644 --- a/optd-persistent/src/cost_model/orm.rs +++ b/optd-persistent/src/cost_model/orm.rs @@ -1,28 +1,27 @@ #![allow(dead_code, unused_imports, unused_variables)] -use std::ptr::null; - use crate::cost_model::interface::Cost; use crate::entities::{prelude::*, *}; -use crate::{BackendError, BackendManager, CostModelError, CostModelStorageLayer, StorageResult}; +use crate::{BackendError, BackendManager, CostModelStorageLayer, StorageResult}; use sea_orm::prelude::{Expr, Json}; use sea_orm::sea_query::Query; use sea_orm::{sqlx::types::chrono::Utc, EntityTrait}; use sea_orm::{ - ActiveModelTrait, ColumnTrait, DbBackend, DbErr, DeleteResult, EntityOrSelect, ModelTrait, - QueryFilter, QueryOrder, QuerySelect, QueryTrait, RuntimeErr, TransactionTrait, + ActiveModelTrait, ColumnTrait, Condition, DbBackend, DbErr, DeleteResult, EntityOrSelect, + ModelTrait, QueryFilter, QueryOrder, QuerySelect, QueryTrait, RuntimeErr, TransactionTrait, }; use serde_json::json; use super::catalog::mock_catalog::{self, MockCatalog}; -use super::interface::{CatalogSource, EpochOption, Stat}; +use super::interface::{ + Attr, AttrId, CatalogSource, EpochId, EpochOption, ExprId, Stat, StatId, StatType, TableId, +}; impl BackendManager { - fn get_description_from_attr_ids( - &self, - attr_ids: Vec<::AttrId>, - ) -> String { - let mut attr_ids = attr_ids; + /// The description is to concat `attr_ids` using commas + /// Note that `attr_ids` should be sorted before concatenation + /// e.g. [1, 2, 3] -> "1,2,3" + fn get_description_from_attr_ids(&self, mut attr_ids: Vec) -> String { attr_ids.sort(); attr_ids .iter() @@ -33,15 +32,8 @@ impl BackendManager { } impl CostModelStorageLayer for BackendManager { - type GroupId = i32; - type TableId = i32; - type AttrId = i32; - type ExprId = i32; - type EpochId = i32; - type StatId = i32; - /// TODO: documentation - async fn create_new_epoch(&self, source: String, data: String) -> StorageResult { + async fn create_new_epoch(&self, source: String, data: String) -> StorageResult { let new_event = event::ActiveModel { source_variant: sea_orm::ActiveValue::Set(source), timestamp: sea_orm::ActiveValue::Set(Utc::now()), @@ -53,7 +45,7 @@ impl CostModelStorageLayer for BackendManager { } /// TODO: documentation - async fn update_stats_from_catalog(&self, c: CatalogSource) -> StorageResult { + async fn update_stats_from_catalog(&self, c: CatalogSource) -> StorageResult { let transaction = self.db.begin().await?; let source = match c { CatalogSource::Mock => "Mock", @@ -210,7 +202,7 @@ impl CostModelStorageLayer for BackendManager { &self, stat: Stat, epoch_option: EpochOption, - ) -> StorageResult> { + ) -> StorageResult> { let transaction = self.db.begin().await?; // 0. Check if the stat already exists. If exists, get stat_id, else insert into statistic table. let stat_id = match stat.table_id { @@ -238,7 +230,7 @@ impl CostModelStorageLayer for BackendManager { stat.attr_ids.len() as i32 ), creation_time: sea_orm::ActiveValue::Set(Utc::now()), - variant_tag: sea_orm::ActiveValue::Set(stat.stat_type), + variant_tag: sea_orm::ActiveValue::Set(stat.stat_type as i32), description: sea_orm::ActiveValue::Set("".to_string()), ..Default::default() }; @@ -246,10 +238,9 @@ impl CostModelStorageLayer for BackendManager { match res { Ok(insert_res) => insert_res.last_insert_id, Err(_) => { - return Err(BackendError::Database(DbErr::Exec( - RuntimeErr::Internal( - "Failed to insert into statistic table".to_string(), - ), + return Err(BackendError::BackendError(format!( + "failed to insert statistic {:?} into statistic table", + stat ))) } } @@ -261,7 +252,7 @@ impl CostModelStorageLayer for BackendManager { let res = Statistic::find() .filter(statistic::Column::NumberOfAttributes.eq(stat.attr_ids.len() as i32)) .filter(statistic::Column::Description.eq(description.clone())) - .filter(statistic::Column::VariantTag.eq(stat.stat_type)) + .filter(statistic::Column::VariantTag.eq(stat.stat_type as i32)) .inner_join(versioned_statistic::Entity) .select_also(versioned_statistic::Entity) .order_by_desc(versioned_statistic::Column::EpochId) @@ -281,7 +272,7 @@ impl CostModelStorageLayer for BackendManager { stat.attr_ids.len() as i32 ), creation_time: sea_orm::ActiveValue::Set(Utc::now()), - variant_tag: sea_orm::ActiveValue::Set(stat.stat_type), + variant_tag: sea_orm::ActiveValue::Set(stat.stat_type as i32), description: sea_orm::ActiveValue::Set(description), ..Default::default() }; @@ -356,8 +347,8 @@ impl CostModelStorageLayer for BackendManager { /// TODO: documentation async fn store_expr_stats_mappings( &self, - expr_id: Self::ExprId, - stat_ids: Vec, + expr_id: ExprId, + stat_ids: Vec, ) -> StorageResult<()> { let to_insert_mappings = stat_ids .iter() @@ -377,16 +368,16 @@ impl CostModelStorageLayer for BackendManager { /// TODO: documentation async fn get_stats_for_table( &self, - table_id: i32, - stat_type: i32, - epoch_id: Option, + table_id: TableId, + stat_type: StatType, + epoch_id: Option, ) -> StorageResult> { match epoch_id { Some(epoch_id) => Ok(VersionedStatistic::find() .filter(versioned_statistic::Column::EpochId.eq(epoch_id)) .inner_join(statistic::Entity) .filter(statistic::Column::TableId.eq(table_id)) - .filter(statistic::Column::VariantTag.eq(stat_type)) + .filter(statistic::Column::VariantTag.eq(stat_type as i32)) .one(&self.db) .await? .map(|stat| stat.statistic_value)), @@ -394,7 +385,7 @@ impl CostModelStorageLayer for BackendManager { None => Ok(VersionedStatistic::find() .inner_join(statistic::Entity) .filter(statistic::Column::TableId.eq(table_id)) - .filter(statistic::Column::VariantTag.eq(stat_type)) + .filter(statistic::Column::VariantTag.eq(stat_type as i32)) .order_by_desc(versioned_statistic::Column::EpochId) .one(&self.db) .await? @@ -405,14 +396,11 @@ impl CostModelStorageLayer for BackendManager { /// TODO: documentation async fn get_stats_for_attr( &self, - mut attr_ids: Vec, - stat_type: i32, - epoch_id: Option, + mut attr_ids: Vec, + stat_type: StatType, + epoch_id: Option, ) -> StorageResult> { let attr_num = attr_ids.len() as i32; - // The description is to concat `attr_ids` using commas - // Note that `attr_ids` should be sorted before concatenation - // e.g. [1, 2, 3] -> "1,2,3" attr_ids.sort(); let description = self.get_description_from_attr_ids(attr_ids); @@ -423,7 +411,7 @@ impl CostModelStorageLayer for BackendManager { .inner_join(statistic::Entity) .filter(statistic::Column::NumberOfAttributes.eq(attr_num)) .filter(statistic::Column::Description.eq(description)) - .filter(statistic::Column::VariantTag.eq(stat_type)) + .filter(statistic::Column::VariantTag.eq(stat_type as i32)) .one(&self.db) .await? .map(|stat| stat.statistic_value)), @@ -432,7 +420,7 @@ impl CostModelStorageLayer for BackendManager { .inner_join(statistic::Entity) .filter(statistic::Column::NumberOfAttributes.eq(attr_num)) .filter(statistic::Column::Description.eq(description)) - .filter(statistic::Column::VariantTag.eq(stat_type)) + .filter(statistic::Column::VariantTag.eq(stat_type as i32)) .order_by_desc(versioned_statistic::Column::EpochId) .one(&self.db) .await? @@ -440,11 +428,42 @@ impl CostModelStorageLayer for BackendManager { } } + async fn get_stats_for_attr_indices_based( + &self, + table_id: TableId, + attr_base_indices: Vec, + stat_type: StatType, + epoch_id: Option, + ) -> StorageResult> { + // Get the attribute ids based on table id and attribute base indices + let mut condition = Condition::any(); + for attr_base_index in &attr_base_indices { + condition = condition.add(attribute::Column::BaseAttributeNumber.eq(*attr_base_index)); + } + let attr_ids = Attribute::find() + .filter(attribute::Column::TableId.eq(table_id)) + .filter(condition) + .all(&self.db) + .await? + .iter() + .map(|attr| attr.id) + .collect::>(); + + if attr_ids.len() != attr_base_indices.len() { + return Err(BackendError::BackendError(format!( + "Not all attributes found for table_id {} and base indices {:?}", + table_id, attr_base_indices + ))); + } + + self.get_stats_for_attr(attr_ids, stat_type, epoch_id).await + } + /// TODO: documentation async fn get_cost_analysis( &self, - expr_id: Self::ExprId, - epoch_id: Self::EpochId, + expr_id: ExprId, + epoch_id: EpochId, ) -> StorageResult> { let cost = PlanCost::find() .filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id)) @@ -460,7 +479,7 @@ impl CostModelStorageLayer for BackendManager { })) } - async fn get_cost(&self, expr_id: Self::ExprId) -> StorageResult> { + async fn get_cost(&self, expr_id: ExprId) -> StorageResult> { let cost = PlanCost::find() .filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id)) .order_by_desc(plan_cost::Column::EpochId) @@ -478,16 +497,17 @@ impl CostModelStorageLayer for BackendManager { /// TODO: documentation async fn store_cost( &self, - physical_expression_id: Self::ExprId, + physical_expression_id: ExprId, cost: Cost, - epoch_id: Self::EpochId, + epoch_id: EpochId, ) -> StorageResult<()> { let expr_exists = PhysicalExpression::find_by_id(physical_expression_id) .one(&self.db) .await?; if expr_exists.is_none() { - return Err(BackendError::Database(DbErr::RecordNotFound( - "ExprId not found in PhysicalExpression table".to_string(), + return Err(BackendError::BackendError(format!( + "physical expression id {} not found when storing cost", + physical_expression_id ))); } @@ -498,8 +518,9 @@ impl CostModelStorageLayer for BackendManager { .await .unwrap(); if epoch_exists.is_none() { - return Err(BackendError::Database(DbErr::RecordNotFound( - "EpochId not found in Event table".to_string(), + return Err(BackendError::BackendError(format!( + "epoch id {} not found when storing cost", + epoch_id ))); } @@ -516,6 +537,26 @@ impl CostModelStorageLayer for BackendManager { let _ = PlanCost::insert(new_cost).exec(&self.db).await?; Ok(()) } + + async fn get_attribute( + &self, + table_id: TableId, + attribute_base_index: i32, + ) -> StorageResult> { + Ok(Attribute::find() + .filter(attribute::Column::TableId.eq(table_id)) + .filter(attribute::Column::BaseAttributeNumber.eq(attribute_base_index)) + .one(&self.db) + .await? + .map(|attr| Attr { + table_id, + name: attr.name, + compression_method: attr.compression_method, + attr_type: attr.variant_tag, + base_index: attribute_base_index, + nullable: !attr.is_not_null, + })) + } } #[cfg(test)] @@ -613,12 +654,12 @@ mod tests { assert_eq!(lookup_res.len(), 3); let stat_res = backend_manager - .get_stats_for_table(1, StatType::TableRowCount as i32, Some(epoch_id)) + .get_stats_for_table(1, StatType::TableRowCount, Some(epoch_id)) .await; assert!(stat_res.is_ok()); assert_eq!(stat_res.unwrap().unwrap(), json!(300)); let stat_res = backend_manager - .get_stats_for_attr([2].to_vec(), StatType::NotNullCount as i32, None) + .get_stats_for_attr([2].to_vec(), StatType::NonNullCount, None) .await; assert!(stat_res.is_ok()); assert_eq!(stat_res.unwrap().unwrap(), json!(200)); @@ -638,7 +679,7 @@ mod tests { .await .unwrap(); let stat = Stat { - stat_type: StatType::NotNullCount as i32, + stat_type: StatType::NonNullCount, stat_value: json!(100), attr_ids: vec![1], table_id: None, @@ -657,7 +698,7 @@ mod tests { println!("{:?}", stat_res); assert_eq!(stat_res[0].number_of_attributes, 1); assert_eq!(stat_res[0].description, "1".to_string()); - assert_eq!(stat_res[0].variant_tag, StatType::NotNullCount as i32); + assert_eq!(stat_res[0].variant_tag, StatType::NonNullCount as i32); let stat_attr_res = StatisticToAttributeJunction::find() .filter(statistic_to_attribute_junction::Column::StatisticId.eq(stat_res[0].id)) .all(&backend_manager.db) @@ -720,7 +761,7 @@ mod tests { .await .unwrap(); let stat2 = Stat { - stat_type: StatType::NotNullCount as i32, + stat_type: StatType::NonNullCount, stat_value: json!(200), attr_ids: vec![1], table_id: None, @@ -774,7 +815,7 @@ mod tests { // 3. Update existed stat with the same value let epoch_num = Event::find().all(&backend_manager.db).await.unwrap().len(); let stat3 = Stat { - stat_type: StatType::NotNullCount as i32, + stat_type: StatType::NonNullCount, stat_value: json!(200), attr_ids: vec![1], table_id: None, @@ -815,21 +856,21 @@ mod tests { let statistics: Vec = vec![ Stat { - stat_type: StatType::TableRowCount as i32, + stat_type: StatType::TableRowCount, stat_value: json!(0), attr_ids: vec![], table_id: Some(1), name: "row_count".to_string(), }, Stat { - stat_type: StatType::TableRowCount as i32, + stat_type: StatType::TableRowCount, stat_value: json!(20), attr_ids: vec![], table_id: Some(1), name: "row_count".to_string(), }, Stat { - stat_type: StatType::TableRowCount as i32, + stat_type: StatType::TableRowCount, stat_value: json!(100), attr_ids: vec![], table_id: Some(table_inserted_res.last_insert_id), @@ -1033,7 +1074,7 @@ mod tests { let backend_manager = binding.as_mut().unwrap(); let epoch_id = 1; let table_id = 1; - let stat_type = StatType::TableRowCount as i32; + let stat_type = StatType::TableRowCount; // Get initial stats let res = backend_manager @@ -1050,7 +1091,7 @@ mod tests { .await .unwrap(); let stat = Stat { - stat_type: StatType::TableRowCount as i32, + stat_type: StatType::TableRowCount, stat_value: json!(100), attr_ids: vec![], table_id: Some(table_id), @@ -1090,7 +1131,7 @@ mod tests { let backend_manager = binding.as_mut().unwrap(); let epoch_id = 1; let attr_ids = vec![1]; - let stat_type = StatType::Cardinality as i32; + let stat_type = StatType::Cardinality; // Get initial stats let res = backend_manager @@ -1110,7 +1151,7 @@ mod tests { .await .unwrap(); let stat = Stat { - stat_type: StatType::Cardinality as i32, + stat_type: StatType::Cardinality, stat_value: json!(100), attr_ids: attr_ids.clone(), table_id: None, @@ -1150,7 +1191,7 @@ mod tests { let backend_manager = binding.as_mut().unwrap(); let epoch_id = 1; let attr_ids = vec![2, 1]; - let stat_type = StatType::Cardinality as i32; + let stat_type = StatType::Cardinality; // Get initial stats let res = backend_manager @@ -1170,7 +1211,7 @@ mod tests { .await .unwrap(); let stat = Stat { - stat_type: StatType::Cardinality as i32, + stat_type: StatType::Cardinality, stat_value: json!(111), attr_ids: attr_ids.clone(), table_id: None, @@ -1201,4 +1242,42 @@ mod tests { remove_db_file(DATABASE_FILE); } + + #[tokio::test] + async fn test_get_stats_for_attr_indices_based() { + const DATABASE_FILE: &str = "test_get_stats_for_attr_indices_based.db"; + let database_url = copy_init_db(DATABASE_FILE).await; + let mut binding = super::BackendManager::new(Some(&database_url)).await; + let backend_manager = binding.as_mut().unwrap(); + let epoch_id = 1; + let table_id = 1; + let attr_base_indices = vec![0, 1]; + let stat_type = StatType::Cardinality; + + // Statistics exist in the database + let res = backend_manager + .get_stats_for_attr_indices_based(table_id, attr_base_indices.clone(), stat_type, None) + .await + .unwrap() + .unwrap(); + let cardinality = res.as_i64().unwrap(); + assert_eq!(cardinality, 0); + + // Statistics do not exist in the database + let attr_base_indices = vec![1]; + let res = backend_manager + .get_stats_for_attr_indices_based(table_id, attr_base_indices.clone(), stat_type, None) + .await + .unwrap(); + assert!(res.is_none()); + + // Attribute base indices not valid. + let attr_base_indices = vec![1, 2]; + let res = backend_manager + .get_stats_for_attr_indices_based(table_id, attr_base_indices.clone(), stat_type, None) + .await; + assert!(res.is_err()); + + remove_db_file(DATABASE_FILE); + } } diff --git a/optd-persistent/src/db/init.db b/optd-persistent/src/db/init.db index 6395ac9..5350952 100644 Binary files a/optd-persistent/src/db/init.db and b/optd-persistent/src/db/init.db differ diff --git a/optd-persistent/src/lib.rs b/optd-persistent/src/lib.rs index 2638940..9bac1b6 100644 --- a/optd-persistent/src/lib.rs +++ b/optd-persistent/src/lib.rs @@ -41,29 +41,21 @@ fn get_sqlite_url(file: &str) -> String { pub type StorageResult = Result; -#[derive(Debug)] -pub enum CostModelError { - // TODO: Add more error types - UnknownStatisticType, - VersionedStatisticNotFound, -} - #[derive(Debug)] pub enum BackendError { - CostModel(CostModelError), - Database(DbErr), - // TODO: Add other variants as needed for different error types + DatabaseError(DbErr), + BackendError(String), } -impl From for BackendError { - fn from(value: CostModelError) -> Self { - BackendError::CostModel(value) +impl From for BackendError { + fn from(value: String) -> Self { + BackendError::BackendError(value) } } impl From for BackendError { fn from(value: DbErr) -> Self { - BackendError::Database(value) + BackendError::DatabaseError(value) } }