Skip to content

Commit 904f275

Browse files
committed
Add get_cost & store_cost in the cost model storage layer
1 parent ea20998 commit 904f275

File tree

11 files changed

+153
-46
lines changed

11 files changed

+153
-46
lines changed

optd-cost-model/src/cost_model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::{
1515
},
1616
memo_ext::MemoExt,
1717
stats::AttributeCombValueStats,
18-
storage::CostModelStorageManager,
18+
storage::{self, CostModelStorageManager},
1919
ComputeCostContext, Cost, CostModel, CostModelResult, EstimatedStatistic, StatValue,
2020
};
2121

optd-cost-model/src/lib.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,52 @@ pub struct ComputeCostContext {
3030
}
3131

3232
#[derive(Default, Clone, Debug, PartialOrd, PartialEq)]
33-
pub struct Cost(pub Vec<f64>);
33+
pub struct Cost {
34+
pub compute_cost: f64,
35+
pub io_cost: f64,
36+
}
37+
38+
impl From<Cost> for optd_persistent::cost_model::interface::Cost {
39+
fn from(c: Cost) -> optd_persistent::cost_model::interface::Cost {
40+
Self {
41+
compute_cost: c.compute_cost,
42+
io_cost: c.io_cost,
43+
}
44+
}
45+
}
46+
47+
impl From<optd_persistent::cost_model::interface::Cost> for Cost {
48+
fn from(c: optd_persistent::cost_model::interface::Cost) -> Cost {
49+
Self {
50+
compute_cost: c.compute_cost,
51+
io_cost: c.io_cost,
52+
}
53+
}
54+
}
3455

3556
/// Estimated statistic calculated by the cost model.
3657
/// It is the estimated output row count of the targeted expression.
3758
#[derive(PartialEq, PartialOrd, Clone, Debug)]
3859
pub struct EstimatedStatistic(pub f64);
3960

61+
impl From<EstimatedStatistic> for f32 {
62+
fn from(e: EstimatedStatistic) -> f32 {
63+
e.0 as f32
64+
}
65+
}
66+
67+
impl From<EstimatedStatistic> for f64 {
68+
fn from(e: EstimatedStatistic) -> f64 {
69+
e.0
70+
}
71+
}
72+
73+
impl From<f32> for EstimatedStatistic {
74+
fn from(f: f32) -> EstimatedStatistic {
75+
Self(f as f64)
76+
}
77+
}
78+
4079
pub type CostModelResult<T> = Result<T, CostModelError>;
4180

4281
#[derive(Debug)]

optd-cost-model/src/storage/mock.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ use std::collections::HashMap;
33

44
use serde::{Deserialize, Serialize};
55

6-
use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult};
6+
use crate::{
7+
common::types::{EpochId, ExprId, TableId},
8+
stats::AttributeCombValueStats,
9+
Cost, CostModelResult, EstimatedStatistic,
10+
};
711

812
use super::CostModelStorageManager;
913

@@ -63,4 +67,23 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl {
6367
let table_stats = self.per_table_stats_map.get(&table_id);
6468
Ok(table_stats.map(|stats| stats.row_cnt))
6569
}
70+
71+
/// TODO: finish this when implementing the cost get/store tests
72+
async fn get_cost(
73+
&self,
74+
expr_id: ExprId,
75+
) -> CostModelResult<(Option<crate::Cost>, Option<EstimatedStatistic>)> {
76+
todo!()
77+
}
78+
79+
/// TODO: finish this when implementing the cost get/store tests
80+
async fn store_cost(
81+
&self,
82+
expr_id: ExprId,
83+
cost: Option<Cost>,
84+
estimated_statistic: Option<EstimatedStatistic>,
85+
epoch_id: EpochId,
86+
) -> CostModelResult<()> {
87+
todo!()
88+
}
6689
}

optd-cost-model/src/storage/mod.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult};
1+
use crate::{
2+
common::types::{EpochId, ExprId, TableId},
3+
stats::AttributeCombValueStats,
4+
Cost, CostModelResult, EstimatedStatistic,
5+
};
26

