Skip to content

Commit ea20998

Browse files
committed
Enable separate get and store cost & estimated_statistic in ORM
1 parent e9cd234 commit ea20998

File tree

6 files changed

+161
-78
lines changed

6 files changed

+161
-78
lines changed

optd-persistent/src/bin/init.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,8 @@ async fn init_all_tables() -> Result<(), sea_orm::error::DbErr> {
355355
id: Set(1),
356356
physical_expression_id: Set(1),
357357
epoch_id: Set(1),
358-
cost: Set(json!({"compute_cost":10, "io_cost":10})),
359-
estimated_statistic: Set(10),
358+
cost: Set(Some(json!({"compute_cost":10, "io_cost":10}))),
359+
estimated_statistic: Set(Some(10)),
360360
is_valid: Set(true),
361361
};
362362
plan_cost::Entity::insert(plan_cost)

optd-persistent/src/cost_model/interface.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,6 @@ pub struct Stat {
9191
pub struct Cost {
9292
pub compute_cost: i32,
9393
pub io_cost: i32,
94-
// Raw estimated output row count of targeted expression.
95-
pub estimated_statistic: i32,
9694
}
9795

9896
#[derive(Clone, Debug)]
@@ -118,8 +116,13 @@ pub trait CostModelStorageLayer {
118116
epoch_option: EpochOption,
119117
) -> StorageResult<Option<EpochId>>;
120118

121-
async fn store_cost(&self, expr_id: ExprId, cost: Cost, epoch_id: EpochId)
122-
-> StorageResult<()>;
119+
async fn store_cost(
120+
&self,
121+
expr_id: ExprId,
122+
cost: Option<Cost>,
123+
estimated_statistic: Option<i32>,
124+
epoch_id: EpochId,
125+
) -> StorageResult<()>;
123126

124127
async fn store_expr_stats_mappings(
125128
&self,
@@ -162,9 +165,9 @@ pub trait CostModelStorageLayer {
162165
&self,
163166
expr_id: ExprId,
164167
epoch_id: EpochId,
165-
) -> StorageResult<Option<Cost>>;
168+
) -> StorageResult<(Option<Cost>, Option<i32>)>;
166169

167-
async fn get_cost(&self, expr_id: ExprId) -> StorageResult<Option<Cost>>;
170+
async fn get_cost(&self, expr_id: ExprId) -> StorageResult<(Option<Cost>, Option<i32>)>;
168171

169172
async fn get_attribute(
170173
&self,

optd-persistent/src/cost_model/orm.rs

Lines changed: 138 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::cost_model::interface::Cost;
44
use crate::entities::{prelude::*, *};
55
use crate::{BackendError, BackendManager, CostModelStorageLayer, StorageResult};
66
use sea_orm::prelude::{Expr, Json};
7-
use sea_orm::sea_query::Query;
7+
use sea_orm::sea_query::{ExprTrait, Query};
88
use sea_orm::{sqlx::types::chrono::Utc, EntityTrait};
99
use sea_orm::{
1010
ActiveModelTrait, ColumnTrait, Condition, DbBackend, DbErr, DeleteResult, EntityOrSelect,
@@ -208,7 +208,7 @@ impl CostModelStorageLayer for BackendManager {
208208
// 0. Check if the stat already exists. If exists, get stat_id, else insert into statistic table.
209209
let stat_id = match stat.table_id {
210210
Some(table_id) => {
211-
// TODO(lanlou): only select needed fields
211+
// TODO: only select needed fields
212212
let res = Statistic::find()
213213
.filter(statistic::Column::TableId.eq(table_id))
214214
.inner_join(versioned_statistic::Entity)
@@ -467,47 +467,68 @@ impl CostModelStorageLayer for BackendManager {
467467
}
468468

469469
/// TODO: documentation
470+
/// Each record in the `plan_cost` table can contain either the cost or the estimated statistic
471+
/// or both, but never neither.
472+
/// The name can be misleading, since it can also return the estimated statistic.
470473
async fn get_cost_analysis(
471474
&self,
472475
expr_id: ExprId,
473476
epoch_id: EpochId,
474-
) -> StorageResult<Option<Cost>> {
477+
) -> StorageResult<(Option<Cost>, Option<i32>)> {
475478
let cost = PlanCost::find()
476479
.filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
477480
.filter(plan_cost::Column::EpochId.eq(epoch_id))
478481
.one(&self.db)
479482
.await?;
480-
assert!(cost.is_some(), "Cost not found in Cost table");
481-
assert!(cost.clone().unwrap().is_valid, "Cost is not valid");
482-
Ok(cost.map(|c| Cost {
483-
compute_cost: c.cost.get("compute_cost").unwrap().as_i64().unwrap() as i32,
484-
io_cost: c.cost.get("io_cost").unwrap().as_i64().unwrap() as i32,
485-
estimated_statistic: c.estimated_statistic,
486-
}))
483+
// When this cost is invalid or not found, we should return None
484+
if cost.is_none() || !cost.clone().unwrap().is_valid {
485+
return Ok((None, None));
486+
}
487+
488+
let real_cost = cost.as_ref().and_then(|c| c.cost.as_ref()).map(|c| Cost {
489+
compute_cost: c.get("compute_cost").unwrap().as_i64().unwrap() as i32,
490+
io_cost: c.get("io_cost").unwrap().as_i64().unwrap() as i32,
491+
});
492+
493+
Ok((real_cost, cost.unwrap().estimated_statistic))
487494
}
488495

489-
async fn get_cost(&self, expr_id: ExprId) -> StorageResult<Option<Cost>> {
496+
/// TODO: documentation
497+
/// It returns the cost and estimated statistic if applicable.
498+
/// Each record in the `plan_cost` table can contain either the cost or the estimated statistic
499+
/// or both, but never neither.
500+
/// The name can be misleading, since it can also return the estimated statistic.
501+
async fn get_cost(&self, expr_id: ExprId) -> StorageResult<(Option<Cost>, Option<i32>)> {
490502
let cost = PlanCost::find()
491503
.filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
492504
.order_by_desc(plan_cost::Column::EpochId)
493505
.one(&self.db)
494506
.await?;
495-
assert!(cost.is_some(), "Cost not found in Cost table");
496-
assert!(cost.clone().unwrap().is_valid, "Cost is not valid");
497-
Ok(cost.map(|c| Cost {
498-
compute_cost: c.cost.get("compute_cost").unwrap().as_i64().unwrap() as i32,
499-
io_cost: c.cost.get("io_cost").unwrap().as_i64().unwrap() as i32,
500-
estimated_statistic: c.estimated_statistic,
501-
}))
507+
// When this cost is invalid or not found, we should return None
508+
if cost.is_none() || !cost.clone().unwrap().is_valid {
509+
return Ok((None, None));
510+
}
511+
512+
let real_cost = cost.as_ref().and_then(|c| c.cost.as_ref()).map(|c| Cost {
513+
compute_cost: c.get("compute_cost").unwrap().as_i64().unwrap() as i32,
514+
io_cost: c.get("io_cost").unwrap().as_i64().unwrap() as i32,
515+
});
516+
517+
Ok((real_cost, cost.unwrap().estimated_statistic))
502518
}
503519

520+
/// This method should handle the case when the cost is already stored.
521+
/// The name maybe misleading, since it can also store the estimated statistic.
504522
/// TODO: documentation
505523
async fn store_cost(
506524
&self,
507525
physical_expression_id: ExprId,
508-
cost: Cost,
526+
cost: Option<Cost>,
527+
estimated_statistic: Option<i32>,
509528
epoch_id: EpochId,
510529
) -> StorageResult<()> {
530+
assert!(cost.is_some() || estimated_statistic.is_some());
531+
// TODO: should we do the following checks in the production environment?
511532
let expr_exists = PhysicalExpression::find_by_id(physical_expression_id)
512533
.one(&self.db)
513534
.await?;
@@ -520,7 +541,6 @@ impl CostModelStorageLayer for BackendManager {
520541
.into(),
521542
));
522543
}
523-
524544
// Check if epoch_id exists in Event table
525545
let epoch_exists = Event::find()
526546
.filter(event::Column::EpochId.eq(epoch_id))
@@ -533,17 +553,42 @@ impl CostModelStorageLayer for BackendManager {
533553
));
534554
}
535555

536-
let new_cost = plan_cost::ActiveModel {
537-
physical_expression_id: sea_orm::ActiveValue::Set(physical_expression_id),
538-
epoch_id: sea_orm::ActiveValue::Set(epoch_id),
539-
cost: sea_orm::ActiveValue::Set(
540-
json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}),
541-
),
542-
estimated_statistic: sea_orm::ActiveValue::Set(cost.estimated_statistic),
543-
is_valid: sea_orm::ActiveValue::Set(true),
544-
..Default::default()
545-
};
546-
let _ = PlanCost::insert(new_cost).exec(&self.db).await?;
556+
let transaction = self.db.begin().await?;
557+
558+
let valid_cost = PlanCost::find()
559+
.filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id))
560+
.filter(plan_cost::Column::EpochId.eq(epoch_id))
561+
.filter(plan_cost::Column::IsValid.eq(true))
562+
.one(&transaction)
563+
.await?;
564+
565+
if valid_cost.is_some() {
566+
let mut new_cost: plan_cost::ActiveModel = valid_cost.unwrap().into();
567+
if cost.is_some() {
568+
new_cost.cost = sea_orm::ActiveValue::Set(Some(json!({
569+
"compute_cost": cost.clone().unwrap().compute_cost,
570+
"io_cost": cost.clone().unwrap().io_cost
571+
})));
572+
}
573+
if estimated_statistic.is_some() {
574+
new_cost.estimated_statistic = sea_orm::ActiveValue::Set(estimated_statistic);
575+
}
576+
let _ = PlanCost::update(new_cost).exec(&transaction).await?;
577+
} else {
578+
let new_cost = plan_cost::ActiveModel {
579+
physical_expression_id: sea_orm::ActiveValue::Set(physical_expression_id),
580+
epoch_id: sea_orm::ActiveValue::Set(epoch_id),
581+
cost: sea_orm::ActiveValue::Set(
582+
cost.map(|c| json!({"compute_cost": c.compute_cost, "io_cost": c.io_cost})),
583+
),
584+
estimated_statistic: sea_orm::ActiveValue::Set(estimated_statistic),
585+
is_valid: sea_orm::ActiveValue::Set(true),
586+
..Default::default()
587+
};
588+
let _ = PlanCost::insert(new_cost).exec(&transaction).await?;
589+
}
590+
591+
transaction.commit().await?;
547592
Ok(())
548593
}
549594

