@@ -534,7 +534,7 @@ impl CostModelStorageLayer for BackendManager {
534534 epoch_id : Option < EpochId > ,
535535 ) -> StorageResult < ( ) > {
536536 assert ! ( cost. is_some( ) || estimated_statistic. is_some( ) ) ;
537- // TODO: should we do the following checks in the production environment?
537+ // TODO: we shouldn't do the following checks in the production environment.
538538 let expr_exists = PhysicalExpression :: find_by_id ( physical_expression_id)
539539 . one ( & self . db )
540540 . await ?;
@@ -561,30 +561,54 @@ impl CostModelStorageLayer for BackendManager {
561561 }
562562 }
563563
564- let epoch_id = match epoch_id {
565- Some ( id) => id,
566- None => {
567- // When init, please make sure there is at least one epoch in the Event table.
568- let latest_epoch_id = Event :: find ( )
569- . order_by_desc ( event:: Column :: EpochId )
570- . one ( & self . db )
571- . await ?
572- . unwrap ( ) ;
573- latest_epoch_id. epoch_id
574- }
575- } ;
576-
577564 let transaction = self . db . begin ( ) . await ?;
578565
579- let valid_cost = PlanCost :: find ( )
580- . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( physical_expression_id) )
581- . filter ( plan_cost:: Column :: EpochId . eq ( epoch_id) )
582- . filter ( plan_cost:: Column :: IsValid . eq ( true ) )
583- . one ( & transaction)
584- . await ?;
566+ /*
567+ The `store_cost` logic is as follows:
568+ 1. If the epoch_id is provided, we should update the cost with the corresponding epoch_id,
569+ or insert a new record if it doesn't exist.
570+ 2. If the epoch_id is not provided, we cannot directly use the latest epoch_id, since in the
571+ plan_cost table, for the current physical expression, there may be a valid cost with a lower
572+ epoch_id, since the update_stats function updates unrelated stats. So we need to handle the
573+ epoch_id in following logics:
574+ 1) If a valid cost is already in the plan_cost table, we use the same epoch_id.
575+ 2) If there is no valid cost in the plan_cost table, or there is no record, we use the
576+ latest epoch_id.
577+ */
578+ // TODO: We should add some integration tests to fully test the above logic
579+ let epoch_id_data;
580+ let existed_cost;
581+ if let Some ( epoch_id) = epoch_id {
582+ epoch_id_data = epoch_id;
583+ existed_cost = PlanCost :: find ( )
584+ . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( physical_expression_id) )
585+ . filter ( plan_cost:: Column :: EpochId . eq ( epoch_id) )
586+ . one ( & transaction)
587+ . await ?;
588+ } else {
589+ existed_cost = PlanCost :: find ( )
590+ . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( physical_expression_id) )
591+ . filter ( plan_cost:: Column :: IsValid . eq ( true ) )
592+ . order_by_desc ( plan_cost:: Column :: EpochId )
593+ . one ( & transaction)
594+ . await ?;
595+ if existed_cost. is_none ( ) {
596+ epoch_id_data = {
597+ // When init, please make sure there is at least one epoch in the Event table.
598+ let latest_epoch_id = Event :: find ( )
599+ . order_by_desc ( event:: Column :: EpochId )
600+ . one ( & self . db )
601+ . await ?
602+ . unwrap ( ) ;
603+ latest_epoch_id. epoch_id
604+ }
605+ } else {
606+ epoch_id_data = existed_cost. clone ( ) . unwrap ( ) . epoch_id ;
607+ }
608+ }
585609
586- if valid_cost . is_some ( ) {
587- let mut new_cost: plan_cost:: ActiveModel = valid_cost . unwrap ( ) . into ( ) ;
610+ if existed_cost . is_some ( ) {
611+ let mut new_cost: plan_cost:: ActiveModel = existed_cost . unwrap ( ) . into ( ) ;
588612 let mut update = false ;
589613 if cost. is_some ( ) {
590614 let input_cost = sea_orm:: ActiveValue :: Set ( Some ( json ! ( {
@@ -604,12 +628,25 @@ impl CostModelStorageLayer for BackendManager {
604628 }
605629 }
606630 if update {
631+ assert ! ( new_cost. epoch_id. is_unchanged( ) ) ;
607632 let _ = PlanCost :: update ( new_cost) . exec ( & transaction) . await ?;
608633 }
609634 } else {
635+ // TODO: we shouldn't do the following checks in the production environment.
636+ // This check may be easy to violate, so consider removing epoch_id input parameter.
637+ let latest_cost = PlanCost :: find ( )
638+ . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( physical_expression_id) )
639+ . order_by_desc ( plan_cost:: Column :: EpochId )
640+ . one ( & transaction)
641+ . await ?;
642+ if latest_cost. is_some ( ) {
643+ assert ! ( latest_cost. clone( ) . unwrap( ) . epoch_id < epoch_id_data) ;
644+ assert ! ( !latest_cost. clone( ) . unwrap( ) . is_valid) ;
645+ }
646+
610647 let new_cost = plan_cost:: ActiveModel {
611648 physical_expression_id : sea_orm:: ActiveValue :: Set ( physical_expression_id) ,
612- epoch_id : sea_orm:: ActiveValue :: Set ( epoch_id ) ,
649+ epoch_id : sea_orm:: ActiveValue :: Set ( epoch_id_data ) ,
613650 cost : sea_orm:: ActiveValue :: Set (
614651 cost. map ( |c| json ! ( { "compute_cost" : c. compute_cost, "io_cost" : c. io_cost} ) ) ,
615652 ) ,
@@ -1035,6 +1072,18 @@ mod tests {
10351072 . create_new_epoch ( "source" . to_string ( ) , "data" . to_string ( ) )
10361073 . await
10371074 . unwrap ( ) ;
1075+ let stat = Stat {
1076+ stat_type : StatType :: TableRowCount ,
1077+ stat_value : json ! ( 10 ) ,
1078+ attr_ids : vec ! [ ] ,
1079+ table_id : Some ( 1 ) ,
1080+ name : "row_count" . to_owned ( ) ,
1081+ } ;
1082+ let res = backend_manager
1083+ . update_stats ( stat, EpochOption :: Existed ( epoch_id) )
1084+ . await ;
1085+ assert ! ( res. is_ok( ) ) ;
1086+
10381087 let physical_expression_id = 1 ;
10391088 let cost = Cost {
10401089 compute_cost : 42.0 ,
@@ -1102,6 +1151,18 @@ mod tests {
11021151 . create_new_epoch ( "source" . to_string ( ) , "data" . to_string ( ) )
11031152 . await
11041153 . unwrap ( ) ;
1154+ let stat = Stat {
1155+ stat_type : StatType :: TableRowCount ,
1156+ stat_value : json ! ( 10 ) ,
1157+ attr_ids : vec ! [ ] ,
1158+ table_id : Some ( 1 ) ,
1159+ name : "row_count" . to_owned ( ) ,
1160+ } ;
1161+ let res = backend_manager
1162+ . update_stats ( stat, EpochOption :: Existed ( epoch_id) )
1163+ . await ;
1164+ assert ! ( res. is_ok( ) ) ;
1165+
11051166 let physical_expression_id = 1 ;
11061167 let cost = Cost {
11071168 compute_cost : 42.0 ,
@@ -1148,6 +1209,18 @@ mod tests {
11481209 . create_new_epoch ( "source" . to_string ( ) , "data" . to_string ( ) )
11491210 . await
11501211 . unwrap ( ) ;
1212+ let stat = Stat {
1213+ stat_type : StatType :: TableRowCount ,
1214+ stat_value : json ! ( 10 ) ,
1215+ attr_ids : vec ! [ ] ,
1216+ table_id : Some ( 1 ) ,
1217+ name : "row_count" . to_owned ( ) ,
1218+ } ;
1219+ let res = backend_manager
1220+ . update_stats ( stat, EpochOption :: Existed ( epoch_id) )
1221+ . await ;
1222+ assert ! ( res. is_ok( ) ) ;
1223+
11511224 let physical_expression_id = 1 ;
11521225 let estimated_statistic = 42.0 ;
11531226 let _ = backend_manager
0 commit comments