diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index 9ae84bb..38957f9 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -15,7 +15,7 @@ use crate::{ }, memo_ext::MemoExt, stats::AttributeCombValueStats, - storage::CostModelStorageManager, + storage::{self, CostModelStorageManager}, ComputeCostContext, Cost, CostModel, CostModelResult, EstimatedStatistic, StatValue, }; @@ -43,24 +43,111 @@ impl CostModelImpl { #[async_trait::async_trait] impl CostModel for CostModelImpl { + /// TODO: should we add epoch_id? async fn compute_operation_cost( &self, - node: &PhysicalNodeType, + node: PhysicalNodeType, predicates: &[ArcPredicateNode], + children_costs: &[Cost], children_stats: &[EstimatedStatistic], context: ComputeCostContext, ) -> CostModelResult { - todo!() + let res = self.storage_manager.get_cost(context.expr_id).await; + if let Ok((Some(cost), _)) = res { + return Ok(cost); + }; + let mut output_statistic = None; + if let Ok((_, Some(statistic))) = res { + output_statistic = Some(statistic); + }; + let output_cost = match node { + PhysicalNodeType::PhysicalScan => { + let output_statistic_data = output_statistic.unwrap_or( + self.derive_statistics( + node, + predicates, + children_stats, + context.clone(), + false, + ) + .await?, + ); + output_statistic = Some(output_statistic_data.clone()); + Cost { + compute_cost: 0.0, + io_cost: output_statistic_data.0, + } + } + PhysicalNodeType::PhysicalEmptyRelation => Cost { + compute_cost: 0.1, + io_cost: 0.0, + }, + PhysicalNodeType::PhysicalLimit => Cost { + compute_cost: children_costs[0].compute_cost, + io_cost: 0.0, + }, + PhysicalNodeType::PhysicalFilter => Cost { + // TODO: now this equation is specific to optd, and try to make this equation more general + compute_cost: children_costs[1].compute_cost * children_stats[0].0, + io_cost: 0.0, + }, + PhysicalNodeType::PhysicalNestedLoopJoin(join_typ) => { + let child_compute_cost = 
children_costs[2].compute_cost; + Cost { + compute_cost: children_stats[0].0 * children_stats[1].0 * child_compute_cost + + children_stats[0].0, + io_cost: 0.0, + } + } + // TODO: we should document that the first child is the left table, which is used to build + // the hash table. + PhysicalNodeType::PhysicalHashJoin(join_typ) => Cost { + compute_cost: children_stats[0].0 * 2.0 + children_stats[1].0, + io_cost: 0.0, + }, + PhysicalNodeType::PhysicalAgg => Cost { + compute_cost: children_stats[0].0 + * (children_costs[1].compute_cost + children_costs[2].compute_cost), + io_cost: 0.0, + }, + PhysicalNodeType::PhysicalProjection => Cost { + compute_cost: children_stats[0].0 * children_costs[1].compute_cost, + io_cost: 0.0, + }, + PhysicalNodeType::PhysicalSort => Cost { + compute_cost: children_stats[0].0 * children_stats[0].0.ln_1p().max(1.0), + io_cost: 0.0, + }, + }; + let res = self + .storage_manager + .store_cost( + context.expr_id, + Some(output_cost.clone()), + output_statistic, + None, + ) + .await; + if res.is_err() { + eprintln!("Failed to store output cost"); + } + Ok(output_cost) } + /// TODO: should we add epoch_id? 
async fn derive_statistics( &self, node: PhysicalNodeType, predicates: &[ArcPredicateNode], children_statistics: &[EstimatedStatistic], context: ComputeCostContext, + store_output_statistic: bool, ) -> CostModelResult { - match node { + let res = self.storage_manager.get_cost(context.expr_id).await; + if let Ok((_, Some(statistic))) = res { + return Ok(statistic); + } + let output_statistic = match node { PhysicalNodeType::PhysicalScan => { let table_id = TableId(predicates[0].data.as_ref().unwrap().as_u64()); let row_cnt = self @@ -114,7 +201,17 @@ impl CostModel for CostModel PhysicalNodeType::PhysicalSort | PhysicalNodeType::PhysicalProjection => { Ok(children_statistics[0].clone()) } - } + }?; + if store_output_statistic { + let res = self + .storage_manager + .store_cost(context.expr_id, None, Some(output_statistic.clone()), None) + .await; + if res.is_err() { + eprintln!("Failed to store output statistic"); + } + }; + Ok(output_statistic) } async fn update_statistics( @@ -167,3 +264,5 @@ impl CostModelImpl { .await } } + +// TODO: Add tests for `derive_statistics` and `compute_operation_cost`. diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index 68b56ac..4b65038 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -30,13 +30,52 @@ pub struct ComputeCostContext { } #[derive(Default, Clone, Debug, PartialOrd, PartialEq)] -pub struct Cost(pub Vec); +pub struct Cost { + pub compute_cost: f64, + pub io_cost: f64, +} + +impl From for optd_persistent::cost_model::interface::Cost { + fn from(c: Cost) -> optd_persistent::cost_model::interface::Cost { + Self { + compute_cost: c.compute_cost, + io_cost: c.io_cost, + } + } +} + +impl From for Cost { + fn from(c: optd_persistent::cost_model::interface::Cost) -> Cost { + Self { + compute_cost: c.compute_cost, + io_cost: c.io_cost, + } + } +} /// Estimated statistic calculated by the cost model. /// It is the estimated output row count of the targeted expression. 
#[derive(PartialEq, PartialOrd, Clone, Debug)] pub struct EstimatedStatistic(pub f64); +impl From for f32 { + fn from(e: EstimatedStatistic) -> f32 { + e.0 as f32 + } +} + +impl From for f64 { + fn from(e: EstimatedStatistic) -> f64 { + e.0 + } +} + +impl From for EstimatedStatistic { + fn from(f: f32) -> EstimatedStatistic { + Self(f as f64) + } +} + pub type CostModelResult = Result; #[derive(Debug)] @@ -79,8 +118,9 @@ pub trait CostModel: 'static + Send + Sync { /// TODO: documentation async fn compute_operation_cost( &self, - node: &PhysicalNodeType, + node: PhysicalNodeType, predicates: &[ArcPredicateNode], + children_costs: &[Cost], children_stats: &[EstimatedStatistic], context: ComputeCostContext, ) -> CostModelResult; @@ -88,14 +128,32 @@ pub trait CostModel: 'static + Send + Sync { /// TODO: documentation /// It is for cardinality estimation. The output should be the estimated /// statistic calculated by the cost model. + /// If this method is called by `compute_operation_cost`, please set + /// `store_output_statistic` to `false`, since the estimated statistic and the cost + /// can then be stored together by calling the ORM method once; if it is called by + /// the optimizer, please set `store_output_statistic` to `true`. + /// + /// TODO: I am not sure whether to introduce `store_output_statistic`, since + /// it adds complexity to the interface, considering currently only Scan needs + /// the output row count to calculate the costs. So updating the database twice + /// seems cheap. But in the future, more cost computations may rely on the output + /// row count. (Of course, it should be removed if we separate the cost and + /// estimated_statistic into 2 tables.) + /// + /// TODO: Consider making it a helper function, so we can store Cost in the + /// ORM more easily. + /// + /// TODO: I would suggest renaming this method to `derive_row_count`, since + /// `statistic` is easily confused with the real statistic. 
+ /// Also we need to update other places to use estimated statistic to row count, + /// either in this crate or in optd-persistent. async fn derive_statistics( &self, node: PhysicalNodeType, predicates: &[ArcPredicateNode], children_stats: &[EstimatedStatistic], context: ComputeCostContext, + store_output_statistic: bool, ) -> CostModelResult; /// TODO: documentation diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs index e2c9b1e..f20c417 100644 --- a/optd-cost-model/src/storage/mock.rs +++ b/optd-cost-model/src/storage/mock.rs @@ -3,7 +3,11 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult}; +use crate::{ + common::types::{EpochId, ExprId, TableId}, + stats::AttributeCombValueStats, + Cost, CostModelResult, EstimatedStatistic, +}; use super::CostModelStorageManager; @@ -63,4 +67,23 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl { let table_stats = self.per_table_stats_map.get(&table_id); Ok(table_stats.map(|stats| stats.row_cnt)) } + + /// TODO: finish this when implementing the cost get/store tests + async fn get_cost( + &self, + expr_id: ExprId, + ) -> CostModelResult<(Option, Option)> { + todo!() + } + + /// TODO: finish this when implementing the cost get/store tests + async fn store_cost( + &self, + expr_id: ExprId, + cost: Option, + estimated_statistic: Option, + epoch_id: Option, + ) -> CostModelResult<()> { + todo!() + } } diff --git a/optd-cost-model/src/storage/mod.rs b/optd-cost-model/src/storage/mod.rs index 311da44..14cccd6 100644 --- a/optd-cost-model/src/storage/mod.rs +++ b/optd-cost-model/src/storage/mod.rs @@ -1,4 +1,8 @@ -use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult}; +use crate::{ + common::types::{EpochId, ExprId, TableId}, + stats::AttributeCombValueStats, + Cost, CostModelResult, EstimatedStatistic, +}; pub mod mock; pub mod persistent; @@ 
-12,4 +16,17 @@ pub trait CostModelStorageManager { ) -> CostModelResult>; async fn get_table_row_count(&self, table_id: TableId) -> CostModelResult>; + + async fn get_cost( + &self, + expr_id: ExprId, + ) -> CostModelResult<(Option, Option)>; + + async fn store_cost( + &self, + expr_id: ExprId, + cost: Option, + estimated_statistic: Option, + epoch_id: Option, + ) -> CostModelResult<()>; } diff --git a/optd-cost-model/src/storage/persistent.rs b/optd-cost-model/src/storage/persistent.rs index 2238507..b3078a6 100644 --- a/optd-cost-model/src/storage/persistent.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -4,9 +4,9 @@ use std::sync::Arc; use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer}; use crate::{ - common::types::TableId, + common::types::{EpochId, ExprId, TableId}, stats::{utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, - CostModelResult, + Cost, CostModelResult, EstimatedStatistic, }; use super::CostModelStorageManager; @@ -125,5 +125,42 @@ impl CostModelStorageManager .transpose()?) } + /// TODO: The name is misleading, since we can also get the estimated statistic. We should + /// rename it. + /// + /// TODO: Add retry logic here. + async fn get_cost( + &self, + expr_id: ExprId, + ) -> CostModelResult<(Option, Option)> { + let (cost, estimated_statistic) = self.backend_manager.get_cost(expr_id.into()).await?; + Ok(( + cost.map(|c| c.into()), + estimated_statistic.map(|x| x.into()), + )) + } + + /// TODO: The name is misleading, since we can also store the estimated statistic. We should + /// rename it. + /// + /// TODO: Add retry logic here. 
+ async fn store_cost( + &self, + expr_id: ExprId, + cost: Option, + estimated_statistic: Option, + epoch_id: Option, + ) -> CostModelResult<()> { + self.backend_manager + .store_cost( + expr_id.into(), + cost.map(|c| c.into()), + estimated_statistic.map(|x| x.into()), + epoch_id.map(|id| id.into()), + ) + .await?; + Ok(()) + } + // TODO: Support querying for a specific type of statistics. } diff --git a/optd-persistent/src/bin/init.rs b/optd-persistent/src/bin/init.rs index 9cc07e2..a0b6b25 100644 --- a/optd-persistent/src/bin/init.rs +++ b/optd-persistent/src/bin/init.rs @@ -355,8 +355,8 @@ async fn init_all_tables() -> Result<(), sea_orm::error::DbErr> { id: Set(1), physical_expression_id: Set(1), epoch_id: Set(1), - cost: Set(json!({"compute_cost":10, "io_cost":10})), - estimated_statistic: Set(10), + cost: Set(Some(json!({"compute_cost":10, "io_cost":10}))), + estimated_statistic: Set(Some(10.0)), is_valid: Set(true), }; plan_cost::Entity::insert(plan_cost) diff --git a/optd-persistent/src/cost_model/interface.rs b/optd-persistent/src/cost_model/interface.rs index ee767d7..e4e4ceb 100644 --- a/optd-persistent/src/cost_model/interface.rs +++ b/optd-persistent/src/cost_model/interface.rs @@ -89,10 +89,8 @@ pub struct Stat { /// TODO: documentation #[derive(Clone, Debug, PartialEq)] pub struct Cost { - pub compute_cost: i32, - pub io_cost: i32, - // Raw estimated output row count of targeted expression. 
- pub estimated_statistic: i32, + pub compute_cost: f64, + pub io_cost: f64, } #[derive(Clone, Debug)] @@ -118,8 +116,13 @@ pub trait CostModelStorageLayer { epoch_option: EpochOption, ) -> StorageResult>; - async fn store_cost(&self, expr_id: ExprId, cost: Cost, epoch_id: EpochId) - -> StorageResult<()>; + async fn store_cost( + &self, + expr_id: ExprId, + cost: Option, + estimated_statistic: Option, + epoch_id: Option, + ) -> StorageResult<()>; async fn store_expr_stats_mappings( &self, @@ -162,9 +165,9 @@ pub trait CostModelStorageLayer { &self, expr_id: ExprId, epoch_id: EpochId, - ) -> StorageResult>; + ) -> StorageResult<(Option, Option)>; - async fn get_cost(&self, expr_id: ExprId) -> StorageResult>; + async fn get_cost(&self, expr_id: ExprId) -> StorageResult<(Option, Option)>; async fn get_attribute( &self, diff --git a/optd-persistent/src/cost_model/orm.rs b/optd-persistent/src/cost_model/orm.rs index d5b7ad6..9c068ae 100644 --- a/optd-persistent/src/cost_model/orm.rs +++ b/optd-persistent/src/cost_model/orm.rs @@ -4,7 +4,7 @@ use crate::cost_model::interface::Cost; use crate::entities::{prelude::*, *}; use crate::{BackendError, BackendManager, CostModelStorageLayer, StorageResult}; use sea_orm::prelude::{Expr, Json}; -use sea_orm::sea_query::Query; +use sea_orm::sea_query::{ExprTrait, Query}; use sea_orm::{sqlx::types::chrono::Utc, EntityTrait}; use sea_orm::{ ActiveModelTrait, ColumnTrait, Condition, DbBackend, DbErr, DeleteResult, EntityOrSelect, @@ -208,7 +208,7 @@ impl CostModelStorageLayer for BackendManager { // 0. Check if the stat already exists. If exists, get stat_id, else insert into statistic table. 
let stat_id = match stat.table_id { Some(table_id) => { - // TODO(lanlou): only select needed fields + // TODO: only select needed fields let res = Statistic::find() .filter(statistic::Column::TableId.eq(table_id)) .inner_join(versioned_statistic::Entity) @@ -467,47 +467,74 @@ impl CostModelStorageLayer for BackendManager { } /// TODO: documentation + /// Each record in the `plan_cost` table can contain either the cost or the estimated statistic + /// or both, but never neither. + /// The name can be misleading, since it can also return the estimated statistic. async fn get_cost_analysis( &self, expr_id: ExprId, epoch_id: EpochId, - ) -> StorageResult> { + ) -> StorageResult<(Option, Option)> { let cost = PlanCost::find() .filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id)) .filter(plan_cost::Column::EpochId.eq(epoch_id)) .one(&self.db) .await?; - assert!(cost.is_some(), "Cost not found in Cost table"); - assert!(cost.clone().unwrap().is_valid, "Cost is not valid"); - Ok(cost.map(|c| Cost { - compute_cost: c.cost.get("compute_cost").unwrap().as_i64().unwrap() as i32, - io_cost: c.cost.get("io_cost").unwrap().as_i64().unwrap() as i32, - estimated_statistic: c.estimated_statistic, - })) + // When this cost is not found, we should return None + if cost.is_none() { + return Ok((None, None)); + } + + let real_cost = cost.as_ref().and_then(|c| c.cost.as_ref()).map(|c| Cost { + compute_cost: c.get("compute_cost").unwrap().as_f64().unwrap(), + io_cost: c.get("io_cost").unwrap().as_f64().unwrap(), + }); + + Ok((real_cost, cost.unwrap().estimated_statistic)) } - async fn get_cost(&self, expr_id: ExprId) -> StorageResult> { + /// TODO: documentation + /// It returns the cost and estimated statistic if applicable. + /// Each record in the `plan_cost` table can contain either the cost or the estimated statistic + /// or both, but never neither. + /// The name can be misleading, since it can also return the estimated statistic. 
+ async fn get_cost(&self, expr_id: ExprId) -> StorageResult<(Option, Option)> { let cost = PlanCost::find() .filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id)) .order_by_desc(plan_cost::Column::EpochId) .one(&self.db) .await?; - assert!(cost.is_some(), "Cost not found in Cost table"); - assert!(cost.clone().unwrap().is_valid, "Cost is not valid"); - Ok(cost.map(|c| Cost { - compute_cost: c.cost.get("compute_cost").unwrap().as_i64().unwrap() as i32, - io_cost: c.cost.get("io_cost").unwrap().as_i64().unwrap() as i32, - estimated_statistic: c.estimated_statistic, - })) + // When this cost is invalid or not found, we should return None + if cost.is_none() || !cost.clone().unwrap().is_valid { + return Ok((None, None)); + } + + let real_cost = cost.as_ref().and_then(|c| c.cost.as_ref()).map(|c| Cost { + compute_cost: c.get("compute_cost").unwrap().as_f64().unwrap(), + io_cost: c.get("io_cost").unwrap().as_f64().unwrap(), + }); + + Ok((real_cost, cost.unwrap().estimated_statistic)) } + /// This method should handle the case when the cost is already stored. + /// The name may be misleading, since it can also store the estimated statistic. + /// If epoch_id is none, we pick the latest epoch_id. + /// + /// TODO: consider whether we need to pass the epoch_id here. When the epoch is + /// stale because someone else updates the stats while we're still computing cost, + /// what is the expected behavior? + /// /// TODO: documentation async fn store_cost( &self, physical_expression_id: ExprId, - cost: Cost, - epoch_id: EpochId, + cost: Option, + estimated_statistic: Option, + epoch_id: Option, ) -> StorageResult<()> { + assert!(cost.is_some() || estimated_statistic.is_some()); + // TODO: should we do the following checks in the production environment? 
let expr_exists = PhysicalExpression::find_by_id(physical_expression_id) .one(&self.db) .await?; @@ -520,30 +547,80 @@ impl CostModelStorageLayer for BackendManager { .into(), )); } - // Check if epoch_id exists in Event table - let epoch_exists = Event::find() - .filter(event::Column::EpochId.eq(epoch_id)) - .one(&self.db) - .await - .unwrap(); - if epoch_exists.is_none() { - return Err(BackendError::CostModel( - format!("epoch id {} not found when storing cost", epoch_id).into(), - )); + if epoch_id.is_some() { + let epoch_exists = Event::find() + .filter(event::Column::EpochId.eq(epoch_id.unwrap())) + .one(&self.db) + .await + .unwrap(); + if epoch_exists.is_none() { + return Err(BackendError::CostModel( + format!("epoch id {} not found when storing cost", epoch_id.unwrap()).into(), + )); + } } - let new_cost = plan_cost::ActiveModel { - physical_expression_id: sea_orm::ActiveValue::Set(physical_expression_id), - epoch_id: sea_orm::ActiveValue::Set(epoch_id), - cost: sea_orm::ActiveValue::Set( - json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}), - ), - estimated_statistic: sea_orm::ActiveValue::Set(cost.estimated_statistic), - is_valid: sea_orm::ActiveValue::Set(true), - ..Default::default() + let epoch_id = match epoch_id { + Some(id) => id, + None => { + // When init, please make sure there is at least one epoch in the Event table. + let latest_epoch_id = Event::find() + .order_by_desc(event::Column::EpochId) + .one(&self.db) + .await? 
+ .unwrap(); + latest_epoch_id.epoch_id + } }; - let _ = PlanCost::insert(new_cost).exec(&self.db).await?; + + let transaction = self.db.begin().await?; + + let valid_cost = PlanCost::find() + .filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id)) + .filter(plan_cost::Column::EpochId.eq(epoch_id)) + .filter(plan_cost::Column::IsValid.eq(true)) + .one(&transaction) + .await?; + + if valid_cost.is_some() { + let mut new_cost: plan_cost::ActiveModel = valid_cost.unwrap().into(); + let mut update = false; + if cost.is_some() { + let input_cost = sea_orm::ActiveValue::Set(Some(json!({ + "compute_cost": cost.clone().unwrap().compute_cost, + "io_cost": cost.clone().unwrap().io_cost + }))); + if new_cost.cost != input_cost { + update = true; + new_cost.cost = input_cost; + } + } + if estimated_statistic.is_some() { + let input_estimated_statistic = sea_orm::ActiveValue::Set(estimated_statistic); + if new_cost.estimated_statistic != input_estimated_statistic { + update = true; + new_cost.estimated_statistic = input_estimated_statistic; + } + } + if update { + let _ = PlanCost::update(new_cost).exec(&transaction).await?; + } + } else { + let new_cost = plan_cost::ActiveModel { + physical_expression_id: sea_orm::ActiveValue::Set(physical_expression_id), + epoch_id: sea_orm::ActiveValue::Set(epoch_id), + cost: sea_orm::ActiveValue::Set( + cost.map(|c| json!({"compute_cost": c.compute_cost, "io_cost": c.io_cost})), + ), + estimated_statistic: sea_orm::ActiveValue::Set(estimated_statistic), + is_valid: sea_orm::ActiveValue::Set(true), + ..Default::default() + }; + let _ = PlanCost::insert(new_cost).exec(&transaction).await?; + } + + transaction.commit().await?; Ok(()) } @@ -577,6 +654,7 @@ impl CostModelStorageLayer for BackendManager { } } +// TODO: add integration tests #[cfg(test)] mod tests { use crate::cost_model::interface::{Cost, EpochOption, StatType}; @@ -755,14 +833,12 @@ mod tests { backend_manager .store_cost( expr_id, - { - Cost { - 
compute_cost: 42, - io_cost: 42, - estimated_statistic: 42, - } - }, - versioned_stat_res[0].epoch_id, + Some(Cost { + compute_cost: 42.0, + io_cost: 42.0, + }), + Some(42.0), + Some(versioned_stat_res[0].epoch_id), ) .await .unwrap(); @@ -826,7 +902,10 @@ mod tests { .await .unwrap(); assert_eq!(cost_res.len(), 1); - assert_eq!(cost_res[0].cost, json!({"compute_cost": 42, "io_cost": 42})); + assert_eq!( + cost_res[0].cost, + Some(json!({"compute_cost": 42.0, "io_cost": 42.0})) + ); assert_eq!(cost_res[0].epoch_id, epoch_id1); assert!(!cost_res[0].is_valid); @@ -958,12 +1037,17 @@ mod tests { .unwrap(); let physical_expression_id = 1; let cost = Cost { - compute_cost: 42, - io_cost: 42, - estimated_statistic: 42, + compute_cost: 42.0, + io_cost: 42.0, }; + let mut estimated_statistic = 42.0; backend_manager - .store_cost(physical_expression_id, cost.clone(), epoch_id) + .store_cost( + physical_expression_id, + Some(cost.clone()), + Some(estimated_statistic), + Some(epoch_id), + ) .await .unwrap(); let costs = super::PlanCost::find() @@ -975,11 +1059,34 @@ mod tests { assert_eq!(costs[1].physical_expression_id, physical_expression_id); assert_eq!( costs[1].cost, - json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}) + Some(json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost})) + ); + assert_eq!(costs[1].estimated_statistic.unwrap(), estimated_statistic); + + estimated_statistic = 50.0; + backend_manager + .store_cost( + physical_expression_id, + None, + Some(estimated_statistic), + None, + ) + .await + .unwrap(); + let costs = super::PlanCost::find() + .all(&backend_manager.db) + .await + .unwrap(); + assert_eq!(costs.len(), 2); // We should not insert a new row + assert_eq!(costs[1].epoch_id, epoch_id); + assert_eq!(costs[1].physical_expression_id, physical_expression_id); + assert_eq!( + costs[1].cost, + Some(json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost})) ); assert_eq!( - costs[1].estimated_statistic as i32, - 
cost.estimated_statistic + costs[1].estimated_statistic.unwrap(), + estimated_statistic // The estimated_statistic should be updated ); remove_db_file(DATABASE_FILE); @@ -997,12 +1104,16 @@ mod tests { .unwrap(); let physical_expression_id = 1; let cost = Cost { - compute_cost: 42, - io_cost: 42, - estimated_statistic: 42, + compute_cost: 42.0, + io_cost: 42.0, }; let _ = backend_manager - .store_cost(physical_expression_id, cost.clone(), epoch_id) + .store_cost( + physical_expression_id, + Some(cost.clone()), + None, + Some(epoch_id), + ) .await; let costs = super::PlanCost::find() .all(&backend_manager.db) @@ -1013,18 +1124,16 @@ mod tests { assert_eq!(costs[1].physical_expression_id, physical_expression_id); assert_eq!( costs[1].cost, - json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}) - ); - assert_eq!( - costs[1].estimated_statistic as i32, - cost.estimated_statistic + Some(json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost})) ); + assert_eq!(costs[1].estimated_statistic, None); let res = backend_manager .get_cost(physical_expression_id) .await .unwrap(); - assert_eq!(res.unwrap(), cost); + assert_eq!(res.0.unwrap(), cost); + assert_eq!(res.1, None); remove_db_file(DATABASE_FILE); } @@ -1040,13 +1149,14 @@ mod tests { .await .unwrap(); let physical_expression_id = 1; - let cost = Cost { - compute_cost: 1420, - io_cost: 42, - estimated_statistic: 42, - }; + let estimated_statistic = 42.0; let _ = backend_manager - .store_cost(physical_expression_id, cost.clone(), epoch_id) + .store_cost( + physical_expression_id, + None, + Some(estimated_statistic), + Some(epoch_id), + ) .await; let costs = super::PlanCost::find() .all(&backend_manager.db) @@ -1055,14 +1165,8 @@ mod tests { assert_eq!(costs.len(), 2); // The first row one is the initialized data assert_eq!(costs[1].epoch_id, epoch_id); assert_eq!(costs[1].physical_expression_id, physical_expression_id); - assert_eq!( - costs[1].cost, - json!({"compute_cost": cost.compute_cost, 
"io_cost": cost.io_cost}) - ); - assert_eq!( - costs[1].estimated_statistic as i32, - cost.estimated_statistic - ); + assert_eq!(costs[1].cost, None); + assert_eq!(costs[1].estimated_statistic.unwrap(), estimated_statistic); println!("{:?}", costs); // Retrieve physical_expression_id 1 and epoch_id 1 @@ -1073,13 +1177,13 @@ mod tests { // The cost in the dummy data is 10 assert_eq!( - res.unwrap(), + res.0.unwrap(), Cost { - compute_cost: 10, - io_cost: 10, - estimated_statistic: 10, + compute_cost: 10.0, + io_cost: 10.0, } ); + assert_eq!(res.1.unwrap(), 10.0); remove_db_file(DATABASE_FILE); } diff --git a/optd-persistent/src/db/init.db b/optd-persistent/src/db/init.db index 5350952..0c3d71c 100644 Binary files a/optd-persistent/src/db/init.db and b/optd-persistent/src/db/init.db differ diff --git a/optd-persistent/src/entities/plan_cost.rs b/optd-persistent/src/entities/plan_cost.rs index 1acf101..5d7e24b 100644 --- a/optd-persistent/src/entities/plan_cost.rs +++ b/optd-persistent/src/entities/plan_cost.rs @@ -2,15 +2,16 @@ use sea_orm::entity::prelude::*; -#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] #[sea_orm(table_name = "plan_cost")] pub struct Model { #[sea_orm(primary_key)] pub id: i32, pub physical_expression_id: i32, pub epoch_id: i32, - pub cost: Json, - pub estimated_statistic: i32, + pub cost: Option, + #[sea_orm(column_type = "Float", nullable)] + pub estimated_statistic: Option, pub is_valid: bool, } diff --git a/optd-persistent/src/migrator/cost_model/m20241029_000001_plan_cost.rs b/optd-persistent/src/migrator/cost_model/m20241029_000001_plan_cost.rs index d8f9bf9..fdd2fef 100644 --- a/optd-persistent/src/migrator/cost_model/m20241029_000001_plan_cost.rs +++ b/optd-persistent/src/migrator/cost_model/m20241029_000001_plan_cost.rs @@ -1,6 +1,14 @@ //! When a statistic is updated, then all the related costs should be invalidated. (IsValid is set to false) //! 
This design (using IsValid flag) is based on the assumption that update_stats will not be called very frequently. //! It favors the compute_cost performance over the update_stats performance. +//! +//! This file stores cost like compute_cost, io_cost, network_cost, etc. for each physical expression. It also +//! stores the estimated output row count (estimated statistic) of each physical expression. +//! Sometimes we only have one of them to store, so we make Cost and EstimatedStatistic optional. But +//! one record must have at least one of them. +//! +//! TODO: Ideally, we can separate them since sometimes we only have the estimated output row count to store, +//! (when calling `derive_statistic`) but we don't have the detailed cost. use crate::migrator::cost_model::event::Event; use crate::migrator::memo::physical_expression::PhysicalExpression; @@ -49,8 +57,8 @@ impl MigrationTrait for Migration { .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ) - .col(json(PlanCost::Cost)) - .col(integer(PlanCost::EstimatedStatistic)) + .col(json_null(PlanCost::Cost)) + .col(float_null(PlanCost::EstimatedStatistic)) .col(boolean(PlanCost::IsValid)) .to_owned(), ) diff --git a/schema/all_tables.dbml b/schema/all_tables.dbml index 305075a..29a2136 100644 --- a/schema/all_tables.dbml +++ b/schema/all_tables.dbml @@ -60,9 +60,9 @@ Table plan_cost { physical_expression_id integer [ref: > physical_expression.id] epoch_id integer [ref: > event.epoch_id] // It is json type, including computation cost, I/O cost, etc. - cost json + cost json [null] // Raw estimated output row count of this expression - estimated_statistic integer + estimated_statistic float [null] // Whether the cost is valid or not. If the latest cost for an expr is invalid, then we need to recompute the cost. // We need to invalidate the cost when the related stats are updated. is_valid boolean