@@ -755,13 +800,11 @@ mod tests {
755800
backend_manager
756801
.store_cost(
757802
expr_id,
758-
{
759-
Cost {
760-
compute_cost: 42,
761-
io_cost: 42,
762-
estimated_statistic: 42,
763-
}
764-
},
803+
Some(Cost {
804+
compute_cost: 42,
805+
io_cost: 42,
806+
}),
807+
Some(42),
765808
versioned_stat_res[0].epoch_id,
766809
)
767810
.await
@@ -826,7 +869,10 @@ mod tests {
826869
.await
827870
.unwrap();
828871
assert_eq!(cost_res.len(), 1);
829-
assert_eq!(cost_res[0].cost, json!({"compute_cost": 42, "io_cost": 42}));
872+
assert_eq!(
873+
cost_res[0].cost,
874+
Some(json!({"compute_cost": 42, "io_cost": 42}))
875+
);
830876
assert_eq!(cost_res[0].epoch_id, epoch_id1);
831877
assert!(!cost_res[0].is_valid);
832878

@@ -960,10 +1006,15 @@ mod tests {
9601006
let cost = Cost {
9611007
compute_cost: 42,
9621008
io_cost: 42,
963-
estimated_statistic: 42,
9641009
};
1010+
let mut estimated_statistic = 42;
9651011
backend_manager
966-
.store_cost(physical_expression_id, cost.clone(), epoch_id)
1012+
.store_cost(
1013+
physical_expression_id,
1014+
Some(cost.clone()),
1015+
Some(estimated_statistic),
1016+
epoch_id,
1017+
)
9671018
.await
9681019
.unwrap();
9691020
let costs = super::PlanCost::find()
@@ -975,11 +1026,37 @@ mod tests {
9751026
assert_eq!(costs[1].physical_expression_id, physical_expression_id);
9761027
assert_eq!(
9771028
costs[1].cost,
978-
json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost})
1029+
Some(json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}))
9791030
);
9801031
assert_eq!(
981-
costs[1].estimated_statistic as i32,
982-
cost.estimated_statistic
1032+
costs[1].estimated_statistic.unwrap() as i32,
1033+
estimated_statistic
1034+
);
1035+
1036+
estimated_statistic = 50;
1037+
backend_manager
1038+
.store_cost(
1039+
physical_expression_id,
1040+
None,
1041+
Some(estimated_statistic),
1042+
epoch_id,
1043+
)
1044+
.await
1045+
.unwrap();
1046+
let costs = super::PlanCost::find()
1047+
.all(&backend_manager.db)
1048+
.await
1049+
.unwrap();
1050+
assert_eq!(costs.len(), 2); // We should not insert a new row
1051+
assert_eq!(costs[1].epoch_id, epoch_id);
1052+
assert_eq!(costs[1].physical_expression_id, physical_expression_id);
1053+
assert_eq!(
1054+
costs[1].cost,
1055+
Some(json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}))
1056+
);
1057+
assert_eq!(
1058+
costs[1].estimated_statistic.unwrap() as i32,
1059+
estimated_statistic // The estimated_statistic should be update
9831060
);
9841061

