Skip to content

Commit 23c444d

Browse files
committed
Finish compute_operation_cost
1 parent 904f275 commit 23c444d

File tree

8 files changed

+164
-26
lines changed

8 files changed

+164
-26
lines changed

optd-cost-model/src/cost_model.rs

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,111 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
4343

4444
#[async_trait::async_trait]
4545
impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModelImpl<S> {
46+
/// TODO: should we add epoch_id?
4647
async fn compute_operation_cost(
4748
&self,
48-
node: &PhysicalNodeType,
49+
node: PhysicalNodeType,
4950
predicates: &[ArcPredicateNode],
51+
children_costs: &[Cost],
5052
children_stats: &[EstimatedStatistic],
5153
context: ComputeCostContext,
5254
) -> CostModelResult<Cost> {
53-
todo!()
55+
let res = self.storage_manager.get_cost(context.expr_id).await;
56+
if let Ok((Some(cost), _)) = res {
57+
return Ok(cost);
58+
};
59+
let mut output_statistic = None;
60+
if let Ok((_, Some(statistic))) = res {
61+
output_statistic = Some(statistic);
62+
};
63+
let output_cost = match node {
64+
PhysicalNodeType::PhysicalScan => {
65+
let output_statistic_data = output_statistic.unwrap_or(
66+
self.derive_statistics(
67+
node,
68+
predicates,
69+
children_stats,
70+
context.clone(),
71+
false,
72+
)
73+
.await?,
74+
);
75+
output_statistic = Some(output_statistic_data.clone());
76+
Cost {
77+
compute_cost: 0.0,
78+
io_cost: output_statistic_data.0,
79+
}
80+
}
81+
PhysicalNodeType::PhysicalEmptyRelation => Cost {
82+
compute_cost: 0.1,
83+
io_cost: 0.0,
84+
},
85+
PhysicalNodeType::PhysicalLimit => Cost {
86+
compute_cost: children_costs[0].compute_cost,
87+
io_cost: 0.0,
88+
},
89+
PhysicalNodeType::PhysicalFilter => Cost {
90+
// TODO: now this equation is specific to optd, and try to make this equation more general
91+
compute_cost: children_costs[1].compute_cost * children_stats[0].0,
92+
io_cost: 0.0,
93+
},
94+
PhysicalNodeType::PhysicalNestedLoopJoin(join_typ) => {
95+
let child_compute_cost = children_costs[2].compute_cost;
96+
Cost {
97+
compute_cost: children_stats[0].0 * children_stats[1].0 * child_compute_cost
98+
+ children_stats[0].0,
99+
io_cost: 0.0,
100+
}
101+
}
102+
// TODO: we should document that the first child is the left table, which is used to build
103+
// the hash table.
104+
PhysicalNodeType::PhysicalHashJoin(join_typ) => Cost {
105+
compute_cost: children_stats[0].0 * 2.0 + children_stats[1].0,
106+
io_cost: 0.0,
107+
},
108+
PhysicalNodeType::PhysicalAgg => Cost {
109+
compute_cost: children_stats[0].0
110+
* (children_costs[1].compute_cost + children_costs[2].compute_cost),
111+
io_cost: 0.0,
112+
},
113+
PhysicalNodeType::PhysicalProjection => Cost {
114+
compute_cost: children_stats[0].0 * children_costs[1].compute_cost,
115+
io_cost: 0.0,
116+
},
117+
PhysicalNodeType::PhysicalSort => Cost {
118+
compute_cost: children_stats[0].0 * children_stats[0].0.ln_1p().max(1.0),
119+
io_cost: 0.0,
120+
},
121+
};
122+
let res = self
123+
.storage_manager
124+
.store_cost(
125+
context.expr_id,
126+
Some(output_cost.clone()),
127+
output_statistic,
128+
None,
129+
)
130+
.await;
131+
if res.is_err() {
132+
eprintln!("Failed to store output cost");
133+
}
134+
Ok(output_cost)
54135
}
55136

137+
/// TODO: should we add epoch_id?
56138
async fn derive_statistics(
57139
&self,
58140
node: PhysicalNodeType,
59141
predicates: &[ArcPredicateNode],
60142
children_statistics: &[EstimatedStatistic],
61143
context: ComputeCostContext,
144+
store_output_statistic: bool,
62145
) -> CostModelResult<EstimatedStatistic> {
63-
match node {
146+
let res = self.storage_manager.get_cost(context.expr_id).await;
147+
if let Ok((_, Some(statistic))) = res {
148+
return Ok(statistic);
149+
}
150+
let output_statistic = match node {
64151
PhysicalNodeType::PhysicalScan => {
65152
let table_id = TableId(predicates[0].data.as_ref().unwrap().as_u64());
66153
let row_cnt = self
@@ -114,7 +201,17 @@ impl<S: CostModelStorageManager + Send + Sync + 'static> CostModel for CostModel
114201
PhysicalNodeType::PhysicalSort | PhysicalNodeType::PhysicalProjection => {
115202
Ok(children_statistics[0].clone())
116203
}
117-
}
204+
}?;
205+
if store_output_statistic {
206+
let res = self
207+
.storage_manager
208+
.store_cost(context.expr_id, None, Some(output_statistic.clone()), None)
209+
.await;
210+
if res.is_err() {
211+
eprintln!("Failed to store output statistic");
212+
}
213+
};
214+
Ok(output_statistic)
118215
}
119216

120217
async fn update_statistics(
@@ -167,3 +264,5 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
167264
.await
168265
}
169266
}
267+
268+
// TODO: Add tests for `derive_statistic`` and `compute_operation_cost`.

optd-cost-model/src/lib.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,41 @@ pub trait CostModel: 'static + Send + Sync {
118118
/// TODO: documentation
119119
async fn compute_operation_cost(
120120
&self,
121-
node: &PhysicalNodeType,
121+
node: PhysicalNodeType,
122122
predicates: &[ArcPredicateNode],
123+
children_costs: &[Cost],
123124
children_stats: &[EstimatedStatistic],
124125
context: ComputeCostContext,
125126
) -> CostModelResult<Cost>;
126127

127128
/// TODO: documentation
128129
/// It is for cardinality estimation. The output should be the estimated
129130
/// statistic calculated by the cost model.
131+
/// If this method is called by `compute_operation_cost`, please set
132+
/// `store_output_statistic` to `false`; if it is called by the optimizer,
133+
/// please set `store_output_statistic` to `true`. Since we can store the
134+
/// estimated statistic and cost by calling the ORM method once.
135+
///
136+
/// TODO: I am not sure whether to introduce `store_output_statistic`, since
137+
/// it add complexity to the interface, considering currently only Scan needs
138+
/// the output row count to calculate the costs. So updating the database twice
139+
/// seems cheap. But in the future, maybe more cost computations rely on the output
140+
/// row count.
141+
///
130142
/// TODO: Consider make it a helper function, so we can store Cost in the
131143
/// ORM more easily.
144+
///
145+
/// TODO: I would suggest to rename this method to `derive_row_count`, since
146+
/// statistic is easily to be confused with the real statistic.
147+
/// Also we need to update other places to use estimated statistic to row count,
148+
/// either in this crate or in optd-persistent.
132149
async fn derive_statistics(
133150
&self,
134151
node: PhysicalNodeType,
135152
predicates: &[ArcPredicateNode],
136153
children_stats: &[EstimatedStatistic],
137154
context: ComputeCostContext,
155+
store_output_statistic: bool,
138156
) -> CostModelResult<EstimatedStatistic>;
139157

140158
/// TODO: documentation

optd-cost-model/src/storage/mock.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl {
8282
expr_id: ExprId,
8383
cost: Option<Cost>,
8484
estimated_statistic: Option<EstimatedStatistic>,
85-
epoch_id: EpochId,
85+
epoch_id: Option<EpochId>,
8686
) -> CostModelResult<()> {
8787
todo!()
8888
}

optd-cost-model/src/storage/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ pub trait CostModelStorageManager {
2727
expr_id: ExprId,
2828
cost: Option<Cost>,
2929
estimated_statistic: Option<EstimatedStatistic>,
30-
epoch_id: EpochId,
30+
epoch_id: Option<EpochId>,
3131
) -> CostModelResult<()>;
3232
}

