Skip to content

Commit eae4e05

Browse files
WinkerDuMazterQyou
authored and committed
Add proper support for null literal by introducing ScalarValue::Null (apache#2364)
* introduce null * fix fmt
1 parent e4087a1 commit eae4e05

File tree

15 files changed

+180
-38
lines changed

15 files changed

+180
-38
lines changed

datafusion/common/src/scalar.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
3939
/// This is the single-valued counter-part of arrow’s `Array`.
4040
#[derive(Clone)]
4141
pub enum ScalarValue {
42+
/// represents `DataType::Null` (castable to/from any other type)
43+
Null,
4244
/// true or false value
4345
Boolean(Option<bool>),
4446
/// 32bit float
@@ -170,6 +172,8 @@ impl PartialEq for ScalarValue {
170172
(IntervalMonthDayNano(_), _) => false,
171173
(Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2),
172174
(Struct(_, _), _) => false,
175+
(Null, Null) => true,
176+
(Null, _) => false,
173177
}
174178
}
175179
}
@@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue {
270274
}
271275
}
272276
(Struct(_, _), _) => None,
277+
(Null, Null) => Some(Ordering::Equal),
278+
(Null, _) => None,
273279
}
274280
}
275281
}
@@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue {
325331
v.hash(state);
326332
t.hash(state);
327333
}
334+
// stable hash for Null value
335+
Null => 1.hash(state),
328336
}
329337
}
330338
}
@@ -594,6 +602,7 @@ impl ScalarValue {
594602
DataType::Interval(IntervalUnit::MonthDayNano)
595603
}
596604
ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()),
605+
ScalarValue::Null => DataType::Null,
597606
}
598607
}
599608

@@ -623,7 +632,8 @@ impl ScalarValue {
623632
pub fn is_null(&self) -> bool {
624633
matches!(
625634
*self,
626-
ScalarValue::Boolean(None)
635+
ScalarValue::Null
636+
| ScalarValue::Boolean(None)
627637
| ScalarValue::UInt8(None)
628638
| ScalarValue::UInt16(None)
629639
| ScalarValue::UInt32(None)
@@ -844,6 +854,7 @@ impl ScalarValue {
844854
ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
845855
Arc::new(decimal_array)
846856
}
857+
DataType::Null => ScalarValue::iter_to_null_array(scalars),
847858
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
848859
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
849860
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
@@ -976,6 +987,17 @@ impl ScalarValue {
976987
Ok(array)
977988
}
978989

990+
/// Builds a `DataType::Null` array with one slot per scalar yielded by `scalars`.
///
/// Only the number of elements is used: the iterator is consumed and counted,
/// and the result is `new_null_array(&DataType::Null, length)`.
///
/// # Panics
///
/// Reaches `unreachable!()` if any element is not `ScalarValue::Null`.
fn iter_to_null_array(scalars: impl IntoIterator<Item = ScalarValue>) -> ArrayRef {
991+
let length =
992+
scalars
993+
.into_iter()
994+
.fold(0usize, |r, element: ScalarValue| match element {
995+
ScalarValue::Null => r + 1,
996+
_ => unreachable!(),
997+
});
998+
new_null_array(&DataType::Null, length)
999+
}
1000+
9791001
fn iter_to_decimal_array(
9801002
scalars: impl IntoIterator<Item = ScalarValue>,
9811003
precision: &usize,
@@ -1249,6 +1271,7 @@ impl ScalarValue {
12491271
Arc::new(StructArray::from(field_values))
12501272
}
12511273
},
1274+
ScalarValue::Null => new_null_array(&DataType::Null, size),
12521275
}
12531276
}
12541277

@@ -1274,6 +1297,7 @@ impl ScalarValue {
12741297
}
12751298

12761299
Ok(match array.data_type() {
1300+
DataType::Null => ScalarValue::Null,
12771301
DataType::Decimal(precision, scale) => {
12781302
ScalarValue::get_decimal_value_from_array(array, index, precision, scale)
12791303
}
@@ -1519,6 +1543,7 @@ impl ScalarValue {
15191543
eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)
15201544
}
15211545
ScalarValue::Struct(_, _) => unimplemented!(),
1546+
ScalarValue::Null => array.data().is_null(index),
15221547
}
15231548
}
15241549

