Skip to content

Commit 42b9474

Browse files
author
longshan.lu
committed
feat: Introduce CASE expression support in SQL parser and planner, enhancing expression handling and type coercion for conditional logic
1 parent 7d6d9d0 commit 42b9474

File tree

24 files changed

+1041
-45
lines changed

24 files changed

+1041
-45
lines changed

qurious/src/common/table_schema.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};
1010

1111
pub type TableSchemaRef = Arc<TableSchema>;
1212

13+
/// Arrow schema metadata key used to preserve per-field qualifiers (table/alias) across planning stages.
14+
///
15+
/// This is needed because Arrow `Schema` fields are identified by name only, and we allow duplicate
16+
/// column names across different relations (e.g. `nation n1`, `nation n2` both have `n_name`).
17+
/// Physical planning must be able to map a `(relation, column_name)` to the correct field index.
18+
pub const FIELD_QUALIFIERS_META_KEY: &str = "qurious.field_qualifiers";
19+
const FIELD_QUALIFIERS_META_SEP: char = '\u{1f}'; // unit separator (unlikely to appear in names)
20+
1321
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1422
pub struct TableSchema {
1523
pub schema: SchemaRef,
@@ -48,7 +56,24 @@ impl TableSchema {
4856
}
4957

5058
pub fn arrow_schema(&self) -> SchemaRef {
51-
self.schema.clone()
59+
// Preserve qualifier information via Schema metadata so physical planning can disambiguate
60+
// same-named fields from different relations.
61+
let mut metadata = self.schema.metadata().clone();
62+
let qualifiers = self
63+
.field_qualifiers
64+
.iter()
65+
.map(|q| q.as_ref().map(|t| t.to_qualified_name()).unwrap_or_default())
66+
.collect::<Vec<_>>()
67+
.join(&FIELD_QUALIFIERS_META_SEP.to_string());
68+
metadata.insert(FIELD_QUALIFIERS_META_KEY.to_string(), qualifiers);
69+
70+
let fields = self
71+
.schema
72+
.fields()
73+
.iter()
74+
.map(|f| f.as_ref().clone())
75+
.collect::<Vec<_>>();
76+
Arc::new(Schema::new_with_metadata(fields, metadata))
5277
}
5378

5479
pub fn has_field(&self, qualifier: Option<&TableRelation>, name: &str) -> bool {

qurious/src/datatypes/scalar.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,18 @@ impl ScalarValue {
213213
DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, Utf8),
214214
DataType::Decimal128(p, s) => typed_cast_decimal!(Decimal128Array, Decimal128, array, index, *p, *s),
215215
DataType::Decimal256(p, s) => typed_cast_decimal!(Decimal256Array, Decimal256, array, index, *p, *s),
216+
DataType::Interval(IntervalUnit::MonthDayNano) => {
217+
let arr = array
218+
.as_any()
219+
.downcast_ref::<IntervalMonthDayNanoArray>()
220+
.ok_or_else(|| {
221+
Error::InternalError(format!(
222+
"could not cast value to {}",
223+
type_name::<IntervalMonthDayNanoArray>()
224+
))
225+
})?;
226+
Ok(ScalarValue::IntervalMonthDayNano(Some(arr.value(index))))
227+
}
216228
_ => unimplemented!("data type {} not supported", array.data_type()),
217229
}
218230
}
@@ -257,6 +269,9 @@ impl TryFrom<&DataType> for ScalarValue {
257269
DataType::Float64 => Ok(ScalarValue::Float64(None)),
258270
DataType::Utf8 => Ok(ScalarValue::Utf8(None)),
259271
DataType::LargeUtf8 => Ok(ScalarValue::Utf8(None)),
272+
DataType::Decimal128(p, s) => Ok(ScalarValue::Decimal128(None, *p, *s)),
273+
DataType::Decimal256(p, s) => Ok(ScalarValue::Decimal256(None, *p, *s)),
274+
DataType::Interval(IntervalUnit::MonthDayNano) => Ok(ScalarValue::IntervalMonthDayNano(None)),
260275
_ => unimplemented!("data type {} not supported", value),
261276
}
262277
}

