Skip to content

Commit eaf51ba

Browse files
authored
Unparse map to sql (#13532)
* map to sql * fix sqllogictest for map arg len error * match array expr concisely
1 parent c0ca4b4 commit eaf51ba

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

datafusion/functions-nested/src/map.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,9 @@ impl ScalarUDFImpl for MapFunc {
214214
}
215215

216216
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
217-
if arg_types.len() % 2 != 0 {
217+
if arg_types.len() != 2 {
218218
return exec_err!(
219-
"map requires an even number of arguments, got {} instead",
219+
"map requires exactly 2 arguments, got {} instead",
220220
arg_types.len()
221221
);
222222
}

datafusion/sql/src/unparser/expr.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ impl Unparser<'_> {
464464
"array_element" => self.array_element_to_sql(args),
465465
"named_struct" => self.named_struct_to_sql(args),
466466
"get_field" => self.get_field_to_sql(args),
467+
"map" => self.map_to_sql(args),
467468
// TODO: support for the construct and access functions of the `map` type
468469
_ => self.scalar_function_to_sql_internal(func_name, args),
469470
}
@@ -567,6 +568,39 @@ impl Unparser<'_> {
567568
Ok(ast::Expr::CompoundIdentifier(id))
568569
}
569570

571+
fn map_to_sql(&self, args: &[Expr]) -> Result<ast::Expr> {
572+
if args.len() != 2 {
573+
return internal_err!("map must have exactly 2 arguments");
574+
}
575+
576+
let ast::Expr::Array(Array { elem: keys, .. }) = self.expr_to_sql(&args[0])?
577+
else {
578+
return internal_err!(
579+
"map expects first argument to be an array, but received: {:?}",
580+
&args[0]
581+
);
582+
};
583+
584+
let ast::Expr::Array(Array { elem: values, .. }) = self.expr_to_sql(&args[1])?
585+
else {
586+
return internal_err!(
587+
"map expects second argument to be an array, but received: {:?}",
588+
&args[1]
589+
);
590+
};
591+
592+
let entries = keys
593+
.into_iter()
594+
.zip(values)
595+
.map(|(key, value)| ast::MapEntry {
596+
key: Box::new(key),
597+
value: Box::new(value),
598+
})
599+
.collect();
600+
601+
Ok(ast::Expr::Map(ast::Map { entries }))
602+
}
603+
570604
pub fn sort_to_sql(&self, sort: &Sort) -> Result<ast::OrderByExpr> {
571605
let Sort {
572606
expr,
@@ -1581,6 +1615,7 @@ mod tests {
15811615
use datafusion_functions_aggregate::count::count_udaf;
15821616
use datafusion_functions_aggregate::expr_fn::sum;
15831617
use datafusion_functions_nested::expr_fn::{array_element, make_array};
1618+
use datafusion_functions_nested::map::map;
15841619
use datafusion_functions_window::row_number::row_number_udwf;
15851620

15861621
use crate::unparser::dialect::{
@@ -1996,6 +2031,10 @@ mod tests {
19962031
"{a: '1', b: 2}",
19972032
),
19982033
(get_field(col("a.b"), "c"), "a.b.c"),
2034+
(
2035+
map(vec![lit("a"), lit("b")], vec![lit(1), lit(2)]),
2036+
"MAP {'a': 1, 'b': 2}",
2037+
),
19992038
];
20002039

20012040
for (expr, expected) in tests {

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
2727
use datafusion_functions::unicode;
2828
use datafusion_functions_aggregate::grouping::grouping_udaf;
2929
use datafusion_functions_nested::make_array::make_array_udf;
30+
use datafusion_functions_nested::map::map_udf;
3031
use datafusion_functions_window::rank::rank_udwf;
3132
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
3233
use datafusion_sql::unparser::dialect::{
@@ -190,7 +191,8 @@ fn roundtrip_statement() -> Result<()> {
190191
"SELECT [1, 2, 3][1]",
191192
"SELECT left[1] FROM array",
192193
"SELECT {a:1, b:2}",
193-
"SELECT s.a FROM (SELECT {a:1, b:2} AS s)"
194+
"SELECT s.a FROM (SELECT {a:1, b:2} AS s)",
195+
"SELECT MAP {'a': 1, 'b': 2}"
194196
];
195197

196198
// For each test sql string, we transform as follows:
@@ -206,6 +208,7 @@ fn roundtrip_statement() -> Result<()> {
206208
let state = MockSessionState::default()
207209
.with_scalar_function(make_array_udf())
208210
.with_scalar_function(array_element_udf())
211+
.with_scalar_function(map_udf())
209212
.with_aggregate_function(sum_udaf())
210213
.with_aggregate_function(count_udaf())
211214
.with_aggregate_function(max_udaf())

datafusion/sqllogictest/test_files/map.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ SELECT MAP([[1,2], [3,4]], ['a', 'b']);
185185
query error
186186
SELECT MAP()
187187

188-
query error DataFusion error: Execution error: map requires an even number of arguments, got 1 instead
188+
query error DataFusion error: Execution error: map requires exactly 2 arguments, got 1 instead
189189
SELECT MAP(['POST', 'HEAD'])
190190

191191
query error DataFusion error: Execution error: Expected list, large_list or fixed_size_list, got Null

0 commit comments

Comments
 (0)