Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 279b561

Browse files
committed
refactor(adv-cost): compiles and not rely on opt ctx (#223)
Signed-off-by: Alex Chi <[email protected]>
1 parent 4470ade commit 279b561

File tree

18 files changed

+379
-457
lines changed

18 files changed

+379
-457
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-optd-cli/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ tokio = { version = "1.24", features = [
5858
] }
5959
url = "2.2"
6060
optd-datafusion-bridge = { path = "../optd-datafusion-bridge" }
61-
# optd-datafusion-repr-adv-cost = { path = "../optd-datafusion-repr-adv-cost" }
61+
optd-datafusion-repr-adv-cost = { path = "../optd-datafusion-repr-adv-cost" }
6262
optd-datafusion-repr = { path = "../optd-datafusion-repr" }
6363
tracing-subscriber = "0.3"
6464
tracing = "0.1"

optd-core/src/nodes.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,17 @@ impl Value {
188188
Value::Int64(i64) => (*i64).try_into().unwrap(),
189189
_ => panic!("{self} could not be converted into an Int32"),
190190
}),
191+
DataType::Int64 => Value::Int64(match self {
192+
Value::Int64(i64) => *i64,
193+
Value::Int32(i32) => (*i32).try_into().unwrap(),
194+
_ => panic!("{self} could not be converted into an Int64"),
195+
}),
196+
DataType::UInt64 => Value::UInt64(match self {
197+
Value::Int64(i64) => (*i64).try_into().unwrap(),
198+
Value::UInt64(i64) => *i64,
199+
Value::UInt32(i32) => (*i32).try_into().unwrap(),
200+
_ => panic!("{self} could not be converted into an UInt64"),
201+
}),
191202
DataType::Date32 => Value::Date32(match self {
192203
Value::Date32(date32) => *date32,
193204
Value::String(str) => {
@@ -307,6 +318,12 @@ impl<T: NodeType> From<ArcPlanNode<T>> for PlanNodeOrGroup<T> {
307318
}
308319
}
309320

321+
impl<T: NodeType> From<GroupId> for PlanNodeOrGroup<T> {
322+
fn from(value: GroupId) -> Self {
323+
Self::Group(value)
324+
}
325+
}
326+
310327
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
311328
pub struct PredNode<T: NodeType> {
312329
/// A generic predicate node type

optd-datafusion-repr-adv-cost/src/adv_stats.rs

Lines changed: 27 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,13 @@ impl<
7171
mod tests {
7272
use arrow_schema::DataType;
7373
use itertools::Itertools;
74-
use optd_core::nodes::Value;
75-
use optd_datafusion_repr::plan_nodes::{
76-
BinOpExpr, BinOpType, CastExpr, ColumnRefExpr, ConstantExpr, Expr, ExprList, InListExpr,
77-
LikeExpr, LogOpExpr, LogOpType, OptRelNode, OptRelNodeRef, UnOpExpr, UnOpType,
74+
use optd_datafusion_repr::{
75+
plan_nodes::{
76+
ArcDfPredNode, BinOpPred, BinOpType, CastPred, ColumnRefPred, ConstantPred,
77+
DfReprPredNode, InListPred, LikePred, ListPred, LogOpPred, LogOpType, UnOpPred,
78+
UnOpType,
79+
},
80+
Value,
7881
};
7982
use serde::{Deserialize, Serialize};
8083
use std::collections::HashMap;
@@ -285,74 +288,46 @@ mod tests {
285288
)
286289
}
287290

288-
pub fn col_ref(idx: u64) -> OptRelNodeRef {
291+
pub fn col_ref(idx: u64) -> ArcDfPredNode {
289292
// this conversion is always safe because idx was originally a usize
290293
let idx_as_usize = idx as usize;
291-
ColumnRefExpr::new(idx_as_usize).into_rel_node()
294+
ColumnRefPred::new(idx_as_usize).into_pred_node()
292295
}
293296

294-
pub fn cnst(value: Value) -> OptRelNodeRef {
295-
ConstantExpr::new(value).into_rel_node()
297+
pub fn cnst(value: Value) -> ArcDfPredNode {
298+
ConstantPred::new(value).into_pred_node()
296299
}
297300

298-
pub fn cast(child: OptRelNodeRef, cast_type: DataType) -> OptRelNodeRef {
299-
CastExpr::new(
300-
Expr::from_rel_node(child).expect("child should be an Expr"),
301-
cast_type,
302-
)
303-
.into_rel_node()
301+
pub fn cast(child: ArcDfPredNode, cast_type: DataType) -> ArcDfPredNode {
302+
CastPred::new(child, cast_type).into_pred_node()
304303
}
305304

306-
pub fn bin_op(op_type: BinOpType, left: OptRelNodeRef, right: OptRelNodeRef) -> OptRelNodeRef {
307-
BinOpExpr::new(
308-
Expr::from_rel_node(left).expect("left should be an Expr"),
309-
Expr::from_rel_node(right).expect("right should be an Expr"),
310-
op_type,
311-
)
312-
.into_rel_node()
305+
pub fn bin_op(op_type: BinOpType, left: ArcDfPredNode, right: ArcDfPredNode) -> ArcDfPredNode {
306+
BinOpPred::new(left, right, op_type).into_pred_node()
313307
}
314308

315-
pub fn log_op(op_type: LogOpType, children: Vec<OptRelNodeRef>) -> OptRelNodeRef {
316-
LogOpExpr::new(
317-
op_type,
318-
ExprList::new(
319-
children
320-
.into_iter()
321-
.map(|opt_rel_node_ref| {
322-
Expr::from_rel_node(opt_rel_node_ref).expect("all children should be Expr")
323-
})
324-
.collect(),
325-
),
326-
)
327-
.into_rel_node()
309+
pub fn log_op(op_type: LogOpType, children: Vec<ArcDfPredNode>) -> ArcDfPredNode {
310+
LogOpPred::new(op_type, children).into_pred_node()
328311
}
329312

330-
pub fn un_op(op_type: UnOpType, child: OptRelNodeRef) -> OptRelNodeRef {
331-
UnOpExpr::new(
332-
Expr::from_rel_node(child).expect("child should be an Expr"),
333-
op_type,
334-
)
335-
.into_rel_node()
313+
pub fn un_op(op_type: UnOpType, child: ArcDfPredNode) -> ArcDfPredNode {
314+
UnOpPred::new(child, op_type).into_pred_node()
336315
}
337316

338-
pub fn in_list(col_ref_idx: u64, list: Vec<Value>, negated: bool) -> InListExpr {
339-
InListExpr::new(
340-
Expr::from_rel_node(col_ref(col_ref_idx)).unwrap(),
341-
ExprList::new(
342-
list.into_iter()
343-
.map(|v| Expr::from_rel_node(cnst(v)).unwrap())
344-
.collect_vec(),
345-
),
317+
pub fn in_list(col_ref_idx: u64, list: Vec<Value>, negated: bool) -> InListPred {
318+
InListPred::new(
319+
col_ref(col_ref_idx),
320+
ListPred::new(list.into_iter().map(|v| cnst(v)).collect_vec()),
346321
negated,
347322
)
348323
}
349324

350-
pub fn like(col_ref_idx: u64, pattern: &str, negated: bool) -> LikeExpr {
351-
LikeExpr::new(
325+
pub fn like(col_ref_idx: u64, pattern: &str, negated: bool) -> LikePred {
326+
LikePred::new(
352327
negated,
353328
false,
354-
Expr::from_rel_node(col_ref(col_ref_idx)).unwrap(),
355-
Expr::from_rel_node(cnst(Value::String(pattern.into()))).unwrap(),
329+
col_ref(col_ref_idx),
330+
cnst(Value::String(pattern.into())),
356331
)
357332
}
358333

Lines changed: 39 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,59 @@
1-
use optd_core::{
2-
cascades::{BindingType, CascadesOptimizer, RelNodeContext},
3-
cost::Cost,
4-
};
1+
use optd_core::cascades::{CascadesOptimizer, RelNodeContext};
52
use serde::{de::DeserializeOwned, Serialize};
63

7-
use crate::adv_cost::{
4+
use crate::adv_stats::{
85
stats::{Distribution, MostCommonValues},
96
DEFAULT_NUM_DISTINCT,
107
};
118
use optd_datafusion_repr::{
12-
plan_nodes::{ExprList, OptRelNode, OptRelNodeTyp},
13-
properties::column_ref::{BaseTableColumnRef, ColumnRef, ColumnRefPropertyBuilder},
9+
plan_nodes::{ArcDfPredNode, DfReprPredNode, ListPred},
10+
properties::column_ref::{
11+
BaseTableColumnRef, ColumnRef, ColumnRefPropertyBuilder, GroupColumnRefs,
12+
},
1413
};
1514

16-
use super::{OptCostModel, DEFAULT_UNK_SEL};
15+
use super::{AdvStats, DEFAULT_UNK_SEL};
1716

1817
impl<
1918
M: MostCommonValues + Serialize + DeserializeOwned,
2019
D: Distribution + Serialize + DeserializeOwned,
21-
> OptCostModel<M, D>
20+
> AdvStats<M, D>
2221
{
23-
pub(super) fn get_agg_cost(
24-
&self,
25-
children: &[Cost],
26-
context: Option<RelNodeContext>,
27-
optimizer: Option<&CascadesOptimizer<DfNodeType>>,
28-
) -> Cost {
29-
let child_row_cnt = Self::row_cnt(&children[0]);
30-
let row_cnt = self.get_agg_row_cnt(context, optimizer, child_row_cnt);
31-
let (_, compute_cost_1, _) = Self::cost_tuple(&children[1]);
32-
let (_, compute_cost_2, _) = Self::cost_tuple(&children[2]);
33-
Self::cost(
34-
row_cnt,
35-
child_row_cnt * (compute_cost_1 + compute_cost_2),
36-
0.0,
37-
)
38-
}
39-
40-
fn get_agg_row_cnt(
22+
pub(crate) fn get_agg_row_cnt(
4123
&self,
42-
context: Option<RelNodeContext>,
43-
optimizer: Option<&CascadesOptimizer<DfNodeType>>,
44-
child_row_cnt: f64,
24+
group_by: ArcDfPredNode,
25+
output_col_refs: GroupColumnRefs,
4526
) -> f64 {
46-
if let (Some(context), Some(optimizer)) = (context, optimizer) {
47-
let group_by_id = context.children_group_ids[2];
48-
let group_by = optimizer
49-
.get_predicate_binding(group_by_id)
50-
.expect("no expression found?");
51-
let group_by = ExprList::from_rel_node(group_by).unwrap();
52-
if group_by.is_empty() {
53-
1.0
54-
} else {
55-
// Multiply the n-distinct of all the group by columns.
56-
// TODO: improve with multi-dimensional n-distinct
57-
let group_col_refs = optimizer
58-
.get_property_by_group::<ColumnRefPropertyBuilder>(context.group_id, 1);
59-
group_col_refs
60-
.base_table_column_refs()
61-
.iter()
62-
.take(group_by.len())
63-
.map(|col_ref| match col_ref {
64-
ColumnRef::BaseTableColumnRef(BaseTableColumnRef { table, col_idx }) => {
65-
let table_stats = self.per_table_stats_map.get(table);
66-
let column_stats = table_stats.and_then(|table_stats| {
67-
table_stats.column_comb_stats.get(&vec![*col_idx])
68-
});
27+
let group_by = ListPred::from_pred_node(group_by).unwrap();
28+
if group_by.is_empty() {
29+
1.0
30+
} else {
31+
// Multiply the n-distinct of all the group by columns.
32+
// TODO: improve with multi-dimensional n-distinct
33+
output_col_refs
34+
.base_table_column_refs()
35+
.iter()
36+
.take(group_by.len())
37+
.map(|col_ref| match col_ref {
38+
ColumnRef::BaseTableColumnRef(BaseTableColumnRef { table, col_idx }) => {
39+
let table_stats = self.per_table_stats_map.get(table);
40+
let column_stats = table_stats.and_then(|table_stats| {
41+
table_stats.column_comb_stats.get(&vec![*col_idx])
42+
});
6943

70-
if let Some(column_stats) = column_stats {
71-
column_stats.ndistinct as f64
72-
} else {
73-
// The column type is not supported or stats are missing.
74-
DEFAULT_NUM_DISTINCT as f64
75-
}
44+
if let Some(column_stats) = column_stats {
45+
column_stats.ndistinct as f64
46+
} else {
47+
// The column type is not supported or stats are missing.
48+
DEFAULT_NUM_DISTINCT as f64
7649
}
77-
ColumnRef::Derived => DEFAULT_NUM_DISTINCT as f64,
78-
_ => panic!(
79-
"GROUP BY base table column ref must either be derived or base table"
80-
),
81-
})
82-
.product()
83-
}
84-
} else {
85-
(child_row_cnt * DEFAULT_UNK_SEL).max(1.0)
50+
}
51+
ColumnRef::Derived => DEFAULT_NUM_DISTINCT as f64,
52+
_ => panic!(
53+
"GROUP BY base table column ref must either be derived or base table"
54+
),
55+
})
56+
.product()
8657
}
8758
}
8859
}

0 commit comments

Comments
 (0)