Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 104 additions & 5 deletions optd-cost-model/src/cost_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
},
memo_ext::MemoExt,
stats::AttributeCombValueStats,
storage::CostModelStorageManager,
storage::{self, CostModelStorageManager},
ComputeCostContext, Cost, CostModel, CostModelResult, EstimatedStatistic, StatValue,
};

Expand Down Expand Up @@ -43,24 +43,111 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {

#[async_trait::async_trait]
impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModelImpl<S> {
/// TODO: should we add epoch_id?
async fn compute_operation_cost(
&self,
node: &PhysicalNodeType,
node: PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_costs: &[Cost],
children_stats: &[EstimatedStatistic],
context: ComputeCostContext,
) -> CostModelResult<Cost> {
todo!()
let res = self.storage_manager.get_cost(context.expr_id).await;
if let Ok((Some(cost), _)) = res {
return Ok(cost);
};
let mut output_statistic = None;
if let Ok((_, Some(statistic))) = res {
output_statistic = Some(statistic);
};
let output_cost = match node {
PhysicalNodeType::PhysicalScan => {
let output_statistic_data = output_statistic.unwrap_or(
self.derive_statistics(
node,
predicates,
children_stats,
context.clone(),
false,
)
.await?,
);
output_statistic = Some(output_statistic_data.clone());
Cost {
compute_cost: 0.0,
io_cost: output_statistic_data.0,
}
}
PhysicalNodeType::PhysicalEmptyRelation => Cost {
compute_cost: 0.1,
io_cost: 0.0,
},
PhysicalNodeType::PhysicalLimit => Cost {
compute_cost: children_costs[0].compute_cost,
io_cost: 0.0,
},
PhysicalNodeType::PhysicalFilter => Cost {
// TODO: now this equation is specific to optd, and try to make this equation more general
compute_cost: children_costs[1].compute_cost * children_stats[0].0,
io_cost: 0.0,
},
PhysicalNodeType::PhysicalNestedLoopJoin(join_typ) => {
let child_compute_cost = children_costs[2].compute_cost;
Cost {
compute_cost: children_stats[0].0 * children_stats[1].0 * child_compute_cost
+ children_stats[0].0,
io_cost: 0.0,
}
}
// TODO: we should document that the first child is the left table, which is used to build
// the hash table.
PhysicalNodeType::PhysicalHashJoin(join_typ) => Cost {
compute_cost: children_stats[0].0 * 2.0 + children_stats[1].0,
io_cost: 0.0,
},
PhysicalNodeType::PhysicalAgg => Cost {
compute_cost: children_stats[0].0
* (children_costs[1].compute_cost + children_costs[2].compute_cost),
io_cost: 0.0,
},
PhysicalNodeType::PhysicalProjection => Cost {
compute_cost: children_stats[0].0 * children_costs[1].compute_cost,
io_cost: 0.0,
},
PhysicalNodeType::PhysicalSort => Cost {
compute_cost: children_stats[0].0 * children_stats[0].0.ln_1p().max(1.0),
io_cost: 0.0,
},
};
let res = self
.storage_manager
.store_cost(
context.expr_id,
Some(output_cost.clone()),
output_statistic,
None,
)
.await;
if res.is_err() {
eprintln!("Failed to store output cost");
}
Ok(output_cost)
}

/// TODO: should we add epoch_id?
async fn derive_statistics(
&self,
node: PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_statistics: &[EstimatedStatistic],
context: ComputeCostContext,
store_output_statistic: bool,
) -> CostModelResult<EstimatedStatistic> {
match node {
let res = self.storage_manager.get_cost(context.expr_id).await;
if let Ok((_, Some(statistic))) = res {
return Ok(statistic);
}
let output_statistic = match node {
PhysicalNodeType::PhysicalScan => {
let table_id = TableId(predicates[0].data.as_ref().unwrap().as_u64());
let row_cnt = self
Expand Down Expand Up @@ -114,7 +201,17 @@ impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModel
PhysicalNodeType::PhysicalSort | PhysicalNodeType::PhysicalProjection => {
Ok(children_statistics[0].clone())
}
}
}?;
if store_output_statistic {
let res = self
.storage_manager
.store_cost(context.expr_id, None, Some(output_statistic.clone()), None)
.await;
if res.is_err() {
eprintln!("Failed to store output statistic");
}
};
Ok(output_statistic)
Comment on lines +205 to +214
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might cause problem if some stats are missed from the storage. We should in the future consider having sth like a retry queue to store failed database requests. But I'm good with this for now. Maybe we can add a comment here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I think the retry should be handled by the storage manager, and I will add a comment there. I don't think it is a big deal, since maybe we can recalculate it again?

}

async fn update_statistics(
Expand Down Expand Up @@ -167,3 +264,5 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
.await
}
}

// TODO: Add tests for `derive_statistic`` and `compute_operation_cost`.
62 changes: 60 additions & 2 deletions optd-cost-model/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,52 @@ pub struct ComputeCostContext {
}

#[derive(Default, Clone, Debug, PartialOrd, PartialEq)]
pub struct Cost(pub Vec<f64>);
pub struct Cost {
pub compute_cost: f64,
pub io_cost: f64,
}

impl From<Cost> for optd_persistent::cost_model::interface::Cost {
fn from(c: Cost) -> optd_persistent::cost_model::interface::Cost {
Self {
compute_cost: c.compute_cost,
io_cost: c.io_cost,
}
}
}

impl From<optd_persistent::cost_model::interface::Cost> for Cost {
fn from(c: optd_persistent::cost_model::interface::Cost) -> Cost {
Self {
compute_cost: c.compute_cost,
io_cost: c.io_cost,
}
}
}

/// Estimated statistic calculated by the cost model.
/// It is the estimated output row count of the targeted expression.
#[derive(PartialEq, PartialOrd, Clone, Debug)]
pub struct EstimatedStatistic(pub f64);

impl From<EstimatedStatistic> for f32 {
fn from(e: EstimatedStatistic) -> f32 {
e.0 as f32
}
}

impl From<EstimatedStatistic> for f64 {
fn from(e: EstimatedStatistic) -> f64 {
e.0
}
}

impl From<f32> for EstimatedStatistic {
fn from(f: f32) -> EstimatedStatistic {
Self(f as f64)
}
}

pub type CostModelResult<T> = Result<T, CostModelError>;

#[derive(Debug)]
Expand Down Expand Up @@ -79,23 +118,42 @@ pub trait CostModel: 'static + Send + Sync {
/// TODO: documentation
async fn compute_operation_cost(
&self,
node: &PhysicalNodeType,
node: PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_costs: &[Cost],
children_stats: &[EstimatedStatistic],
context: ComputeCostContext,
) -> CostModelResult<Cost>;

/// TODO: documentation
/// It is for cardinality estimation. The output should be the estimated
/// statistic calculated by the cost model.
/// If this method is called by `compute_operation_cost`, please set
/// `store_output_statistic` to `false`; if it is called by the optimizer,
/// please set `store_output_statistic` to `true`. Since we can store the
/// estimated statistic and cost by calling the ORM method once.
///
/// TODO: I am not sure whether to introduce `store_output_statistic`, since
/// it add complexity to the interface, considering currently only Scan needs
/// the output row count to calculate the costs. So updating the database twice
/// seems cheap. But in the future, maybe more cost computations rely on the output
/// row count. (Of course, it should be removed if we separate the cost and
/// estimated_statistic into 2 tables.)
///
/// TODO: Consider make it a helper function, so we can store Cost in the
/// ORM more easily.
///
/// TODO: I would suggest to rename this method to `derive_row_count`, since
/// statistic is easily to be confused with the real statistic.
/// Also we need to update other places to use estimated statistic to row count,
/// either in this crate or in optd-persistent.
async fn derive_statistics(
&self,
node: PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_stats: &[EstimatedStatistic],
context: ComputeCostContext,
store_output_statistic: bool,
) -> CostModelResult<EstimatedStatistic>;

/// TODO: documentation
Expand Down
25 changes: 24 additions & 1 deletion optd-cost-model/src/storage/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult};
use crate::{
common::types::{EpochId, ExprId, TableId},
stats::AttributeCombValueStats,
Cost, CostModelResult, EstimatedStatistic,
};

use super::CostModelStorageManager;

Expand Down Expand Up @@ -63,4 +67,23 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl {
let table_stats = self.per_table_stats_map.get(&table_id);
Ok(table_stats.map(|stats| stats.row_cnt))
}

/// TODO: finish this when implementing the cost get/store tests
async fn get_cost(
&self,
expr_id: ExprId,
) -> CostModelResult<(Option<crate::Cost>, Option<EstimatedStatistic>)> {
todo!()
}

/// TODO: finish this when implementing the cost get/store tests
async fn store_cost(
&self,
expr_id: ExprId,
cost: Option<Cost>,
estimated_statistic: Option<EstimatedStatistic>,
epoch_id: Option<EpochId>,
) -> CostModelResult<()> {
todo!()
}
}
19 changes: 18 additions & 1 deletion optd-cost-model/src/storage/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult};
use crate::{
common::types::{EpochId, ExprId, TableId},
stats::AttributeCombValueStats,
Cost, CostModelResult, EstimatedStatistic,
};

