Skip to content

Commit 84a79e1

Browse files
fix: InList Dictionary filter pushdown type mismatch (apache#20962)
## Which issue does this PR close? Closes apache#20937 ## Rationale for this change `ArrayStaticFilter::contains()` unconditionally unwraps dictionary-encoded needles to evaluate against distinct values. This only works when `in_array` has the same type as the dictionary's value type. When `in_array` is also Dictionary, the unwrap creates a type mismatch in `make_comparator`. This became reachable after: - apache#20505 which removed dictionary flattening in the InList builder, allowing Dictionary arrays to reach `ArrayStaticFilter` via HashJoin dynamic filter pushdown with `pushdown_filters` enabled. ## What changes are included in this PR? - Guard the dictionary unwrap in `ArrayStaticFilter::contains()` to only fire when the dictionary value type matches `in_array`'s type. When both sides are Dictionary, fall through to `make_comparator(Dict, Dict)` which arrow-ord handles natively. - Update the sqllogictest from apache#20960 to expect success. - Add unit tests covering all `try_new_from_array` type combinations (primitive specializations, Utf8, Dictionary, Struct). ## Are these changes tested? Yes — unit tests and sqllogictest.
1 parent 11b9693 commit 84a79e1

File tree

2 files changed

+216
-4
lines changed

2 files changed

+216
-4
lines changed

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 210 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,18 @@ impl StaticFilter for ArrayStaticFilter {
9999
));
100100
}
101101