9851062
remove_db_file(DATABASE_FILE);
@@ -999,10 +1076,9 @@ mod tests {
9991076
let cost = Cost {
10001077
compute_cost: 42,
10011078
io_cost: 42,
1002-
estimated_statistic: 42,
10031079
};
10041080
let _ = backend_manager
1005-
.store_cost(physical_expression_id, cost.clone(), epoch_id)
1081+
.store_cost(physical_expression_id, Some(cost.clone()), None, epoch_id)
10061082
.await;
10071083
let costs = super::PlanCost::find()
10081084
.all(&backend_manager.db)
@@ -1013,18 +1089,16 @@ mod tests {
10131089
assert_eq!(costs[1].physical_expression_id, physical_expression_id);
10141090
assert_eq!(
10151091
costs[1].cost,
1016-
json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost})
1017-
);
1018-
assert_eq!(
1019-
costs[1].estimated_statistic as i32,
1020-
cost.estimated_statistic
1092+
Some(json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}))
10211093
);
1094+
assert_eq!(costs[1].estimated_statistic, None);
10221095

10231096
let res = backend_manager
10241097
.get_cost(physical_expression_id)
10251098
.await
10261099
.unwrap();
1027-
assert_eq!(res.unwrap(), cost);
1100+
assert_eq!(res.0.unwrap(), cost);
1101+
assert_eq!(res.1, None);
10281102

