Skip to content

Commit aaea901

Browse files
author
cancai
committed
feat: Support default alias for non-column refs for sql select
1 parent ede725e commit aaea901

File tree

4 files changed

+216
-3
lines changed

4 files changed

+216
-3
lines changed

src/daft-sql/src/expr_name.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
use sqlparser::ast::{
2+
BinaryOperator, Expr as SQLExpr, FunctionArg, FunctionArgExpr, FunctionArguments, Ident,
3+
ObjectName, ObjectNamePart, UnaryOperator, Value, ValueWithSpan,
4+
};
5+
6+
fn ident_to_string(ident: &Ident) -> String {
7+
ident.value.clone()
8+
}
9+
10+
fn object_name_to_string(name: &ObjectName) -> String {
11+
name.0
12+
.iter()
13+
.map(|part| match part {
14+
ObjectNamePart::Identifier(ident) => ident_to_string(ident),
15+
ObjectNamePart::Function(func) => func.name.value.clone(),
16+
})
17+
.collect::<Vec<_>>()
18+
.join(".")
19+
}
20+
21+
fn value_to_string(v: &Value) -> String {
22+
match v {
23+
Value::Number(n, _) => n.clone(),
24+
Value::Boolean(b) => b.to_string(),
25+
Value::Null => "null".to_string(),
26+
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => format!("\"{s}\""),
27+
other => format!("{other}"),
28+
}
29+
}
30+
31+
fn value_with_span_to_string(v: &ValueWithSpan) -> String {
32+
value_to_string(&v.value)
33+
}
34+
35+
fn binary_op_to_string(op: &BinaryOperator) -> String {
36+
match op {
37+
BinaryOperator::Plus => "+".to_string(),
38+
BinaryOperator::Minus => "-".to_string(),
39+
BinaryOperator::Multiply => "*".to_string(),
40+
BinaryOperator::Divide => "/".to_string(),
41+
BinaryOperator::Modulo => "%".to_string(),
42+
BinaryOperator::StringConcat => "||".to_string(),
43+
other => format!("{other}"),
44+
}
45+
}
46+
47+
fn unary_op_to_string(op: &UnaryOperator) -> String {
48+
match op {
49+
UnaryOperator::Plus => "+".to_string(),
50+
UnaryOperator::Minus => "-".to_string(),
51+
UnaryOperator::Not => "not".to_string(),
52+
other => format!("{other}"),
53+
}
54+
}
55+
56+
fn expr_to_string(e: &SQLExpr) -> String {
57+
match e {
58+
SQLExpr::Identifier(ident) => ident_to_string(ident),
59+
SQLExpr::CompoundIdentifier(idents) => idents
60+
.iter()
61+
.map(ident_to_string)
62+
.collect::<Vec<_>>()
63+
.join("."),
64+
SQLExpr::Value(v) => value_with_span_to_string(v),
65+
66+
SQLExpr::UnaryOp { op, expr } => {
67+
let op = unary_op_to_string(op);
68+
format!("({op} {})", expr_to_string(expr))
69+
}
70+
SQLExpr::BinaryOp { left, op, right } => {
71+
let op = binary_op_to_string(op);
72+
format!("({} {op} {})", expr_to_string(left), expr_to_string(right))
73+
}
74+
75+
SQLExpr::Nested(inner) => expr_to_string(inner),
76+
SQLExpr::Cast {
77+
expr, data_type, ..
78+
} => {
79+
format!("cast({} as {data_type})", expr_to_string(expr))
80+
}
81+
82+
SQLExpr::Function(func) => {
83+
let fn_name = object_name_to_string(&func.name).to_lowercase();
84+
85+
// Special-case COUNT(*) to preserve the star.
86+
if fn_name == "count"
87+
&& matches!(
88+
&func.args,
89+
FunctionArguments::List(args)
90+
if args.args.len() == 1
91+
&& matches!(&args.args[0], FunctionArg::Unnamed(FunctionArgExpr::Wildcard))
92+
)
93+
{
94+
return "count(*)".to_string();
95+
}
96+
97+
let args: Vec<String> = match &func.args {
98+
FunctionArguments::None => vec![],
99+
FunctionArguments::Subquery(_) => vec!["<subquery>".to_string()],
100+
FunctionArguments::List(args) => args
101+
.args
102+
.iter()
103+
.map(|arg| match arg {
104+
FunctionArg::Named { arg, .. } | FunctionArg::ExprNamed { arg, .. } => {
105+
match arg {
106+
FunctionArgExpr::Expr(e) => expr_to_string(e),
107+
FunctionArgExpr::QualifiedWildcard(obj) => {
108+
format!("{}.*", object_name_to_string(obj))
109+
}
110+
FunctionArgExpr::Wildcard => "*".to_string(),
111+
}
112+
}
113+
FunctionArg::Unnamed(arg) => match arg {
114+
FunctionArgExpr::Expr(e) => expr_to_string(e),
115+
FunctionArgExpr::QualifiedWildcard(obj) => {
116+
format!("{}.*", object_name_to_string(obj))
117+
}
118+
FunctionArgExpr::Wildcard => "*".to_string(),
119+
},
120+
})
121+
.collect(),
122+
};
123+
124+
format!("{fn_name}({})", args.join(", "))
125+
}
126+
127+
// If we haven't normalized a specific expression yet, fall back to sqlparser's Display
128+
// to avoid changing the set of supported queries.
129+
other => format!("{other}"),
130+
}
131+
}
132+
133+
/// Returns a stable, user-facing name for an unnamed SQL projection expression.
134+
///
135+
/// This is intended to be shared across SQL (AST -> name) and, in the future,
136+
/// DataFrame (DSL Expr -> name) so that default column names follow consistent rules.
137+
pub fn normalized_sql_expr_name(expr: &SQLExpr) -> String {
138+
expr_to_string(expr)
139+
}

