Skip to content

Commit dfaee9a

Browse files
authored
feat(cost-model): implement compute_operation_cost (#44)
* Enable separate get and store cost & estimated_statistic in ORM * Add get_cost & store_cost in the cost model storage layer * Finish compute_operation_cost * Refine store_cost in the ORM layer * Improve comments * Apply comment suggestions * Refine comment
1 parent e9cd234 commit dfaee9a

File tree

12 files changed

+462
-112
lines changed

12 files changed

+462
-112
lines changed

optd-cost-model/src/cost_model.rs

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::{
1515
},
1616
memo_ext::MemoExt,
1717
stats::AttributeCombValueStats,
18-
storage::CostModelStorageManager,
18+
storage::{self, CostModelStorageManager},
1919
ComputeCostContext, Cost, CostModel, CostModelResult, EstimatedStatistic, StatValue,
2020
};
2121

@@ -43,24 +43,111 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
4343

4444
#[async_trait::async_trait]
4545
impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModelImpl<S> {
46+
/// TODO: should we add epoch_id?
4647
async fn compute_operation_cost(
4748
&self,
48-
node: &PhysicalNodeType,
49+
node: PhysicalNodeType,
4950
predicates: &[ArcPredicateNode],
51+
children_costs: &[Cost],
5052
children_stats: &[EstimatedStatistic],
5153
context: ComputeCostContext,
5254
) -> CostModelResult<Cost> {
53-
todo!()
55+
let res = self.storage_manager.get_cost(context.expr_id).await;
56+
if let Ok((Some(cost), _)) = res {
57+
return Ok(cost);
58+
};
59+
let mut output_statistic = None;
60+
if let Ok((_, Some(statistic))) = res {
61+
output_statistic = Some(statistic);
62+
};
63+
let output_cost = match node {
64+
PhysicalNodeType::PhysicalScan => {
65+
let output_statistic_data = output_statistic.unwrap_or(
66+
self.derive_statistics(
67+
node,
68+
predicates,
69+
children_stats,
70+
context.clone(),
71+
false,
72+
)
73+
.await?,
74+
);
75+
output_statistic = Some(output_statistic_data.clone());
76+
Cost {
77+
compute_cost: 0.0,
78+
io_cost: output_statistic_data.0,
79+
}
80+
}
81+
PhysicalNodeType::PhysicalEmptyRelation => Cost {
82+
compute_cost: 0.1,
83+
io_cost: 0.0,
84+
},
85+
PhysicalNodeType::PhysicalLimit => Cost {
86+
compute_cost: children_costs[0].compute_cost,
87+
io_cost: 0.0,
88+
},
89+
PhysicalNodeType::PhysicalFilter => Cost {
90+
// TODO: now this equation is specific to optd, and try to make this equation more general
91+
compute_cost: children_costs[1].compute_cost * children_stats[0].0,
92+
io_cost: 0.0,
93+
},
94+
PhysicalNodeType::PhysicalNestedLoopJoin(join_typ) => {
95+
let child_compute_cost = children_costs[2].compute_cost;
96+
Cost {
97+
compute_cost: children_stats[0].0 * children_stats[1].0 * child_compute_cost
98+
+ children_stats[0].0,
99+
io_cost: 0.0,
100+
}
101+
}
102+
// TODO: we should document that the first child is the left table, which is used to build
103+
// the hash table.
104+
PhysicalNodeType::PhysicalHashJoin(join_typ) => Cost {
105+
compute_cost: children_stats[0].0 * 2.0 + children_stats[1].0,
106+
io_cost: 0.0,
107+
},
108+
PhysicalNodeType::PhysicalAgg => Cost {
109+
compute_cost: children_stats[0].0
110+
* (children_costs[1].compute_cost + children_costs[2].compute_cost),
111+
io_cost: 0.0,
112+
},
113+
PhysicalNodeType::PhysicalProjection => Cost {
114+
compute_cost: children_stats[0].0 * children_costs[1].compute_cost,
115+
io_cost: 0.0,
116+
},
117+
PhysicalNodeType::PhysicalSort => Cost {
118+
compute_cost: children_stats[0].0 * children_stats[0].0.ln_1p().max(1.0),
119+
io_cost: 0.0,
120+
},
121+
};
122+
let res = self
123+
.storage_manager
124+
.store_cost(
125+
context.expr_id,
126+
Some(output_cost.clone()),
127+
output_statistic,
128+
None,
129+
)
130+
.await;
131+
if res.is_err() {
132+
eprintln!("Failed to store output cost");
133+
}
134+
Ok(output_cost)
54135
}
55136

137+
/// TODO: should we add epoch_id?
56138
async fn derive_statistics(
57139
&self,
58140
node: PhysicalNodeType,
59141
predicates: &[ArcPredicateNode],
60142
children_statistics: &[EstimatedStatistic],
61143
context: ComputeCostContext,
144+
store_output_statistic: bool,
62145
) -> CostModelResult<EstimatedStatistic> {
63-
match node {
146+
let res = self.storage_manager.get_cost(context.expr_id).await;
147+
if let Ok((_, Some(statistic))) = res {
148+
return Ok(statistic);
149+
}
150+
let output_statistic = match node {
64151
PhysicalNodeType::PhysicalScan => {
65152
let table_id = TableId(predicates[0].data.as_ref().unwrap().as_u64());
66153
let row_cnt = self
@@ -114,7 +201,17 @@ impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModel
114201
PhysicalNodeType::PhysicalSort | PhysicalNodeType::PhysicalProjection => {
115202
Ok(children_statistics[0].clone())
116203
}
117-
}
204+
}?;
205+
if store_output_statistic {
206+
let res = self
207+
.storage_manager
208+
.store_cost(context.expr_id, None, Some(output_statistic.clone()), None)
209+
.await;
210+
if res.is_err() {
211+
eprintln!("Failed to store output statistic");
212+
}
213+
};
214+
Ok(output_statistic)
118215
}
119216

120217
async fn update_statistics(
@@ -167,3 +264,5 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
167264
.await
168265
}
169266
}
267+
268+
// TODO: Add tests for `derive_statistic`` and `compute_operation_cost`.

optd-cost-model/src/lib.rs

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,52 @@ pub struct ComputeCostContext {
3030
}
3131

3232
#[derive(Default, Clone, Debug, PartialOrd, PartialEq)]
33-
pub struct Cost(pub Vec<f64>);
33+
pub struct Cost {
34+
pub compute_cost: f64,
35+
pub io_cost: f64,
36+
}
37+
38+
impl From<Cost> for optd_persistent::cost_model::interface::Cost {
39+
fn from(c: Cost) -> optd_persistent::cost_model::interface::Cost {
40+
Self {
41+
compute_cost: c.compute_cost,
42+
io_cost: c.io_cost,
43+
}
44+
}
45+
}
46+
47+
impl From<optd_persistent::cost_model::interface::Cost> for Cost {
48+
fn from(c: optd_persistent::cost_model::interface::Cost) -> Cost {
49+
Self {
50+
compute_cost: c.compute_cost,
51+
io_cost: c.io_cost,
52+
}
53+
}
54+
}
3455

3556
/// Estimated statistic calculated by the cost model.
3657
/// It is the estimated output row count of the targeted expression.
3758
#[derive(PartialEq, PartialOrd, Clone, Debug)]
3859
pub struct EstimatedStatistic(pub f64);
3960

61+
impl From<EstimatedStatistic> for f32 {
62+
fn from(e: EstimatedStatistic) -> f32 {
63+
e.0 as f32
64+
}
65+
}
66+
67+
impl From<EstimatedStatistic> for f64 {
68+
fn from(e: EstimatedStatistic) -> f64 {
69+
e.0
70+
}
71+
}
72+
73+
impl From<f32> for EstimatedStatistic {
74+
fn from(f: f32) -> EstimatedStatistic {
75+
Self(f as f64)
76+
}
77+
}
78+
4079
pub type CostModelResult<T> = Result<T, CostModelError>;
4180

4281
#[derive(Debug)]
@@ -79,23 +118,42 @@ pub trait CostModel: 'static + Send + Sync {
79118
/// TODO: documentation
80119
async fn compute_operation_cost(
81120
&self,
82-
node: &PhysicalNodeType,
121+
node: PhysicalNodeType,
83122
predicates: &[ArcPredicateNode],
123+
children_costs: &[Cost],
84124
children_stats: &[EstimatedStatistic],
85125
context: ComputeCostContext,
86126
) -> CostModelResult<Cost>;
87127

88128
/// TODO: documentation
89129
/// It is for cardinality estimation. The output should be the estimated
90130
/// statistic calculated by the cost model.
131+
/// If this method is called by `compute_operation_cost`, please set
132+
/// `store_output_statistic` to `false`; if it is called by the optimizer,
133+
/// please set `store_output_statistic` to `true`. Since we can store the
134+
/// estimated statistic and cost by calling the ORM method once.
135+
///
136+
/// TODO: I am not sure whether to introduce `store_output_statistic`, since
137+
/// it add complexity to the interface, considering currently only Scan needs
138+
/// the output row count to calculate the costs. So updating the database twice
139+
/// seems cheap. But in the future, maybe more cost computations rely on the output
140+
/// row count. (Of course, it should be removed if we separate the cost and
141+
/// estimated_statistic into 2 tables.)
142+
///
91143
/// TODO: Consider make it a helper function, so we can store Cost in the
92144
/// ORM more easily.
145+
///
146+
/// TODO: I would suggest to rename this method to `derive_row_count`, since
147+
/// statistic is easily to be confused with the real statistic.
148+
/// Also we need to update other places to use estimated statistic to row count,
149+
/// either in this crate or in optd-persistent.
93150
async fn derive_statistics(
94151
&self,
95152
node: PhysicalNodeType,
96153
predicates: &[ArcPredicateNode],
97154
children_stats: &[EstimatedStatistic],
98155
context: ComputeCostContext,
156+
store_output_statistic: bool,
99157
) -> CostModelResult<EstimatedStatistic>;
100158

101159
/// TODO: documentation

optd-cost-model/src/storage/mock.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ use std::collections::HashMap;
33

44
use serde::{Deserialize, Serialize};
55

6-
use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult};
6+
use crate::{
7+
common::types::{EpochId, ExprId, TableId},
8+
stats::AttributeCombValueStats,
9+
Cost, CostModelResult, EstimatedStatistic,
10+
};
711