qurious/src/execution/session.rs

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -381,23 +381,46 @@ mod tests {
381381
// session.sql("INSERT INTO test VALUES (1, 1), (2, 2), (3, 3), (3, 5), (NULL, NULL);")?;
382382
// session.sql("select a, b, c, d from x join y on a = c")?;
383383
println!("++++++++++++++");
384-
let batch = session
385-
.sql(
386-
"
384+
// Debug helper while this test is ignored: runs a TPC-H Q10-style query
385+
// (returned-item revenue per customer, joined with `nation`) and prints the result.
386+
let debug = session.sql(
387+
"
388+
387389
select
388-
sum(l_extendedprice * l_discount) as revenue
390+
c_custkey,
391+
c_name,
392+
sum(l_extendedprice * (1 - l_discount)) as revenue,
393+
c_acctbal,
394+
n_name,
395+
c_address,
396+
c_phone,
397+
c_comment
389398
from
390-
lineitem
399+
customer,
400+
orders,
401+
lineitem,
402+
nation
391403
where
392-
l_shipdate >= date '1994-01-01'
393-
and l_shipdate < date '1995-01-01'
394-
and l_discount between 0.06 - 0.01 and 0.06 + 0.01
395-
and l_quantity < 24;
404+
c_custkey = o_custkey
405+
and l_orderkey = o_orderkey
406+
and o_orderdate >= date '1993-10-01'
407+
and o_orderdate < date '1994-01-01'
408+
and l_returnflag = 'R'
409+
and c_nationkey = n_nationkey
410+
group by
411+
c_custkey,
412+
c_name,
413+
c_acctbal,
414+
c_phone,
415+
n_name,
416+
c_address,
417+
c_comment
418+
order by
419+
revenue desc
420+
limit 10;
396421
",
397-
)
398-
.unwrap();
399-
400-
print_batches(&batch)?;
422+
)?;
423+
print_batches(&debug)?;
401424

402425
Ok(())
403426
}
@@ -422,7 +445,17 @@ where
422445
session.register_table("t", Arc::new(datasource))?;
423446

424447
let batch = session.sql("SELECT a.* FROM t as a")?;
425-
assert_eq!(data, batch);
448+
// Compare data ignoring schema-level metadata (we attach qualifiers in schema metadata
449+
// to support disambiguation across aliased relations).
450+
assert_eq!(data.len(), batch.len());
451+
for (expected, actual) in data.iter().zip(batch.iter()) {
452+
assert_eq!(expected.schema().fields(), actual.schema().fields());
453+
assert_eq!(expected.num_rows(), actual.num_rows());
454+
assert_eq!(expected.num_columns(), actual.num_columns());
455+
for i in 0..expected.num_columns() {
456+
assert_eq!(expected.column(i).as_ref(), actual.column(i).as_ref());
457+
}
458+
}
426459

427460
Ok(())
428461
}

