diff --git a/optd-persistent/src/bin/migrate_test.rs b/optd-persistent/src/bin/migrate_test.rs
new file mode 100644
index 0000000..f4c9001
--- /dev/null
+++ b/optd-persistent/src/bin/migrate_test.rs
@@ -0,0 +1,23 @@
+use optd_persistent::{migrate, TEST_DATABASE_URL};
+use sea_orm::*;
+use sea_orm_migration::prelude::*;
+
+#[tokio::main]
+async fn main() {
+    let _ = std::fs::remove_file(TEST_DATABASE_URL);
+
+    let db = Database::connect(TEST_DATABASE_URL)
+        .await
+        .expect("Unable to connect to the database");
+
+    migrate(&db)
+        .await
+        .expect("Something went wrong during migration");
+
+    db.execute(sea_orm::Statement::from_string(
+        sea_orm::DatabaseBackend::Sqlite,
+        "PRAGMA foreign_keys = ON;".to_owned(),
+    ))
+    .await
+    .expect("Unable to enable foreign keys");
+}
diff --git a/optd-persistent/src/entities/prelude.rs b/optd-persistent/src/entities/prelude.rs
index 3c76253..fd96671 100644
--- a/optd-persistent/src/entities/prelude.rs
+++ b/optd-persistent/src/entities/prelude.rs
@@ -1,4 +1,5 @@
 //! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0
+#![allow(dead_code, unused_imports, unused_variables)]
 
 pub use super::attribute::Entity as Attribute;
 pub use super::attribute_constraint_junction::Entity as AttributeConstraintJunction;
diff --git a/optd-persistent/src/lib.rs b/optd-persistent/src/lib.rs
index c11bafa..8e1e149 100644
--- a/optd-persistent/src/lib.rs
+++ b/optd-persistent/src/lib.rs
@@ -1,11 +1,15 @@
 use sea_orm::*;
 use sea_orm_migration::prelude::*;
 
+mod entities;
 mod migrator;
+mod orm_manager;
+mod storage_layer;
 use migrator::Migrator;
 
 pub const DATABASE_URL: &str = "sqlite:./sqlite.db?mode=rwc";
 pub const DATABASE_FILE: &str = "./sqlite.db";
