22
33use std:: ptr:: null;
44
5+ use crate :: cost_model:: interface:: Cost ;
56use crate :: entities:: { prelude:: * , * } ;
67use crate :: { BackendError , BackendManager , CostModelError , CostModelStorageLayer , StorageResult } ;
78use sea_orm:: prelude:: { Expr , Json } ;
@@ -11,6 +12,7 @@ use sea_orm::{
1112 ActiveModelTrait , ColumnTrait , DbBackend , DbErr , DeleteResult , EntityOrSelect , ModelTrait ,
1213 QueryFilter , QueryOrder , QuerySelect , QueryTrait , RuntimeErr , TransactionTrait ,
1314} ;
15+ use serde_json:: json;
1416
1517use super :: catalog:: mock_catalog:: { self , MockCatalog } ;
1618use super :: interface:: { CatalogSource , EpochOption , Stat } ;
@@ -443,33 +445,41 @@ impl CostModelStorageLayer for BackendManager {
443445 & self ,
444446 expr_id : Self :: ExprId ,
445447 epoch_id : Self :: EpochId ,
446- ) -> StorageResult < Option < i32 > > {
448+ ) -> StorageResult < Option < Cost > > {
447449 let cost = PlanCost :: find ( )
448450 . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( expr_id) )
449451 . filter ( plan_cost:: Column :: EpochId . eq ( epoch_id) )
450452 . one ( & self . db )
451453 . await ?;
452454 assert ! ( cost. is_some( ) , "Cost not found in Cost table" ) ;
453455 assert ! ( cost. clone( ) . unwrap( ) . is_valid, "Cost is not valid" ) ;
454- Ok ( cost. map ( |c| c. cost ) )
456+ Ok ( cost. map ( |c| Cost {
457+ compute_cost : c. cost . get ( "compute_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
458+ io_cost : c. cost . get ( "io_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
459+ estimated_statistic : c. estimated_statistic ,
460+ } ) )
455461 }
456462
457- async fn get_cost ( & self , expr_id : Self :: ExprId ) -> StorageResult < Option < i32 > > {
463+ async fn get_cost ( & self , expr_id : Self :: ExprId ) -> StorageResult < Option < Cost > > {
458464 let cost = PlanCost :: find ( )
459465 . filter ( plan_cost:: Column :: PhysicalExpressionId . eq ( expr_id) )
460466 . order_by_desc ( plan_cost:: Column :: EpochId )
461467 . one ( & self . db )
462468 . await ?;
463469 assert ! ( cost. is_some( ) , "Cost not found in Cost table" ) ;
464470 assert ! ( cost. clone( ) . unwrap( ) . is_valid, "Cost is not valid" ) ;
465- Ok ( cost. map ( |c| c. cost ) )
471+ Ok ( cost. map ( |c| Cost {
472+ compute_cost : c. cost . get ( "compute_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
473+ io_cost : c. cost . get ( "io_cost" ) . unwrap ( ) . as_i64 ( ) . unwrap ( ) as i32 ,
474+ estimated_statistic : c. estimated_statistic ,
475+ } ) )
466476 }
467477
468478 /// TODO: documentation
469479 async fn store_cost (
470480 & self ,
471481 physical_expression_id : Self :: ExprId ,
472- cost : i32 ,
482+ cost : Cost ,
473483 epoch_id : Self :: EpochId ,
474484 ) -> StorageResult < ( ) > {
475485 let expr_exists = PhysicalExpression :: find_by_id ( physical_expression_id)
@@ -496,7 +506,10 @@ impl CostModelStorageLayer for BackendManager {
496506 let new_cost = plan_cost:: ActiveModel {
497507 physical_expression_id : sea_orm:: ActiveValue :: Set ( physical_expression_id) ,
498508 epoch_id : sea_orm:: ActiveValue :: Set ( epoch_id) ,
499- cost : sea_orm:: ActiveValue :: Set ( cost) ,
509+ cost : sea_orm:: ActiveValue :: Set (
510+ json ! ( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} ) ,
511+ ) ,
512+ estimated_statistic : sea_orm:: ActiveValue :: Set ( cost. estimated_statistic ) ,
500513 is_valid : sea_orm:: ActiveValue :: Set ( true ) ,
501514 ..Default :: default ( )
502515 } ;
@@ -507,7 +520,7 @@ impl CostModelStorageLayer for BackendManager {
507520
508521#[ cfg( test) ]
509522mod tests {
510- use crate :: cost_model:: interface:: { EpochOption , StatType } ;
523+ use crate :: cost_model:: interface:: { Cost , EpochOption , StatType } ;
511524 use crate :: { cost_model:: interface:: Stat , migrate, CostModelStorageLayer } ;
512525 use crate :: { get_sqlite_url, TEST_DATABASE_FILE } ;
513526 use sea_orm:: sqlx:: database;
@@ -681,7 +694,17 @@ mod tests {
681694 . await
682695 . unwrap ( ) ;
683696 backend_manager
684- . store_cost ( expr_id, 42 , versioned_stat_res[ 0 ] . epoch_id )
697+ . store_cost (
698+ expr_id,
699+ {
700+ Cost {
701+ compute_cost : 42 ,
702+ io_cost : 42 ,
703+ estimated_statistic : 42 ,
704+ }
705+ } ,
706+ versioned_stat_res[ 0 ] . epoch_id ,
707+ )
685708 . await
686709 . unwrap ( ) ;
687710 let cost_res = PlanCost :: find ( )
@@ -744,7 +767,7 @@ mod tests {
744767 . await
745768 . unwrap ( ) ;
746769 assert_eq ! ( cost_res. len( ) , 1 ) ;
747- assert_eq ! ( cost_res[ 0 ] . cost, 42 ) ;
770+ assert_eq ! ( cost_res[ 0 ] . cost, json! ( { "compute_cost" : 42 , "io_cost" : 42 } ) ) ;
748771 assert_eq ! ( cost_res[ 0 ] . epoch_id, epoch_id1) ;
749772 assert ! ( !cost_res[ 0 ] . is_valid) ;
750773
@@ -875,9 +898,13 @@ mod tests {
875898 . await
876899 . unwrap ( ) ;
877900 let physical_expression_id = 1 ;
878- let cost = 42 ;
901+ let cost = Cost {
902+ compute_cost : 42 ,
903+ io_cost : 42 ,
904+ estimated_statistic : 42 ,
905+ } ;
879906 backend_manager
880- . store_cost ( physical_expression_id, cost, epoch_id)
907+ . store_cost ( physical_expression_id, cost. clone ( ) , epoch_id)
881908 . await
882909 . unwrap ( ) ;
883910 let costs = super :: PlanCost :: find ( )
@@ -887,7 +914,14 @@ mod tests {
887914 assert_eq ! ( costs. len( ) , 2 ) ; // The first row one is the initialized data
888915 assert_eq ! ( costs[ 1 ] . epoch_id, epoch_id) ;
889916 assert_eq ! ( costs[ 1 ] . physical_expression_id, physical_expression_id) ;
890- assert_eq ! ( costs[ 1 ] . cost, cost) ;
917+ assert_eq ! (
918+ costs[ 1 ] . cost,
919+ json!( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} )
920+ ) ;
921+ assert_eq ! (
922+ costs[ 1 ] . estimated_statistic as i32 ,
923+ cost. estimated_statistic
924+ ) ;
891925
892926 remove_db_file ( DATABASE_FILE ) ;
893927 }
@@ -903,9 +937,13 @@ mod tests {
903937 . await
904938 . unwrap ( ) ;
905939 let physical_expression_id = 1 ;
906- let cost = 42 ;
940+ let cost = Cost {
941+ compute_cost : 42 ,
942+ io_cost : 42 ,
943+ estimated_statistic : 42 ,
944+ } ;
907945 let _ = backend_manager
908- . store_cost ( physical_expression_id, cost, epoch_id)
946+ . store_cost ( physical_expression_id, cost. clone ( ) , epoch_id)
909947 . await ;
910948 let costs = super :: PlanCost :: find ( )
911949 . all ( & backend_manager. db )
@@ -914,7 +952,14 @@ mod tests {
914952 assert_eq ! ( costs. len( ) , 2 ) ; // The first row one is the initialized data
915953 assert_eq ! ( costs[ 1 ] . epoch_id, epoch_id) ;
916954 assert_eq ! ( costs[ 1 ] . physical_expression_id, physical_expression_id) ;
917- assert_eq ! ( costs[ 1 ] . cost, cost) ;
955+ assert_eq ! (
956+ costs[ 1 ] . cost,
957+ json!( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} )
958+ ) ;
959+ assert_eq ! (
960+ costs[ 1 ] . estimated_statistic as i32 ,
961+ cost. estimated_statistic
962+ ) ;
918963
919964 let res = backend_manager
920965 . get_cost ( physical_expression_id)
@@ -936,9 +981,13 @@ mod tests {
936981 . await
937982 . unwrap ( ) ;
938983 let physical_expression_id = 1 ;
939- let cost = 42 ;
984+ let cost = Cost {
985+ compute_cost : 1420 ,
986+ io_cost : 42 ,
987+ estimated_statistic : 42 ,
988+ } ;
940989 let _ = backend_manager
941- . store_cost ( physical_expression_id, cost, epoch_id)
990+ . store_cost ( physical_expression_id, cost. clone ( ) , epoch_id)
942991 . await ;
943992 let costs = super :: PlanCost :: find ( )
944993 . all ( & backend_manager. db )
@@ -947,7 +996,14 @@ mod tests {
947996 assert_eq ! ( costs. len( ) , 2 ) ; // The first row one is the initialized data
948997 assert_eq ! ( costs[ 1 ] . epoch_id, epoch_id) ;
949998 assert_eq ! ( costs[ 1 ] . physical_expression_id, physical_expression_id) ;
950- assert_eq ! ( costs[ 1 ] . cost, cost) ;
999+ assert_eq ! (
1000+ costs[ 1 ] . cost,
1001+ json!( { "compute_cost" : cost. compute_cost, "io_cost" : cost. io_cost} )
1002+ ) ;
1003+ assert_eq ! (
1004+ costs[ 1 ] . estimated_statistic as i32 ,
1005+ cost. estimated_statistic
1006+ ) ;
9511007 println ! ( "{:?}" , costs) ;
9521008
9531009 // Retrieve physical_expression_id 1 and epoch_id 1
@@ -957,7 +1013,14 @@ mod tests {
9571013 . unwrap ( ) ;
9581014
9591015 // The cost in the dummy data is 10
960- assert_eq ! ( res. unwrap( ) , 10 ) ;
1016+ assert_eq ! (
1017+ res. unwrap( ) ,
1018+ Cost {
1019+ compute_cost: 10 ,
1020+ io_cost: 10 ,
1021+ estimated_statistic: 10 ,
1022+ }
1023+ ) ;
9611024
9621025 remove_db_file ( DATABASE_FILE ) ;
9631026 }
0 commit comments