pub mod mock;
pub mod persistent;
Expand All @@ -12,4 +16,17 @@ pub trait CostModelStorageManager {
) -> CostModelResult<Option<AttributeCombValueStats>>;

async fn get_table_row_count(&self, table_id: TableId) -> CostModelResult<Option<u64>>;

async fn get_cost(
&self,
expr_id: ExprId,
) -> CostModelResult<(Option<Cost>, Option<EstimatedStatistic>)>;

async fn store_cost(
&self,
expr_id: ExprId,
cost: Option<Cost>,
estimated_statistic: Option<EstimatedStatistic>,
epoch_id: Option<EpochId>,
) -> CostModelResult<()>;
}
41 changes: 39 additions & 2 deletions optd-cost-model/src/storage/persistent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use std::sync::Arc;
use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer};

use crate::{
common::types::TableId,
common::types::{EpochId, ExprId, TableId},
stats::{utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues},
CostModelResult,
Cost, CostModelResult, EstimatedStatistic,
};

use super::CostModelStorageManager;
Expand Down Expand Up @@ -125,5 +125,42 @@ impl<S: CostModelStorageLayer + Send + Sync> CostModelStorageManager
.transpose()?)
}

/// TODO: The name is misleading, since we can also get the estimated statistic. We should
/// rename it.
///
/// TODO: Add retry logic here.
async fn get_cost(
&self,
expr_id: ExprId,
) -> CostModelResult<(Option<Cost>, Option<EstimatedStatistic>)> {
let (cost, estimated_statistic) = self.backend_manager.get_cost(expr_id.into()).await?;
Ok((
cost.map(|c| c.into()),
estimated_statistic.map(|x| x.into()),
))
}

/// TODO: The name is misleading, since we can also get the estimated statistic. We should
/// rename it.
///
/// TODO: Add retry logic here.
async fn store_cost(
&self,
expr_id: ExprId,
cost: Option<Cost>,
estimated_statistic: Option<EstimatedStatistic>,
epoch_id: Option<EpochId>,
) -> CostModelResult<()> {
self.backend_manager
.store_cost(
expr_id.into(),
cost.map(|c| c.into()),
estimated_statistic.map(|x| x.into()),
epoch_id.map(|id| id.into()),
)
.await?;
Ok(())
}

// TODO: Support querying for a specific type of statistics.
}
4 changes: 2 additions & 2 deletions optd-persistent/src/bin/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ async fn init_all_tables() -> Result<(), sea_orm::error::DbErr> {
id: Set(1),
physical_expression_id: Set(1),
epoch_id: Set(1),
cost: Set(json!({"compute_cost":10, "io_cost":10})),
estimated_statistic: Set(10),
cost: Set(Some(json!({"compute_cost":10, "io_cost":10}))),
estimated_statistic: Set(Some(10.0)),
is_valid: Set(true),
};
plan_cost::Entity::insert(plan_cost)
Expand Down
Loading