102+
// Unwrap dictionary-encoded needles when the value type matches
103+
// in_array, evaluating against the dictionary values and mapping
104+
// back via keys.
102105
downcast_dictionary_array! {
103106
v => {
104-
let values_contains = self.contains(v.values().as_ref(), negated)?;
105-
let result = take(&values_contains, v.keys(), None)?;
106-
return Ok(downcast_array(result.as_ref()))
107+
// Only unwrap when the haystack (in_array) type matches
108+
// the dictionary value type
109+
if v.values().data_type() == self.in_array.data_type() {
110+
let values_contains = self.contains(v.values().as_ref(), negated)?;
111+
let result = take(&values_contains, v.keys(), None)?;
112+
return Ok(downcast_array(result.as_ref()));
113+
}
107114
}
108115
_ => {}
109116
}
@@ -3878,4 +3885,204 @@ mod tests {
38783885
);
38793886
Ok(())
38803887
}
3888+
3889+
// -----------------------------------------------------------------------
3890+
// Tests for try_new_from_array: evaluates `needle IN in_array`.
3891+
//
3892+
// This exercises the code path used by HashJoin dynamic filter pushdown,
3893+
// where in_array is built directly from the join's build-side arrays.
3894+
// Unlike try_new (used by SQL IN expressions), which always produces a
3895+
// non-Dictionary in_array because evaluate_list() flattens Dictionary
3896+
// scalars, try_new_from_array passes the array directly and can produce
3897+
// a Dictionary in_array.
3898+
// -----------------------------------------------------------------------
3899+
3900+
fn wrap_in_dict(array: ArrayRef) -> ArrayRef {
3901+
let keys = Int32Array::from((0..array.len() as i32).collect::<Vec<_>>());
3902+
Arc::new(DictionaryArray::new(keys, array))
3903+
}
3904+
3905+
/// Evaluates `needle IN in_array` via try_new_from_array, the same
3906+
/// path used by HashJoin dynamic filter pushdown (not the SQL literal
3907+
/// IN path which goes through try_new).
3908+
fn eval_in_list_from_array(
3909+
needle: ArrayRef,
3910+
in_array: ArrayRef,
3911+
) -> Result<BooleanArray> {
3912+
let schema =
3913+
Schema::new(vec![Field::new("a", needle.data_type().clone(), false)]);
3914+
let col_a = col("a", &schema)?;
3915+
let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?)
3916+
as Arc<dyn PhysicalExpr>;
3917+
let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?;
3918+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3919+
Ok(as_boolean_array(&result).clone())
3920+
}
3921+
3922+
#[test]
3923+
fn test_in_list_from_array_type_combinations() -> Result<()> {
3924+
use arrow::compute::cast;
3925+
3926+
// All cases: needle[0] and needle[2] match, needle[1] does not.
3927+
let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);
3928+
3929+
// Base arrays cast to each target type
3930+
let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef;
3931+
let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef;
3932+
3933+
// Test all specializations in instantiate_static_filter
3934+
let primitive_types = vec![
3935+
DataType::Int8,
3936+
DataType::Int16,
3937+
DataType::Int32,
3938+
DataType::Int64,
3939+
DataType::UInt8,
3940+
DataType::UInt16,
3941+
DataType::UInt32,
3942+
DataType::UInt64,
3943+
DataType::Float32,
3944+
DataType::Float64,
3945+
];
3946+
3947+
for dt in &primitive_types {
3948+
let in_array = cast(&base_in, dt)?;
3949+
let needle = cast(&base_needle, dt)?;
3950+
3951+
// T in_array, T needle
3952+
assert_eq!(
3953+
expected,
3954+
eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?,
3955+
"same-type failed for {dt:?}"
3956+
);
3957+
3958+
// T in_array, Dict(Int32, T) needle
3959+
assert_eq!(
3960+
expected,
3961+
eval_in_list_from_array(wrap_in_dict(needle), in_array)?,
3962+
"dict-needle failed for {dt:?}"
3963+
);
3964+
}
3965+
3966+
// Utf8 (falls through to ArrayStaticFilter)
3967+
let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
3968+
let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;
3969+
3970+
// Utf8 in_array, Utf8 needle
3971+
assert_eq!(
3972+
expected,
3973+
eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in),)?
3974+
);
3975+
3976+
// Utf8 in_array, Dict(Utf8) needle
3977+
assert_eq!(
3978+
expected,
3979+
eval_in_list_from_array(
3980+
wrap_in_dict(Arc::clone(&utf8_needle)),
3981+
Arc::clone(&utf8_in),
3982+
)?
3983+
);
3984+
3985+
// Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
3986+
assert_eq!(
3987+
expected,
3988+
eval_in_list_from_array(
3989+
wrap_in_dict(Arc::clone(&utf8_needle)),
3990+
wrap_in_dict(Arc::clone(&utf8_in)),
3991+
)?
3992+
);
3993+
3994+
// Struct in_array, Struct needle: multi-column join
3995+
let struct_fields = Fields::from(vec![
3996+
Field::new("c0", DataType::Utf8, true),
3997+
Field::new("c1", DataType::Int64, true),
3998+
]);
3999+
let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
4000+
let pairs: Vec<(FieldRef, ArrayRef)> =
4001+
struct_fields.iter().cloned().zip([c0, c1]).collect();
4002+
Arc::new(StructArray::from(pairs))
4003+
};
4004+
assert_eq!(
4005+
expected,
4006+
eval_in_list_from_array(
4007+
make_struct(
4008+
Arc::clone(&utf8_needle),
4009+
Arc::new(Int64Array::from(vec![1, 4, 2])),
4010+
),
4011+
make_struct(
4012+
Arc::clone(&utf8_in),
4013+
Arc::new(Int64Array::from(vec![1, 2, 3])),
4014+
),
4015+
)?
4016+
);
4017+
4018+
// Struct with Dict fields: multi-column Dict join
4019+
let dict_struct_fields = Fields::from(vec![
4020+
Field::new(
4021+
"c0",
4022+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
4023+
true,
4024+
),
4025+
Field::new("c1", DataType::Int64, true),
4026+
]);
4027+
let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
4028+
let pairs: Vec<(FieldRef, ArrayRef)> =
4029+
dict_struct_fields.iter().cloned().zip([c0, c1]).collect();
4030+
Arc::new(StructArray::from(pairs))
4031+
};
4032+
assert_eq!(
4033+
expected,
4034+
eval_in_list_from_array(
4035+
make_dict_struct(
4036+
wrap_in_dict(Arc::clone(&utf8_needle)),
4037+
Arc::new(Int64Array::from(vec![1, 4, 2])),
4038+
),
4039+
make_dict_struct(
4040+
wrap_in_dict(Arc::clone(&utf8_in)),
4041+
Arc::new(Int64Array::from(vec![1, 2, 3])),
4042+
),
4043+
)?
4044+
);
4045+
4046+
Ok(())
4047+
}
4048+
4049+
#[test]
4050+
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
4051+
// Utf8 needle, Dict(Utf8) in_array
4052+
let err = eval_in_list_from_array(
4053+
Arc::new(StringArray::from(vec!["a", "d", "b"])),
4054+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
4055+
)
4056+
.unwrap_err()
4057+
.to_string();
4058+
assert!(
4059+
err.contains("Can't compare arrays of different types"),
4060+
"{err}"
4061+
);
4062+
4063+
// Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter
4064+
// rejects the Utf8 dictionary values at construction time
4065+
let err = eval_in_list_from_array(
4066+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
4067+
Arc::new(Int64Array::from(vec![1, 2, 3])),
4068+
)
4069+
.unwrap_err()
4070+
.to_string();
4071+
assert!(err.contains("Failed to downcast"), "{err}");
4072+
4073+
// Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
4074+
// value types, make_comparator rejects the comparison
4075+
let err = eval_in_list_from_array(
4076+
wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
4077+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
4078+
)
4079+
.unwrap_err()
4080+
.to_string();
4081+
assert!(
4082+
err.contains("Can't compare arrays of different types"),
4083+
"{err}"
4084+
);
4085+
4086+
Ok(())
4087+
}
38814088
}

datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -918,13 +918,18 @@ CREATE EXTERNAL TABLE dict_filter_bug
918918
STORED AS PARQUET
919919
LOCATION 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet';
920920

921-
query error Can't compare arrays of different types
921+
query TR
922922
SELECT t.tag1, t.value
923923
FROM dict_filter_bug t
924924
JOIN (VALUES ('A'), ('B')) AS v(c1)
925925
ON t.tag1 = v.c1
926926
ORDER BY t.tag1, t.value
927927
LIMIT 4;
928+
----
929+
A 0
930+
A 26
931+
A 52
932+
A 78
928933

929934
# Cleanup
930935
statement ok

0 commit comments

Comments
 (0)