src/daft-sql/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ pub mod error;
22
pub mod functions;
33

44
mod exec;
5+
mod expr_name;
56
mod modules;
67
mod planner;
78
mod schema;

src/daft-sql/src/planner.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ use sqlparser::{
3838
};
3939

4040
use crate::{
41-
column_not_found_err, error::*, invalid_operation_err, schema::sql_dtype_to_dtype,
42-
statement::Statement, table_not_found_err, unsupported_sql_err,
41+
column_not_found_err, error::*, expr_name::normalized_sql_expr_name, invalid_operation_err,
42+
schema::sql_dtype_to_dtype, statement::Statement, table_not_found_err, unsupported_sql_err,
4343
};
4444

4545
/// Bindings are used to lookup in-scope tables, views, and columns (targets T).
@@ -1192,7 +1192,11 @@ impl SQLPlanner<'_> {
11921192
self.bound_columns.insert(alias.clone(), expr.clone());
11931193
Ok(vec![expr.alias(alias)])
11941194
}
1195-
SelectItem::UnnamedExpr(expr) => self.plan_expr(expr).map(|e| vec![e]),
1195+
SelectItem::UnnamedExpr(expr) => {
1196+
let name = normalized_sql_expr_name(expr);
1197+
let expr = self.plan_expr(expr)?;
1198+
Ok(vec![expr.alias(name)])
1199+
}
11961200

11971201
SelectItem::Wildcard(wildcard_opts) => {
11981202
check_wildcard_options(wildcard_opts)?;
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from __future__ import annotations
2+
3+
import daft
4+
5+
6+
def test_sql_default_column_name_agg() -> None:
    """An unaliased aggregate is named after its normalized call text."""
    df = daft.from_pydict({"a": [1, 2, 3]})
    result = daft.sql("select sum(a) from df", df=df)
    assert result.to_pydict() == {"sum(a)": [6]}
12+
13+
14+
def test_sql_default_column_name_count() -> None:
    """`count` keeps its argument text (`*`, a literal, or a column) in the default name."""
    df = daft.from_pydict({"a": [1, 2, 3]})

    # Each call text doubles as the expected output column name.
    for call in ("count(*)", "count(1)", "count(a)"):
        res = daft.sql(f"select {call} from df", df=df).to_pydict()
        assert res == {call: [3]}
28+
29+
30+
def test_sql_default_column_name_non_agg_functions() -> None:
    """Scalar function calls, including nested ones, get their call text as the default name."""
    df = daft.from_pydict({"a": ["AbC", None]})

    cases = [
        ("select upper(a) from df", {"upper(a)": ["ABC", None]}),
        ("select coalesce(a, 'x') from df", {'coalesce(a, "x")': ["AbC", "x"]}),
        # String concatenation operator `||` is not supported; use concat().
        ("select concat(upper(a), 'x') from df", {'concat(upper(a), "x")': ["ABCx", None]}),
    ]
    for query, expected in cases:
        assert daft.sql(query, df=df).to_pydict() == expected
42+
43+
44+
def test_sql_default_column_name_binary_op_is_parenthesized() -> None:
    """A binary operation's default name is wrapped in parentheses."""
    df = daft.from_pydict({"a": [1, 2, 3]})
    result = daft.sql("select a + 1 from df", df=df)
    assert result.to_pydict() == {"(a + 1)": [2, 3, 4]}
50+
51+
52+
def test_sql_default_column_name_string_literal_is_quoted() -> None:
    """A string literal's default name is the literal wrapped in double quotes."""
    df = daft.from_pydict({"a": [1]})
    result = daft.sql("select 'a' from df", df=df)
    assert result.to_pydict() == {'"a"': ["a"]}
58+
59+
60+
def test_sql_explicit_alias_overrides_default_name() -> None:
    """An explicit `AS` alias wins over the default name and keeps its letter case."""
    df = daft.from_pydict({"a": [1, 2, 3]})

    for alias in ("s", "S"):
        res = daft.sql(f"select sum(a) as {alias} from df", df=df).to_pydict()
        assert res == {alias: [6]}

0 commit comments

Comments
 (0)