Skip to content

Commit be71afb

Browse files
committed
pass group id to join and fix filter-related tests
1 parent 489ff48 commit be71afb

File tree

5 files changed

+124
-63
lines changed

5 files changed

+124
-63
lines changed

optd-cost-model/src/common/properties/mod.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,21 @@ impl std::fmt::Display for Attribute {
2121
}
2222
}
2323
}
24+
25+
// Convenience constructors for `Attribute`.
impl Attribute {
26+
/// Creates an attribute from an explicit name, type, and nullability flag.
///
/// Each argument is stored unchanged in the field of the same name.
pub fn new(name: String, typ: ConstantType, nullable: bool) -> Self {
27+
Self {
28+
name,
29+
typ,
30+
nullable,
31+
}
32+
}
33+
34+
/// Creates an attribute named `name` whose type is `ConstantType::Int64`
/// and whose `nullable` flag is `false`.
pub fn new_non_null_int64(name: String) -> Self {
35+
Self {
36+
name,
37+
typ: ConstantType::Int64,
38+
nullable: false,
39+
}
40+
}
41+
}

optd-cost-model/src/cost/join/core.rs

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
3434
pub(crate) async fn get_join_selectivity_from_expr_tree(
3535
&self,
3636
join_typ: JoinType,
37+
group_id: GroupId,
3738
expr_tree: ArcPredicateNode,
3839
attr_refs: &AttrRefs,
3940
input_correlation: Option<SemanticCorrelation>,
@@ -61,6 +62,7 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
6162
};
6263
self.get_join_selectivity_core(
6364
join_typ,
65+
group_id,
6466
on_attr_ref_pairs,
6567
filter_expr_tree,
6668
attr_refs,
@@ -75,6 +77,7 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
7577
if let Some(on_attr_ref_pair) = get_on_attr_ref_pair(expr_tree.clone(), attr_refs) {
7678
self.get_join_selectivity_core(
7779
join_typ,
80+
group_id,
7881
vec![on_attr_ref_pair],
7982
None,
8083
attr_refs,
@@ -87,6 +90,7 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
8790
} else {
8891
self.get_join_selectivity_core(
8992
join_typ,
93+
group_id,
9094
vec![],
9195
Some(expr_tree),
9296
attr_refs,
@@ -105,6 +109,7 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
105109
pub(crate) async fn get_join_selectivity_from_keys(
106110
&self,
107111
join_typ: JoinType,
112+
group_id: GroupId,
108113
left_keys: ListPred,
109114
right_keys: ListPred,
110115
attr_refs: &AttrRefs,
@@ -129,6 +134,7 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
129134
.collect_vec();
130135
self.get_join_selectivity_core(
131136
join_typ,
137+
group_id,
132138
on_attr_ref_pairs,
133139
None,
134140
attr_refs,
@@ -156,6 +162,7 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
156162
async fn get_join_selectivity_core(
157163
&self,
158164
join_typ: JoinType,
165+
group_id: GroupId,
159166
on_attr_ref_pairs: Vec<(AttrIndexPred, AttrIndexPred)>,
160167
filter_expr_tree: Option<ArcPredicateNode>,
161168
attr_refs: &AttrRefs,
@@ -180,8 +187,6 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
180187
// get_filter_selectivity() function, but this may change in the future.
181188
let join_filter_selectivity = match filter_expr_tree {
182189
Some(filter_expr_tree) => {
183-
// FIXME(group_id): Pass in group id or schema & attr_refs
184-
let group_id = GroupId(0);
185190
self.get_filter_selectivity(group_id, filter_expr_tree)
186191
.await?
187192
}
@@ -405,20 +410,28 @@ mod tests {
405410
use std::collections::HashMap;
406411

407412
use crate::{
408-
common::{predicates::attr_index_pred, types::TableId, values::Value},
413+
common::{
414+
predicates::{attr_index_pred, constant_pred::ConstantType},
415+
properties::Attribute,
416+
types::TableId,
417+
values::Value,
418+
},
409419
cost_model::tests::{
410420
attr_index, bin_op, cnst, create_four_table_mock_cost_model, create_mock_cost_model,
411421
create_three_table_mock_cost_model, create_two_table_mock_cost_model,
412422
create_two_table_mock_cost_model_custom_row_cnts, empty_per_attr_stats, log_op,
413423
per_attr_stats_with_dist_and_ndistinct, per_attr_stats_with_ndistinct,
414-
TestOptCostModelMock, TestPerAttributeStats, TEST_TABLE1_ID, TEST_TABLE2_ID,
415-
TEST_TABLE3_ID, TEST_TABLE4_ID,
424+
TestOptCostModelMock, TestPerAttributeStats, TEST_ATTR1_NAME, TEST_ATTR2_NAME,
425+
TEST_TABLE1_ID, TEST_TABLE2_ID, TEST_TABLE3_ID, TEST_TABLE4_ID,
416426
},
427+
memo_ext::tests::MemoGroupInfo,
417428
stats::DEFAULT_EQ_SEL,
418429
};
419430

420431
use super::*;
421432

433+
/// Group id the join tests pass to `get_join_selectivity_from_expr_tree`.
/// Tests that supply memo info to the mock cost model key it by this id.
const JOIN_GROUP_ID: GroupId = GroupId(10);
434+
422435
/// A wrapper around get_join_selectivity_from_expr_tree that extracts the
423436
/// table row counts from the cost model.
424437
async fn test_get_join_selectivity(
@@ -436,6 +449,7 @@ mod tests {
436449
cost_model
437450
.get_join_selectivity_from_expr_tree(
438451
join_typ,
452+
JOIN_GROUP_ID,
439453
expr_tree,
440454
attr_refs,
441455
input_correlation,
@@ -448,6 +462,7 @@ mod tests {
448462
cost_model
449463
.get_join_selectivity_from_expr_tree(
450464
join_typ,
465+
JOIN_GROUP_ID,
451466
expr_tree,
452467
attr_refs,
453468
input_correlation,
@@ -470,6 +485,7 @@ mod tests {
470485
cost_model
471486
.get_join_selectivity_from_expr_tree(
472487
JoinType::Inner,
488+
JOIN_GROUP_ID,
473489
cnst(Value::Bool(true)),
474490
&vec![],
475491
None,
@@ -484,6 +500,7 @@ mod tests {
484500
cost_model
485501
.get_join_selectivity_from_expr_tree(
486502
JoinType::Inner,
503+
JOIN_GROUP_ID,
487504
cnst(Value::Bool(false)),
488505
&vec![],
489506
None,
@@ -501,6 +518,7 @@ mod tests {
501518
let cost_model = create_two_table_mock_cost_model(
502519
per_attr_stats_with_ndistinct(5),
503520
per_attr_stats_with_ndistinct(4),
521+
None,
504522
);
505523

506524
let attr_refs = vec![
@@ -540,6 +558,7 @@ mod tests {
540558
let cost_model = create_two_table_mock_cost_model(
541559
per_attr_stats_with_ndistinct(5),
542560
per_attr_stats_with_ndistinct(4),
561+
None,
543562
);
544563

545564
let attr_refs = vec![
@@ -578,11 +597,28 @@ mod tests {
578597
}
579598

580599
#[tokio::test]
581-
#[ignore = "index out of bounds: the len is 1 but the index is 1"]
582600
async fn test_inner_and_of_oncond_and_filter() {
601+
let join_memo = HashMap::from([(
602+
JOIN_GROUP_ID,
603+
MemoGroupInfo::new(
604+
vec![
605+
Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string()),
606+
Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string()),
607+
]
608+
.into(),
609+
GroupAttrRefs::new(
610+
vec![
611+
AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0),
612+
AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0),
613+
],
614+
None,
615+
),
616+
),
617+
)]);
583618
let cost_model = create_two_table_mock_cost_model(
584619
per_attr_stats_with_ndistinct(5),
585620
per_attr_stats_with_ndistinct(4),
621+
Some(join_memo),
586622
);
587623

588624
let attr_refs = vec![
@@ -621,11 +657,28 @@ mod tests {
621657
}
622658

623659
#[tokio::test]
624-
#[ignore = "filter todo"]
625660
async fn test_inner_and_of_filters() {
661+
let join_memo = HashMap::from([(
662+
JOIN_GROUP_ID,
663+
MemoGroupInfo::new(
664+
vec![
665+
Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string()),
666+
Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string()),
667+
]
668+
.into(),
669+
GroupAttrRefs::new(
670+
vec![
671+
AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0),
672+
AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0),
673+
],
674+
None,
675+
),
676+
),
677+
)]);
626678
let cost_model = create_two_table_mock_cost_model(
627679
per_attr_stats_with_ndistinct(5),
628680
per_attr_stats_with_ndistinct(4),
681+
Some(join_memo),
629682
);
630683

631684
let attr_refs = vec![
@@ -668,6 +721,7 @@ mod tests {
668721
let cost_model = create_two_table_mock_cost_model(
669722
per_attr_stats_with_ndistinct(5),
670723
per_attr_stats_with_ndistinct(4),
724+
None,
671725
);
672726

673727
let attr_refs = vec![
@@ -812,6 +866,7 @@ mod tests {
812866
per_attr_stats_with_ndistinct(4),
813867
5,
814868
4,
869+
None,
815870
);
816871

817872
let attr_refs = vec![
@@ -863,6 +918,7 @@ mod tests {
863918
per_attr_stats_with_ndistinct(4),
864919
10,
865920
8,
921+
None,
866922
);
867923

868924
let attr_refs = vec![
@@ -916,6 +972,7 @@ mod tests {
916972
per_attr_stats_with_ndistinct(2),
917973
20,
918974
4,
975+
None,
919976
);
920977

921978
let attr_refs = vec![
@@ -964,11 +1021,29 @@ mod tests {
9641021
/// the inner will be < 1 / row count of both tables
9651022
#[tokio::test]
9661023
async fn test_outer_unique_oncond_filter() {
1024+
let join_memo = HashMap::from([(
1025+
JOIN_GROUP_ID,
1026+
MemoGroupInfo::new(
1027+
vec![
1028+
Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string()),
1029+
Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string()),
1030+
]
1031+
.into(),
1032+
GroupAttrRefs::new(
1033+
vec![
1034+
AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0),
1035+
AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0),
1036+
],
1037+
None,
1038+
),
1039+
),
1040+
)]);
9671041
let cost_model = create_two_table_mock_cost_model_custom_row_cnts(
9681042
per_attr_stats_with_dist_and_ndistinct(vec![(Value::Int32(128), 0.4)], 50),
9691043
per_attr_stats_with_ndistinct(4),
9701044
50,
9711045
4,
1046+
Some(join_memo),
9721047
);
9731048

9741049
let attr_refs = vec![

optd-cost-model/src/cost/join/hash_join.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
3737
let input_correlation = get_input_correlation(left_attr_refs, right_attr_refs);
3838
self.get_join_selectivity_from_keys(
3939
join_typ,
40+
group_id,
4041
left_keys,
4142
right_keys,
4243
output_attr_refs.attr_refs(),

optd-cost-model/src/cost/join/nested_loop_join.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
3232

3333
self.get_join_selectivity_from_expr_tree(
3434
join_typ,
35+
group_id,
3536
join_cond,
3637
output_attr_refs.attr_refs(),
3738
input_correlation,

0 commit comments

Comments
 (0)