Skip to content

Commit e25ff2c

Browse files
authored
fix(cost-model): fix the epoch_id in store_cost (#45)
* fix: fix the epoch_id in store_cost * fix failed tests
1 parent dfaee9a commit e25ff2c

File tree

1 file changed

+96
-23
lines changed
  • optd-persistent/src/cost_model

1 file changed

+96
-23
lines changed

optd-persistent/src/cost_model/orm.rs

Lines changed: 96 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ impl CostModelStorageLayer for BackendManager {
534534
epoch_id: Option<EpochId>,
535535
) -> StorageResult<()> {
536536
assert!(cost.is_some() || estimated_statistic.is_some());
537-
// TODO: should we do the following checks in the production environment?
537+
// TODO: we shouldn't do the following checks in the production environment.
538538
let expr_exists = PhysicalExpression::find_by_id(physical_expression_id)
539539
.one(&self.db)
540540
.await?;
@@ -561,30 +561,54 @@ impl CostModelStorageLayer for BackendManager {
561561
}
562562
}
563563

564-
let epoch_id = match epoch_id {
565-
Some(id) => id,
566-
None => {
567-
// When init, please make sure there is at least one epoch in the Event table.
568-
let latest_epoch_id = Event::find()
569-
.order_by_desc(event::Column::EpochId)
570-
.one(&self.db)
571-
.await?
572-
.unwrap();
573-
latest_epoch_id.epoch_id
574-
}
575-
};
576-
577564
let transaction = self.db.begin().await?;
578565

579-
let valid_cost = PlanCost::find()
580-
.filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id))
581-
.filter(plan_cost::Column::EpochId.eq(epoch_id))
582-
.filter(plan_cost::Column::IsValid.eq(true))
583-
.one(&transaction)
584-
.await?;
566+
/*
567+
The `store_cost` logic is as follows:
568+
1. If the epoch_id is provided, we should update the cost with the corresponding epoch_id,
569+
or insert a new record if it doesn't exist.
570+
2. If the epoch_id is not provided, we cannot simply use the latest epoch_id: in the
571+
plan_cost table, the current physical expression may still have a valid cost with a lower
572+
epoch_id, because the update_stats function also updates unrelated stats. So we choose the
573+
epoch_id as follows:
574+
1) If a valid cost is already in the plan_cost table, we use the same epoch_id.
575+
2) If there is no valid cost in the plan_cost table, or there is no record, we use the
576+
latest epoch_id.
577+
*/
578+
// TODO: We should add some integration tests to fully test the above logic
579+
let epoch_id_data;
580+
let existed_cost;
581+
if let Some(epoch_id) = epoch_id {
582+
epoch_id_data = epoch_id;
583+
existed_cost = PlanCost::find()
584+
.filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id))
585+
.filter(plan_cost::Column::EpochId.eq(epoch_id))
586+
.one(&transaction)
587+
.await?;
588+
} else {
589+
existed_cost = PlanCost::find()
590+
.filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id))
591+
.filter(plan_cost::Column::IsValid.eq(true))
592+
.order_by_desc(plan_cost::Column::EpochId)
593+
.one(&transaction)
594+
.await?;
595+
if existed_cost.is_none() {
596+
epoch_id_data = {
597+
// On initialization, make sure there is at least one epoch in the Event table.
598+
let latest_epoch_id = Event::find()
599+
.order_by_desc(event::Column::EpochId)
600+
.one(&self.db)
601+
.await?
602+
.unwrap();
603+
latest_epoch_id.epoch_id
604+
}
605+
} else {
606+
epoch_id_data = existed_cost.clone().unwrap().epoch_id;
607+
}
608+
}
585609

586-
if valid_cost.is_some() {
587-
let mut new_cost: plan_cost::ActiveModel = valid_cost.unwrap().into();
610+
if existed_cost.is_some() {
611+
let mut new_cost: plan_cost::ActiveModel = existed_cost.unwrap().into();
588612
let mut update = false;
589613
if cost.is_some() {
590614
let input_cost = sea_orm::ActiveValue::Set(Some(json!({
@@ -604,12 +628,25 @@ impl CostModelStorageLayer for BackendManager {
604628
}
605629
}
606630
if update {
631+
assert!(new_cost.epoch_id.is_unchanged());
607632
let _ = PlanCost::update(new_cost).exec(&transaction).await?;
608633
}
609634
} else {
635+
// TODO: we shouldn't do the following checks in the production environment.
636+
// This check may be easy to violate, so consider removing the epoch_id input parameter.
637+
let latest_cost = PlanCost::find()
638+
.filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id))
639+
.order_by_desc(plan_cost::Column::EpochId)
640+
.one(&transaction)
641+
.await?;
642+
if latest_cost.is_some() {
643+
assert!(latest_cost.clone().unwrap().epoch_id < epoch_id_data);
644+
assert!(!latest_cost.clone().unwrap().is_valid);
645+
}
646+
610647
let new_cost = plan_cost::ActiveModel {
611648
physical_expression_id: sea_orm::ActiveValue::Set(physical_expression_id),
612-
epoch_id: sea_orm::ActiveValue::Set(epoch_id),
649+
epoch_id: sea_orm::ActiveValue::Set(epoch_id_data),
613650
cost: sea_orm::ActiveValue::Set(
614651
cost.map(|c| json!({"compute_cost": c.compute_cost, "io_cost": c.io_cost})),
615652
),
@@ -1035,6 +1072,18 @@ mod tests {
10351072
.create_new_epoch("source".to_string(), "data".to_string())
10361073
.await
10371074
.unwrap();
1075+
let stat = Stat {
1076+
stat_type: StatType::TableRowCount,
1077+
stat_value: json!(10),
1078+
attr_ids: vec![],
1079+
table_id: Some(1),
1080+
name: "row_count".to_owned(),
1081+
};
1082+
let res = backend_manager
1083+
.update_stats(stat, EpochOption::Existed(epoch_id))
1084+
.await;
1085+
assert!(res.is_ok());
1086+
10381087
let physical_expression_id = 1;
10391088
let cost = Cost {
10401089
compute_cost: 42.0,
@@ -1102,6 +1151,18 @@ mod tests {
11021151
.create_new_epoch("source".to_string(), "data".to_string())
11031152
.await
11041153
.unwrap();
1154+
let stat = Stat {
1155+
stat_type: StatType::TableRowCount,
1156+
stat_value: json!(10),
1157+
attr_ids: vec![],
1158+
table_id: Some(1),
1159+
name: "row_count".to_owned(),
1160+
};
1161+
let res = backend_manager
1162+
.update_stats(stat, EpochOption::Existed(epoch_id))
1163+
.await;
1164+
assert!(res.is_ok());
1165+
11051166
let physical_expression_id = 1;
11061167
let cost = Cost {
11071168
compute_cost: 42.0,
@@ -1148,6 +1209,18 @@ mod tests {
11481209
.create_new_epoch("source".to_string(), "data".to_string())
11491210
.await
11501211
.unwrap();
1212+
let stat = Stat {
1213+
stat_type: StatType::TableRowCount,
1214+
stat_value: json!(10),
1215+
attr_ids: vec![],
1216+
table_id: Some(1),
1217+
name: "row_count".to_owned(),
1218+
};
1219+
let res = backend_manager
1220+
.update_stats(stat, EpochOption::Existed(epoch_id))
1221+
.await;
1222+
assert!(res.is_ok());
1223+
11511224
let physical_expression_id = 1;
11521225
let estimated_statistic = 42.0;
11531226
let _ = backend_manager

0 commit comments

Comments
 (0)