Skip to content

Commit acec058

Browse files
authored
perf: Use Arrow vectorized eq kernel for IN list with column references (#20528)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Relates to #20427 . ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> When the IN list contains column references (e.g. `SELECT * FROM t WHERE a IN (b, c, d, e)`), DataFusion falls back to a row-by-row `make_comparator` path which is significantly slower than it needs to be. Arrow provides SIMD-optimized `eq` kernels that can compare entire arrays in one call. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> - Use Arrow's vectorized `eq` kernel instead of row-by-row `make_comparator` for non-nested types (primitive, string, binary) in the column-reference IN list evaluation path - For nested types (Struct, List, etc.), fall back to `make_comparator` since Arrow's `eq` kernel does not support them - Add 6 unit tests covering the column-reference evaluation path (Int32, Utf8, NOT IN, NULL handling, NaN semantics) ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Yes. 6 new unit tests added: - `test_in_list_with_columns_int32_scalars` - `test_in_list_with_columns_int32_column_refs` - `test_in_list_with_columns_utf8_column_refs` - `test_in_list_with_columns_negated` - `test_in_list_with_columns_null_in_list` - `test_in_list_with_columns_float_nan` ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No API changes. Queries with column-reference IN lists will run faster. <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 5d8249f commit acec058

File tree

1 file changed

+231
-14
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+231
-14
lines changed

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 231 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use crate::physical_expr::physical_exprs_bag_equal;
2828
use arrow::array::*;
2929
use arrow::buffer::{BooleanBuffer, NullBuffer};
3030
use arrow::compute::kernels::boolean::{not, or_kleene};
31+
use arrow::compute::kernels::cmp::eq as arrow_eq;
3132
use arrow::compute::{SortOptions, take};
3233
use arrow::datatypes::*;
3334
use arrow::util::bit_iterator::BitIndexIterator;
@@ -138,6 +139,21 @@ impl StaticFilter for ArrayStaticFilter {
138139
}
139140
}
140141

