Skip to content

Commit 1a26be2

Browse files
WinkerDuMazterQyou
authored andcommitted
Numeric, String, Boolean comparisons with literal NULL (apache#2481)
1 parent a809c5e commit 1a26be2

File tree

4 files changed

+181
-36
lines changed

4 files changed

+181
-36
lines changed

datafusion/core/src/physical_plan/planner.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1793,7 +1793,8 @@ mod tests {
17931793
let bool_expr = col("c1").eq(col("c1"));
17941794
let cases = vec![
17951795
// utf8 < u32
1796-
col("c1").lt(col("c2")),
1796+
// NOTE(cubesql): valid
1797+
//col("c1").lt(col("c2")),
17971798
// utf8 AND utf8
17981799
col("c1").and(col("c1")),
17991800
// u8 AND u8

datafusion/core/tests/sql/expr.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,3 +1280,123 @@ async fn nested_subquery() -> Result<()> {
12801280
assert_batches_eq!(expected, &actual);
12811281
Ok(())
12821282
}
1283+
1284+
#[tokio::test]
1285+
async fn comparisons_with_null() -> Result<()> {
1286+
let ctx = SessionContext::new();
1287+
// 1. Numeric comparison with NULL
1288+
let sql = "select column1 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
1289+
let actual = execute_to_batches(&ctx, sql).await;
1290+
let expected = vec![
1291+
"+-------------------------+",
1292+
"| t.column1 Lt Utf8(NULL) |",
1293+
"+-------------------------+",
1294+
"| |",
1295+
"| |",
1296+
"+-------------------------+",
1297+
];
1298+
assert_batches_eq!(expected, &actual);
1299+
1300+
let sql =
1301+
"select column1 <= NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
1302+
let actual = execute_to_batches(&ctx, sql).await;
1303+
let expected = vec![
1304+
"+---------------------------+",
1305+
"| t.column1 LtEq Utf8(NULL) |",
1306+
"+---------------------------+",
1307+
"| |",
1308+
"| |",
1309+
"+---------------------------+",
1310+
];
1311+
assert_batches_eq!(expected, &actual);
1312+
1313+
let sql = "select column1 > NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
1314+
let actual = execute_to_batches(&ctx, sql).await;
1315+
let expected = vec![
1316+
"+-------------------------+",
1317+
"| t.column1 Gt Utf8(NULL) |",
1318+
"+-------------------------+",
1319+
"| |",
1320+
"| |",
1321+
"+-------------------------+",
1322+
];
1323+
assert_batches_eq!(expected, &actual);
1324+
1325+
let sql =
1326+
"select column1 >= NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
1327+
let actual = execute_to_batches(&ctx, sql).await;
1328+
let expected = vec![
1329+
"+---------------------------+",
1330+
"| t.column1 GtEq Utf8(NULL) |",
1331+
"+---------------------------+",
1332+
"| |",
1333+
"| |",
1334+
"+---------------------------+",
1335+
];
1336+
assert_batches_eq!(expected, &actual);
1337+
1338+
let sql = "select column1 = NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
1339+
let actual = execute_to_batches(&ctx, sql).await;
1340+
let expected = vec![
1341+
"+-------------------------+",
1342+
"| t.column1 Eq Utf8(NULL) |",
1343+
"+-------------------------+",
1344+
"| |",
1345+
"| |",
1346+
"+-------------------------+",
1347+
];
1348+
assert_batches_eq!(expected, &actual);
1349+
1350+
let sql =
1351+
"select column1 != NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
1352+
let actual = execute_to_batches(&ctx, sql).await;
1353+
let expected = vec![
1354+
"+----------------------------+",
1355+
"| t.column1 NotEq Utf8(NULL) |",
1356+
"+----------------------------+",
1357+
"| |",
1358+
"| |",
1359+
"+----------------------------+",
1360+
];
1361+
assert_batches_eq!(expected, &actual);
1362+
1363+
// 1.1 Float value comparison with NULL
1364+
let sql = "select column3 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
1365+
let actual = execute_to_batches(&ctx, sql).await;
1366+
let expected = vec![
1367+
"+-------------------------+",
1368+
"| t.column3 Lt Utf8(NULL) |",
1369+
"+-------------------------+",
1370+
"| |",
1371+
"| |",
1372+
"+-------------------------+",
1373+
];
1374+
assert_batches_eq!(expected, &actual);
1375+
1376+
// String comparison with NULL
1377+
let sql = "select column2 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) as t";
1378+
let actual = execute_to_batches(&ctx, sql).await;
1379+
let expected = vec![
1380+
"+-------------------------+",
1381+
"| t.column2 Lt Utf8(NULL) |",
1382+
"+-------------------------+",
1383+
"| |",
1384+
"| |",
1385+
"+-------------------------+",
1386+
];
1387+
assert_batches_eq!(expected, &actual);
1388+
1389+
// Boolean comparison with NULL
1390+
let sql = "select column1 < NULL from (VALUES (true), (false)) as t";
1391+
let actual = execute_to_batches(&ctx, sql).await;
1392+
let expected = vec![
1393+
"+-------------------------+",
1394+
"| t.column1 Lt Utf8(NULL) |",
1395+
"+-------------------------+",
1396+
"| |",
1397+
"| |",
1398+
"+-------------------------+",
1399+
];
1400+
assert_batches_eq!(expected, &actual);
1401+
Ok(())
1402+
}

datafusion/physical-expr/src/coercion_rule/binary_rule.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,10 @@ pub fn comparison_eq_coercion(
122122
.or_else(|| string_coercion(lhs_type, rhs_type))
123123
.or_else(|| null_coercion(lhs_type, rhs_type))
124124
.or_else(|| string_numeric_coercion(lhs_type, rhs_type))
125+
.or_else(|| string_boolean_coercion(lhs_type, rhs_type))
125126
}
126127

128+
// NOTE: NULL hack!
127129
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
128130
use arrow::datatypes::DataType::*;
129131
match (lhs_type, rhs_type) {
@@ -135,6 +137,15 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
135137
}
136138
}
137139

140+
fn string_boolean_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
141+
use arrow::datatypes::DataType::*;
142+
match (lhs_type, rhs_type) {
143+
(Utf8, Boolean) | (Boolean, Utf8) => Some(Utf8),
144+
(LargeUtf8, Boolean) | (Boolean, LargeUtf8) => Some(LargeUtf8),
145+
_ => None,
146+
}
147+
}
148+
138149
fn comparison_order_coercion(
139150
lhs_type: &DataType,
140151
rhs_type: &DataType,
@@ -149,6 +160,9 @@ fn comparison_order_coercion(
149160
.or_else(|| string_coercion(lhs_type, rhs_type))
150161
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
151162
.or_else(|| temporal_coercion(lhs_type, rhs_type))
163+
.or_else(|| null_coercion(lhs_type, rhs_type))
164+
.or_else(|| string_numeric_coercion(lhs_type, rhs_type))
165+
.or_else(|| string_boolean_coercion(lhs_type, rhs_type))
152166
}
153167

154168
fn comparison_binary_numeric_coercion(

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -844,17 +844,15 @@ macro_rules! compute_utf8_op_scalar {
844844

845845
/// Invoke a compute kernel on a data array and a scalar value
846846
macro_rules! compute_utf8_op_dyn_scalar {
847-
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
847+
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
848848
if let Some(string_value) = $RIGHT {
849849
Ok(Arc::new(paste::expr! {[<$OP _dyn_utf8_scalar>]}(
850850
$LEFT,
851851
&string_value,
852852
)?))
853853
} else {
854-
Err(DataFusionError::Internal(format!(
855-
"compute_utf8_op_scalar for '{}' failed with literal 'none' value",
856-
stringify!($OP),
857-
)))
854+
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
855+
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
858856
}
859857
}};
860858
}
@@ -878,7 +876,7 @@ macro_rules! compute_bool_op_scalar {
878876

879877
/// Invoke a compute kernel on a boolean data array and a scalar value
880878
macro_rules! compute_bool_op_dyn_scalar {
881-
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
879+
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
882880
// generate the scalar function name, such as lt_dyn_bool_scalar, from the $OP parameter
883881
// (which could have a value of lt) and the suffix _scalar
884882
if let Some(b) = $RIGHT {
@@ -887,10 +885,8 @@ macro_rules! compute_bool_op_dyn_scalar {
887885
b,
888886
)?))
889887
} else {
890-
Err(DataFusionError::Internal(format!(
891-
"compute_utf8_op_scalar for '{}' failed with literal 'none' value",
892-
stringify!($OP),
893-
)))
888+
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
889+
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
894890
}
895891
}};
896892
}
@@ -938,8 +934,9 @@ macro_rules! compute_op_scalar {
938934

939935
/// Invoke a dyn compute kernel on a data array and a scalar value
940936
/// LEFT is Primitive or Dictionart array of numeric values, RIGHT is scalar value
937+
/// OP_TYPE is the return type of scalar function
941938
macro_rules! compute_op_dyn_scalar {
942-
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
939+
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
943940
// generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter
944941
// (which could have a value of lt_dyn) and the suffix _scalar
945942
if let Some(value) = $RIGHT {
@@ -948,10 +945,8 @@ macro_rules! compute_op_dyn_scalar {
948945
value,
949946
)?))
950947
} else {
951-
Err(DataFusionError::Internal(format!(
952-
"compute_utf8_op_scalar for '{}' failed with literal 'none' value",
953-
stringify!($OP),
954-
)))
948+
// when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
949+
Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
955950
}
956951
}};
957952
}
@@ -1359,22 +1354,22 @@ impl PhysicalExpr for BinaryExpr {
13591354
/// such as Utf8 strings.
13601355
#[macro_export]
13611356
macro_rules! binary_array_op_dyn_scalar {
1362-
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
1357+
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
13631358
let result: Result<Arc<dyn Array>> = match $RIGHT {
1364-
ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP),
1359+
ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP, $OP_TYPE),
13651360
ScalarValue::Decimal128(..) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray),
1366-
ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP),
1367-
ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP),
1368-
ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
1369-
ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
1370-
ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
1371-
ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
1372-
ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
1373-
ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
1374-
ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
1375-
ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
1376-
ScalarValue::Float32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
1377-
ScalarValue::Float64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
1361+
ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1362+
ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1363+
ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1364+
ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1365+
ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1366+
ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1367+
ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1368+
ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1369+
ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1370+
ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1371+
ScalarValue::Float32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
1372+
ScalarValue::Float64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, $OP_TYPE),
13781373
ScalarValue::Date32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array),
13791374
ScalarValue::Date64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array),
13801375
ScalarValue::TimestampSecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray),
@@ -1397,22 +1392,37 @@ impl BinaryExpr {
13971392
) -> Result<Option<Result<ArrayRef>>> {
13981393
let scalar_result = match &self.op {
13991394
Operator::Lt => {
1400-
binary_array_op_dyn_scalar!(array, scalar.clone(), lt)
1395+
binary_array_op_dyn_scalar!(array, scalar.clone(), lt, &DataType::Boolean)
14011396
}
14021397
Operator::LtEq => {
1403-
binary_array_op_dyn_scalar!(array, scalar.clone(), lt_eq)
1398+
binary_array_op_dyn_scalar!(
1399+
array,
1400+
scalar.clone(),
1401+
lt_eq,
1402+
&DataType::Boolean
1403+
)
14041404
}
14051405
Operator::Gt => {
1406-
binary_array_op_dyn_scalar!(array, scalar.clone(), gt)
1406+
binary_array_op_dyn_scalar!(array, scalar.clone(), gt, &DataType::Boolean)
14071407
}
14081408
Operator::GtEq => {
1409-
binary_array_op_dyn_scalar!(array, scalar.clone(), gt_eq)
1409+
binary_array_op_dyn_scalar!(
1410+
array,
1411+
scalar.clone(),
1412+
gt_eq,
1413+
&DataType::Boolean
1414+
)
14101415
}
14111416
Operator::Eq => {
1412-
binary_array_op_dyn_scalar!(array, scalar.clone(), eq)
1417+
binary_array_op_dyn_scalar!(array, scalar.clone(), eq, &DataType::Boolean)
14131418
}
14141419
Operator::NotEq => {
1415-
binary_array_op_dyn_scalar!(array, scalar.clone(), neq)
1420+
binary_array_op_dyn_scalar!(
1421+
array,
1422+
scalar.clone(),
1423+
neq,
1424+
&DataType::Boolean
1425+
)
14161426
}
14171427
Operator::Like => {
14181428
binary_string_array_op_scalar!(array, scalar.clone(), like)

0 commit comments

Comments
 (0)