Skip to content

Commit 8f90814

Browse files
authored
refactor: update cmp and nested data in binary operator (#18256)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Related #18210 ## Rationale for this change To keep logic clear in binary operator and make it possible to use binary operators for nested data structures in coming changes. ## What changes are included in this PR? Another housekeeping refactor for binary operators. - Keep the API from datum module consistent by using `Operator` instead of kernel function - Move nested data structure check into cmp operators. This allows us to implement binary operators for `List`, `Struct` and etc. ## Are these changes tested? Unit tests ## Are there any user-facing changes? N/A
1 parent 3de195a commit 8f90814

File tree

3 files changed

+66
-38
lines changed

3 files changed

+66
-38
lines changed

datafusion/physical-expr-common/src/datum.rs

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
use arrow::array::BooleanArray;
1919
use arrow::array::{make_comparator, ArrayRef, Datum};
2020
use arrow::buffer::NullBuffer;
21-
use arrow::compute::SortOptions;
21+
use arrow::compute::kernels::cmp::{
22+
distinct, eq, gt, gt_eq, lt, lt_eq, neq, not_distinct,
23+
};
24+
use arrow::compute::{ilike, like, nilike, nlike, SortOptions};
2225
use arrow::error::ArrowError;
2326
use datafusion_common::DataFusionError;
2427
use datafusion_common::{arrow_datafusion_err, internal_err};
@@ -53,22 +56,49 @@ pub fn apply(
5356
}
5457
}
5558

56-
/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs`
59+
/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs`
5760
pub fn apply_cmp(
61+
op: Operator,
5862
lhs: &ColumnarValue,
5963
rhs: &ColumnarValue,
60-
f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
6164
) -> Result<ColumnarValue> {
62-
apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
65+
if lhs.data_type().is_nested() {
66+
apply_cmp_for_nested(op, lhs, rhs)
67+
} else {
68+
let f = match op {
69+
Operator::Eq => eq,
70+
Operator::NotEq => neq,
71+
Operator::Lt => lt,
72+
Operator::LtEq => lt_eq,
73+
Operator::Gt => gt,
74+
Operator::GtEq => gt_eq,
75+
Operator::IsDistinctFrom => distinct,
76+
Operator::IsNotDistinctFrom => not_distinct,
77+
78+
Operator::LikeMatch => like,
79+
Operator::ILikeMatch => ilike,
80+
Operator::NotLikeMatch => nlike,
81+
Operator::NotILikeMatch => nilike,
82+
83+
_ => {
84+
return internal_err!("Invalid compare operator: {}", op);
85+
}
86+
};
87+
88+
apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
89+
}
6390
}
6491

