Skip to content

Commit c73e9f9

Browse files
committed
Extend not to support not(expr)
1 parent 295c46d commit c73e9f9

File tree

4 files changed

+119
-17
lines changed

4 files changed

+119
-17
lines changed

datafusion/core/tests/parquet/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ use datafusion::{
3737
prelude::{ParquetReadOptions, SessionConfig, SessionContext},
3838
};
3939
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
40+
use datafusion_physical_plan::execute_stream;
4041
use parquet::arrow::ArrowWriter;
4142
use parquet::file::properties::{EnabledStatistics, WriterProperties};
4243
use std::sync::Arc;
@@ -225,6 +226,7 @@ impl ContextWithParquet {
225226
) -> Self {
226227
// Use a single partition for deterministic results no matter how many CPUs the host has
227228
config = config.with_target_partitions(1);
229+
config.options_mut().execution.parquet.pushdown_filters = true;
228230
let file = match unit {
229231
Unit::RowGroup(row_per_group) => {
230232
config = config.with_parquet_bloom_filter_pruning(true);
@@ -308,6 +310,15 @@ impl ContextWithParquet {
308310
.await
309311
.expect("creating physical plan");
310312

313+
/*
314+
use arrow::util::pretty::print_batches;
315+
use futures::TryStreamExt;
316+
let res =
317+
execute_stream(physical_plan.clone(), self.ctx.task_ctx().clone()).unwrap();
318+
let batches = res.try_collect::<Vec<_>>().await.unwrap();
319+
print_batches(&batches).unwrap();
320+
*/
321+
311322
let task_ctx = state.task_ctx();
312323
let results = datafusion::physical_plan::collect(physical_plan.clone(), task_ctx)
313324
.await

datafusion/core/tests/parquet/row_group_pruning.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,11 @@ impl RowGroupPruningTest {
174174
self,
175175
schema: Arc<Schema>,
176176
batches: Vec<RecordBatch>,
177+
max_row_per_row_group: usize,
177178
) {
178179
let output = ContextWithParquet::with_custom_data(
179180
self.scenario,
180-
RowGroup(2),
181+
RowGroup(max_row_per_row_group),
181182
schema,
182183
batches,
183184
)
@@ -1745,7 +1746,7 @@ async fn test_limit_pruning() -> datafusion_common::error::Result<()> {
17451746
// So 3 row groups are effectively pruned due to limit pruning.
17461747

17471748
let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)]));
1748-
let query = "explain verbose SELECT c1 FROM t WHERE c1 > 0 LIMIT 2";
1749+
let query = "SELECT c1 FROM t WHERE c1 > 0 LIMIT 2";
17491750

17501751
let batches = vec![
17511752
make_i32_batch("c1", vec![1, 2])?, // RG0: Fully matched, 2 rows
@@ -1764,8 +1765,8 @@ async fn test_limit_pruning() -> datafusion_common::error::Result<()> {
17641765
.with_pruned_by_bloom_filter(Some(0))
17651766
.with_matched_by_stats(Some(3)) // RG0, RG1, RG2 are matched by stats (c1 > 0)
17661767
.with_pruned_by_stats(Some(1)) // RG3 is pruned by stats (c1 = [-1, 0] does not satisfy c1 > 0)
1767-
// .with_limit_pruned_row_groups(Some(2)) // RG1, RG2 are pruned by limit. (RG3 is already pruned by stats)
1768-
.test_row_group_prune_with_custom_data(schema, batches)
1768+
.with_limit_pruned_row_groups(Some(2)) // RG1, RG2 are pruned by limit. (RG3 is already pruned by stats)
1769+
.test_row_group_prune_with_custom_data(schema, batches, 2)
17691770
.await;
17701771

17711772
Ok(())

datafusion/datasource-parquet/src/row_group_filter.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ impl RowGroupAccessPlanFilter {
183183
match predicate.prune(&pruning_stats) {
184184
Ok(values) => {
185185
let mut new_access_plan = ParquetAccessPlan::new_all(groups.len());
186-
let mut fully_contained_candidates_original_idx: Vec<usize> = Vec::new();
186+
let mut fully_contained_candidates_original_idxes: Vec<usize> =
187+
Vec::new();
187188

188189
for (idx_in_pruning_stats_result, &pruning_result) in
189190
values.iter().enumerate()
@@ -194,13 +195,13 @@ impl RowGroupAccessPlanFilter {
194195
new_access_plan.skip(original_row_group_idx);
195196
metrics.row_groups_pruned_statistics.add(1);
196197
} else {
197-
fully_contained_candidates_original_idx
198+
fully_contained_candidates_original_idxes
198199
.push(original_row_group_idx);
199200
metrics.row_groups_matched_statistics.add(1);
200201
}
201202
}
202203

203-
if !fully_contained_candidates_original_idx.is_empty() {
204+
if !fully_contained_candidates_original_idxes.is_empty() {
204205
// Use NotExpr to create the inverted predicate
205206
let inverted_expr =
206207
Arc::new(NotExpr::new(predicate.orig_expr().clone()));
@@ -210,18 +211,20 @@ impl RowGroupAccessPlanFilter {
210211
) {
211212
let inverted_pruning_stats = RowGroupPruningStatistics {
212213
parquet_schema,
213-
row_group_metadatas: fully_contained_candidates_original_idx
214-
.iter()
215-
.map(|&i| &groups[i])
216-
.collect::<Vec<_>>(),
214+
row_group_metadatas:
215+
fully_contained_candidates_original_idxes
216+
.iter()
217+
.map(|&i| &groups[i])
218+
.collect::<Vec<_>>(),
217219
arrow_schema,
218220
};
219-
220221
if let Ok(inverted_values) =
221222
inverted_predicate.prune(&inverted_pruning_stats)
222223
{
223224
for (i, &original_row_group_idx) in
224-
fully_contained_candidates_original_idx.iter().enumerate()
225+
fully_contained_candidates_original_idxes
226+
.iter()
227+
.enumerate()
225228
{
226229
// If the inverted predicate *also* prunes this row group (meaning inverted_values[i] is false),
227230
// it implies that *all* rows in this group satisfy the original predicate.

datafusion/pruning/src/pruning_predicate.rs

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,14 +1415,39 @@ fn build_predicate_expression(
14151415
.unwrap_or_else(|| unhandled_hook.handle(expr));
14161416
}
14171417
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
1418-
// match !col (don't do so recursively)
14191418
if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
14201419
return build_single_column_expr(col, schema, required_columns, true)
14211420
.unwrap_or_else(|| unhandled_hook.handle(expr));
1422-
} else {
1421+
}
1422+
1423+
let inner_expr = build_predicate_expression(
1424+
not.arg(),
1425+
schema,
1426+
required_columns,
1427+
unhandled_hook,
1428+
);
1429+
1430+
// Only apply NOT if the inner expression is NOT a true literal
1431+
// (because true literals may come from unhandled cases)
1432+
if is_always_true(&inner_expr) {
1433+
// Conservative approach: if inner returns true (possibly unhandled),
1434+
// then NOT should also return true (unhandled) to be safe
14231435
return unhandled_hook.handle(expr);
14241436
}
1437+
1438+
// Handle other boolean literals
1439+
if let Some(literal) = inner_expr.as_any().downcast_ref::<phys_expr::Literal>() {
1440+
if let ScalarValue::Boolean(Some(val)) = literal.value() {
1441+
return Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(
1442+
!val,
1443+
))));
1444+
}
1445+
}
1446+
1447+
// Apply NOT to the result
1448+
return Arc::new(phys_expr::NotExpr::new(inner_expr));
14251449
}
1450+
14261451
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
14271452
if !in_list.list().is_empty()
14281453
&& in_list.list().len() <= MAX_LIST_VALUE_SIZE_REWRITE
@@ -1868,7 +1893,7 @@ mod tests {
18681893

18691894
use super::*;
18701895
use datafusion_common::test_util::batches_to_string;
1871-
use datafusion_expr::{and, col, lit, or};
1896+
use datafusion_expr::{and, col, lit, not, or};
18721897
use insta::assert_snapshot;
18731898

18741899
use arrow::array::Decimal128Array;
@@ -4422,7 +4447,7 @@ mod tests {
44224447
true,
44234448
// s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep)
44244449
true,
4425-
// s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate
4450+
// s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate
44264451
// original (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}")
44274452
true,
44284453
];
@@ -5175,4 +5200,66 @@ mod tests {
51755200
"c1_null_count@2 != row_count@3 AND c1_min@0 <= a AND a <= c1_max@1";
51765201
assert_eq!(res.to_string(), expected);
51775202
}
5203+
5204+
#[test]
5205+
fn test_not_expression_unhandled_inner_true() -> Result<()> {
5206+
// Test case: when inner expression returns true (unhandled),
5207+
// NOT should also return true (unhandled) for safety
5208+
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
5209+
5210+
// NOT(c1) for Int32 returns true because build_single_column_expr
5211+
// only handles boolean columns, so non-boolean columns fall back to unhandled_hook
5212+
let expr = not(col("c1"));
5213+
let predicate_expr =
5214+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
5215+
assert_eq!(predicate_expr.to_string(), "true");
5216+
Ok(())
5217+
}
5218+
5219+
#[test]
5220+
fn test_not_expression_boolean_literal_handling() -> Result<()> {
5221+
let schema = Schema::empty();
5222+
5223+
// NOT(false) -> true
5224+
let expr = not(lit(false));
5225+
let predicate_expr =
5226+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
5227+
assert_eq!(predicate_expr.to_string(), "true");
5228+
5229+
// NOT(true) -> true (conservatively)
5230+
let expr = not(lit(true));
5231+
let predicate_expr =
5232+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
5233+
assert_eq!(predicate_expr.to_string(), "true");
5234+
5235+
Ok(())
5236+
}
5237+
5238+
#[test]
5239+
fn test_not_expression_wraps_complex_expressions() -> Result<()> {
5240+
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
5241+
5242+
let expr = not(col("c1").gt(lit(5)));
5243+
let predicate_expr =
5244+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
5245+
5246+
let result_str = predicate_expr.to_string();
5247+
assert_eq!(
5248+
result_str,
5249+
"NOT c1_null_count@1 != row_count@2 AND c1_max@0 > 5"
5250+
);
5251+
5252+
// NOT(c1 = 10)
5253+
let expr = not(col("c1").eq(lit(10)));
5254+
let predicate_expr =
5255+
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
5256+
5257+
let result_str = predicate_expr.to_string();
5258+
assert_eq!(
5259+
result_str,
5260+
"NOT c1_null_count@2 != row_count@3 AND c1_min@0 <= 10 AND 10 <= c1_max@1"
5261+
);
5262+
5263+
Ok(())
5264+
}
51785265
}

0 commit comments

Comments
 (0)