@@ -534,7 +534,7 @@ impl CostModelStorageLayer for BackendManager {
534534 epoch_id : Option < EpochId > ,
535535 ) -> StorageResult < ( ) > {
536536 assert ! ( cost. is_some( ) || estimated_statistic. is_some( ) ) ;
537- // TODO: should we do the following checks in the production environment?
537+ // TODO: we shouldn't do the following checks in the production environment.
538538 let expr_exists = PhysicalExpression :: find_by_id ( physical_expression_id)
539539 . one ( & self . db )
540540 . await ?;
@@ -561,30 +561,54 @@ impl CostModelStorageLayer for BackendManager {
561561 }
562562 }
563563
564- let epoch_id = match epoch_id {
565- Some ( id) => id,
566- None => {
567- // When init, please make sure there is at least one epoch in the Event table.
568- let latest_epoch_id = Event :: find ( )
569- . order_by_desc ( event:: Column :: EpochId )
570- . one ( & self . db )
571- . await ?
572- . unwrap ( ) ;
573- latest_epoch_id. epoch_id
574- }
575- } ;
576-
577564 let transaction = self . db . begin ( ) . await ?;
578565
579- let valid_cost = PlanCost :: find ( )
580- . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( physical_expression_id) )
581- . filter ( plan_cost:: Column :: EpochId . eq ( epoch_id) )
582- . filter ( plan_cost:: Column :: IsValid . eq ( true ) )
583- . one ( & transaction)
584- . await ?;
566+ /*
567+ The `store_cost` logic is as follows:
568+ 1. If the epoch_id is provided, we should update the cost with the corresponding epoch_id,
569+ or insert a new record if it doesn't exist.
570+ 2. If the epoch_id is not provided, we cannot directly use the latest epoch_id: the plan_cost
571+ table may already hold a valid cost for the current physical expression under a lower
572+ epoch_id, because the update_stats function updates unrelated stats. So we resolve the
573+ epoch_id as follows:
574+ 1) If a valid cost is already in the plan_cost table, we use the same epoch_id.
575+ 2) If there is no valid cost in the plan_cost table, or there is no record, we use the
576+ latest epoch_id.
577+ */
578+ // TODO: We should add some integration tests to fully test the above logic
579+ let epoch_id_data;
580+ let existed_cost;
581+ if let Some ( epoch_id) = epoch_id {
582+ epoch_id_data = epoch_id;
583+ existed_cost = PlanCost :: find ( )
584+ . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( physical_expression_id) )
585+ . filter ( plan_cost:: Column :: EpochId . eq ( epoch_id) )
586+ . one ( & transaction)
587+ . await ?;
588+ } else {
589+ existed_cost = PlanCost :: find ( )
590+ . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( physical_expression_id) )
591+ . filter ( plan_cost:: Column :: IsValid . eq ( true ) )
592+ . order_by_desc ( plan_cost:: Column :: EpochId )
593+ . one ( & transaction)
594+ . await ?;
595+ if existed_cost. is_none ( ) {
596+ epoch_id_data = {
597+ // On initialization, make sure the Event table contains at least one epoch; otherwise the `unwrap` below panics.
598+ let latest_epoch_id = Event :: find ( )
599+ . order_by_desc ( event:: Column :: EpochId )
600+ . one ( & self . db )
601+ . await ?
602+ . unwrap ( ) ;
603+ latest_epoch_id. epoch_id
604+ }
605+ } else {
606+ epoch_id_data = existed_cost. clone ( ) . unwrap ( ) . epoch_id ;
607+ }
608+ }
585609
586- if valid_cost . is_some ( ) {
587- let mut new_cost: plan_cost:: ActiveModel = valid_cost . unwrap ( ) . into ( ) ;
610+ if existed_cost . is_some ( ) {
611+ let mut new_cost: plan_cost:: ActiveModel = existed_cost . unwrap ( ) . into ( ) ;
588612 let mut update = false ;
589613 if cost. is_some ( ) {
590614 let input_cost = sea_orm:: ActiveValue :: Set ( Some ( json ! ( {
@@ -604,12 +628,25 @@ impl CostModelStorageLayer for BackendManager {
604628 }
605629 }
606630 if update {
631+ assert ! ( new_cost. epoch_id == sea_orm:: ActiveValue :: Set ( epoch_id_data) ) ;
607632 let _ = PlanCost :: update ( new_cost) . exec ( & transaction) . await ?;
608633 }
609634 } else {
635+ // TODO: we shouldn't do the following checks in the production environment.
636+ // This check may be easy to violate, so consider removing the `epoch_id` input parameter.
637+ let latest_cost = PlanCost :: find ( )
638+ . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( physical_expression_id) )
639+ . order_by_desc ( plan_cost:: Column :: EpochId )
640+ . one ( & transaction)
641+ . await ?;
642+ if latest_cost. is_some ( ) {
643+ assert ! ( latest_cost. clone( ) . unwrap( ) . epoch_id < epoch_id_data) ;
644+ assert ! ( !latest_cost. clone( ) . unwrap( ) . is_valid) ;
645+ }
646+
610647 let new_cost = plan_cost:: ActiveModel {
611648 physical_expression_id : sea_orm:: ActiveValue :: Set ( physical_expression_id) ,
612- epoch_id : sea_orm:: ActiveValue :: Set ( epoch_id ) ,
649+ epoch_id : sea_orm:: ActiveValue :: Set ( epoch_id_data ) ,
613650 cost : sea_orm:: ActiveValue :: Set (
614651 cost. map ( |c| json ! ( { "compute_cost" : c. compute_cost, "io_cost" : c. io_cost} ) ) ,
615652 ) ,
0 commit comments