37
pub mod mock;
48
pub mod persistent;
@@ -12,4 +16,17 @@ pub trait CostModelStorageManager {
1216
) -> CostModelResult<Option<AttributeCombValueStats>>;
1317

1418
async fn get_table_row_count(&self, table_id: TableId) -> CostModelResult<Option<u64>>;
19+
20+
async fn get_cost(
21+
&self,
22+
expr_id: ExprId,
23+
) -> CostModelResult<(Option<Cost>, Option<EstimatedStatistic>)>;
24+
25+
async fn store_cost(
26+
&self,
27+
expr_id: ExprId,
28+
cost: Option<Cost>,
29+
estimated_statistic: Option<EstimatedStatistic>,
30+
epoch_id: EpochId,
31+
) -> CostModelResult<()>;
1532
}

optd-cost-model/src/storage/persistent.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use std::sync::Arc;
44
use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer};
55

66
use crate::{
7-
common::types::TableId,
7+
common::types::{EpochId, ExprId, TableId},
88
stats::{utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues},
9-
CostModelResult,
9+
Cost, CostModelResult, EstimatedStatistic,
1010
};
1111

1212
use super::CostModelStorageManager;
@@ -125,5 +125,38 @@ impl<S: CostModelStorageLayer + Send + Sync> CostModelStorageManager
125125
.transpose()?)
126126
}
127127

128+
/// TODO: The name is misleading, since we can also get the estimated statistic. We should
129+
/// rename it.
130+
async fn get_cost(
131+
&self,
132+
expr_id: ExprId,
133+
) -> CostModelResult<(Option<Cost>, Option<EstimatedStatistic>)> {
134+
let (cost, estimated_statistic) = self.backend_manager.get_cost(expr_id.into()).await?;
135+
Ok((
136+
cost.map(|c| c.into()),
137+
estimated_statistic.map(|x| x.into()),
138+
))
139+
}
140+
141+
/// TODO: The name is misleading, since we can also get the estimated statistic. We should
142+
/// rename it.
143+
async fn store_cost(
144+
&self,
145+
expr_id: ExprId,
146+
cost: Option<Cost>,
147+
estimated_statistic: Option<EstimatedStatistic>,
148+
epoch_id: EpochId,
149+
) -> CostModelResult<()> {
150+
self.backend_manager
151+
.store_cost(
152+
expr_id.into(),
153+
cost.map(|c| c.into()),
154+
estimated_statistic.map(|x| x.into()),
155+
epoch_id.into(),
156+
)
157+
.await?;
158+
Ok(())
159+
}
160+
128161
// TODO: Support querying for a specific type of statistics.
129162
}

optd-persistent/src/bin/init.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ async fn init_all_tables() -> Result<(), sea_orm::error::DbErr> {
356356
physical_expression_id: Set(1),
357357
epoch_id: Set(1),
358358
cost: Set(Some(json!({"compute_cost":10, "io_cost":10}))),
359-
estimated_statistic: Set(Some(10)),
359+
estimated_statistic: Set(Some(10.0)),
360360
is_valid: Set(true),
361361
};
362362
plan_cost::Entity::insert(plan_cost)

optd-persistent/src/cost_model/interface.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ pub struct Stat {
8989
/// TODO: documentation
9090
#[derive(Clone, Debug, PartialEq)]
9191
pub struct Cost {
92-
pub compute_cost: i32,
93-
pub io_cost: i32,
92+
pub compute_cost: f64,
93+
pub io_cost: f64,
9494
}
9595

9696
#[derive(Clone, Debug)]
@@ -120,7 +120,7 @@ pub trait CostModelStorageLayer {
120120
&self,
121121
expr_id: ExprId,
122122
cost: Option<Cost>,
123-
estimated_statistic: Option<i32>,
123+
estimated_statistic: Option<f32>,
124124
epoch_id: EpochId,
125125
) -> StorageResult<()>;
126126

@@ -165,9 +165,9 @@ pub trait CostModelStorageLayer {
165165
&self,
166166
expr_id: ExprId,
167167
epoch_id: EpochId,
168-
) -> StorageResult<(Option<Cost>, Option<i32>)>;
168+
) -> StorageResult<(Option<Cost>, Option<f32>)>;
169169