optd-cost-model/src/storage/persistent.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,14 @@ impl<S: CostModelStorageLayer + Send + Sync> CostModelStorageManager
145145
expr_id: ExprId,
146146
cost: Option<Cost>,
147147
estimated_statistic: Option<EstimatedStatistic>,
148-
epoch_id: EpochId,
148+
epoch_id: Option<EpochId>,
149149
) -> CostModelResult<()> {
150150
self.backend_manager
151151
.store_cost(
152152
expr_id.into(),
153153
cost.map(|c| c.into()),
154154
estimated_statistic.map(|x| x.into()),
155-
epoch_id.into(),
155+
epoch_id.map(|id| id.into()),
156156
)
157157
.await?;
158158
Ok(())

optd-persistent/src/cost_model/interface.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ pub trait CostModelStorageLayer {
121121
expr_id: ExprId,
122122
cost: Option<Cost>,
123123
estimated_statistic: Option<f32>,
124-
epoch_id: EpochId,
124+
epoch_id: Option<EpochId>,
125125
) -> StorageResult<()>;
126126

127127
async fn store_expr_stats_mappings(

optd-persistent/src/cost_model/orm.rs

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -519,13 +519,14 @@ impl CostModelStorageLayer for BackendManager {
519519

520520
/// This method should handle the case when the cost is already stored.
521521
/// The name maybe misleading, since it can also store the estimated statistic.
522+
/// If epoch_id is none, we pick the latest epoch_id.
522523
/// TODO: documentation
523524
async fn store_cost(
524525
&self,
525526
physical_expression_id: ExprId,
526527
cost: Option<Cost>,
527528
estimated_statistic: Option<f32>,
528-
epoch_id: EpochId,
529+
epoch_id: Option<EpochId>,
529530
) -> StorageResult<()> {
530531
assert!(cost.is_some() || estimated_statistic.is_some());
531532
// TODO: should we do the following checks in the production environment?
@@ -542,17 +543,32 @@ impl CostModelStorageLayer for BackendManager {
542543
));
543544
}
544545
// Check if epoch_id exists in Event table
545-
let epoch_exists = Event::find()
546-
.filter(event::Column::EpochId.eq(epoch_id))
547-
.one(&self.db)
548-
.await
549-
.unwrap();
550-
if epoch_exists.is_none() {
551-
return Err(BackendError::CostModel(
552-
format!("epoch id {} not found when storing cost", epoch_id).into(),
553-
));
546+
if epoch_id.is_some() {
547+
let epoch_exists = Event::find()
548+
.filter(event::Column::EpochId.eq(epoch_id.unwrap()))
549+
.one(&self.db)
550+
.await
551+
.unwrap();
552+
if epoch_exists.is_none() {
553+
return Err(BackendError::CostModel(
554+
format!("epoch id {} not found when storing cost", epoch_id.unwrap()).into(),
555+
));
556+
}
554557
}
555558

559+
let epoch_id = match epoch_id {
560+
Some(id) => id,
561+
None => {
562+
// When init, please make sure there is at least one epoch in the Event table.
563+
let latest_epoch_id = Event::find()
564+
.order_by_desc(event::Column::EpochId)
565+
.one(&self.db)
566+
.await?
567+
.unwrap();
568+
latest_epoch_id.epoch_id
569+
}
570+
};
571+
556572
let transaction = self.db.begin().await?;
557573

558574
let valid_cost = PlanCost::find()
@@ -805,7 +821,7 @@ mod tests {
805821
io_cost: 42.0,
806822
}),
807823
Some(42.0),
808-
versioned_stat_res[0].epoch_id,
824+
Some(versioned_stat_res[0].epoch_id),
809825
)
810826
.await
811827
.unwrap();
@@ -871,7 +887,7 @@ mod tests {
871887
assert_eq!(cost_res.len(), 1);
872888
assert_eq!(
873889
cost_res[0].cost,
874-
Some(json!({"compute_cost": 42, "io_cost": 42}))
890+
Some(json!({"compute_cost": 42.0, "io_cost": 42.0}))
875891
);
876892
assert_eq!(cost_res[0].epoch_id, epoch_id1);
877893
assert!(!cost_res[0].is_valid);
@@ -1013,7 +1029,7 @@ mod tests {
10131029
physical_expression_id,
10141030
Some(cost.clone()),
10151031
Some(estimated_statistic),
1016-
epoch_id,
1032+
Some(epoch_id),
10171033
)
10181034
.await
10191035
.unwrap();
@@ -1036,7 +1052,7 @@ mod tests {
10361052
physical_expression_id,
10371053
None,
10381054
Some(estimated_statistic),
1039-
epoch_id,
1055+
None,
10401056
)
10411057
.await
10421058
.unwrap();
@@ -1075,7 +1091,12 @@ mod tests {
10751091
io_cost: 42.0,
10761092
};
10771093
let _ = backend_manager
1078-
.store_cost(physical_expression_id, Some(cost.clone()), None, epoch_id)
1094+
.store_cost(
1095+
physical_expression_id,
1096+
Some(cost.clone()),
1097+
None,
1098+
Some(epoch_id),
1099+
)
10791100
.await;
10801101
let costs = super::PlanCost::find()
10811102
.all(&backend_manager.db)
@@ -1117,7 +1138,7 @@ mod tests {
11171138
physical_expression_id,
11181139
None,
11191140
Some(estimated_statistic),
1120-
epoch_id,
1141+
Some(epoch_id),
11211142
)
11221143
.await;
11231144
let costs = super::PlanCost::find()

optd-persistent/src/db/init.db

0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)