Skip to content

Commit 7f947d3

Browse files
committed
implement cost model derive statistics
1 parent 49baa5b commit 7f947d3

File tree

5 files changed

+80
-20
lines changed

5 files changed

+80
-20
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

optd-cost-model/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ itertools = "0.13"
1818
assert_approx_eq = "1.1.0"
1919
trait-variant = "0.1.2"
2020
tokio = { version = "1.0.1", features = ["macros", "rt-multi-thread"] }
21+
async-trait = "0.1"
2122

2223
[dev-dependencies]
2324
crossbeam = "0.8"

optd-cost-model/src/cost/join/core.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,8 @@ mod tests {
904904
expected_inner_sel
905905
);
906906
// check the outer sels
907-
assert_outer_selectivities(&cost_model, expr_tree, expr_tree_rev, &attr_refs, 0.25, 0.2);
907+
assert_outer_selectivities(&cost_model, expr_tree, expr_tree_rev, &attr_refs, 0.25, 0.2)
908+
.await;
908909
}
909910

910911
/// Non-unique oncond means the column is not unique in either table

optd-cost-model/src/cost_model.rs

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ use optd_persistent::{
99

1010
use crate::{
1111
common::{
12-
nodes::{ArcPredicateNode, PhysicalNodeType},
12+
nodes::{ArcPredicateNode, PhysicalNodeType, ReprPredicateNode},
13+
predicates::list_pred::ListPred,
1314
types::{AttrId, EpochId, ExprId, TableId},
1415
},
1516
memo_ext::MemoExt,
@@ -40,28 +41,83 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
4041
}
4142
}
4243