65-
/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like
92+
/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs` for nested type like
6693
/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type
6794
pub fn apply_cmp_for_nested(
6895
op: Operator,
6996
lhs: &ColumnarValue,
7097
rhs: &ColumnarValue,
7198
) -> Result<ColumnarValue> {
99+
let left_data_type = lhs.data_type();
100+
let right_data_type = rhs.data_type();
101+
72102
if matches!(
73103
op,
74104
Operator::Eq
@@ -79,12 +109,18 @@ pub fn apply_cmp_for_nested(
79109
| Operator::GtEq
80110
| Operator::IsDistinctFrom
81111
| Operator::IsNotDistinctFrom
82-
) {
112+
) && left_data_type.equals_datatype(&right_data_type)
113+
{
83114
apply(lhs, rhs, |l, r| {
84115
Ok(Arc::new(compare_op_for_nested(op, l, r)?))
85116
})
86117
} else {
87-
internal_err!("invalid operator for nested")
118+
internal_err!(
119+
"invalid operator or data type mismatch for nested data, op {} left {}, right {}",
120+
op,
121+
left_data_type,
122+
right_data_type
123+
)
88124
}
89125
}
90126

@@ -97,7 +133,7 @@ pub fn compare_with_eq(
97133
if is_nested {
98134
compare_op_for_nested(Operator::Eq, lhs, rhs)
99135
} else {
100-
arrow::compute::kernels::cmp::eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e))
136+
eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e))
101137
}
102138
}
103139

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,8 @@ use std::{any::Any, sync::Arc};
2424

2525
use arrow::array::*;
2626
use arrow::compute::kernels::boolean::{and_kleene, or_kleene};
27-
use arrow::compute::kernels::cmp::*;
2827
use arrow::compute::kernels::concat_elements::concat_elements_utf8;
29-
use arrow::compute::{
30-
cast, filter_record_batch, ilike, like, nilike, nlike, SlicesIterator,
31-
};
28+
use arrow::compute::{cast, filter_record_batch, SlicesIterator};
3229
use arrow::datatypes::*;
3330
use arrow::error::ArrowError;
3431
use datafusion_common::cast::as_boolean_array;
@@ -42,7 +39,7 @@ use datafusion_expr::statistics::{
4239
new_generic_from_binary_op, Distribution,
4340
};
4441
use datafusion_expr::{ColumnarValue, Operator};
45-
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
42+
use datafusion_physical_expr_common::datum::{apply, apply_cmp};
4643

4744
use kernels::{
4845
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
@@ -251,13 +248,6 @@ impl PhysicalExpr for BinaryExpr {
251248
let schema = batch.schema();
252249
let input_schema = schema.as_ref();
253250

254-
if left_data_type.is_nested() {
255-
if !left_data_type.equals_datatype(&right_data_type) {
256-
return internal_err!("Cannot evaluate binary expression because of type mismatch: left {}, right {} ", left_data_type, right_data_type);
257-
}
258-
return apply_cmp_for_nested(self.op, &lhs, &rhs);
259-
}
260-
261251
match self.op {
262252
Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
263253
Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
@@ -267,18 +257,21 @@ impl PhysicalExpr for BinaryExpr {
267257
Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
268258
Operator::Divide => return apply(&lhs, &rhs, div),
269259
Operator::Modulo => return apply(&lhs, &rhs, rem),
270-
Operator::Eq => return apply_cmp(&lhs, &rhs, eq),
271-
Operator::NotEq => return apply_cmp(&lhs, &rhs, neq),
272-
Operator::Lt => return apply_cmp(&lhs, &rhs, lt),
273-
Operator::Gt => return apply_cmp(&lhs, &rhs, gt),
274-
Operator::LtEq => return apply_cmp(&lhs, &rhs, lt_eq),
275-
Operator::GtEq => return apply_cmp(&lhs, &rhs, gt_eq),
276-
Operator::IsDistinctFrom => return apply_cmp(&lhs, &rhs, distinct),
277-
Operator::IsNotDistinctFrom => return apply_cmp(&lhs, &rhs, not_distinct),
278-
Operator::LikeMatch => return apply_cmp(&lhs, &rhs, like),
279-
Operator::ILikeMatch => return apply_cmp(&lhs, &rhs, ilike),
280-
Operator::NotLikeMatch => return apply_cmp(&lhs, &rhs, nlike),
281-
Operator::NotILikeMatch => return apply_cmp(&lhs, &rhs, nilike),
260+
261+
Operator::Eq
262+
| Operator::NotEq
263+
| Operator::Lt
264+
| Operator::Gt
265+
| Operator::LtEq
266+
| Operator::GtEq
267+
| Operator::IsDistinctFrom
268+
| Operator::IsNotDistinctFrom
269+
| Operator::LikeMatch
270+
| Operator::ILikeMatch
271+
| Operator::NotLikeMatch
272+
| Operator::NotILikeMatch => {
273+
return apply_cmp(self.op, &lhs, &rhs);
274+
}
282275
_ => {}
283276
}
284277

datafusion/physical-expr/src/expressions/like.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::PhysicalExpr;
1919
use arrow::datatypes::{DataType, Schema};
2020
use arrow::record_batch::RecordBatch;
2121
use datafusion_common::{internal_err, Result};
22-
use datafusion_expr::ColumnarValue;
22+
use datafusion_expr::{ColumnarValue, Operator};
2323
use datafusion_physical_expr_common::datum::apply_cmp;
2424
use std::hash::Hash;
2525
use std::{any::Any, sync::Arc};
@@ -118,14 +118,13 @@ impl PhysicalExpr for LikeExpr {
118118
}
119119

120120
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
121-
use arrow::compute::*;
122121
let lhs = self.expr.evaluate(batch)?;
123122
let rhs = self.pattern.evaluate(batch)?;
124123
match (self.negated, self.case_insensitive) {
125-
(false, false) => apply_cmp(&lhs, &rhs, like),
126-
(false, true) => apply_cmp(&lhs, &rhs, ilike),
127-
(true, false) => apply_cmp(&lhs, &rhs, nlike),
128-
(true, true) => apply_cmp(&lhs, &rhs, nilike),
124+
(false, false) => apply_cmp(Operator::LikeMatch, &lhs, &rhs),
125+
(false, true) => apply_cmp(Operator::ILikeMatch, &lhs, &rhs),
126+
(true, false) => apply_cmp(Operator::NotLikeMatch, &lhs, &rhs),
127+
(true, true) => apply_cmp(Operator::NotILikeMatch, &lhs, &rhs),
129128
}
130129
}
131130

0 commit comments

Comments
 (0)