812
use super::CostModelStorageManager;
913

@@ -63,4 +67,23 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl {
6367
let table_stats = self.per_table_stats_map.get(&table_id);
6468
Ok(table_stats.map(|stats| stats.row_cnt))
6569
}
70+
71+
/// TODO: finish this when implementing the cost get/store tests
72+
async fn get_cost(
73+
&self,
74+
expr_id: ExprId,
75+
) -> CostModelResult<(Option<crate::Cost>, Option<EstimatedStatistic>)> {
76+
todo!()
77+
}
78+
79+
/// TODO: finish this when implementing the cost get/store tests
80+
async fn store_cost(
81+
&self,
82+
expr_id: ExprId,
83+
cost: Option<Cost>,
84+
estimated_statistic: Option<EstimatedStatistic>,
85+
epoch_id: Option<EpochId>,
86+
) -> CostModelResult<()> {
87+
todo!()
88+
}
6689
}

optd-cost-model/src/storage/mod.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult};
1+
use crate::{
2+
common::types::{EpochId, ExprId, TableId},
3+
stats::AttributeCombValueStats,
4+
Cost, CostModelResult, EstimatedStatistic,
5+
};
26

37
pub mod mock;
48
pub mod persistent;
@@ -12,4 +16,17 @@ pub trait CostModelStorageManager {
1216
) -> CostModelResult<Option<AttributeCombValueStats>>;
1317

