@@ -9,7 +9,8 @@ use optd_persistent::{
99
1010use crate :: {
1111 common:: {
12- nodes:: { ArcPredicateNode , PhysicalNodeType } ,
12+ nodes:: { ArcPredicateNode , PhysicalNodeType , ReprPredicateNode } ,
13+ predicates:: list_pred:: ListPred ,
1314 types:: { AttrId , EpochId , ExprId , TableId } ,
1415 } ,
1516 memo_ext:: MemoExt ,
@@ -40,28 +41,83 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
4041 }
4142}
4243
44+ #[ async_trait:: async_trait]
4345impl < S : CostModelStorageManager + Send + Sync + ' static > CostModel for CostModelImpl < S > {
44- fn compute_operation_cost (
46+ async fn compute_operation_cost (
4547 & self ,
4648 node : & PhysicalNodeType ,
4749 predicates : & [ ArcPredicateNode ] ,
48- children_stats : & [ Option < & EstimatedStatistic > ] ,
50+ children_stats : & [ EstimatedStatistic ] ,
4951 context : ComputeCostContext ,
5052 ) -> CostModelResult < Cost > {
5153 todo ! ( )
5254 }
5355
54- fn derive_statistics (
56+ async fn derive_statistics (
5557 & self ,
5658 node : PhysicalNodeType ,
5759 predicates : & [ ArcPredicateNode ] ,
58- children_statistics : & [ Option < & EstimatedStatistic > ] ,
60+ children_statistics : & [ EstimatedStatistic ] ,
5961 context : ComputeCostContext ,
6062 ) -> CostModelResult < EstimatedStatistic > {
61- todo ! ( )
63+ match node {
64+ PhysicalNodeType :: PhysicalScan => {
65+ let table_id = TableId ( predicates[ 0 ] . data . as_ref ( ) . unwrap ( ) . as_u64 ( ) ) ;
66+ let row_cnt = self
67+ . storage_manager
68+ . get_table_row_count ( table_id)
69+ . await ?
70+ . unwrap_or ( 1 ) as f64 ;
71+ Ok ( EstimatedStatistic ( row_cnt) )
72+ }
73+ PhysicalNodeType :: PhysicalEmptyRelation => Ok ( EstimatedStatistic ( 0.01 ) ) ,
74+ PhysicalNodeType :: PhysicalLimit => {
75+ self . get_limit_row_cnt ( children_statistics[ 0 ] . clone ( ) , predicates[ 1 ] . clone ( ) )
76+ }
77+ PhysicalNodeType :: PhysicalFilter => {
78+ self . get_filter_row_cnt (
79+ children_statistics[ 0 ] . clone ( ) ,
80+ context. group_id ,
81+ predicates[ 0 ] . clone ( ) ,
82+ )
83+ . await
84+ }
85+ PhysicalNodeType :: PhysicalNestedLoopJoin ( join_typ) => {
86+ self . get_nlj_row_cnt (
87+ join_typ,
88+ context. group_id ,
89+ children_statistics[ 0 ] . clone ( ) ,
90+ children_statistics[ 1 ] . clone ( ) ,
91+ context. children_group_ids [ 0 ] ,
92+ context. children_group_ids [ 1 ] ,
93+ predicates[ 0 ] . clone ( ) ,
94+ )
95+ . await
96+ }
97+ PhysicalNodeType :: PhysicalHashJoin ( join_typ) => {
98+ self . get_hash_join_row_cnt (
99+ join_typ,
100+ context. group_id ,
101+ children_statistics[ 0 ] . clone ( ) ,
102+ children_statistics[ 1 ] . clone ( ) ,
103+ context. children_group_ids [ 0 ] ,
104+ context. children_group_ids [ 1 ] ,
105+ ListPred :: from_pred_node ( predicates[ 0 ] . clone ( ) ) . unwrap ( ) ,
106+ ListPred :: from_pred_node ( predicates[ 1 ] . clone ( ) ) . unwrap ( ) ,
107+ )
108+ . await
109+ }
110+ PhysicalNodeType :: PhysicalAgg => {
111+ self . get_agg_row_cnt ( context. group_id , predicates[ 1 ] . clone ( ) )
112+ . await
113+ }
114+ PhysicalNodeType :: PhysicalSort | PhysicalNodeType :: PhysicalProjection => {
115+ Ok ( children_statistics[ 0 ] . clone ( ) )
116+ }
117+ }
62118 }
63119
64- fn update_statistics (
120+ async fn update_statistics (
65121 & self ,
66122 stats : Vec < Stat > ,
67123 source : String ,
@@ -70,7 +126,7 @@ impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModel
70126 todo ! ( )
71127 }
72128
73- fn get_table_statistic_for_analysis (
129+ async fn get_table_statistic_for_analysis (
74130 & self ,
75131 table_id : TableId ,
76132 stat_type : StatType ,
@@ -79,7 +135,7 @@ impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModel
79135 todo ! ( )
80136 }
81137
82- fn get_attribute_statistic_for_analysis (
138+ async fn get_attribute_statistic_for_analysis (
83139 & self ,
84140 attr_ids : Vec < AttrId > ,
85141 stat_type : StatType ,
@@ -88,7 +144,7 @@ impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModel
88144 todo ! ( )
89145 }
90146
91- fn get_cost_for_analysis (
147+ async fn get_cost_for_analysis (
92148 & self ,
93149 expr_id : ExprId ,
94150 epoch_id : Option < EpochId > ,
0 commit comments