Skip to content

Commit a78e521

Browse files
WinkerDuMazterQyou
authored andcommitted
Add proper support for null literal by introducing ScalarValue::Null (apache#2364)
* introduce null * fix fmt
1 parent 3fc9b4c commit a78e521

File tree

15 files changed

+178
-38
lines changed

15 files changed

+178
-38
lines changed

datafusion/common/src/scalar.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
3939
/// This is the single-valued counter-part of arrow’s `Array`.
4040
#[derive(Clone)]
4141
pub enum ScalarValue {
42+
/// represents `DataType::Null` (castable to/from any other type)
43+
Null,
4244
/// true or false value
4345
Boolean(Option<bool>),
4446
/// 32bit float
@@ -170,6 +172,8 @@ impl PartialEq for ScalarValue {
170172
(IntervalMonthDayNano(_), _) => false,
171173
(Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2),
172174
(Struct(_, _), _) => false,
175+
(Null, Null) => true,
176+
(Null, _) => false,
173177
}
174178
}
175179
}
@@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue {
270274
}
271275
}
272276
(Struct(_, _), _) => None,
277+
(Null, Null) => Some(Ordering::Equal),
278+
(Null, _) => None,
273279
}
274280
}
275281
}
@@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue {
325331
v.hash(state);
326332
t.hash(state);
327333
}
334+
// stable hash for Null value
335+
Null => 1.hash(state),
328336
}
329337
}
330338
}
@@ -594,6 +602,7 @@ impl ScalarValue {
594602
DataType::Interval(IntervalUnit::MonthDayNano)
595603
}
596604
ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()),
605+
ScalarValue::Null => DataType::Null,
597606
}
598607
}
599608

@@ -623,7 +632,8 @@ impl ScalarValue {
623632
pub fn is_null(&self) -> bool {
624633
matches!(
625634
*self,
626-
ScalarValue::Boolean(None)
635+
ScalarValue::Null
636+
| ScalarValue::Boolean(None)
627637
| ScalarValue::UInt8(None)
628638
| ScalarValue::UInt16(None)
629639
| ScalarValue::UInt32(None)
@@ -847,6 +857,7 @@ impl ScalarValue {
847857
ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
848858
Arc::new(decimal_array)
849859
}
860+
DataType::Null => ScalarValue::iter_to_null_array(scalars),
850861
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
851862
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
852863
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
@@ -979,6 +990,17 @@ impl ScalarValue {
979990
Ok(array)
980991
}
981992

993+
fn iter_to_null_array(scalars: impl IntoIterator<Item = ScalarValue>) -> ArrayRef {
994+
let length =
995+
scalars
996+
.into_iter()
997+
.fold(0usize, |r, element: ScalarValue| match element {
998+
ScalarValue::Null => r + 1,
999+
_ => unreachable!(),
1000+
});
1001+
new_null_array(&DataType::Null, length)
1002+
}
1003+
9821004
fn iter_to_decimal_array(
9831005
scalars: impl IntoIterator<Item = ScalarValue>,
9841006
precision: &usize,
@@ -1252,6 +1274,7 @@ impl ScalarValue {
12521274
Arc::new(StructArray::from(field_values))
12531275
}
12541276
},
1277+
ScalarValue::Null => new_null_array(&DataType::Null, size),
12551278
}
12561279
}
12571280

@@ -1277,6 +1300,7 @@ impl ScalarValue {
12771300
}
12781301

12791302
Ok(match array.data_type() {
1303+
DataType::Null => ScalarValue::Null,
12801304
DataType::Decimal(precision, scale) => {
12811305
ScalarValue::get_decimal_value_from_array(array, index, precision, scale)
12821306
}
@@ -1530,6 +1554,7 @@ impl ScalarValue {
15301554
eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)
15311555
}
15321556
ScalarValue::Struct(_, _) => unimplemented!(),
1557+
ScalarValue::Null => array.data().is_null(index),
15331558
}
15341559
}
15351560