1418
async fn get_table_row_count(&self, table_id: TableId) -> CostModelResult<Option<u64>>;
19+
20+
async fn get_cost(
21+
&self,
22+
expr_id: ExprId,
23+
) -> CostModelResult<(Option<Cost>, Option<EstimatedStatistic>)>;
24+
25+
async fn store_cost(
26+
&self,
27+
expr_id: ExprId,
28+
cost: Option<Cost>,
29+
estimated_statistic: Option<EstimatedStatistic>,
30+
epoch_id: Option<EpochId>,
31+
) -> CostModelResult<()>;
1532
}

optd-cost-model/src/storage/persistent.rs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use std::sync::Arc;
44
use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer};
55

66
use crate::{
7-
common::types::TableId,
7+
common::types::{EpochId, ExprId, TableId},
88
stats::{utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues},
9-
CostModelResult,
9+
Cost, CostModelResult, EstimatedStatistic,
1010
};
1111

1212
use super::CostModelStorageManager;
@@ -125,5 +125,42 @@ impl<S: CostModelStorageLayer + Send + Sync> CostModelStorageManager
125125
.transpose()?)
126126
}
127127

128+
/// TODO: The name is misleading, since we can also get the estimated statistic. We should
129+
/// rename it.
130+
///
131+
/// TODO: Add retry logic here.
132+
async fn get_cost(
133+
&self,
134+
expr_id: ExprId,
135+
) -> CostModelResult<(Option<Cost>, Option<EstimatedStatistic>)> {
136+
let (cost, estimated_statistic) = self.backend_manager.get_cost(expr_id.into()).await?;
137+
Ok((
138+
cost.map(|c| c.into()),
139+
estimated_statistic.map(|x| x.into()),
140+
))
141+
}
142+
143+
/// TODO: The name is misleading, since we can also get the estimated statistic. We should
144+
/// rename it.
145+
///
146+
/// TODO: Add retry logic here.
147+
async fn store_cost(
148+
&self,
149+
expr_id: ExprId,
150+
cost: Option<Cost>,
151+
estimated_statistic: Option<EstimatedStatistic>,
152+
epoch_id: Option<EpochId>,
153+
) -> CostModelResult<()> {
154+
self.backend_manager
155+
.store_cost(
156+
expr_id.into(),
157+
cost.map(|c| c.into()),
158+
estimated_statistic.map(|x| x.into()),
159+
epoch_id.map(|id| id.into()),
160+
)
161+
.await?;
162+
Ok(())
163+
}
164+
128165
// TODO: Support querying for a specific type of statistics.
129166
}

optd-persistent/src/bin/init.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,8 @@ async fn init_all_tables() -> Result<(), sea_orm::error::DbErr> {
355355
id: Set(1),
356356
physical_expression_id: Set(1),
357357
epoch_id: Set(1),
358-
cost: Set(json!({"compute_cost":10, "io_cost":10})),
359-
estimated_statistic: Set(10),
358+
cost: Set(Some(json!({"compute_cost":10, "io_cost":10}))),
359+
estimated_statistic: Set(Some(10.0)),
360360
is_valid: Set(true),
361361
};
362362
plan_cost::Entity::insert(plan_cost)

0 commit comments

Comments
 (0)