10291103
remove_db_file(DATABASE_FILE);
10301104
}
@@ -1040,13 +1114,14 @@ mod tests {
10401114
.await
10411115
.unwrap();
10421116
let physical_expression_id = 1;
1043-
let cost = Cost {
1044-
compute_cost: 1420,
1045-
io_cost: 42,
1046-
estimated_statistic: 42,
1047-
};
1117+
let estimated_statistic = 42;
10481118
let _ = backend_manager
1049-
.store_cost(physical_expression_id, cost.clone(), epoch_id)
1119+
.store_cost(
1120+
physical_expression_id,
1121+
None,
1122+
Some(estimated_statistic),
1123+
epoch_id,
1124+
)
10501125
.await;
10511126
let costs = super::PlanCost::find()
10521127
.all(&backend_manager.db)
@@ -1055,13 +1130,10 @@ mod tests {
10551130
assert_eq!(costs.len(), 2); // The first row one is the initialized data
10561131
assert_eq!(costs[1].epoch_id, epoch_id);
10571132
assert_eq!(costs[1].physical_expression_id, physical_expression_id);
1133+
assert_eq!(costs[1].cost, None);
10581134
assert_eq!(
1059-
costs[1].cost,
1060-
json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost})
1061-
);
1062-
assert_eq!(
1063-
costs[1].estimated_statistic as i32,
1064-
cost.estimated_statistic
1135+
costs[1].estimated_statistic.unwrap() as i32,
1136+
estimated_statistic
10651137
);
10661138
println!("{:?}", costs);
10671139

@@ -1073,13 +1145,13 @@ mod tests {
10731145

10741146
// The cost in the dummy data is 10
10751147
assert_eq!(
1076-
res.unwrap(),
1148+
res.0.unwrap(),
10771149
Cost {
10781150
compute_cost: 10,
10791151
io_cost: 10,
1080-
estimated_statistic: 10,
10811152
}
10821153
);
1154+
assert_eq!(res.1.unwrap(), 10);
10831155

10841156
remove_db_file(DATABASE_FILE);
10851157
}

optd-persistent/src/db/init.db

0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)