qurious/src/functions/datetime/extract.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::str::FromStr;
33
use arrow::array::Array;
44
use arrow::compute::kernels::cast_utils::IntervalUnit;
55
use arrow::{
6+
array::new_empty_array,
67
array::{ArrayRef, AsArray},
78
compute::{cast, date_part, DatePart},
89
datatypes::DataType,
@@ -31,16 +32,34 @@ impl UserDefinedFunction for DatetimeExtract {
3132
return Err(Error::InvalidArgumentError("EXTRACT requires 2 arguments".to_string()));
3233
}
3334

35+
// EXTRACT is often used in projections that can be evaluated on empty batches (e.g. after filters).
36+
// In that case, literal args can be represented as empty arrays, so we must not assume len > 0.
37+
if args[1].len() == 0 {
38+
return Ok(new_empty_array(&DataType::Int64));
39+
}
40+
3441
// get interval_unit value
3542
let interval_unit = if let Some(val) = args.get(0) {
36-
val.as_string::<i32>().value(0)
43+
let arr = val.as_string::<i32>();
44+
if arr.len() == 0 {
45+
// args[1] is known to be non-empty here, but the unit array itself may still be empty; guard before reading index 0
46+
return Ok(new_empty_array(&DataType::Int64));
47+
}
48+
if !arr.is_valid(0) {
49+
return Err(Error::InvalidArgumentError(
50+
"First argument of `EXTRACT` must be non-null scalar Utf8".to_string(),
51+
));
52+
}
53+
arr.value(0)
3754
} else {
3855
return Err(Error::InvalidArgumentError(
3956
"First argument of `DATE_PART` must be non-null scalar Utf8".to_string(),
4057
));
4158
};
4259

43-
match IntervalUnit::from_str(interval_unit)? {
60+
// Normalize case so we accept YEAR / year / Year, etc.
61+
let interval_unit = interval_unit.to_ascii_lowercase();
62+
match IntervalUnit::from_str(&interval_unit)? {
4463
IntervalUnit::Year => date_part_f64(args[1].as_ref(), DatePart::Year),
4564
IntervalUnit::Month => date_part_f64(args[1].as_ref(), DatePart::Month),
4665
IntervalUnit::Week => date_part_f64(args[1].as_ref(), DatePart::Week),
@@ -60,3 +79,22 @@ impl UserDefinedFunction for DatetimeExtract {
6079
fn date_part_f64(array: &dyn Array, part: DatePart) -> Result<ArrayRef> {
6180
cast(date_part(array, part)?.as_ref(), &DataType::Int64).map_err(|e| arrow_err!(e))
6281
}
82+
83+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{new_empty_array, StringArray};
    use arrow::datatypes::DataType;
    use std::sync::Arc;

    /// EXTRACT over a zero-row value array must return an empty `Int64` array, not panic.
    #[test]
    fn test_extract_empty_input_does_not_panic() {
        let unit_arg: ArrayRef = Arc::new(StringArray::from(vec!["YEAR"]));
        let value_arg = new_empty_array(&DataType::Date32);

        let extractor = DatetimeExtract;
        let result = extractor.eval(vec![unit_arg, value_arg]).unwrap();

        assert_eq!(result.len(), 0);
        assert_eq!(result.data_type(), &DataType::Int64);
    }
}

qurious/src/logical/expr/case.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
use std::fmt::{Display, Formatter};
2+
use std::sync::Arc;
3+
4+
use arrow::datatypes::{DataType, Field, FieldRef, Schema};
5+
6+
use crate::datatypes::scalar::ScalarValue;
7+
use crate::error::Result;
8+
use crate::logical::plan::LogicalPlan;
9+
10+
use super::LogicalExpr;
11+
12+
/// SQL CASE expression.
13+
///
14+
/// Supports both:
15+
/// - searched CASE: `CASE WHEN <cond> THEN <value> ... ELSE <value> END`
16+
/// - simple CASE: `CASE <operand> WHEN <expr> THEN <value> ... ELSE <value> END`
17+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18+
pub struct CaseExpr {
19+
pub operand: Option<Box<LogicalExpr>>,
20+
pub when_then: Vec<(LogicalExpr, LogicalExpr)>,
21+
pub else_expr: Box<LogicalExpr>,
22+
}
23+
24+
impl CaseExpr {
25+
pub fn field(&self, plan: &LogicalPlan) -> Result<FieldRef> {
26+
// Best-effort: use first THEN type (or ELSE) as the output type.
27+
// The optimizer's type coercion rule will cast branches to a common type later.
28+
let schema = plan.schema();
29+
let dt = self.data_type(&schema)?;
30+
Ok(Arc::new(Field::new(format!("{}", self), dt, true)))
31+
}
32+
33+
pub fn data_type(&self, schema: &Arc<Schema>) -> Result<DataType> {
34+
for (_, then_expr) in &self.when_then {
35+
let dt = then_expr.data_type(schema)?;
36+
if dt != DataType::Null {
37+
return Ok(dt);
38+
}
39+
}
40+
self.else_expr.data_type(schema)
41+
}
42+
}
43+
44+
impl Display for CaseExpr {
45+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46+
write!(f, "CASE")?;
47+
if let Some(op) = &self.operand {
48+
write!(f, " {op}")?;
49+
}
50+
for (w, t) in &self.when_then {
51+
write!(f, " WHEN {w} THEN {t}")?;
52+
}
53+
// Always print ELSE for determinism (planner fills missing ELSE with NULL).
54+
write!(f, " ELSE {} END", self.else_expr)
55+
}
56+
}
57+
58+
impl From<ScalarValue> for CaseExpr {
59+
fn from(value: ScalarValue) -> Self {
60+
CaseExpr {
61+
operand: None,
62+
when_then: vec![],
63+
else_expr: Box::new(LogicalExpr::Literal(value)),
64+
}
65+
}
66+
}
67+
68+

0 commit comments

Comments
 (0)