@@ -4,7 +4,7 @@ use crate::cost_model::interface::Cost;
44use crate :: entities:: { prelude:: * , * } ;
55use crate :: { BackendError , BackendManager , CostModelStorageLayer , StorageResult } ;
66use sea_orm:: prelude:: { Expr , Json } ;
7- use sea_orm:: sea_query:: Query ;
7+ use sea_orm:: sea_query:: { ExprTrait , Query } ;
88use sea_orm:: { sqlx:: types:: chrono:: Utc , EntityTrait } ;
99use sea_orm:: {
1010 ActiveModelTrait , ColumnTrait , Condition , DbBackend , DbErr , DeleteResult , EntityOrSelect ,
@@ -208,7 +208,7 @@ impl CostModelStorageLayer for BackendManager {
208208 // 0. Check if the stat already exists. If exists, get stat_id, else insert into statistic table.
209209 let stat_id = match stat. table_id {
210210 Some ( table_id) => {
211- // TODO(lanlou) : only select needed fields
211+ // TODO: only select needed fields
212212 let res = Statistic :: find ( )
213213 . filter ( statistic:: Column :: TableId . eq ( table_id) )
214214 . inner_join ( versioned_statistic:: Entity )
@@ -467,47 +467,68 @@ impl CostModelStorageLayer for BackendManager {
467467 }
468468
469469 /// TODO: documentation
470+ /// Each record in the `plan_cost` table can contain either the cost or the estimated statistic
471+ /// or both, but never neither.
472+ /// The name can be misleading, since it can also return the estimated statistic.
470473 async fn get_cost_analysis (
471474 & self ,
472475 expr_id : ExprId ,
473476 epoch_id : EpochId ,
474- ) -> StorageResult < Option < Cost > > {
477+ ) -> StorageResult < ( Option < Cost > , Option < i32 > ) > {
475478 let cost = PlanCost :: find ( )
476479 . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( expr_id) )
477480 . filter ( plan_cost:: Column :: EpochId . eq ( epoch_id) )
478481 . one ( & self . db )
479482 . await ?;
480- assert ! ( cost. is_some( ) , "Cost not found in Cost table" ) ;
481- assert ! ( cost. clone( ) . unwrap( ) . is_valid, "Cost is not valid" ) ;
482- Ok ( cost. map ( |c| Cost {
483- compute_cost : c. cost . get ( "compute_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
484- io_cost : c. cost . get ( "io_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
485- estimated_statistic : c. estimated_statistic ,
486- } ) )
483+ // When this cost is invalid or not found, we should return None
484+ if cost. is_none ( ) || !cost. clone ( ) . unwrap ( ) . is_valid {
485+ return Ok ( ( None , None ) ) ;
486+ }
487+
488+ let real_cost = cost. as_ref ( ) . and_then ( |c| c. cost . as_ref ( ) ) . map ( |c| Cost {
489+ compute_cost : c. get ( "compute_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
490+ io_cost : c. get ( "io_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
491+ } ) ;
492+
493+ Ok ( ( real_cost, cost. unwrap ( ) . estimated_statistic ) )
487494 }
488495
489- async fn get_cost ( & self , expr_id : ExprId ) -> StorageResult < Option < Cost > > {
496+ /// TODO: documentation
497+ /// It returns the cost and estimated statistic if applicable.
498+ /// Each record in the `plan_cost` table can contain either the cost or the estimated statistic
499+ /// or both, but never neither.
500+ /// The name can be misleading, since it can also return the estimated statistic.
501+ async fn get_cost ( & self , expr_id : ExprId ) -> StorageResult < ( Option < Cost > , Option < i32 > ) > {
490502 let cost = PlanCost :: find ( )
491503 . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( expr_id) )
492504 . order_by_desc ( plan_cost:: Column :: EpochId )
493505 . one ( & self . db )
494506 . await ?;
495- assert ! ( cost. is_some( ) , "Cost not found in Cost table" ) ;
496- assert ! ( cost. clone( ) . unwrap( ) . is_valid, "Cost is not valid" ) ;
497- Ok ( cost. map ( |c| Cost {
498- compute_cost : c. cost . get ( "compute_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
499- io_cost : c. cost . get ( "io_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
500- estimated_statistic : c. estimated_statistic ,
501- } ) )
507+ // When this cost is invalid or not found, we should return None
508+ if cost. is_none ( ) || !cost. clone ( ) . unwrap ( ) . is_valid {
509+ return Ok ( ( None , None ) ) ;
510+ }
511+
512+ let real_cost = cost. as_ref ( ) . and_then ( |c| c. cost . as_ref ( ) ) . map ( |c| Cost {
513+ compute_cost : c. get ( "compute_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
514+ io_cost : c. get ( "io_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
515+ } ) ;
516+
517+ Ok ( ( real_cost, cost. unwrap ( ) . estimated_statistic ) )
502518 }
503519
520+ /// This method should handle the case when the cost is already stored.
521+ /// The name may be misleading, since it can also store the estimated statistic.
504522 /// TODO: documentation
505523 async fn store_cost (
506524 & self ,
507525 physical_expression_id : ExprId ,
508- cost : Cost ,
526+ cost : Option < Cost > ,
527+ estimated_statistic : Option < i32 > ,
509528 epoch_id : EpochId ,
510529 ) -> StorageResult < ( ) > {
530+ assert ! ( cost. is_some( ) || estimated_statistic. is_some( ) ) ;
531+ // TODO: should we do the following checks in the production environment?
511532 let expr_exists = PhysicalExpression :: find_by_id ( physical_expression_id)
512533 . one ( & self . db )
513534 . await ?;
@@ -520,7 +541,6 @@ impl CostModelStorageLayer for BackendManager {
520541 . into ( ) ,
521542 ) ) ;
522543 }
523-
524544 // Check if epoch_id exists in Event table
525545 let epoch_exists = Event :: find ( )
526546 . filter ( event:: Column :: EpochId . eq ( epoch_id) )
@@ -533,17 +553,42 @@ impl CostModelStorageLayer for BackendManager {
533553 ) ) ;
534554 }
535555
536- let new_cost = plan_cost:: ActiveModel {
537- physical_expression_id : sea_orm:: ActiveValue :: Set ( physical_expression_id) ,
538- epoch_id : sea_orm:: ActiveValue :: Set ( epoch_id) ,
539- cost : sea_orm:: ActiveValue :: Set (
540- json ! ( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} ) ,
541- ) ,
542- estimated_statistic : sea_orm:: ActiveValue :: Set ( cost. estimated_statistic ) ,
543- is_valid : sea_orm:: ActiveValue :: Set ( true ) ,
544- ..Default :: default ( )
545- } ;
546- let _ = PlanCost :: insert ( new_cost) . exec ( & self . db ) . await ?;
556+ let transaction = self . db . begin ( ) . await ?;
557+
558+ let valid_cost = PlanCost :: find ( )
559+ . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( physical_expression_id) )
560+ . filter ( plan_cost:: Column :: EpochId . eq ( epoch_id) )
561+ . filter ( plan_cost:: Column :: IsValid . eq ( true ) )
562+ . one ( & transaction)
563+ . await ?;
564+
565+ if valid_cost. is_some ( ) {
566+ let mut new_cost: plan_cost:: ActiveModel = valid_cost. unwrap ( ) . into ( ) ;
567+ if cost. is_some ( ) {
568+ new_cost. cost = sea_orm:: ActiveValue :: Set ( Some ( json ! ( {
569+ "compute_cost" : cost. clone( ) . unwrap( ) . compute_cost,
570+ "io_cost" : cost. clone( ) . unwrap( ) . io_cost
571+ } ) ) ) ;
572+ }
573+ if estimated_statistic. is_some ( ) {
574+ new_cost. estimated_statistic = sea_orm:: ActiveValue :: Set ( estimated_statistic) ;
575+ }
576+ let _ = PlanCost :: update ( new_cost) . exec ( & transaction) . await ?;
577+ } else {
578+ let new_cost = plan_cost:: ActiveModel {
579+ physical_expression_id : sea_orm:: ActiveValue :: Set ( physical_expression_id) ,
580+ epoch_id : sea_orm:: ActiveValue :: Set ( epoch_id) ,
581+ cost : sea_orm:: ActiveValue :: Set (
582+ cost. map ( |c| json ! ( { "compute_cost" : c. compute_cost, "io_cost" : c. io_cost} ) ) ,
583+ ) ,
584+ estimated_statistic : sea_orm:: ActiveValue :: Set ( estimated_statistic) ,
585+ is_valid : sea_orm:: ActiveValue :: Set ( true ) ,
586+ ..Default :: default ( )
587+ } ;
588+ let _ = PlanCost :: insert ( new_cost) . exec ( & transaction) . await ?;
589+ }
590+
591+ transaction. commit ( ) . await ?;
547592 Ok ( ( ) )
548593 }
549594
@@ -755,13 +800,11 @@ mod tests {
755800 backend_manager
756801 . store_cost (
757802 expr_id,
758- {
759- Cost {
760- compute_cost : 42 ,
761- io_cost : 42 ,
762- estimated_statistic : 42 ,
763- }
764- } ,
803+ Some ( Cost {
804+ compute_cost : 42 ,
805+ io_cost : 42 ,
806+ } ) ,
807+ Some ( 42 ) ,
765808 versioned_stat_res[ 0 ] . epoch_id ,
766809 )
767810 . await
@@ -826,7 +869,10 @@ mod tests {
826869 . await
827870 . unwrap ( ) ;
828871 assert_eq ! ( cost_res. len( ) , 1 ) ;
829- assert_eq ! ( cost_res[ 0 ] . cost, json!( { "compute_cost" : 42 , "io_cost" : 42 } ) ) ;
872+ assert_eq ! (
873+ cost_res[ 0 ] . cost,
874+ Some ( json!( { "compute_cost" : 42 , "io_cost" : 42 } ) )
875+ ) ;
830876 assert_eq ! ( cost_res[ 0 ] . epoch_id, epoch_id1) ;
831877 assert ! ( !cost_res[ 0 ] . is_valid) ;
832878
@@ -960,10 +1006,15 @@ mod tests {
9601006 let cost = Cost {
9611007 compute_cost : 42 ,
9621008 io_cost : 42 ,
963- estimated_statistic : 42 ,
9641009 } ;
1010+ let mut estimated_statistic = 42 ;
9651011 backend_manager
966- . store_cost ( physical_expression_id, cost. clone ( ) , epoch_id)
1012+ . store_cost (
1013+ physical_expression_id,
1014+ Some ( cost. clone ( ) ) ,
1015+ Some ( estimated_statistic) ,
1016+ epoch_id,
1017+ )
9671018 . await
9681019 . unwrap ( ) ;
9691020 let costs = super :: PlanCost :: find ( )
@@ -975,11 +1026,37 @@ mod tests {
9751026 assert_eq ! ( costs[ 1 ] . physical_expression_id, physical_expression_id) ;
9761027 assert_eq ! (
9771028 costs[ 1 ] . cost,
978- json!( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} )
1029+ Some ( json!( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} ) )
9791030 ) ;
9801031 assert_eq ! (
981- costs[ 1 ] . estimated_statistic as i32 ,
982- cost. estimated_statistic
1032+ costs[ 1 ] . estimated_statistic. unwrap( ) as i32 ,
1033+ estimated_statistic
1034+ ) ;
1035+
1036+ estimated_statistic = 50 ;
1037+ backend_manager
1038+ . store_cost (
1039+ physical_expression_id,
1040+ None ,
1041+ Some ( estimated_statistic) ,
1042+ epoch_id,
1043+ )
1044+ . await
1045+ . unwrap ( ) ;
1046+ let costs = super :: PlanCost :: find ( )
1047+ . all ( & backend_manager. db )
1048+ . await
1049+ . unwrap ( ) ;
1050+ assert_eq ! ( costs. len( ) , 2 ) ; // We should not insert a new row
1051+ assert_eq ! ( costs[ 1 ] . epoch_id, epoch_id) ;
1052+ assert_eq ! ( costs[ 1 ] . physical_expression_id, physical_expression_id) ;
1053+ assert_eq ! (
1054+ costs[ 1 ] . cost,
1055+ Some ( json!( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} ) )
1056+ ) ;
1057+ assert_eq ! (
1058+ costs[ 1 ] . estimated_statistic. unwrap( ) as i32 ,
1059+ estimated_statistic // The estimated_statistic should be updated
9831060 ) ;
9841061
9851062 remove_db_file ( DATABASE_FILE ) ;
@@ -999,10 +1076,9 @@ mod tests {
9991076 let cost = Cost {
10001077 compute_cost : 42 ,
10011078 io_cost : 42 ,
1002- estimated_statistic : 42 ,
10031079 } ;
10041080 let _ = backend_manager
1005- . store_cost ( physical_expression_id, cost. clone ( ) , epoch_id)
1081+ . store_cost ( physical_expression_id, Some ( cost. clone ( ) ) , None , epoch_id)
10061082 . await ;
10071083 let costs = super :: PlanCost :: find ( )
10081084 . all ( & backend_manager. db )
@@ -1013,18 +1089,16 @@ mod tests {
10131089 assert_eq ! ( costs[ 1 ] . physical_expression_id, physical_expression_id) ;
10141090 assert_eq ! (
10151091 costs[ 1 ] . cost,
1016- json!( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} )
1017- ) ;
1018- assert_eq ! (
1019- costs[ 1 ] . estimated_statistic as i32 ,
1020- cost. estimated_statistic
1092+ Some ( json!( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} ) )
10211093 ) ;
1094+ assert_eq ! ( costs[ 1 ] . estimated_statistic, None ) ;
10221095
10231096 let res = backend_manager
10241097 . get_cost ( physical_expression_id)
10251098 . await
10261099 . unwrap ( ) ;
1027- assert_eq ! ( res. unwrap( ) , cost) ;
1100+ assert_eq ! ( res. 0 . unwrap( ) , cost) ;
1101+ assert_eq ! ( res. 1 , None ) ;
10281102
10291103 remove_db_file ( DATABASE_FILE ) ;
10301104 }
@@ -1040,13 +1114,14 @@ mod tests {
10401114 . await
10411115 . unwrap ( ) ;
10421116 let physical_expression_id = 1 ;
1043- let cost = Cost {
1044- compute_cost : 1420 ,
1045- io_cost : 42 ,
1046- estimated_statistic : 42 ,
1047- } ;
1117+ let estimated_statistic = 42 ;
10481118 let _ = backend_manager
1049- . store_cost ( physical_expression_id, cost. clone ( ) , epoch_id)
1119+ . store_cost (
1120+ physical_expression_id,
1121+ None ,
1122+ Some ( estimated_statistic) ,
1123+ epoch_id,
1124+ )
10501125 . await ;
10511126 let costs = super :: PlanCost :: find ( )
10521127 . all ( & backend_manager. db )
@@ -1055,13 +1130,10 @@ mod tests {
10551130 assert_eq ! ( costs. len( ) , 2 ) ; // The first row is the initialized data
10561131 assert_eq ! ( costs[ 1 ] . epoch_id, epoch_id) ;
10571132 assert_eq ! ( costs[ 1 ] . physical_expression_id, physical_expression_id) ;
1133+ assert_eq ! ( costs[ 1 ] . cost, None ) ;
10581134 assert_eq ! (
1059- costs[ 1 ] . cost,
1060- json!( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} )
1061- ) ;
1062- assert_eq ! (
1063- costs[ 1 ] . estimated_statistic as i32 ,
1064- cost. estimated_statistic
1135+ costs[ 1 ] . estimated_statistic. unwrap( ) as i32 ,
1136+ estimated_statistic
10651137 ) ;
10661138 println ! ( "{:?}" , costs) ;
10671139
@@ -1073,13 +1145,13 @@ mod tests {
10731145
10741146 // The cost in the dummy data is 10
10751147 assert_eq ! (
1076- res. unwrap( ) ,
1148+ res. 0 . unwrap( ) ,
10771149 Cost {
10781150 compute_cost: 10 ,
10791151 io_cost: 10 ,
1080- estimated_statistic: 10 ,
10811152 }
10821153 ) ;
1154+ assert_eq ! ( res. 1 . unwrap( ) , 10 ) ;
10831155
10841156 remove_db_file ( DATABASE_FILE ) ;
10851157 }
0 commit comments