@@ -1740,6 +1765,7 @@ impl TryFrom<&DataType> for ScalarValue {
17401765
DataType::Struct(fields) => {
17411766
ScalarValue::Struct(None, Box::new(fields.clone()))
17421767
}
1768+
DataType::Null => ScalarValue::Null,
17431769
_ => {
17441770
return Err(DataFusionError::NotImplemented(format!(
17451771
"Can't create a scalar from data_type \"{:?}\"",
@@ -1832,6 +1858,7 @@ impl fmt::Display for ScalarValue {
18321858
)?,
18331859
None => write!(f, "NULL")?,
18341860
},
1861+
ScalarValue::Null => write!(f, "NULL")?,
18351862
};
18361863
Ok(())
18371864
}
@@ -1899,6 +1926,7 @@ impl fmt::Debug for ScalarValue {
18991926
None => write!(f, "Struct(NULL)"),
19001927
}
19011928
}
1929+
ScalarValue::Null => write!(f, "NULL"),
19021930
}
19031931
}
19041932
}

datafusion/core/src/logical_plan/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ impl LogicalPlanBuilder {
154154
.iter()
155155
.enumerate()
156156
.map(|(j, expr)| {
157-
if let Expr::Literal(ScalarValue::Utf8(None)) = expr {
157+
if let Expr::Literal(ScalarValue::Null) = expr {
158158
nulls.push((i, j));
159159
Ok(field_types[j].clone())
160160
} else {

datafusion/core/src/physical_plan/hash_join.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,11 @@ fn equal_rows(
845845
.iter()
846846
.zip(right_arrays)
847847
.all(|(l, r)| match l.data_type() {
848-
DataType::Null => true,
848+
DataType::Null => {
849+
// lhs and rhs are both `DataType::Null`, so the equality result
850+
// is dependent on `null_equals_null`
851+
null_equals_null
852+
}
849853
DataType::Boolean => {
850854
equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null)
851855
}

datafusion/core/src/physical_plan/hash_utils.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ fn combine_hashes(l: u64, r: u64) -> u64 {
3939
hash.wrapping_mul(37).wrapping_add(r)
4040
}
4141

42+
/// Fills `hashes_buffer` with the hash used for a `DataType::Null` column.
///
/// Every slot gets the same row-independent value,
/// `i128::get_hash(&1, random_state)`.
///
/// * `mul_col == true`: the value is merged with the hash already stored in
///   each slot via `combine_hashes`.
/// * `mul_col == false`: each slot is overwritten with the value.
fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) {
43+
if mul_col {
44+
hashes_buffer.iter_mut().for_each(|hash| {
45+
// stable hash for null value
46+
*hash = combine_hashes(i128::get_hash(&1, random_state), *hash);
47+
})
48+
} else {
49+
hashes_buffer.iter_mut().for_each(|hash| {
50+
*hash = i128::get_hash(&1, random_state);
51+
})
52+
}
53+
}
54+
4255
fn hash_decimal128<'a>(
4356
array: &ArrayRef,
4457
random_state: &RandomState,
@@ -310,6 +323,9 @@ pub fn create_hashes<'a>(
310323

311324
for col in arrays {
312325
match col.data_type() {
326+
DataType::Null => {
327+
hash_null(random_state, hashes_buffer, multi_col);
328+
}
313329
DataType::Decimal(_, _) => {
314330
hash_decimal128(col, random_state, hashes_buffer, multi_col);
315331
}

datafusion/core/src/sql/planner.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,7 +1681,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
16811681
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
16821682
SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)),
16831683
SQLExpr::Value(Value::Null) => {
1684-
Ok(Expr::Literal(ScalarValue::Utf8(None)))
1684+
Ok(Expr::Literal(ScalarValue::Null))
16851685
}
16861686
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
16871687
SQLExpr::UnaryOp { op, expr } => {
@@ -1707,7 +1707,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
17071707
SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),
17081708
SQLExpr::Value(Value::EscapedStringLiteral(ref s)) => Ok(lit(s.clone())),
17091709
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
1710-
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))),
1710+
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)),
17111711
SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction {
17121712
fun: BuiltinScalarFunction::DatePart,
17131713
args: vec![
@@ -4258,9 +4258,9 @@ mod tests {
42584258
fn union_with_null() {
42594259
let sql = "SELECT NULL a UNION ALL SELECT 1.1 a";
42604260
let expected = "Union\
4261-
\n Projection: Utf8(NULL) AS a\
4261+
\n Projection: CAST(NULL AS Float64) AS a\
42624262
\n EmptyRelation\
4263-
\n Projection: CAST(Float64(1.1) AS Utf8) AS a\
4263+
\n Projection: Float64(1.1) AS a\
42644264
\n EmptyRelation";
42654265
quick_test(sql, expected);
42664266
}

datafusion/core/tests/sql/expr.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,11 @@ async fn test_string_concat_operator() -> Result<()> {
362362
let sql = "SELECT 'aa' || NULL || 'd'";
363363
let actual = execute_to_batches(&ctx, sql).await;
364364
let expected = vec![
365-
"+---------------------------------------+",
366-
"| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |",
367-
"+---------------------------------------+",
368-
"| |",
369-
"+---------------------------------------+",
365+
"+---------------------------------+",
366+
"| Utf8(\"aa\") || NULL || Utf8(\"d\") |",
367+
"+---------------------------------+",
368+
"| |",
369+
"+---------------------------------+",
370370
];
371371
assert_batches_eq!(expected, &actual);
372372

datafusion/core/tests/sql/functions.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,11 @@ async fn coalesce_static_value_with_null() -> Result<()> {
197197
let sql = "SELECT COALESCE(NULL, 'test')";
198198
let actual = execute_to_batches(&ctx, sql).await;
199199
let expected = vec![
200-
"+-----------------------------------+",
201-
"| coalesce(Utf8(NULL),Utf8(\"test\")) |",
202-
"+-----------------------------------+",
203-
"| test |",
204-
"+-----------------------------------+",
200+
"+-----------------------------+",
201+
"| coalesce(NULL,Utf8(\"test\")) |",
202+
"+-----------------------------+",
203+
"| test |",
204+
"+-----------------------------+",
205205
];
206206
assert_batches_eq!(expected, &actual);
207207
Ok(())

datafusion/core/tests/sql/joins.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,11 @@ async fn inner_join_nulls() {
840840
let sql = "SELECT * FROM (SELECT null AS id1) t1
841841
INNER JOIN (SELECT null AS id2) t2 ON id1 = id2";
842842

843-
let expected = vec!["++", "++"];
843+
#[rustfmt::skip]
844+
let expected = vec![
845+
"++",
846+
"++",
847+
];
844848

845849
let ctx = create_join_context_qualified("t1", "t2").unwrap();
846850
let actual = execute_to_batches(&ctx, sql).await;

datafusion/core/tests/sql/select.rs

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -398,15 +398,37 @@ async fn select_distinct_from() {
398398
1 IS NOT DISTINCT FROM CAST(NULL as INT) as c,
399399
1 IS NOT DISTINCT FROM 1 as d,
400400
NULL IS DISTINCT FROM NULL as e,
401-
NULL IS NOT DISTINCT FROM NULL as f
401+
NULL IS NOT DISTINCT FROM NULL as f,
402+
NULL is DISTINCT FROM 1 as g,
403+
NULL is NOT DISTINCT FROM 1 as h
402404
";
403405
let actual = execute_to_batches(&ctx, sql).await;
404406
let expected = vec![
405-
"+------+-------+-------+------+-------+------+",
406-
"| a | b | c | d | e | f |",
407-
"+------+-------+-------+------+-------+------+",
408-
"| true | false | false | true | false | true |",
409-
"+------+-------+-------+------+-------+------+",
407+
"+------+-------+-------+------+-------+------+------+-------+",
408+
"| a | b | c | d | e | f | g | h |",
409+
"+------+-------+-------+------+-------+------+------+-------+",
410+
"| true | false | false | true | false | true | true | false |",
411+
"+------+-------+-------+------+-------+------+------+-------+",
412+
];
413+
assert_batches_eq!(expected, &actual);
414+
415+
let sql = "select
416+
NULL IS DISTINCT FROM NULL as a,
417+
NULL IS NOT DISTINCT FROM NULL as b,
418+
NULL is DISTINCT FROM 1 as c,
419+
NULL is NOT DISTINCT FROM 1 as d,
420+
1 IS DISTINCT FROM CAST(NULL as INT) as e,
421+
1 IS DISTINCT FROM 1 as f,
422+
1 IS NOT DISTINCT FROM CAST(NULL as INT) as g,
423+
1 IS NOT DISTINCT FROM 1 as h
424+
";
425+
let actual = execute_to_batches(&ctx, sql).await;
426+
let expected = vec![
427+
"+-------+------+------+-------+------+-------+-------+------+",
428+
"| a | b | c | d | e | f | g | h |",
429+
"+-------+------+------+-------+------+-------+-------+------+",
430+
"| false | true | true | false | true | false | false | true |",
431+
"+-------+------+------+-------+------+-------+-------+------+",
410432
];
411433
assert_batches_eq!(expected, &actual);
412434
}

datafusion/expr/src/binary_rule.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
625625
numerical_coercion(lhs_type, rhs_type)
626626
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
627627
.or_else(|| temporal_coercion(lhs_type, rhs_type))
628+
.or_else(|| null_coercion(lhs_type, rhs_type))
628629
}
629630

630631
/// Coercion rule for interval

0 commit comments

Comments
 (0)