Skip to content

Commit 9ba03e6

Browse files
authored
feat(cost-model): introduce MemoExt trait for property info (#38)
To compute the cost of an expression, we need information about the schema and the attribute refs (including attribute correlations). In the current optd, this is done by calling the optimizer core's methods. We were against this approach in previous discussions because we thought it couples the core and the cost model too tightly -- we therefore eliminated the optimizer parameter and intended to get all of this information from the storage/ORM. However, getting everything from the ORM has a performance drawback: the core should already have all the information we need (schema & attribute refs) in memory, so it is more efficient for the core to pass it in than for the cost model to query the underlying external database. This also aligns better with the way the cascades optimizer works: building the memo table bottom-up and remembering everything. Therefore, to avoid getting everything from the ORM while still using one general interface for all types of nodes, the core needs to implement a trait provided by the cost model, and the cost model calls the corresponding methods to get the information, i.e. MemoExt in this PR. This allows the core to remain ignorant of what the cost model needs for computing the cost.
1 parent b5aed2b commit 9ba03e6

File tree

8 files changed

+455
-5
lines changed

8 files changed

+455
-5
lines changed

optd-cost-model/src/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod nodes;
22
pub mod predicates;
3+
pub mod properties;
34
pub mod types;
45
pub mod values;
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
use std::collections::HashSet;
2+
3+
use crate::{common::types::TableId, utils::DisjointSets};
4+
5+
/// A list of [`AttrRef`]s.
pub type AttrRefs = Vec<AttrRef>;
6+
7+
/// [`BaseTableAttrRef`] represents a reference to an attribute in a base table,
8+
/// i.e. a table existing in the catalog.
9+
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
10+
pub struct BaseTableAttrRef {
11+
pub table_id: TableId,
12+
pub attr_idx: u64,
13+
}
14+
15+
/// [`AttrRef`] represents a reference to an attribute in a query.
16+
#[derive(Clone, Debug)]
17+
pub enum AttrRef {
18+
/// Reference to a base table attribute.
19+
BaseTableAttrRef(BaseTableAttrRef),
20+
/// Reference to a derived attribute (e.g. t.v1 + t.v2).
21+
/// TODO: Better representation of derived attributes.
22+
Derived,
23+
}
24+
25+
impl AttrRef {
26+
pub fn base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self {
27+
AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx })
28+
}
29+
}
30+
31+
/// Wraps a [`BaseTableAttrRef`] in the [`AttrRef::BaseTableAttrRef`] variant.
impl From<BaseTableAttrRef> for AttrRef {
    fn from(base: BaseTableAttrRef) -> Self {
        Self::BaseTableAttrRef(base)
    }
}
36+
37+
/// [`EqPredicate`] represents an equality predicate between two attributes.
38+
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
39+
pub struct EqPredicate {
40+
pub left: BaseTableAttrRef,
41+
pub right: BaseTableAttrRef,
42+
}
43+
44+
impl EqPredicate {
45+
pub fn new(left: BaseTableAttrRef, right: BaseTableAttrRef) -> Self {
46+
Self { left, right }
47+
}
48+
}
49+
50+
/// [`SemanticCorrelation`] represents the semantic correlation between attributes in a
51+
/// query. "Semantic" means that the attributes are correlated based on the
52+
/// semantics of the query, not the statistics.
53+
///
54+
/// [`SemanticCorrelation`] contains equal attributes denoted by disjoint sets of base
55+
/// table attributes, e.g. {{ t1.c1 = t2.c1 = t3.c1 }, { t1.c2 = t2.c2 }}.
56+
#[derive(Clone, Debug, Default)]
57+
pub struct SemanticCorrelation {
58+
/// A disjoint set of base table attributes with equal values in the same row.
59+
disjoint_eq_attr_sets: DisjointSets<BaseTableAttrRef>,
60+
/// The predicates that define the equalities.
61+
eq_predicates: HashSet<EqPredicate>,
62+
}
63+
64+
impl SemanticCorrelation {
65+
pub fn new() -> Self {
66+
Self {
67+
disjoint_eq_attr_sets: DisjointSets::new(),
68+
eq_predicates: HashSet::new(),
69+
}
70+
}
71+
72+
pub fn add_predicate(&mut self, predicate: EqPredicate) {
73+
let left = &predicate.left;
74+
let right = &predicate.right;
75+
76+
// Add the indices to the set if they do not exist.
77+
if !self.disjoint_eq_attr_sets.contains(left) {
78+
self.disjoint_eq_attr_sets
79+
.make_set(left.clone())
80+
.expect("just checked left attribute index does not exist");
81+
}
82+
if !self.disjoint_eq_attr_sets.contains(right) {
83+
self.disjoint_eq_attr_sets
84+
.make_set(right.clone())
85+
.expect("just checked right attribute index does not exist");
86+
}
87+
// Union the attributes.
88+
self.disjoint_eq_attr_sets
89+
.union(left, right)
90+
.expect("both attribute indices should exist");
91+
92+
// Keep track of the predicate.
93+
self.eq_predicates.insert(predicate);
94+
}
95+
96+
/// Determine if two attributes are in the same set.
97+
pub fn is_eq(&mut self, left: &BaseTableAttrRef, right: &BaseTableAttrRef) -> bool {
98+
self.disjoint_eq_attr_sets
99+
.same_set(left, right)
100+
.unwrap_or(false)
101+
}
102+
103+
pub fn contains(&self, base_attr_ref: &BaseTableAttrRef) -> bool {
104+
self.disjoint_eq_attr_sets.contains(base_attr_ref)
105+
}
106+
107+
/// Get the number of attributes that are equal to `attr`, including `attr` itself.
108+
pub fn num_eq_attributes(&mut self, attr: &BaseTableAttrRef) -> usize {
109+
self.disjoint_eq_attr_sets.set_size(attr).unwrap()
110+
}
111+
112+
/// Find the set of predicates that define the equality of the set of attributes `attr` belongs to.
113+
pub fn find_predicates_for_eq_attr_set(&mut self, attr: &BaseTableAttrRef) -> Vec<EqPredicate> {
114+
let mut predicates = Vec::new();
115+
for predicate in &self.eq_predicates {
116+
let left = &predicate.left;
117+
let right = &predicate.right;
118+
if (left != attr && self.disjoint_eq_attr_sets.same_set(attr, left).unwrap())
119+
|| (right != attr && self.disjoint_eq_attr_sets.same_set(attr, right).unwrap())
120+
{
121+
predicates.push(predicate.clone());
122+
}
123+
}
124+
predicates
125+
}
126+
127+
/// Find the set of attributes that define the equality of the set of attributes `attr` belongs to.
128+
pub fn find_attrs_for_eq_attribute_set(
129+
&mut self,
130+
attr: &BaseTableAttrRef,
131+
) -> HashSet<BaseTableAttrRef> {
132+
let predicates = self.find_predicates_for_eq_attr_set(attr);
133+
predicates
134+
.into_iter()
135+
.flat_map(|predicate| vec![predicate.left, predicate.right])
136+
.collect()
137+
}
138+
139+
/// Union two `EqBaseTableattributesets` to produce a new disjoint sets.
140+
pub fn union(x: Self, y: Self) -> Self {
141+
let mut eq_attr_sets = Self::new();
142+
for predicate in x
143+
.eq_predicates
144+
.into_iter()
145+
.chain(y.eq_predicates.into_iter())
146+
{
147+
eq_attr_sets.add_predicate(predicate);
148+
}
149+
eq_attr_sets
150+
}
151+
152+
pub fn merge(x: Option<Self>, y: Option<Self>) -> Option<Self> {
153+
let eq_attr_sets = match (x, y) {
154+
(Some(x), Some(y)) => Self::union(x, y),
155+
(Some(x), None) => x.clone(),
156+
(None, Some(y)) => y.clone(),
157+
_ => return None,
158+
};
159+
Some(eq_attr_sets)
160+
}
161+
}
162+
163+
/// [`GroupAttrRefs`] represents the attributes of a group in a query.
164+
#[derive(Clone, Debug)]
165+
pub struct GroupAttrRefs {
166+
attribute_refs: AttrRefs,
167+
/// Correlation of the output attributes of the group.
168+
output_correlation: Option<SemanticCorrelation>,
169+
}
170+
171+
impl GroupAttrRefs {
172+
pub fn new(attribute_refs: AttrRefs, output_correlation: Option<SemanticCorrelation>) -> Self {
173+
Self {
174+
attribute_refs,
175+
output_correlation,
176+
}
177+
}
178+
179+
pub fn base_table_attribute_refs(&self) -> &AttrRefs {
180+
&self.attribute_refs
181+
}
182+
183+
pub fn output_correlation(&self) -> Option<&SemanticCorrelation> {
184+
self.output_correlation.as_ref()
185+
}
186+
}
187+
188+
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds equality predicates one at a time and checks, after each step, which attributes are
    /// considered equal and which predicates are reported for a given attribute.
    #[test]
    fn test_eq_base_table_attribute_sets() {
        let [attr1, attr2, attr3, attr4] = [
            BaseTableAttrRef {
                table_id: TableId(1),
                attr_idx: 1,
            },
            BaseTableAttrRef {
                table_id: TableId(2),
                attr_idx: 2,
            },
            BaseTableAttrRef {
                table_id: TableId(3),
                attr_idx: 3,
            },
            BaseTableAttrRef {
                table_id: TableId(4),
                attr_idx: 4,
            },
        ];
        let pred1 = EqPredicate::new(attr1.clone(), attr2.clone());
        let pred2 = EqPredicate::new(attr3.clone(), attr4.clone());
        let pred3 = EqPredicate::new(attr1.clone(), attr3.clone());

        let mut correlation = SemanticCorrelation::new();

        // (1, 2)
        correlation.add_predicate(pred1.clone());
        assert!(correlation.is_eq(&attr1, &attr2));

        // (1, 2), (3, 4)
        correlation.add_predicate(pred2.clone());
        assert!(correlation.is_eq(&attr3, &attr4));
        assert!(!correlation.is_eq(&attr2, &attr3));

        let for_attr1 = correlation.find_predicates_for_eq_attr_set(&attr1);
        assert_eq!(for_attr1.len(), 1);
        assert!(for_attr1.contains(&pred1));

        let for_attr3 = correlation.find_predicates_for_eq_attr_set(&attr3);
        assert_eq!(for_attr3.len(), 1);
        assert!(for_attr3.contains(&pred2));

        // (1, 2, 3, 4)
        correlation.add_predicate(pred3.clone());
        assert!(correlation.is_eq(&attr1, &attr3));
        assert!(correlation.is_eq(&attr2, &attr4));
        assert!(correlation.is_eq(&attr1, &attr4));

        let merged = correlation.find_predicates_for_eq_attr_set(&attr1);
        assert_eq!(merged.len(), 3);
        for expected in [&pred1, &pred2, &pred3] {
            assert!(merged.contains(expected));
        }
    }
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use serde::{Deserialize, Serialize};
2+
3+
use super::predicates::constant_pred::ConstantType;
4+
5+
pub mod attr_ref;
6+
pub mod schema;
7+
8+
#[derive(Clone, Debug, Serialize, Deserialize)]
9+
pub struct Attribute {
10+
pub name: String,
11+
pub typ: ConstantType,
12+
pub nullable: bool,
13+
}
14+
15+
impl std::fmt::Display for Attribute {
16+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17+
if self.nullable {
18+
write!(f, "{}:{:?}", self.name, self.typ)
19+
} else {
20+
write!(f, "{}:{:?}(non-null)", self.name, self.typ)
21+
}
22+
}
23+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
use itertools::Itertools;
2+
3+
use serde::{Deserialize, Serialize};
4+
5+
use super::Attribute;
6+
7+
/// [`Schema`] represents the schema of a group in the memo. It contains a list of attributes.
8+
#[derive(Clone, Debug, Serialize, Deserialize)]
9+
pub struct Schema {
10+
pub attributes: Vec<Attribute>,
11+
}
12+
13+
impl std::fmt::Display for Schema {
14+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15+
write!(
16+
f,
17+
"[{}]",
18+
self.attributes.iter().map(|x| x.to_string()).join(", ")
19+
)
20+
}
21+
}
22+
23+
impl Schema {
24+
pub fn new(attributes: Vec<Attribute>) -> Self {
25+
Self { attributes }
26+
}
27+
28+
pub fn len(&self) -> usize {
29+
self.attributes.len()
30+
}
31+
32+
pub fn is_empty(&self) -> bool {
33+
self.len() == 0
34+
}
35+
}

