Skip to content

Commit e183f02

Browse files
committed
add test for cost model agg
1 parent 082f0be commit e183f02

File tree

6 files changed

+177
-16
lines changed

6 files changed

+177
-16
lines changed

optd-cost-model/src/common/predicates/attr_ref_pred.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ use super::id_pred::IdPred;
2828
pub struct AttributeRefPred(pub ArcPredicateNode);
2929

3030
impl AttributeRefPred {
31-
pub fn new(table_id: TableId, attribute_idx: usize) -> AttributeRefPred {
31+
pub fn new(table_id: TableId, attribute_idx: u64) -> AttributeRefPred {
3232
AttributeRefPred(
3333
PredicateNode {
3434
typ: PredicateType::AttributeRef,
3535
children: vec![
36-
IdPred::new(table_id.0).into_pred_node(),
36+
IdPred::new(table_id.0 as u64).into_pred_node(),
3737
IdPred::new(attribute_idx).into_pred_node(),
3838
],
3939
data: None,

optd-cost-model/src/common/predicates/id_pred.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@ use crate::common::{
1111
pub struct IdPred(pub ArcPredicateNode);
1212

1313
impl IdPred {
14-
pub fn new(id: usize) -> IdPred {
15-
// This conversion is always safe since usize is at most u64.
16-
let u64_id = id as u64;
14+
pub fn new(id: u64) -> IdPred {
1715
IdPred(
1816
PredicateNode {
1917
typ: PredicateType::Id,
2018
children: vec![],
21-
data: Some(Value::UInt64(u64_id)),
19+
data: Some(Value::UInt64(id)),
2220
}
2321
.into(),
2422
)

optd-cost-model/src/cost/agg.rs

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,166 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
6161
}
6262
}
6363
}
64+
65+
#[cfg(test)]
66+
mod tests {
67+
use std::collections::HashMap;
68+
69+
use crate::{
70+
common::{predicates::constant_pred::ConstantType, types::TableId, values::Value},
71+
cost_model::tests::{
72+
attr_ref, cnst, create_cost_model_mock_storage, empty_list, empty_per_attr_stats, list,
73+
TestPerAttributeStats,
74+
},
75+
stats::{utilities::simple_map::SimpleMap, MostCommonValues, DEFAULT_NUM_DISTINCT},
76+
storage::Attribute,
77+
EstimatedStatistic,
78+
};
79+
80+
#[tokio::test]
81+
async fn test_agg_no_stats() {
82+
let table_id = TableId(0);
83+
let attr_infos = HashMap::from([(
84+
table_id,
85+
HashMap::from([
86+
(
87+
0,
88+
Attribute {
89+
name: String::from("attr1"),
90+
typ: ConstantType::Int32,
91+
nullable: false,
92+
},
93+
),
94+
(
95+
1,
96+
Attribute {
97+
name: String::from("attr2"),
98+
typ: ConstantType::Int64,
99+
nullable: false,
100+
},
101+
),
102+
]),
103+
)]);
104+
let cost_model =
105+
create_cost_model_mock_storage(vec![table_id], vec![], vec![None], attr_infos);
106+
107+
// Group by empty list should return 1.
108+
let group_bys = empty_list();
109+
assert_eq!(
110+
cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
111+
EstimatedStatistic(1)
112+
);
113+
114+
// Group by single column should return the default value since there are no stats.
115+
let group_bys = list(vec![attr_ref(table_id, 0)]);
116+
assert_eq!(
117+
cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
118+
EstimatedStatistic(DEFAULT_NUM_DISTINCT)
119+
);
120+
121+
// Group by two columns should return the default value squared since there are no stats.
122+
let group_bys = list(vec![attr_ref(table_id, 0), attr_ref(table_id, 1)]);
123+
assert_eq!(
124+
cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
125+
EstimatedStatistic(DEFAULT_NUM_DISTINCT * DEFAULT_NUM_DISTINCT)
126+
);
127+
}
128+
129+
#[tokio::test]
130+
async fn test_agg_with_stats() {
131+
let table_id = TableId(0);
132+
let attr1_base_idx = 0;
133+
let attr2_base_idx = 1;
134+
let attr3_base_idx = 2;
135+
let attr_infos = HashMap::from([(
136+
table_id,
137+
HashMap::from([
138+
(
139+
attr1_base_idx,
140+
Attribute {
141+
name: String::from("attr1"),
142+
typ: ConstantType::Int32,
143+
nullable: false,
144+
},
145+
),
146+
(
147+
attr2_base_idx,
148+
Attribute {
149+
name: String::from("attr2"),
150+
typ: ConstantType::Int64,
151+
nullable: false,
152+
},
153+
),
154+
(
155+
attr3_base_idx,
156+
Attribute {
157+
name: String::from("attr3"),
158+
typ: ConstantType::Int64,
159+
nullable: false,
160+
},
161+
),
162+
]),
163+
)]);
164+
165+
let attr1_ndistinct = 12;
166+
let attr2_ndistinct = 645;
167+
let attr1_stats = TestPerAttributeStats::new(
168+
MostCommonValues::SimpleFrequency(SimpleMap::default()),
169+
None,
170+
attr1_ndistinct,
171+
0.0,
172+
);
173+
let attr2_stats = TestPerAttributeStats::new(
174+
MostCommonValues::SimpleFrequency(SimpleMap::default()),
175+
None,
176+
attr2_ndistinct,
177+
0.0,
178+
);
179+
180+
let cost_model = create_cost_model_mock_storage(
181+
vec![table_id],
182+
vec![HashMap::from([
183+
(attr1_base_idx, attr1_stats),
184+
(attr2_base_idx, attr2_stats),
185+
])],
186+
vec![None],
187+
attr_infos,
188+
);
189+
190+
// Group by empty list should return 1.
191+
let group_bys = empty_list();
192+
assert_eq!(
193+
cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
194+
EstimatedStatistic(1)
195+
);
196+
197+
// Group by single column should return the n-distinct of the column.
198+
let group_bys = list(vec![attr_ref(table_id, attr1_base_idx)]);
199+
assert_eq!(
200+
cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
201+
EstimatedStatistic(attr1_ndistinct)
202+
);
203+
204+
// Group by two columns should return the product of the n-distinct of the columns.
205+
let group_bys = list(vec![
206+
attr_ref(table_id, attr1_base_idx),
207+
attr_ref(table_id, attr2_base_idx),
208+
]);
209+
assert_eq!(
210+
cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
211+
EstimatedStatistic(attr1_ndistinct * attr2_ndistinct)
212+
);
213+
214+
// Group by multiple columns should return the product of the n-distinct of the columns. If one of the columns
215+
// does not have stats, it should use the default value instead.
216+
let group_bys = list(vec![
217+
attr_ref(table_id, attr1_base_idx),
218+
attr_ref(table_id, attr2_base_idx),
219+
attr_ref(table_id, attr3_base_idx),
220+
]);
221+
assert_eq!(
222+
cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
223+
EstimatedStatistic(attr1_ndistinct * attr2_ndistinct * DEFAULT_NUM_DISTINCT)
224+
);
225+
}
226+
}

optd-cost-model/src/cost/filter/controller.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ mod tests {
240240
#[tokio::test]
241241
async fn test_attr_ref_leq_constint_no_mcvs_in_range() {
242242
let per_attribute_stats = TestPerAttributeStats::new(
243-
MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])),
243+
MostCommonValues::SimpleFrequency(SimpleMap::default()),
244244
Some(Distribution::SimpleDistribution(SimpleMap::new(vec![(
245245
Value::Int32(15),
246246
0.7,
@@ -364,7 +364,7 @@ mod tests {
364364
#[tokio::test]
365365
async fn test_attr_ref_lt_constint_no_mcvs_in_range() {
366366
let per_attribute_stats = TestPerAttributeStats::new(
367-
MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])),
367+
MostCommonValues::SimpleFrequency(SimpleMap::default()),
368368
Some(Distribution::SimpleDistribution(SimpleMap::new(vec![(
369369
Value::Int32(15),
370370
0.7,
@@ -492,7 +492,7 @@ mod tests {
492492
#[tokio::test]
493493
async fn test_attr_ref_gt_constint() {
494494
let per_attribute_stats = TestPerAttributeStats::new(
495-
MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])),
495+
MostCommonValues::SimpleFrequency(SimpleMap::default()),
496496
Some(Distribution::SimpleDistribution(SimpleMap::new(vec![(
497497
Value::Int32(15),
498498
0.7,
@@ -530,7 +530,7 @@ mod tests {
530530
#[tokio::test]
531531
async fn test_attr_ref_geq_constint() {
532532
let per_attribute_stats = TestPerAttributeStats::new(
533-
MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])),
533+
MostCommonValues::SimpleFrequency(SimpleMap::default()),
534534
Some(Distribution::SimpleDistribution(SimpleMap::new(vec![(
535535
Value::Int32(15),
536536
0.7,
@@ -798,7 +798,7 @@ mod tests {
798798
#[tokio::test]
799799
async fn test_cast_attr_ref_eq_attr_ref() {
800800
let per_attribute_stats = TestPerAttributeStats::new(
801-
MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])),
801+
MostCommonValues::SimpleFrequency(SimpleMap::default()),
802802
None,
803803
0,
804804
0.0,

optd-cost-model/src/cost_model.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ pub mod tests {
176176
CostModelImpl::new(storage_manager, CatalogSource::Mock)
177177
}
178178

179-
pub fn attr_ref(table_id: TableId, attr_base_index: usize) -> ArcPredicateNode {
179+
pub fn attr_ref(table_id: TableId, attr_base_index: u64) -> ArcPredicateNode {
180180
AttributeRefPred::new(table_id, attr_base_index).into_pred_node()
181181
}
182182

@@ -214,7 +214,7 @@ pub mod tests {
214214

215215
pub fn in_list(
216216
table_id: TableId,
217-
attr_ref_idx: usize,
217+
attr_ref_idx: u64,
218218
list: Vec<Value>,
219219
negated: bool,
220220
) -> InListPred {
@@ -225,7 +225,7 @@ pub mod tests {
225225
)
226226
}
227227

228-
pub fn like(table_id: TableId, attr_ref_idx: usize, pattern: &str, negated: bool) -> LikePred {
228+
pub fn like(table_id: TableId, attr_ref_idx: u64, pattern: &str, negated: bool) -> LikePred {
229229
LikePred::new(
230230
negated,
231231
false,

optd-cost-model/src/storage/mock.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ impl TableStats {
3030
}
3131

3232
pub type BaseTableStats = HashMap<TableId, TableStats>;
33-
pub type BaseTableAttrInfo = HashMap<TableId, HashMap<i32, Attribute>>;
33+
pub type BaseTableAttrInfo = HashMap<TableId, HashMap<u64, Attribute>>; // (table_id, (attr_base_index, attr))
3434

3535
pub struct CostModelStorageMockManagerImpl {
3636
pub(crate) per_table_stats_map: BaseTableStats,
@@ -58,7 +58,7 @@ impl CostModelStorageManager for CostModelStorageMockManagerImpl {
5858
let table_attr_infos = self.per_table_attr_infos_map.get(&table_id);
5959
match table_attr_infos {
6060
None => Ok(None),
61-
Some(table_attr_infos) => match table_attr_infos.get(&attr_base_index) {
61+
Some(table_attr_infos) => match table_attr_infos.get(&(attr_base_index as u64)) {
6262
None => Ok(None),
6363
Some(attr) => Ok(Some(attr.clone())),
6464
},

0 commit comments

Comments
 (0)