Skip to content

Commit c446ba3

Browse files
perf(in_list): Optimize primitive and ByteView filters with Bloom filter and single hashing
This commit introduces a unified high-performance path for IN list operations:

- Implements an optimal 3-bit Bloom filter (512-bit space) for both `PrimitiveFilter` and `ByteViewMaskedFilter`.
- Refactors filters to use `hashbrown::HashTable` directly, ensuring the hash is computed only once and reused for the Bloom check, the stage-1 probe, and the stage-2 verification.
- Centralizes chunked `u64` result building in `build_no_nulls_result` to bypass bit-by-bit builder overhead in null-free paths.
- Adds a saturation threshold (256 elements) that automatically bypasses the Bloom check for large lists, preventing overhead.
- Includes an exhaustive structured test suite verifying 3,000+ combinations of types, sizes, and null logic.
- Fixes several clippy warnings and ensures compliance with project standards.

These changes significantly reduce probe latency and memory pressure for medium-to-large IN lists across all optimized types.
1 parent 0f312b1 commit c446ba3

File tree

4 files changed

+528
-194
lines changed

4 files changed

+528
-194
lines changed

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3033,4 +3033,302 @@ mod tests {
30333033

30343034
Ok(())
30353035
}
3036+
3037+
/// Exhaustive structured test that verifies all lookup strategies, boundaries, and DataFusion types.
#[test]
fn test_in_list_exhaustive() -> Result<()> {
    // Every type family with a dedicated or generic IN-list lookup path.
    let types = [
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
        DataType::Utf8,
        DataType::LargeUtf8,
        DataType::Binary,
        DataType::LargeBinary,
        DataType::Utf8View,
        DataType::BinaryView,
        DataType::Date32,
        DataType::Date64,
        DataType::Timestamp(TimeUnit::Nanosecond, None),
        DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
        DataType::Decimal128(10, 2),
        DataType::Decimal256(38, 10),
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
        DataType::Struct(Fields::from(vec![Field::new("f1", DataType::Int32, true)])),
    ];

    // Critical thresholds: 4 (Decimal boundary), 16 (8B boundary), 32 (4B boundary), 256 (Bloom boundary)
    // Test all sizes 1..33 to cover ALL unrolled branchless filters, plus saturation boundaries.
    let sizes: Vec<usize> = (1..=33).chain([100, 256, 257, 512]).collect();

    for dt in &types {
        for &size in &sizes {
            // Enumerate all 16 combinations of the four boolean knobs via a
            // bit mask; bit layout mirrors the original nesting order
            // (input nulls outermost, scalar input innermost).
            for case in 0u32..16 {
                run_exhaustive_test_case(
                    dt,
                    size,
                    case & 0b1000 != 0, // input_has_nulls
                    case & 0b0100 != 0, // list_has_nulls
                    case & 0b0010 != 0, // negated
                    case & 0b0001 != 0, // input_is_scalar
                )?;
            }
        }
    }
    Ok(())
}
3091+
3092+
fn run_exhaustive_test_case(
3093+
dt: &DataType,
3094+
list_size: usize,
3095+
input_has_nulls: bool,
3096+
list_has_nulls: bool,
3097+
negated: bool,
3098+
input_is_scalar: bool,
3099+
) -> Result<()> {
3100+
let array_len = 100;
3101+
let (input_array, in_list_scalars) = generate_deterministic_data(
3102+
dt,
3103+
array_len,
3104+
list_size,
3105+
input_has_nulls,
3106+
list_has_nulls,
3107+
);
3108+
3109+
let schema = Schema::new(vec![Field::new("a", dt.clone(), true)]);
3110+
let batch = RecordBatch::try_new(
3111+
Arc::new(schema.clone()),
3112+
vec![Arc::clone(&input_array)],
3113+
)?;
3114+
let col_a = col("a", &schema)?;
3115+
let list_exprs: Vec<Arc<dyn PhysicalExpr>> =
3116+
in_list_scalars.iter().map(|s| lit(s.clone())).collect();
3117+
3118+
let opt_result_arr = if input_is_scalar {
3119+
let scalar = ScalarValue::try_from_array(&input_array, 0)?;
3120+
let expr =
3121+
InListExpr::try_new(lit(scalar), list_exprs.clone(), negated, &schema)?;
3122+
expr.evaluate(&batch)?.into_array(batch.num_rows())?
3123+
} else {
3124+
let expr = InListExpr::try_new(
3125+
Arc::clone(&col_a),
3126+
list_exprs.clone(),
3127+
negated,
3128+
&schema,
3129+
)?;
3130+
expr.evaluate(&batch)?.into_array(batch.num_rows())?
3131+
};
3132+
3133+
let opt_result_bool = opt_result_arr
3134+
.as_any()
3135+
.downcast_ref::<BooleanArray>()
3136+
.unwrap();
3137+
let ref_result = compute_reference_result(
3138+
&batch,
3139+
&col_a,
3140+
&in_list_scalars,
3141+
negated,
3142+
input_is_scalar,
3143+
)?;
3144+
3145+
assert_eq!(
3146+
opt_result_bool, &ref_result,
3147+
"FAIL: Type={dt:?}, Size={list_size}, InNull={input_has_nulls}, ListNull={list_has_nulls}, Neg={negated}, Scalar={input_is_scalar}"
3148+
);
3149+
3150+
Ok(())
3151+
}
3152+
3153+
fn generate_deterministic_data(
3154+
dt: &DataType,
3155+
len: usize,
3156+
list_size: usize,
3157+
input_has_nulls: bool,
3158+
list_has_nulls: bool,
3159+
) -> (ArrayRef, Vec<ScalarValue>) {
3160+
let base_dt = if let DataType::Dictionary(_, v) = dt {
3161+
v.as_ref()
3162+
} else {
3163+
dt
3164+
};
3165+
3166+
let input_scalars: Vec<ScalarValue> = (0..len)
3167+
.map(|i| {
3168+
if input_has_nulls && i % 10 == 0 {
3169+
ScalarValue::try_new_null(base_dt).unwrap()
3170+
} else {
3171+
deterministic_scalar(base_dt, i)
3172+
}
3173+
})
3174+
.collect();
3175+
let mut input_array = ScalarValue::iter_to_array(input_scalars).unwrap();
3176+
3177+
if let DataType::Dictionary(k, _) = dt {
3178+
match k.as_ref() {
3179+
DataType::Int32 => {
3180+
let keys = Int32Array::from(
3181+
(0..len as i32).map(|i| i % 10).collect::<Vec<_>>(),
3182+
);
3183+
input_array =
3184+
Arc::new(DictionaryArray::<Int32Type>::new(keys, input_array))
3185+
as ArrayRef;
3186+
}
3187+
DataType::Int8 => {
3188+
let keys = Int8Array::from(
3189+
(0..len as i8).map(|i| i % 10).collect::<Vec<_>>(),
3190+
);
3191+
input_array =
3192+
Arc::new(DictionaryArray::<Int8Type>::new(keys, input_array))
3193+
as ArrayRef;
3194+
}
3195+
_ => unreachable!(),
3196+
};
3197+
}
3198+
3199+
let mut in_list: Vec<ScalarValue> = (0..list_size)
3200+
.map(|i| deterministic_scalar(base_dt, i * 2))
3201+
.collect();
3202+
3203+
if list_has_nulls && !in_list.is_empty() {
3204+
in_list[0] = ScalarValue::try_new_null(base_dt).unwrap();
3205+
}
3206+
3207+
(input_array, in_list)
3208+
}
3209+
3210+
/// Produces a reproducible, non-null scalar of type `dt` derived purely from
/// index `i`, so generated inputs and IN lists overlap deterministically.
///
/// Numeric casts (`i as i8`, etc.) intentionally wrap for large `i`; callers
/// keep `i` small enough that this does not matter for test determinism.
fn deterministic_scalar(dt: &DataType, i: usize) -> ScalarValue {
    match dt {
        DataType::Int8 => ScalarValue::Int8(Some(i as i8)),
        DataType::Int16 => ScalarValue::Int16(Some(i as i16)),
        DataType::Int32 => ScalarValue::Int32(Some(i as i32)),
        DataType::Int64 => ScalarValue::Int64(Some(i as i64)),
        DataType::Float32 => ScalarValue::Float32(Some(i as f32)),
        DataType::Float64 => ScalarValue::Float64(Some(i as f64)),
        // All string flavors share the same payload; only the variant differs.
        DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => {
            let s = format!("val-{i}");
            if dt == &DataType::Utf8View {
                ScalarValue::Utf8View(Some(s))
            } else if dt == &DataType::LargeUtf8 {
                ScalarValue::LargeUtf8(Some(s))
            } else {
                ScalarValue::Utf8(Some(s))
            }
        }
        // Same pattern for the binary flavors.
        DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
            let b = format!("bin-{i}").into_bytes();
            if dt == &DataType::BinaryView {
                ScalarValue::BinaryView(Some(b))
            } else if dt == &DataType::LargeBinary {
                ScalarValue::LargeBinary(Some(b))
            } else {
                ScalarValue::Binary(Some(b))
            }
        }
        DataType::Date32 => ScalarValue::Date32(Some(i as i32)),
        DataType::Date64 => ScalarValue::Date64(Some(i as i64)),
        DataType::Timestamp(unit, tz) => {
            // Scale by 1000 so values are distinguishable across units.
            let val = (i * 1000) as i64;
            match unit {
                TimeUnit::Nanosecond => {
                    ScalarValue::TimestampNanosecond(Some(val), tz.clone())
                }
                TimeUnit::Microsecond => {
                    ScalarValue::TimestampMicrosecond(Some(val), tz.clone())
                }
                TimeUnit::Millisecond => {
                    ScalarValue::TimestampMillisecond(Some(val), tz.clone())
                }
                TimeUnit::Second => {
                    ScalarValue::TimestampSecond(Some(val), tz.clone())
                }
            }
        }
        DataType::Decimal128(p, s) => {
            ScalarValue::Decimal128(Some(i as i128), *p, *s)
        }
        DataType::Decimal256(p, s) => {
            ScalarValue::Decimal256(Some(i256::from_i128(i as i128)), *p, *s)
        }
        // Structs recurse per field with the same index, building a one-row
        // StructArray scalar.
        DataType::Struct(fields) => {
            let values: Vec<ScalarValue> = fields
                .iter()
                .map(|f| deterministic_scalar(f.data_type(), i))
                .collect();
            ScalarValue::Struct(Arc::new(
                StructArray::try_new(
                    fields.clone(),
                    values.iter().map(|s| s.to_array().unwrap()).collect(),
                    None,
                )
                .unwrap(),
            ))
        }
        // Fallback for any type not explicitly handled above.
        _ => ScalarValue::Int32(Some(i as i32)),
    }
}
3280+
3281+
fn compute_reference_result(
3282+
batch: &RecordBatch,
3283+
col_a: &Arc<dyn PhysicalExpr>,
3284+
list: &[ScalarValue],
3285+
negated: bool,
3286+
input_is_scalar: bool,
3287+
) -> Result<BooleanArray> {
3288+
let input_array = if input_is_scalar {
3289+
let val = col_a.evaluate(batch)?;
3290+
let arr = val.into_array(batch.num_rows())?;
3291+
let s = ScalarValue::try_from_array(&arr, 0)?;
3292+
s.to_array_of_size(batch.num_rows())?
3293+
} else {
3294+
col_a.evaluate(batch)?.into_array(batch.num_rows())?
3295+
};
3296+
3297+
let num_rows = batch.num_rows();
3298+
let mut result = Vec::with_capacity(num_rows);
3299+
for i in 0..num_rows {
3300+
let mut val = ScalarValue::try_from_array(&input_array, i)?;
3301+
if let ScalarValue::Dictionary(_, v) = val {
3302+
val = *v;
3303+
}
3304+
3305+
if val.is_null() {
3306+
result.push(None);
3307+
continue;
3308+
}
3309+
3310+
let mut found = false;
3311+
let mut has_null = false;
3312+
for list_val in list {
3313+
if list_val.is_null() {
3314+
has_null = true;
3315+
continue;
3316+
}
3317+
if list_val == &val {
3318+
found = true;
3319+
break;
3320+
}
3321+
}
3322+
3323+
let res = if found {
3324+
Some(!negated)
3325+
} else if has_null {
3326+
None
3327+
} else {
3328+
Some(negated)
3329+
};
3330+
result.push(res);
3331+
}
3332+
Ok(BooleanArray::from(result))
3333+
}
30363334
}

0 commit comments

Comments
 (0)