142+
/// Returns true if Arrow's vectorized `eq` kernel supports this data type.
143+
///
144+
/// Supported: primitives, boolean, strings (Utf8/LargeUtf8/Utf8View),
145+
/// binary (Binary/LargeBinary/BinaryView/FixedSizeBinary), Null, and
146+
/// Dictionary-encoded variants of the above.
147+
/// Unsupported: nested types (Struct, List, Map, Union) and RunEndEncoded.
148+
fn supports_arrow_eq(dt: &DataType) -> bool {
149+
use DataType::*;
150+
match dt {
151+
Boolean | Binary | LargeBinary | BinaryView | FixedSizeBinary(_) => true,
152+
Dictionary(_, v) => supports_arrow_eq(v.as_ref()),
153+
_ => dt.is_primitive() || dt.is_null() || dt.is_string(),
154+
}
155+
}
156+
141157
fn instantiate_static_filter(
142158
in_array: ArrayRef,
143159
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
@@ -771,32 +787,45 @@ impl PhysicalExpr for InListExpr {
771787
}
772788
}
773789
None => {
774-
// No static filter: iterate through each expression, compare, and OR results
790+
// No static filter: iterate through each expression, compare, and OR results.
791+
// Use Arrow's vectorized eq kernel for types it supports (primitive,
792+
// boolean, string, binary, dictionary), falling back to row-by-row
793+
// comparator for unsupported types (nested, RunEndEncoded, etc.).
775794
let value = value.into_array(num_rows)?;
795+
let lhs_supports_arrow_eq = supports_arrow_eq(value.data_type());
776796
let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold(
777797
BooleanArray::new(BooleanBuffer::new_unset(num_rows), None),
778798
|result, expr| -> Result<BooleanArray> {
779799
let rhs = match expr? {
780800
ColumnarValue::Array(array) => {
781-
let cmp = make_comparator(
782-
value.as_ref(),
783-
array.as_ref(),
784-
SortOptions::default(),
785-
)?;
786-
(0..num_rows)
787-
.map(|i| {
788-
if value.is_null(i) || array.is_null(i) {
789-
return None;
790-
}
791-
Some(cmp(i, i).is_eq())
792-
})
793-
.collect::<BooleanArray>()
801+
if lhs_supports_arrow_eq
802+
&& supports_arrow_eq(array.data_type())
803+
{
804+
arrow_eq(&value, &array)?
805+
} else {
806+
let cmp = make_comparator(
807+
value.as_ref(),
808+
array.as_ref(),
809+
SortOptions::default(),
810+
)?;
811+
(0..num_rows)
812+
.map(|i| {
813+
if value.is_null(i) || array.is_null(i) {
814+
return None;
815+
}
816+
Some(cmp(i, i).is_eq())
817+
})
818+
.collect::<BooleanArray>()
819+
}
794820
}
795821
ColumnarValue::Scalar(scalar) => {
796822
// Check if scalar is null once, before the loop
797823
if scalar.is_null() {
798824
// If scalar is null, all comparisons return null
799825
BooleanArray::from(vec![None; num_rows])
826+
} else if lhs_supports_arrow_eq {
827+
let scalar_datum = scalar.to_scalar()?;
828+
arrow_eq(&value, &scalar_datum)?
800829
} else {
801830
// Convert scalar to 1-element array
802831
let array = scalar.to_array()?;
@@ -3507,4 +3536,192 @@ mod tests {
35073536

35083537
Ok(())
35093538
}
3539+
3540+
/// Helper: creates an InListExpr with `static_filter = None`
3541+
/// to force the column-reference evaluation path.
3542+
fn make_in_list_with_columns(
3543+
expr: Arc<dyn PhysicalExpr>,
3544+
list: Vec<Arc<dyn PhysicalExpr>>,
3545+
negated: bool,
3546+
) -> Arc<InListExpr> {
3547+
Arc::new(InListExpr::new(expr, list, negated, None))
3548+
}
3549+
3550+
#[test]
3551+
fn test_in_list_with_columns_int32_scalars() -> Result<()> {
3552+
// Column-reference path with scalar literals (bypassing static filter)
3553+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
3554+
let col_a = col("a", &schema)?;
3555+
let batch = RecordBatch::try_new(
3556+
Arc::new(schema),
3557+
vec![Arc::new(Int32Array::from(vec![
3558+
Some(1),
3559+
Some(2),
3560+
Some(3),
3561+
None,
3562+
]))],
3563+
)?;
3564+
3565+
let list = vec![
3566+
lit(ScalarValue::Int32(Some(1))),
3567+
lit(ScalarValue::Int32(Some(3))),
3568+
];
3569+
let expr = make_in_list_with_columns(col_a, list, false);
3570+
3571+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3572+
let result = as_boolean_array(&result);
3573+
assert_eq!(
3574+
result,
3575+
&BooleanArray::from(vec![Some(true), Some(false), Some(true), None,])
3576+
);
3577+
Ok(())
3578+
}
3579+
3580+
#[test]
3581+
fn test_in_list_with_columns_int32_column_refs() -> Result<()> {
3582+
// IN list with column references
3583+
let schema = Schema::new(vec![
3584+
Field::new("a", DataType::Int32, true),
3585+
Field::new("b", DataType::Int32, true),
3586+
Field::new("c", DataType::Int32, true),
3587+
]);
3588+
let batch = RecordBatch::try_new(
3589+
Arc::new(schema.clone()),
3590+
vec![
3591+
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3), None])),
3592+
Arc::new(Int32Array::from(vec![
3593+
Some(1),
3594+
Some(99),
3595+
Some(99),
3596+
Some(99),
3597+
])),
3598+
Arc::new(Int32Array::from(vec![Some(99), Some(99), Some(3), None])),
3599+
],
3600+
)?;
3601+
3602+
let col_a = col("a", &schema)?;
3603+
let list = vec![col("b", &schema)?, col("c", &schema)?];
3604+
let expr = make_in_list_with_columns(col_a, list, false);
3605+
3606+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3607+
let result = as_boolean_array(&result);
3608+
// row 0: 1 IN (1, 99) → true
3609+
// row 1: 2 IN (99, 99) → false
3610+
// row 2: 3 IN (99, 3) → true
3611+
// row 3: NULL IN (99, NULL) → NULL
3612+
assert_eq!(
3613+
result,
3614+
&BooleanArray::from(vec![Some(true), Some(false), Some(true), None,])
3615+
);
3616+
Ok(())
3617+
}
3618+
3619+
#[test]
3620+
fn test_in_list_with_columns_utf8_column_refs() -> Result<()> {
3621+
// IN list with Utf8 column references
3622+
let schema = Schema::new(vec![
3623+
Field::new("a", DataType::Utf8, false),
3624+
Field::new("b", DataType::Utf8, false),
3625+
]);
3626+
let batch = RecordBatch::try_new(
3627+
Arc::new(schema.clone()),
3628+
vec![
3629+
Arc::new(StringArray::from(vec!["x", "y", "z"])),
3630+
Arc::new(StringArray::from(vec!["x", "x", "z"])),
3631+
],
3632+
)?;
3633+
3634+
let col_a = col("a", &schema)?;
3635+
let list = vec![col("b", &schema)?];
3636+
let expr = make_in_list_with_columns(col_a, list, false);
3637+
3638+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3639+
let result = as_boolean_array(&result);
3640+
// row 0: "x" IN ("x") → true
3641+
// row 1: "y" IN ("x") → false
3642+
// row 2: "z" IN ("z") → true
3643+
assert_eq!(result, &BooleanArray::from(vec![true, false, true]));
3644+
Ok(())
3645+
}
3646+
3647+
#[test]
3648+
fn test_in_list_with_columns_negated() -> Result<()> {
3649+
// NOT IN with column references
3650+
let schema = Schema::new(vec![
3651+
Field::new("a", DataType::Int32, false),
3652+
Field::new("b", DataType::Int32, false),
3653+
]);
3654+
let batch = RecordBatch::try_new(
3655+
Arc::new(schema.clone()),
3656+
vec![
3657+
Arc::new(Int32Array::from(vec![1, 2, 3])),
3658+
Arc::new(Int32Array::from(vec![1, 99, 3])),
3659+
],
3660+
)?;
3661+
3662+
let col_a = col("a", &schema)?;
3663+
let list = vec![col("b", &schema)?];
3664+
let expr = make_in_list_with_columns(col_a, list, true);
3665+
3666+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3667+
let result = as_boolean_array(&result);
3668+
// row 0: 1 NOT IN (1) → false
3669+
// row 1: 2 NOT IN (99) → true
3670+
// row 2: 3 NOT IN (3) → false
3671+
assert_eq!(result, &BooleanArray::from(vec![false, true, false]));
3672+
Ok(())
3673+
}
3674+
3675+
#[test]
3676+
fn test_in_list_with_columns_null_in_list() -> Result<()> {
3677+
// IN list with NULL scalar (column-reference path)
3678+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
3679+
let col_a = col("a", &schema)?;
3680+
let batch = RecordBatch::try_new(
3681+
Arc::new(schema),
3682+
vec![Arc::new(Int32Array::from(vec![1, 2]))],
3683+
)?;
3684+
3685+
let list = vec![
3686+
lit(ScalarValue::Int32(None)),
3687+
lit(ScalarValue::Int32(Some(1))),
3688+
];
3689+
let expr = make_in_list_with_columns(col_a, list, false);
3690+
3691+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3692+
let result = as_boolean_array(&result);
3693+
// row 0: 1 IN (NULL, 1) → true (true OR null = true)
3694+
// row 1: 2 IN (NULL, 1) → NULL (false OR null = null)
3695+
assert_eq!(result, &BooleanArray::from(vec![Some(true), None]));
3696+
Ok(())
3697+
}
3698+
3699+
#[test]
3700+
fn test_in_list_with_columns_float_nan() -> Result<()> {
3701+
// Verify NaN == NaN is true in the column-reference path
3702+
// (consistent with Arrow's totalOrder semantics)
3703+
let schema = Schema::new(vec![
3704+
Field::new("a", DataType::Float64, false),
3705+
Field::new("b", DataType::Float64, false),
3706+
]);
3707+
let batch = RecordBatch::try_new(
3708+
Arc::new(schema.clone()),
3709+
vec![
3710+
Arc::new(Float64Array::from(vec![f64::NAN, 1.0, f64::NAN])),
3711+
Arc::new(Float64Array::from(vec![f64::NAN, 2.0, 0.0])),
3712+
],
3713+
)?;
3714+
3715+
let col_a = col("a", &schema)?;
3716+
let list = vec![col("b", &schema)?];
3717+
let expr = make_in_list_with_columns(col_a, list, false);
3718+
3719+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3720+
let result = as_boolean_array(&result);
3721+
// row 0: NaN IN (NaN) → true
3722+
// row 1: 1.0 IN (2.0) → false
3723+
// row 2: NaN IN (0.0) → false
3724+
assert_eq!(result, &BooleanArray::from(vec![true, false, false]));
3725+
Ok(())
3726+
}
35103727
}

0 commit comments

Comments
 (0)