Skip to content

Commit 2b0f71d

Browse files
author
longshan.lu
committed
feat: Implement IN and NOT IN list expression rewriting in SQL planner, enhancing filter handling for SQL queries
1 parent 42b9474 commit 2b0f71d

File tree

8 files changed

+543
-53
lines changed

8 files changed

+543
-53
lines changed

qurious/src/execution/session.rs

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -385,39 +385,36 @@ mod tests {
385385
// validate we actually have BRAZIL rows in the derived `all_nations` subquery.
386386
let debug = session.sql(
387387
"
388-
389388
select
390-
c_custkey,
391-
c_name,
392-
sum(l_extendedprice * (1 - l_discount)) as revenue,
393-
c_acctbal,
394-
n_name,
395-
c_address,
396-
c_phone,
397-
c_comment
389+
l_shipmode,
390+
sum(case
391+
when o_orderpriority = '1-URGENT'
392+
or o_orderpriority = '2-HIGH'
393+
then 1
394+
else 0
395+
end) as high_line_count,
396+
sum(case
397+
when o_orderpriority != '1-URGENT'
398+
and o_orderpriority != '2-HIGH'
399+
then 1
400+
else 0
401+
end) as low_line_count
398402
from
399-
customer,
400-
orders,
401-
lineitem,
402-
nation
403+
lineitem
404+
join
405+
orders
406+
on
407+
l_orderkey = o_orderkey
403408
where
404-
c_custkey = o_custkey
405-
and l_orderkey = o_orderkey
406-
and o_orderdate >= date '1993-10-01'
407-
and o_orderdate < date '1994-01-01'
408-
and l_returnflag = 'R'
409-
and c_nationkey = n_nationkey
409+
l_shipmode in ('MAIL', 'SHIP')
410+
and l_commitdate < l_receiptdate
411+
and l_shipdate < l_commitdate
412+
and l_receiptdate >= date '1994-01-01'
413+
and l_receiptdate < date '1995-01-01'
410414
group by
411-
c_custkey,
412-
c_name,
413-
c_acctbal,
414-
c_phone,
415-
n_name,
416-
c_address,
417-
c_comment
415+
l_shipmode
418416
order by
419-
revenue desc
420-
limit 10;
417+
l_shipmode;
421418
",
422419
)?;
423420
print_batches(&debug)?;

qurious/src/logical/expr/aggregate.rs

Lines changed: 188 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,46 @@ use std::fmt::Display;
1212
use std::sync::Arc;
1313

1414
use super::Column;
15+
use crate::datatypes::scalar::ScalarValue;
16+
17+
/// Format an expression for *naming* purposes, stripping out CAST(...) (and nested alias) wrappers.
18+
///
19+
/// Type coercion may insert casts without changing the logical meaning of an expression; we don't
20+
/// want those casts to affect output field names, otherwise downstream column lookups can break
21+
/// (e.g. SUM(a*b) vs SUM(a*CAST(b AS ...))).
22+
fn fmt_expr_for_name(expr: &LogicalExpr) -> String {
23+
match expr {
24+
LogicalExpr::Cast(c) => fmt_expr_for_name(&c.expr),
25+
LogicalExpr::Alias(a) => fmt_expr_for_name(&a.expr),
26+
LogicalExpr::Column(c) => c.to_string(),
27+
LogicalExpr::Literal(v) => v.to_string(),
28+
LogicalExpr::Negative(e) => format!("- {}", fmt_expr_for_name(e)),
29+
LogicalExpr::BinaryExpr(b) => format!(
30+
"{} {} {}",
31+
fmt_expr_for_name(&b.left),
32+
b.op,
33+
fmt_expr_for_name(&b.right)
34+
),
35+
LogicalExpr::Case(case) => {
36+
let mut s = String::from("CASE");
37+
if let Some(op) = &case.operand {
38+
s.push(' ');
39+
s.push_str(&fmt_expr_for_name(op));
40+
}
41+
for (w, t) in &case.when_then {
42+
s.push_str(" WHEN ");
43+
s.push_str(&fmt_expr_for_name(w));
44+
s.push_str(" THEN ");
45+
s.push_str(&fmt_expr_for_name(t));
46+
}
47+
s.push_str(" ELSE ");
48+
s.push_str(&fmt_expr_for_name(&case.else_expr));
49+
s.push_str(" END");
50+
s
51+
}
52+
other => other.to_string(),
53+
}
54+
}
1555

1656
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1757
pub enum AggregateOperator {
@@ -96,10 +136,16 @@ pub struct AggregateExpr {
96136
impl AggregateExpr {
97137
pub fn field(&self, plan: &LogicalPlan) -> Result<FieldRef> {
98138
self.expr.field(plan).and_then(|field| {
99-
let col_name = if let LogicalExpr::Column(inner) = self.expr.as_ref() {
100-
&inner.qualified_name()
101-
} else {
102-
field.name()
139+
// Use the *expression string* for non-column arguments, otherwise we may generate
140+
// names like COUNT(i32) from Arrow field names which won't match expression display.
141+
//
142+
// Special case: COUNT(*) is rewritten to COUNT(1) by `CountWildcardRule`, but the
143+
// output column name must remain COUNT(*) for SQL compatibility / tests.
144+
let col_name = match (self.op.clone(), self.expr.as_ref()) {
145+
(AggregateOperator::Count, LogicalExpr::Literal(ScalarValue::Int32(Some(1))))
146+
| (AggregateOperator::Count, LogicalExpr::Literal(ScalarValue::Int64(Some(1)))) => "*".to_string(),
147+
(_, LogicalExpr::Column(inner)) => inner.qualified_name(),
148+
(_, other) => fmt_expr_for_name(other),
103149
};
104150

105151
Ok(Arc::new(Field::new(
@@ -111,13 +157,30 @@ impl AggregateExpr {
111157
}
112158

113159
pub(crate) fn as_column(&self) -> Result<LogicalExpr> {
114-
self.expr.as_column().map(|inner_col| {
115-
LogicalExpr::Column(Column {
116-
name: format!("{}({})", self.op, inner_col),
117-
relation: None,
118-
is_outer_ref: false,
119-
})
120-
})
160+
// Keep COUNT(*) naming stable even if it was rewritten to COUNT(1) internally.
161+
if self.op == AggregateOperator::Count {
162+
if matches!(
163+
self.expr.as_ref(),
164+
LogicalExpr::Literal(ScalarValue::Int32(Some(1))) | LogicalExpr::Literal(ScalarValue::Int64(Some(1)))
165+
) {
166+
return Ok(LogicalExpr::Column(Column {
167+
name: "COUNT(*)".to_string(),
168+
relation: None,
169+
is_outer_ref: false,
170+
}));
171+
}
172+
}
173+
174+
let arg_name = match self.expr.as_ref() {
175+
LogicalExpr::Column(c) => c.to_string(),
176+
other => fmt_expr_for_name(other),
177+
};
178+
179+
Ok(LogicalExpr::Column(Column {
180+
name: format!("{}({})", self.op, arg_name),
181+
relation: None,
182+
is_outer_ref: false,
183+
}))
121184
}
122185
}
123186

@@ -126,3 +189,117 @@ impl Display for AggregateExpr {
126189
write!(f, "{}({})", self.op, self.expr)
127190
}
128191
}
192+
193+
#[cfg(test)]
194+
mod tests {
195+
use super::*;
196+
use crate::datatypes::operator::Operator;
197+
use crate::logical::expr::{BinaryExpr, CaseExpr, CastExpr, Column, LogicalExpr};
198+
use crate::logical::plan::{EmptyRelation, LogicalPlan};
199+
use arrow::datatypes::{DataType, Field, Schema};
200+
use std::sync::Arc;
201+
202+
fn empty_plan_with_schema(fields: Vec<Field>) -> LogicalPlan {
203+
LogicalPlan::EmptyRelation(EmptyRelation {
204+
produce_one_row: true,
205+
schema: Arc::new(Schema::new(fields)),
206+
})
207+
}
208+
209+
#[test]
210+
fn count_star_keeps_output_name_after_rewrite_to_count_1() {
211+
// Optimizer rule rewrites COUNT(*) -> COUNT(1) for execution.
212+
// However, the output column name must remain COUNT(*) to match SQL surface semantics
213+
// and sqllogictest expectations.
214+
let plan = empty_plan_with_schema(vec![]);
215+
let agg = AggregateExpr {
216+
op: AggregateOperator::Count,
217+
expr: Box::new(LogicalExpr::Literal(ScalarValue::Int32(Some(1)))),
218+
};
219+
220+
let field = agg.field(&plan).unwrap();
221+
assert_eq!(field.name(), "COUNT(*)");
222+
223+
let col = agg.as_column().unwrap();
224+
assert_eq!(col.to_string(), "COUNT(*)");
225+
}
226+
227+
#[test]
228+
fn aggregate_naming_ignores_casts_in_argument_expression() {
229+
// TypeCoercion may insert CASTs (e.g. to make DECIMAL * INT valid),
230+
// but we don't want those CASTs to affect the aggregate output column name.
231+
let plan = empty_plan_with_schema(vec![
232+
Field::new("a", DataType::Decimal128(15, 2), false),
233+
Field::new("b", DataType::Int64, false),
234+
]);
235+
236+
let expr = LogicalExpr::BinaryExpr(BinaryExpr::new(
237+
LogicalExpr::Column(Column::new(
238+
"a",
239+
None::<crate::common::table_relation::TableRelation>,
240+
false,
241+
)),
242+
Operator::Mul,
243+
LogicalExpr::Cast(CastExpr::new(
244+
LogicalExpr::Column(Column::new(
245+
"b",
246+
None::<crate::common::table_relation::TableRelation>,
247+
false,
248+
)),
249+
DataType::Decimal128(20, 0),
250+
)),
251+
));
252+
253+
let agg = AggregateExpr {
254+
op: AggregateOperator::Sum,
255+
expr: Box::new(expr),
256+
};
257+
258+
let field = agg.field(&plan).unwrap();
259+
// cast is ignored for naming: b (not CAST(b AS ...))
260+
assert_eq!(field.name(), "SUM(a * b)");
261+
assert_eq!(agg.as_column().unwrap().to_string(), "SUM(a * b)");
262+
}
263+
264+
#[test]
265+
fn aggregate_naming_ignores_casts_inside_case_expression() {
266+
// Similar to TPCH Q8: CASE branch literals may get casted by type coercion,
267+
// but the aggregate output name should stay stable.
268+
let plan = empty_plan_with_schema(vec![
269+
Field::new("cond", DataType::Boolean, false),
270+
Field::new("v", DataType::Decimal128(38, 4), false),
271+
]);
272+
273+
let case = CaseExpr {
274+
operand: None,
275+
when_then: vec![(
276+
LogicalExpr::Column(Column::new(
277+
"cond",
278+
None::<crate::common::table_relation::TableRelation>,
279+
false,
280+
)),
281+
LogicalExpr::Column(Column::new(
282+
"v",
283+
None::<crate::common::table_relation::TableRelation>,
284+
false,
285+
)),
286+
)],
287+
else_expr: Box::new(LogicalExpr::Cast(CastExpr::new(
288+
LogicalExpr::Literal(ScalarValue::Int64(Some(0))),
289+
DataType::Decimal128(38, 4),
290+
))),
291+
};
292+
293+
let agg = AggregateExpr {
294+
op: AggregateOperator::Sum,
295+
expr: Box::new(LogicalExpr::Case(case)),
296+
};
297+
298+
let field = agg.field(&plan).unwrap();
299+
assert_eq!(field.name(), "SUM(CASE WHEN cond THEN v ELSE Int64(0) END)");
300+
assert_eq!(
301+
agg.as_column().unwrap().to_string(),
302+
"SUM(CASE WHEN cond THEN v ELSE Int64(0) END)"
303+
);
304+
}
305+
}

qurious/src/optimizer/rule/rule_optimizer.rs

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ impl RuleBaseOptimizer {
2828
Self {
2929
rules: vec![
3030
Box::new(CountWildcardRule),
31-
Box::new(TypeCoercion),
3231
Box::new(SimplifyExprs),
3332
Box::new(ScalarSubqueryToJoin::default()),
3433
Box::new(DecorrelatePredicateSubquery::default()),
3534
Box::new(EliminateCrossJoin),
3635
Box::new(ExtractEquijoinPredicate),
3736
Box::new(PushdownFilter),
37+
// Run type coercion late so correlated subqueries have been rewritten/decorrelated
38+
// (avoids trying to type-check outer-ref columns inside subquery schemas).
39+
Box::new(TypeCoercion),
3840
],
3941
}
4042
}
@@ -54,3 +56,75 @@ impl Optimizer for RuleBaseOptimizer {
5456
Ok(current_plan)
5557
}
5658
}
59+
60+
#[cfg(test)]
61+
mod tests {
62+
use super::*;
63+
use crate::{test_utils::sql_to_plan, utils};
64+
65+
#[test]
66+
fn tpch_q2_should_keep_part_table_scan() {
67+
// Regression test: optimizer must not drop the `part` relation for TPC-H Q2.
68+
let sql = r#"
69+
select
70+
s_acctbal,
71+
s_name,
72+
n_name,
73+
p_partkey,
74+
p_mfgr,
75+
s_address,
76+
s_phone,
77+
s_comment
78+
from
79+
part,
80+
supplier,
81+
partsupp,
82+
nation,
83+
region
84+
where
85+
p_partkey = ps_partkey
86+
and s_suppkey = ps_suppkey
87+
and p_size = 15
88+
and p_type like '%BRASS'
89+
and s_nationkey = n_nationkey
90+
and n_regionkey = r_regionkey
91+
and r_name = 'EUROPE'
92+
and ps_supplycost = (
93+
select
94+
min(ps_supplycost)
95+
from
96+
partsupp,
97+
supplier,
98+
nation,
99+
region
100+
where
101+
p_partkey = ps_partkey
102+
and s_suppkey = ps_suppkey
103+
and s_nationkey = n_nationkey
104+
and n_regionkey = r_regionkey
105+
and r_name = 'EUROPE'
106+
)
107+
order by
108+
s_acctbal desc,
109+
n_name,
110+
s_name,
111+
p_partkey
112+
limit 10;
113+
"#;
114+
115+
let plan = sql_to_plan(sql);
116+
let original = utils::format(&plan, 0);
117+
assert!(
118+
original.contains("TableScan: part"),
119+
"original plan lost part scan:\n{original}"
120+
);
121+
122+
let optimizer = RuleBaseOptimizer::new();
123+
let optimized = optimizer.optimize(&plan).unwrap();
124+
let formatted = utils::format(&optimized, 0);
125+
assert!(
126+
formatted.contains("TableScan: part"),
127+
"optimized plan lost part scan:\n{formatted}"
128+
);
129+
}
130+
}

0 commit comments

Comments
 (0)