optd-cost-model/src/cost_model.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::{
1212
nodes::{ArcPredicateNode, PhysicalNodeType},
1313
types::{AttrId, EpochId, ExprId, TableId},
1414
},
15+
memo_ext::MemoExt,
1516
storage::CostModelStorageManager,
1617
ComputeCostContext, Cost, CostModel, CostModelResult, EstimatedStatistic, StatValue,
1718
};
@@ -20,28 +21,31 @@ use crate::{
2021
pub struct CostModelImpl<S: CostModelStorageLayer> {
2122
storage_manager: CostModelStorageManager<S>,
2223
default_catalog_source: CatalogSource,
24+
_memo: Arc<dyn MemoExt>,
2325
}
2426

2527
impl<S: CostModelStorageLayer> CostModelImpl<S> {
2628
/// TODO: documentation
2729
pub fn new(
2830
storage_manager: CostModelStorageManager<S>,
2931
default_catalog_source: CatalogSource,
32+
memo: Arc<dyn MemoExt>,
3033
) -> Self {
3134
Self {
3235
storage_manager,
3336
default_catalog_source,
37+
_memo: memo,
3438
}
3539
}
3640
}
3741

38-
impl<S: CostModelStorageLayer + std::marker::Sync + 'static> CostModel for CostModelImpl<S> {
42+
impl<S: CostModelStorageLayer + Sync + 'static> CostModel for CostModelImpl<S> {
3943
fn compute_operation_cost(
4044
&self,
4145
node: &PhysicalNodeType,
4246
predicates: &[ArcPredicateNode],
4347
children_stats: &[Option<&EstimatedStatistic>],
44-
context: Option<ComputeCostContext>,
48+
context: ComputeCostContext,
4549
) -> CostModelResult<Cost> {
4650
todo!()
4751
}
@@ -51,7 +55,7 @@ impl<S: CostModelStorageLayer + std::marker::Sync + 'static> CostModel for CostM
5155
node: PhysicalNodeType,
5256
predicates: &[ArcPredicateNode],
5357
children_statistics: &[Option<&EstimatedStatistic>],
54-
context: Option<ComputeCostContext>,
58+
context: ComputeCostContext,
5559
) -> CostModelResult<EstimatedStatistic> {
5660
todo!()
5761
}

optd-cost-model/src/lib.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ use optd_persistent::{
1010
pub mod common;
1111
pub mod cost;
1212
pub mod cost_model;
13+
pub mod memo_ext;
1314
pub mod stats;
1415
pub mod storage;
16+
pub mod utils;
1517

1618
pub enum StatValue {
1719
Int(i64),
@@ -63,7 +65,7 @@ pub trait CostModel: 'static + Send + Sync {
6365
node: &PhysicalNodeType,
6466
predicates: &[ArcPredicateNode],
6567
children_stats: &[Option<&EstimatedStatistic>],
66-
context: Option<ComputeCostContext>,
68+
context: ComputeCostContext,
6769
) -> CostModelResult<Cost>;
6870

6971
/// TODO: documentation
@@ -76,7 +78,7 @@ pub trait CostModel: 'static + Send + Sync {
7678
node: PhysicalNodeType,
7779
predicates: &[ArcPredicateNode],
7880
children_statistics: &[Option<&EstimatedStatistic>],
79-
context: Option<ComputeCostContext>,
81+
context: ComputeCostContext,
8082
) -> CostModelResult<EstimatedStatistic>;
8183

8284
/// TODO: documentation

0 commit comments

Comments
 (0)