@@ -1760,6 +1785,7 @@ impl TryFrom<&DataType> for ScalarValue {
17601785
DataType::Interval(IntervalUnit::MonthDayNano) => {
17611786
ScalarValue::IntervalMonthDayNano(None)
17621787
}
1788+
DataType::Null => ScalarValue::Null,
17631789
_ => {
17641790
return Err(DataFusionError::NotImplemented(format!(
17651791
"Can't create a scalar from data_type \"{:?}\"",
@@ -1852,6 +1878,7 @@ impl fmt::Display for ScalarValue {
18521878
)?,
18531879
None => write!(f, "NULL")?,
18541880
},
1881+
ScalarValue::Null => write!(f, "NULL")?,
18551882
};
18561883
Ok(())
18571884
}
@@ -1919,6 +1946,7 @@ impl fmt::Debug for ScalarValue {
19191946
None => write!(f, "Struct(NULL)"),
19201947
}
19211948
}
1949+
ScalarValue::Null => write!(f, "NULL"),
19221950
}
19231951
}
19241952
}

datafusion/core/src/logical_plan/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ impl LogicalPlanBuilder {
154154
.iter()
155155
.enumerate()
156156
.map(|(j, expr)| {
157-
if let Expr::Literal(ScalarValue::Utf8(None)) = expr {
157+
if let Expr::Literal(ScalarValue::Null) = expr {
158158
nulls.push((i, j));
159159
Ok(field_types[j].clone())
160160
} else {

datafusion/core/src/physical_plan/hash_join.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,11 @@ fn equal_rows(
845845
.iter()
846846
.zip(right_arrays)
847847
.all(|(l, r)| match l.data_type() {
848-
DataType::Null => true,
848+
DataType::Null => {
849+
// lhs and rhs are both `DataType::Null`, so the euqal result
850+
// is dependent on `null_equals_null`
851+
null_equals_null
852+
}
849853
DataType::Boolean => {
850854
equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null)
851855
}

datafusion/core/src/physical_plan/hash_utils.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ fn combine_hashes(l: u64, r: u64) -> u64 {
3939
hash.wrapping_mul(37).wrapping_add(r)
4040
}
4141

42+
fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) {
43+
if mul_col {
44+
hashes_buffer.iter_mut().for_each(|hash| {
45+
// stable hash for null value
46+
*hash = combine_hashes(i128::get_hash(&1, random_state), *hash);
47+
})
48+
} else {
49+
hashes_buffer.iter_mut().for_each(|hash| {
50+
*hash = i128::get_hash(&1, random_state);
51+
})
52+
}
53+
}
54+
4255
fn hash_decimal128<'a>(
4356
array: &ArrayRef,
4457
random_state: &RandomState,
@@ -310,6 +323,9 @@ pub fn create_hashes<'a>(
310323

311324
for col in arrays {
312325
match col.data_type() {
326+
DataType::Null => {
327+
hash_null(random_state, hashes_buffer, multi_col);
328+
}
313329
DataType::Decimal(_, _) => {
314330
hash_decimal128(col, random_state, hashes_buffer, multi_col);
315331
}

datafusion/core/src/sql/planner.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,7 +1681,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
16811681
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
16821682
SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)),
16831683
SQLExpr::Value(Value::Null) => {
1684-
Ok(Expr::Literal(ScalarValue::Utf8(None)))
1684+
Ok(Expr::Literal(ScalarValue::Null))
16851685
}
16861686
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
16871687
SQLExpr::UnaryOp { op, expr } => {
@@ -1707,7 +1707,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
17071707
SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),
17081708
SQLExpr::Value(Value::EscapedStringLiteral(ref s)) => Ok(lit(s.clone())),
17091709
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
1710-
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))),
1710+
SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)),
17111711
SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction {
17121712
fun: BuiltinScalarFunction::DatePart,
17131713
args: vec![
@@ -4259,9 +4259,9 @@ mod tests {
42594259
fn union_with_null() {
42604260
let sql = "SELECT NULL a UNION ALL SELECT 1.1 a";
42614261
let expected = "Union\
4262-
\n Projection: Utf8(NULL) AS a\
4262+
\n Projection: CAST(NULL AS Float64) AS a\
42634263
\n EmptyRelation\
4264-
\n Projection: CAST(Float64(1.1) AS Utf8) AS a\
4264+
\n Projection: Float64(1.1) AS a\
42654265
\n EmptyRelation";
42664266
quick_test(sql, expected);
42674267
}

datafusion/core/tests/sql/expr.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,11 @@ async fn test_string_concat_operator() -> Result<()> {
362362
let sql = "SELECT 'aa' || NULL || 'd'";
363363
let actual = execute_to_batches(&ctx, sql).await;
364364
let expected = vec![
365-
"+---------------------------------------+",
366-
"| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |",
367-
"+---------------------------------------+",
368-
"| |",
369-
"+---------------------------------------+",
365+
"+---------------------------------+",
366+
"| Utf8(\"aa\") || NULL || Utf8(\"d\") |",
367+
"+---------------------------------+",
368+
"| |",
369+
"+---------------------------------+",
370370
];
371371
assert_batches_eq!(expected, &actual);
372372

datafusion/core/tests/sql/functions.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,11 @@ async fn coalesce_static_value_with_null() -> Result<()> {
197197
let sql = "SELECT COALESCE(NULL, 'test')";
198198
let actual = execute_to_batches(&ctx, sql).await;
199199
let expected = vec![
200-
"+-----------------------------------+",
201-
"| coalesce(Utf8(NULL),Utf8(\"test\")) |",
202-
"+-----------------------------------+",
203-
"| test |",
204-
"+-----------------------------------+",
200+
"+-----------------------------+",
201+
"| coalesce(NULL,Utf8(\"test\")) |",
202+
"+-----------------------------+",
203+
"| test |",
204+
"+-----------------------------+",
205205
];
206206
assert_batches_eq!(expected, &actual);
207207
Ok(())

datafusion/core/tests/sql/joins.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,11 @@ async fn inner_join_nulls() {
840840
let sql = "SELECT * FROM (SELECT null AS id1) t1
841841
INNER JOIN (SELECT null AS id2) t2 ON id1 = id2";
842842

843-
let expected = vec!["++", "++"];
843+
#[rustfmt::skip]
844+
let expected = vec![
845+
"++",
846+
"++",
847+
];
844848

845849
let ctx = create_join_context_qualified("t1", "t2").unwrap();
846850
let actual = execute_to_batches(&ctx, sql).await;

datafusion/core/tests/sql/select.rs

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -398,15 +398,37 @@ async fn select_distinct_from() {
398398
1 IS NOT DISTINCT FROM CAST(NULL as INT) as c,
399399
1 IS NOT DISTINCT FROM 1 as d,
400400
NULL IS DISTINCT FROM NULL as e,
401-
NULL IS NOT DISTINCT FROM NULL as f
401+
NULL IS NOT DISTINCT FROM NULL as f,
402+
NULL is DISTINCT FROM 1 as g,
403+
NULL is NOT DISTINCT FROM 1 as h
402404
";
403405
let actual = execute_to_batches(&ctx, sql).await;
404406
let expected = vec![
405-
"+------+-------+-------+------+-------+------+",
406-
"| a | b | c | d | e | f |",
407-
"+------+-------+-------+------+-------+------+",
408-
"| true | false | false | true | false | true |",
409-
"+------+-------+-------+------+-------+------+",
407+
"+------+-------+-------+------+-------+------+------+-------+",
408+
"| a | b | c | d | e | f | g | h |",
409+
"+------+-------+-------+------+-------+------+------+-------+",
410+
"| true | false | false | true | false | true | true | false |",
411+
"+------+-------+-------+------+-------+------+------+-------+",
412+
];
413+
assert_batches_eq!(expected, &actual);
414+
415+
let sql = "select
416+
NULL IS DISTINCT FROM NULL as a,
417+
NULL IS NOT DISTINCT FROM NULL as b,
418+
NULL is DISTINCT FROM 1 as c,
419+
NULL is NOT DISTINCT FROM 1 as d,
420+
1 IS DISTINCT FROM CAST(NULL as INT) as e,
421+
1 IS DISTINCT FROM 1 as f,
422+
1 IS NOT DISTINCT FROM CAST(NULL as INT) as g,
423+
1 IS NOT DISTINCT FROM 1 as h
424+
";
425+
let actual = execute_to_batches(&ctx, sql).await;
426+
let expected = vec![
427+
"+-------+------+------+-------+------+-------+-------+------+",
428+
"| a | b | c | d | e | f | g | h |",
429+
"+-------+------+------+-------+------+-------+-------+------+",
430+
"| false | true | true | false | true | false | false | true |",
431+
"+-------+------+------+-------+------+-------+-------+------+",
410432
];
411433
assert_batches_eq!(expected, &actual);
412434
}

datafusion/expr/src/binary_rule.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
625625
numerical_coercion(lhs_type, rhs_type)
626626
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
627627
.or_else(|| temporal_coercion(lhs_type, rhs_type))
628+
.or_else(|| null_coercion(lhs_type, rhs_type))
628629
}
629630

630631
/// Coercion rule for interval

0 commit comments

Comments
 (0)