44+
#[async_trait::async_trait]
4345
impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModelImpl<S> {
44-
fn compute_operation_cost(
46+
async fn compute_operation_cost(
4547
&self,
4648
node: &PhysicalNodeType,
4749
predicates: &[ArcPredicateNode],
48-
children_stats: &[Option<&EstimatedStatistic>],
50+
children_stats: &[EstimatedStatistic],
4951
context: ComputeCostContext,
5052
) -> CostModelResult<Cost> {
5153
todo!()
5254
}
5355

54-
fn derive_statistics(
56+
async fn derive_statistics(
5557
&self,
5658
node: PhysicalNodeType,
5759
predicates: &[ArcPredicateNode],
58-
children_statistics: &[Option<&EstimatedStatistic>],
60+
children_statistics: &[EstimatedStatistic],
5961
context: ComputeCostContext,
6062
) -> CostModelResult<EstimatedStatistic> {
61-
todo!()
63+
match node {
64+
PhysicalNodeType::PhysicalScan => {
65+
let table_id = TableId(predicates[0].data.as_ref().unwrap().as_u64());
66+
let row_cnt = self
67+
.storage_manager
68+
.get_table_row_count(table_id)
69+
.await?
70+
.unwrap_or(1) as f64;
71+
Ok(EstimatedStatistic(row_cnt))
72+
}
73+
PhysicalNodeType::PhysicalEmptyRelation => Ok(EstimatedStatistic(0.01)),
74+
PhysicalNodeType::PhysicalLimit => {
75+
self.get_limit_row_cnt(children_statistics[0].clone(), predicates[1].clone())
76+
}
77+
PhysicalNodeType::PhysicalFilter => {
78+
self.get_filter_row_cnt(
79+
children_statistics[0].clone(),
80+
context.group_id,
81+
predicates[0].clone(),
82+
)
83+
.await
84+
}
85+
PhysicalNodeType::PhysicalNestedLoopJoin(join_typ) => {
86+
self.get_nlj_row_cnt(
87+
join_typ,
88+
context.group_id,
89+
children_statistics[0].clone(),
90+
children_statistics[1].clone(),
91+
context.children_group_ids[0],
92+
context.children_group_ids[1],
93+
predicates[0].clone(),
94+
)
95+
.await
96+
}
97+
PhysicalNodeType::PhysicalHashJoin(join_typ) => {
98+
self.get_hash_join_row_cnt(
99+
join_typ,
100+
context.group_id,
101+
children_statistics[0].clone(),
102+
children_statistics[1].clone(),
103+
context.children_group_ids[0],
104+
context.children_group_ids[1],
105+
ListPred::from_pred_node(predicates[0].clone()).unwrap(),
106+
ListPred::from_pred_node(predicates[1].clone()).unwrap(),
107+
)
108+
.await
109+
}
110+
PhysicalNodeType::PhysicalAgg => {
111+
self.get_agg_row_cnt(context.group_id, predicates[1].clone())
112+
.await
113+
}
114+
PhysicalNodeType::PhysicalSort | PhysicalNodeType::PhysicalProjection => {
115+
Ok(children_statistics[0].clone())
116+
}
117+
}
62118
}
63119

64-
fn update_statistics(
120+
async fn update_statistics(
65121
&self,
66122
stats: Vec<Stat>,
67123
source: String,
@@ -70,7 +126,7 @@ impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModel
70126
todo!()
71127
}
72128

73-
fn get_table_statistic_for_analysis(
129+
async fn get_table_statistic_for_analysis(
74130
&self,
75131
table_id: TableId,
76132
stat_type: StatType,
@@ -79,7 +135,7 @@ impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModel
79135
todo!()
80136
}
81137

82-
fn get_attribute_statistic_for_analysis(
138+
async fn get_attribute_statistic_for_analysis(
83139
&self,
84140
attr_ids: Vec<AttrId>,
85141
stat_type: StatType,
@@ -88,7 +144,7 @@ impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModel
88144
todo!()
89145
}
90146

91-
fn get_cost_for_analysis(
147+
async fn get_cost_for_analysis(
92148
&self,
93149
expr_id: ExprId,
94150
epoch_id: Option<EpochId>,

optd-cost-model/src/lib.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub struct Cost(pub Vec<f64>);
3333

3434
/// Estimated statistic calculated by the cost model.
3535
/// It is the estimated output row count of the targeted expression.
36-
#[derive(PartialEq, PartialOrd, Debug)]
36+
#[derive(PartialEq, PartialOrd, Clone, Debug)]
3737
pub struct EstimatedStatistic(pub f64);
3838

3939
pub type CostModelResult<T> = Result<T, CostModelError>;
@@ -73,13 +73,14 @@ impl From<serde_json::Error> for CostModelError {
7373
}
7474
}
7575

76+
#[async_trait::async_trait]
7677
pub trait CostModel: 'static + Send + Sync {
7778
/// TODO: documentation
78-
fn compute_operation_cost(
79+
async fn compute_operation_cost(
7980
&self,
8081
node: &PhysicalNodeType,
8182
predicates: &[ArcPredicateNode],
82-
children_stats: &[Option<&EstimatedStatistic>],
83+
children_stats: &[EstimatedStatistic],
8384
context: ComputeCostContext,
8485
) -> CostModelResult<Cost>;
8586

@@ -88,42 +89,42 @@ pub trait CostModel: 'static + Send + Sync {
8889
/// statistic calculated by the cost model.
8990
/// TODO: Consider making it a helper function, so we can store Cost in the
9091
/// ORM more easily.
91-
fn derive_statistics(
92+
async fn derive_statistics(
9293
&self,
9394
node: PhysicalNodeType,
9495
predicates: &[ArcPredicateNode],
95-
children_stats: &[Option<&EstimatedStatistic>],
96+
children_stats: &[EstimatedStatistic],
9697
context: ComputeCostContext,
9798
) -> CostModelResult<EstimatedStatistic>;
9899

99100
/// TODO: documentation
100101
/// It is for **REAL** statistic updates, not for estimated statistics.
101102
/// TODO: Change data from String to other types.
102-
fn update_statistics(
103+
async fn update_statistics(
103104
&self,
104105
stats: Vec<Stat>,
105106
source: String,
106107
data: String,
107108
) -> CostModelResult<()>;
108109

109110
/// TODO: documentation
110-
fn get_table_statistic_for_analysis(
111+
async fn get_table_statistic_for_analysis(
111112
&self,
112113
table_id: TableId,
113114
stat_type: StatType,
114115
epoch_id: Option<EpochId>,
115116
) -> CostModelResult<Option<StatValue>>;
116117

117118
/// TODO: documentation
118-
fn get_attribute_statistic_for_analysis(
119+
async fn get_attribute_statistic_for_analysis(
119120
&self,
120121
attr_ids: Vec<AttrId>,
121122
stat_type: StatType,
122123
epoch_id: Option<EpochId>,
123124
) -> CostModelResult<Option<StatValue>>;
124125

125126
/// TODO: documentation
126-
fn get_cost_for_analysis(
127+
async fn get_cost_for_analysis(
127128
&self,
128129
expr_id: ExprId,
129130
epoch_id: Option<EpochId>,

0 commit comments

Comments
 (0)