+pub const TEST_DATABASE_URL: &str = "sqlite:./test.db?mode=rwc";
 
 pub async fn migrate(db: &DatabaseConnection) -> Result<(), DbErr> {
     Migrator::refresh(db).await
diff --git a/optd-persistent/src/orm_manager.rs b/optd-persistent/src/orm_manager.rs
new file mode 100644
index 0000000..5a9f921
--- /dev/null
+++ b/optd-persistent/src/orm_manager.rs
@@ -0,0 +1,397 @@
+#![allow(dead_code, unused_imports, unused_variables)]
+
+use crate::entities::{prelude::*, *};
+use crate::orm_manager::{Event, PlanCost};
+use crate::storage_layer::{self, EpochId, StorageLayer, StorageResult};
+use crate::DATABASE_URL;
+use sea_orm::*;
+use sea_query::Expr;
+use sqlx::types::chrono::Utc;
+
+pub struct ORMManager {
+    db_conn: DatabaseConnection,
+    // TODO: Change EpochId to event::Model::epoch_id
+    latest_epoch_id: EpochId,
+}
+
+impl ORMManager {
+    pub async fn new(database_url: Option<&str>) -> Self {
+        let latest_epoch_id = -1;
+        let db_conn = Database::connect(database_url.unwrap_or(DATABASE_URL))
+            .await
+            .unwrap();
+        Self {
+            db_conn,
+            latest_epoch_id,
+        }
+    }
+}
+
+impl StorageLayer for ORMManager {
+    async fn create_new_epoch(
+        &mut self,
+        source: String,
+        data: String,
+    ) -> StorageResult<EpochId> {
+        let new_event = event::ActiveModel {
+            source_variant: sea_orm::ActiveValue::Set(source),
+            timestamp: sea_orm::ActiveValue::Set(Utc::now()),
+            data: sea_orm::ActiveValue::Set(sea_orm::JsonValue::String(data)),
+            ..Default::default()
+        };
+        let res = Event::insert(new_event).exec(&self.db_conn).await;
+        res.and_then(|insert_res| {
+            self.latest_epoch_id = insert_res.last_insert_id;
+            Ok(self.latest_epoch_id)
+        })
+    }
+
+    async fn update_stats_from_catalog(
+        &self,
+        c: storage_layer::CatalogSource,
+        epoch_id: storage_layer::EpochId,
+    ) -> StorageResult<()> {
+        todo!()
+    }
+
+    async fn update_stats(
+        &self,
+        stats: i32,
+        epoch_id: storage_layer::EpochId,
+    ) -> StorageResult<()> {
+        todo!()
+    }
+
+    async fn store_cost(
+        &self,
+        expr_id: storage_layer::ExprId,
+        cost: i32,
+        epoch_id: storage_layer::EpochId,
+    ) -> StorageResult<()> {
+        // TODO: update PhysicalExpression and Event tables
+        // Check if expr_id exists in PhysicalExpression table
+        let expr_exists = PhysicalExpression::find_by_id(expr_id)
+            .one(&self.db_conn)
+            .await?;
+        if expr_exists.is_none() {
+            return Err(DbErr::RecordNotFound(
+                "ExprId not found in PhysicalExpression table".to_string(),
+            ));
+        }
+
+        // Check if epoch_id exists in Event table
+        let epoch_exists = Event::find()
+            .filter(event::Column::EpochId.eq(epoch_id))
+            .one(&self.db_conn)
+            .await
+            .unwrap();
+        if epoch_exists.is_none() {
+            return Err(DbErr::RecordNotFound(
+                "EpochId not found in Event table".to_string(),
+            ));
+        }
+
+        let new_cost = plan_cost::ActiveModel {
+            physical_expression_id: ActiveValue::Set(expr_id),
+            epoch_id: ActiveValue::Set(epoch_id),
+            cost: ActiveValue::Set(cost),
+            is_valid: ActiveValue::Set(true),
+            ..Default::default()
+        };
+        PlanCost::insert(new_cost)
+            .exec(&self.db_conn)
+            .await
+            .map(|_| ())
+    }
+
+    async fn get_stats_for_table(
+        &self,
+        table_id: i32,
+        stat_type: i32,
+        epoch_id: Option<EpochId>,
+    ) -> StorageResult<Option<f32>> {
+        match epoch_id {
+            Some(epoch_id) => Statistic::find()
+                .filter(statistic::Column::TableId.eq(table_id))
+                .filter(statistic::Column::StatisticType.eq(stat_type))
+                .filter(statistic::Column::EpochId.eq(epoch_id))
+                .one(&self.db_conn)
+                .await
+                .map(|stat| stat.map(|s| s.statistic_value)),
+
+            None => Statistic::find()
+                .filter(statistic::Column::TableId.eq(table_id))
+                .filter(statistic::Column::StatisticType.eq(stat_type))
+                .order_by_desc(statistic::Column::EpochId)
+                .one(&self.db_conn)
+                .await
+                .map(|stat| stat.map(|s| s.statistic_value)),
+        }
+    }
+
+    async fn get_stats_for_attr(
+        &self,
+        attr_id: i32,
+        stat_type: i32,
+        epoch_id: Option<EpochId>,
+    ) -> StorageResult<Option<f32>> {
+        match epoch_id {
+            Some(epoch_id) => Statistic::find()
+                .filter(statistic::Column::NumberOfAttributes.eq(1))
+                .filter(statistic::Column::StatisticType.eq(stat_type))
+                .filter(statistic::Column::EpochId.eq(epoch_id))
+                .inner_join(statistic_to_attribute_junction::Entity)
+                .filter(statistic_to_attribute_junction::Column::AttributeId.eq(attr_id))
+                .one(&self.db_conn)
+                .await
+                .map(|stat| stat.map(|s| s.statistic_value)),
+
+            None => Statistic::find()
+                .filter(statistic::Column::NumberOfAttributes.eq(1))
+                .filter(statistic::Column::StatisticType.eq(stat_type))
+                .inner_join(statistic_to_attribute_junction::Entity)
+                .filter(statistic_to_attribute_junction::Column::AttributeId.eq(attr_id))
+                .order_by_desc(statistic::Column::EpochId)
+                .one(&self.db_conn)
+                .await
+                .map(|stat| stat.map(|s| s.statistic_value)),
+        }
+    }
+
+    async fn get_stats_for_attrs(
+        &self,
+        attr_ids: Vec<i32>,
+        stat_type: i32,
+        epoch_id: Option<EpochId>,
+    ) -> StorageResult<Option<f32>> {
+        let attr_count = attr_ids.len() as i32;
+        match epoch_id {
+            Some(epoch_id) => Statistic::find()
+                .filter(statistic::Column::NumberOfAttributes.eq(attr_count))
+                .filter(statistic::Column::StatisticType.eq(stat_type))
+                .filter(statistic::Column::EpochId.eq(epoch_id))
+                .inner_join(statistic_to_attribute_junction::Entity)
+                .filter(statistic_to_attribute_junction::Column::AttributeId.is_in(attr_ids))
+                .group_by(statistic::Column::Id)
+                .having(Expr::col(statistic::Column::Name).count().eq(attr_count))
+                .one(&self.db_conn)
+                .await
+                .map(|stat| stat.map(|s| s.statistic_value)),
+
+            None => Statistic::find()
+                .filter(statistic::Column::NumberOfAttributes.eq(attr_count))
+                .filter(statistic::Column::StatisticType.eq(stat_type))
+                .inner_join(statistic_to_attribute_junction::Entity)
+                .filter(statistic_to_attribute_junction::Column::AttributeId.is_in(attr_ids))
+                .group_by(statistic::Column::Id)
+                .having(Expr::col(statistic::Column::Name).count().eq(attr_count))
+                .order_by_desc(statistic::Column::EpochId)
+                .one(&self.db_conn)
+                .await
+                .map(|stat| stat.map(|s| s.statistic_value)),
+        }
+    }
+
+    async fn get_cost_analysis(
+        &self,
+        expr_id: storage_layer::ExprId,
+        epoch_id: storage_layer::EpochId,
+    ) -> StorageResult<Option<i32>> {
+        let cost = PlanCost::find()
+            .filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
+            .filter(plan_cost::Column::EpochId.eq(epoch_id))
+            .one(&self.db_conn)
+            .await?;
+        assert!(cost.is_some(), "Cost not found in Cost table");
+        assert!(cost.clone().unwrap().is_valid, "Cost is not valid");
+        Ok(cost.map(|c| c.cost))
+    }
+
+    /// Get the latest cost for an expression
+    async fn get_cost(&self, expr_id: storage_layer::ExprId) -> StorageResult<Option<i32>> {
+        let cost = PlanCost::find()
+            .filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
+            .order_by_desc(plan_cost::Column::EpochId)
+            .one(&self.db_conn)
+            .await?;
+        assert!(cost.is_some(), "Cost not found in Cost table");
+        assert!(cost.clone().unwrap().is_valid, "Cost is not valid");
+        Ok(cost.map(|c| c.cost))
+    }
+
+    async fn get_group_winner_from_group_id(
+        &self,
+        group_id: i32,
+    ) -> StorageResult<Option<physical_expression::Model>> {
+        todo!()
+    }
+
+    async fn add_new_expr(
+        &mut self,
+        expr: storage_layer::Expression,
+    ) -> StorageResult<(storage_layer::GroupId, storage_layer::ExprId)> {
+        todo!()
+    }
+
+    async fn add_expr_to_group(
+        &mut self,
+        expr: storage_layer::Expression,
+        group_id: storage_layer::GroupId,
+    ) -> StorageResult<Option<storage_layer::ExprId>> {
+        todo!()
+    }
+
+    async fn get_group_id(
+        &self,
+        expr_id: storage_layer::ExprId,
+    ) -> StorageResult<storage_layer::GroupId> {
+        todo!()
+    }
+
+    async fn get_expr_memoed(
+        &self,
+        expr_id: storage_layer::ExprId,
+    ) -> StorageResult<storage_layer::Expression> {
+        todo!()
+    }
+
+    async fn get_all_group_ids(&self) -> StorageResult<Vec<storage_layer::GroupId>> {
+        todo!()
+    }
+
+    async fn get_group(
+        &self,
+        group_id: storage_layer::GroupId,
+    ) -> StorageResult<cascades_group::Model> {
+        todo!()
+    }
+
+    async fn update_group_winner(
+        &mut self,
+        group_id: storage_layer::GroupId,
+        latest_winner: Option<storage_layer::ExprId>,
+    ) -> StorageResult<()> {
+        todo!()
+    }
+
+    async fn get_all_exprs_in_group(
+        &self,
+        group_id: storage_layer::GroupId,
+    ) -> StorageResult<Vec<storage_layer::ExprId>> {
+        todo!()
+    }
+
+    async fn get_group_info(
+        &self,
+        group_id: storage_layer::GroupId,
+    ) -> StorageResult<&Option<storage_layer::WinnerInfo>> {
+        todo!()
+    }
+
+    async fn get_predicate_binding(
+        &self,
+        group_id: storage_layer::GroupId,
+    ) -> StorageResult<Option<storage_layer::Expression>> {
+        todo!()
+    }
+
+    async fn try_get_predicate_binding(
+        &self,
+        group_id: storage_layer::GroupId,
+    ) -> StorageResult<Option<storage_layer::Expression>> {
+        todo!()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::migrate;
+    use sea_orm::{ConnectionTrait, Database, EntityTrait, ModelTrait};
+    use serde_json::de;
+
+    use crate::entities::event::Entity as Event;
+    use crate::storage_layer::StorageLayer;
+    use crate::TEST_DATABASE_URL;
+
+    async fn run_migration() {
+        let _ = std::fs::remove_file(TEST_DATABASE_URL);
+
+        let db = Database::connect(TEST_DATABASE_URL)
+            .await
+            .expect("Unable to connect to the database");
+
+        migrate(&db)
+            .await
+            .expect("Something went wrong during migration");
+
+        db.execute(sea_orm::Statement::from_string(
+            sea_orm::DatabaseBackend::Sqlite,
+            "PRAGMA foreign_keys = ON;".to_owned(),
+        ))
+        .await
+        .expect("Unable to enable foreign keys");
enable foreign keys"); + } + + #[tokio::test] + async fn test_create_new_epoch() { + run_migration().await; + let mut orm_manager = super::ORMManager::new(Some(TEST_DATABASE_URL)).await; + let res = orm_manager + .create_new_epoch("source".to_string(), "data".to_string()) + .await; + println!("{:?}", res); + assert!(res.is_ok()); + assert_eq!( + super::Event::find() + .all(&orm_manager.db_conn) + .await + .unwrap() + .len(), + 1 + ); + println!( + "{:?}", + super::Event::find() + .all(&orm_manager.db_conn) + .await + .unwrap()[0] + ); + assert_eq!( + super::Event::find() + .all(&orm_manager.db_conn) + .await + .unwrap()[0] + .epoch_id, + res.unwrap() + ); + } + + #[tokio::test] + #[ignore] // Need to update all tables + async fn test_store_cost() { + run_migration().await; + let mut orm_manager = super::ORMManager::new(Some(TEST_DATABASE_URL)).await; + let epoch_id = orm_manager + .create_new_epoch("source".to_string(), "data".to_string()) + .await + .unwrap(); + let expr_id = 1; + let cost = 42; + let res = orm_manager.store_cost(expr_id, cost, epoch_id).await; + match res { + Ok(_) => assert!(true), + Err(e) => { + println!("Error: {:?}", e); + assert!(false); + } + } + let costs = super::PlanCost::find() + .all(&orm_manager.db_conn) + .await + .unwrap(); + assert_eq!(costs.len(), 1); + assert_eq!(costs[0].epoch_id, epoch_id); + assert_eq!(costs[0].physical_expression_id, expr_id); + assert_eq!(costs[0].cost, cost); + } +} diff --git a/optd-persistent/src/storage_layer.rs b/optd-persistent/src/storage_layer.rs new file mode 100644 index 0000000..93590f6 --- /dev/null +++ b/optd-persistent/src/storage_layer.rs @@ -0,0 +1,176 @@ +#![allow(dead_code, unused_imports)] + +use crate::entities::cascades_group; +use crate::entities::event::Model as event_model; +use crate::entities::logical_expression; +use crate::entities::physical_expression; +use sea_orm::*; +use sea_orm_migration::prelude::*; +use serde_json::json; +use std::sync::Arc; + +pub type GroupId = i32; +pub type ExprId = i32; +pub type EpochId = i32; + +pub type StorageResult = Result; + +pub enum CatalogSource { + Iceberg(), +} + +pub enum Expression { + LogicalExpression(logical_expression::Model), + PhysicalExpression(physical_expression::Model), +} + +// TODO +// A dummy WinnerInfo struct +// pub struct WinnerInfo { +// pub expr_id: ExprId, +// pub total_weighted_cost: f64, +// pub operation_weighted_cost: f64, +// pub total_cost: Cost, +// pub operation_cost: Cost, +// pub statistics: Arc, +// } +// The optd WinnerInfo struct makes everything too coupled. +pub struct WinnerInfo {} + +pub trait StorageLayer { + // TODO: Change EpochId to event::Model::epoch_id + async fn create_new_epoch(&mut self, source: String, data: String) -> StorageResult; + + async fn update_stats_from_catalog( + &self, + c: CatalogSource, + epoch_id: EpochId, + ) -> StorageResult<()>; + + // i32 in `stats:i32` is a placeholder for the stats type + async fn update_stats(&self, stats: i32, epoch_id: EpochId) -> StorageResult<()>; + + async fn store_cost(&self, expr_id: ExprId, cost: i32, epoch_id: EpochId) -> StorageResult<()>; + + /// Get the statistics for a given table. + /// + /// If `epoch_id` is None, it will return the latest statistics. + async fn get_stats_for_table( + &self, + table_id: i32, + stat_type: i32, + epoch_id: Option, + ) -> StorageResult>; + + /// Get the statistics for a given attribute. + /// + /// If `epoch_id` is None, it will return the latest statistics. 
+    async fn get_stats_for_attr(
+        &self,
+        attr_id: i32,
+        stat_type: i32,
+        epoch_id: Option<EpochId>,
+    ) -> StorageResult<Option<f32>>;
+
+    /// Get the joint statistics for a list of attributes.
+    ///
+    /// If `epoch_id` is None, it will return the latest statistics.
+    async fn get_stats_for_attrs(
+        &self,
+        attr_ids: Vec<i32>,
+        stat_type: i32,
+        epoch_id: Option<EpochId>,
+    ) -> StorageResult<Option<f32>>;
+
+    async fn get_cost_analysis(
+        &self,
+        expr_id: ExprId,
+        epoch_id: EpochId,
+    ) -> StorageResult<Option<i32>>;
+
+    async fn get_cost(&self, expr_id: ExprId) -> StorageResult<Option<i32>>;
+
+    async fn get_group_winner_from_group_id(
+        &self,
+        group_id: i32,
+    ) -> StorageResult<Option<physical_expression::Model>>;
+
+    /// Add an expression to the memo table. If the expression already exists, it will return the existing group id and
+    /// expr id. Otherwise, a new group and expr will be created.
+    async fn add_new_expr(&mut self, expr: Expression) -> StorageResult<(GroupId, ExprId)>;
+
+    /// Add a new expression to an existing group. If the expression is a group, it will merge the two groups. Otherwise,
+    /// it will add the expression to the group. Returns the expr id if the expression is not a group.
+    async fn add_expr_to_group(
+        &mut self,
+        expr: Expression,
+        group_id: GroupId,
+    ) -> StorageResult<Option<ExprId>>;
+
+    /// Get the group id of an expression.
+    /// The group id is volatile, depending on whether the groups are merged.
+    async fn get_group_id(&self, expr_id: ExprId) -> StorageResult<GroupId>;
+
+    /// Get the memoized representation of a node.
+    async fn get_expr_memoed(&self, expr_id: ExprId) -> StorageResult<Expression>;
+
+    /// Get all group IDs in the memo table.
+    async fn get_all_group_ids(&self) -> StorageResult<Vec<GroupId>>;
+
+    /// Get a group by ID
+    async fn get_group(&self, group_id: GroupId) -> StorageResult<cascades_group::Model>;
+
+    /// Update the group winner.
+    async fn update_group_winner(
+        &mut self,
+        group_id: GroupId,
+        latest_winner: Option<ExprId>,
+    ) -> StorageResult<()>;
+
+    // The below functions can be overwritten by the memo table implementation if there
+    // are more efficient ways to retrieve the information.
+
+    /// Get all expressions in the group.
+    async fn get_all_exprs_in_group(&self, group_id: GroupId) -> StorageResult<Vec<ExprId>>;
+
+    /// Get winner info for a group id
+    async fn get_group_info(&self, group_id: GroupId) -> StorageResult<&Option<WinnerInfo>>;
+
+    // TODO:
+    /// Get the best group binding based on the cost
+    // fn get_best_group_binding(
+    //     &self,
+    //     group_id: GroupId,
+    //     mut post_process: impl FnMut(Arc<RelNode<T>>, GroupId, &WinnerInfo),
+    // ) -> Result<Arc<RelNode<T>>>;
+    // {
+    //     // let info: &GroupInfo = this.get_group_info(group_id);
+    //     // if let Winner::Full(info @ WinnerInfo { expr_id, .. }) = &info.winner {
+    //     //     let expr = this.get_expr_memoed(*expr_id);
+    //     //     let mut children = Vec::with_capacity(expr.children.len());
+    //     //     for child in &expr.children {
+    //     //         children.push(
+    //     //             get_best_group_binding_inner(this, *child, post_process)
+    //     //                 .with_context(|| format!("when processing expr {}", expr_id))?,
+    //     //         );
+    //     //     }
+    //     //     let node = Arc::new(RelNode {
+    //     //         typ: expr.typ.clone(),
+    //     //         children,
+    //     //         data: expr.data.clone(),
+    //     //     });
+    //     //     post_process(node.clone(), group_id, info);
+    //     //     return Ok(node);
+    //     // }
+    //     // bail!("no best group binding for group {}", group_id)
+    // };
+
+    /// Get all bindings of a predicate group. Will panic if the group contains more than one binding.
+    async fn get_predicate_binding(&self, group_id: GroupId) -> StorageResult<Option<Expression>>;
+
+    /// Get all bindings of a predicate group. Returns None if the group contains zero or more than one binding.
+    async fn try_get_predicate_binding(
+        &self,
+        group_id: GroupId,
+    ) -> StorageResult<Option<Expression>>;
+}