170-
async fn get_cost(&self, expr_id: ExprId) -> StorageResult<(Option<Cost>, Option<i32>)>;
170+
async fn get_cost(&self, expr_id: ExprId) -> StorageResult<(Option<Cost>, Option<f32>)>;
171171

172172
async fn get_attribute(
173173
&self,

optd-persistent/src/cost_model/orm.rs

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ impl CostModelStorageLayer for BackendManager {
474474
&self,
475475
expr_id: ExprId,
476476
epoch_id: EpochId,
477-
) -> StorageResult<(Option<Cost>, Option<i32>)> {
477+
) -> StorageResult<(Option<Cost>, Option<f32>)> {
478478
let cost = PlanCost::find()
479479
.filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
480480
.filter(plan_cost::Column::EpochId.eq(epoch_id))
@@ -486,8 +486,8 @@ impl CostModelStorageLayer for BackendManager {
486486
}
487487

488488
let real_cost = cost.as_ref().and_then(|c| c.cost.as_ref()).map(|c| Cost {
489-
compute_cost: c.get("compute_cost").unwrap().as_i64().unwrap() as i32,
490-
io_cost: c.get("io_cost").unwrap().as_i64().unwrap() as i32,
489+
compute_cost: c.get("compute_cost").unwrap().as_f64().unwrap(),
490+
io_cost: c.get("io_cost").unwrap().as_f64().unwrap(),
491491
});
492492

493493
Ok((real_cost, cost.unwrap().estimated_statistic))
@@ -498,7 +498,7 @@ impl CostModelStorageLayer for BackendManager {
498498
/// Each record in the `plan_cost` table can contain either the cost or the estimated statistic
499499
/// or both, but never neither.
500500
/// The name can be misleading, since it can also return the estimated statistic.
501-
async fn get_cost(&self, expr_id: ExprId) -> StorageResult<(Option<Cost>, Option<i32>)> {
501+
async fn get_cost(&self, expr_id: ExprId) -> StorageResult<(Option<Cost>, Option<f32>)> {
502502
let cost = PlanCost::find()
503503
.filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
504504
.order_by_desc(plan_cost::Column::EpochId)
@@ -510,8 +510,8 @@ impl CostModelStorageLayer for BackendManager {
510510
}
511511

512512
let real_cost = cost.as_ref().and_then(|c| c.cost.as_ref()).map(|c| Cost {
513-
compute_cost: c.get("compute_cost").unwrap().as_i64().unwrap() as i32,
514-
io_cost: c.get("io_cost").unwrap().as_i64().unwrap() as i32,
513+
compute_cost: c.get("compute_cost").unwrap().as_f64().unwrap(),
514+
io_cost: c.get("io_cost").unwrap().as_f64().unwrap(),
515515
});
516516

517517
Ok((real_cost, cost.unwrap().estimated_statistic))
@@ -524,7 +524,7 @@ impl CostModelStorageLayer for BackendManager {
524524
&self,
525525
physical_expression_id: ExprId,
526526
cost: Option<Cost>,
527-
estimated_statistic: Option<i32>,
527+
estimated_statistic: Option<f32>,
528528
epoch_id: EpochId,
529529
) -> StorageResult<()> {
530530
assert!(cost.is_some() || estimated_statistic.is_some());
@@ -801,10 +801,10 @@ mod tests {
801801
.store_cost(
802802
expr_id,
803803
Some(Cost {
804-
compute_cost: 42,
805-
io_cost: 42,
804+
compute_cost: 42.0,
805+
io_cost: 42.0,
806806
}),
807-
Some(42),
807+
Some(42.0),
808808
versioned_stat_res[0].epoch_id,
809809
)
810810
.await
@@ -1004,10 +1004,10 @@ mod tests {
10041004
.unwrap();
10051005
let physical_expression_id = 1;
10061006
let cost = Cost {
1007-
compute_cost: 42,
1008-
io_cost: 42,
1007+
compute_cost: 42.0,
1008+
io_cost: 42.0,
10091009
};
1010-
let mut estimated_statistic = 42;
1010+
let mut estimated_statistic = 42.0;
10111011
backend_manager
10121012
.store_cost(
10131013
physical_expression_id,
@@ -1028,12 +1028,9 @@ mod tests {
10281028
costs[1].cost,
10291029
Some(json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}))
10301030
);
1031-
assert_eq!(
1032-
costs[1].estimated_statistic.unwrap() as i32,
1033-
estimated_statistic
1034-
);
1031+
assert_eq!(costs[1].estimated_statistic.unwrap(), estimated_statistic);
10351032

1036-
estimated_statistic = 50;
1033+
estimated_statistic = 50.0;
10371034
backend_manager
10381035
.store_cost(
10391036
physical_expression_id,
@@ -1055,7 +1052,7 @@ mod tests {
10551052
Some(json!({"compute_cost": cost.compute_cost, "io_cost": cost.io_cost}))
10561053
);
10571054
assert_eq!(
1058-
costs[1].estimated_statistic.unwrap() as i32,
1055+
costs[1].estimated_statistic.unwrap(),
10591056
estimated_statistic // The estimated_statistic should be update
10601057
);
10611058

@@ -1074,8 +1071,8 @@ mod tests {
10741071
.unwrap();
10751072
let physical_expression_id = 1;
10761073
let cost = Cost {
1077-
compute_cost: 42,
1078-
io_cost: 42,
1074+
compute_cost: 42.0,
1075+
io_cost: 42.0,
10791076
};
10801077
let _ = backend_manager
10811078
.store_cost(physical_expression_id, Some(cost.clone()), None, epoch_id)
@@ -1114,7 +1111,7 @@ mod tests {
11141111
.await
11151112
.unwrap();
11161113
let physical_expression_id = 1;
1117-
let estimated_statistic = 42;
1114+
let estimated_statistic = 42.0;
11181115
let _ = backend_manager
11191116
.store_cost(
11201117
physical_expression_id,
@@ -1131,10 +1128,7 @@ mod tests {
11311128
assert_eq!(costs[1].epoch_id, epoch_id);
11321129
assert_eq!(costs[1].physical_expression_id, physical_expression_id);
11331130
assert_eq!(costs[1].cost, None);
1134-
assert_eq!(
1135-
costs[1].estimated_statistic.unwrap() as i32,
1136-
estimated_statistic
1137-
);
1131+
assert_eq!(costs[1].estimated_statistic.unwrap(), estimated_statistic);
11381132
println!("{:?}", costs);
11391133

11401134
// Retrieve physical_expression_id 1 and epoch_id 1
@@ -1147,11 +1141,11 @@ mod tests {
11471141
assert_eq!(
11481142
res.0.unwrap(),
11491143
Cost {
1150-
compute_cost: 10,
1151-
io_cost: 10,
1144+
compute_cost: 10.0,
1145+
io_cost: 10.0,
11521146
}
11531147
);
1154-
assert_eq!(res.1.unwrap(), 10);
1148+
assert_eq!(res.1.unwrap(), 10.0);
11551149

11561150
remove_db_file(DATABASE_FILE);
11571151
}

optd-persistent/src/entities/plan_cost.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
33
use sea_orm::entity::prelude::*;
44

5-
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)]
5+
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
66
#[sea_orm(table_name = "plan_cost")]
77
pub struct Model {
88
#[sea_orm(primary_key)]
99
pub id: i32,
1010
pub physical_expression_id: i32,
1111
pub epoch_id: i32,
1212
pub cost: Option<Json>,
13-
pub estimated_statistic: Option<i32>,
13+
#[sea_orm(column_type = "Float", nullable)]
14+
pub estimated_statistic: Option<f32>,
1415
pub is_valid: bool,
1516
}
1617

optd-persistent/src/migrator/cost_model/m20241029_000001_plan_cost.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ impl MigrationTrait for Migration {
5858
.on_update(ForeignKeyAction::Cascade),
5959
)
6060
.col(json_null(PlanCost::Cost))
61-
.col(integer_null(PlanCost::EstimatedStatistic))
61+
.col(float_null(PlanCost::EstimatedStatistic))
6262
.col(boolean(PlanCost::IsValid))
6363
.to_owned(),
6464
)

0 commit comments

Comments
 (0)