diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index eb0886a31e8df..1a71d039d28c2 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -239,6 +239,31 @@ fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) { ); b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) }); + + // Benchmark for CASE expr WHEN literal THEN column with many branches + // This tests the WithExprLookupTable optimization + c.bench_function( + format!( + "case_when {}x{}: CASE c1 WHEN 0 THEN c2 WHEN 1 THEN c3 ... (20 branches, column THEN)", + batch.num_rows(), + batch.num_columns() + ) + .as_str(), + |b| { + // Create 20 branches mapping different literal values to columns + let when_thens: Vec<_> = (0..20i32) + .map(|i| { + let col_idx = (i as usize % 3) + 1; + let col_name = format!("c{col_idx}"); + (lit(i), col(&col_name, &batch.schema()).unwrap()) + }) + .collect(); + let expr = Arc::new( + case(Some(Arc::clone(&c1)), when_thens, Some(Arc::clone(&c2))).unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap())) + }, + ); } struct Options { diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 758317d3d2798..cde783844a9bb 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -38,7 +38,9 @@ use std::borrow::Cow; use std::hash::Hash; use std::{any::Any, sync::Arc}; -use crate::expressions::case::literal_lookup_table::LiteralLookupTable; +use crate::expressions::case::literal_lookup_table::{ + LiteralLookupTable, WhenLiteralIndexMap, try_creating_lookup_table, +}; use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_physical_expr_common::datum::compare_with_eq; @@ -81,6 +83,12 @@ enum EvalMethod { /// /// 
See [`LiteralLookupTable`] for more details WithExprScalarLookupTable(LiteralLookupTable), + + /// This is a specialization for [`EvalMethod::WithExpression`] when the WHEN values are literals + /// but the THEN expressions can be arbitrary expressions (not just literals). + /// + /// Uses HashMap for O(1) branch lookup, then evaluates THEN expressions only for matching rows. + WithExprLookupTable(ExpressionLookupTable), } /// Implementing hash so we can use `derive` on [`EvalMethod`]. @@ -107,6 +115,87 @@ impl PartialEq for LiteralLookupTable { impl Eq for LiteralLookupTable {} +/// Lookup table for CASE expressions where WHEN values are literals but THEN expressions +/// can be arbitrary expressions. +#[derive(Debug)] +struct ExpressionLookupTable { + lookup: Box<dyn WhenLiteralIndexMap>, + num_branches: usize, + else_index: u32, +} + +/// No-op hash: the table feeds nothing into the hasher. +impl Hash for ExpressionLookupTable { + fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {} +} + +/// Any two tables compare equal; the table contents are never inspected. +impl PartialEq for ExpressionLookupTable { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for ExpressionLookupTable {} + +impl ExpressionLookupTable { + /// Builds the table for `body`. Returns `None` when there is no base expression, fewer than + /// two branches, or a WHEN value that is not a literal, is NULL, repeats an earlier WHEN + /// value, or has a different data type than the first one. + fn maybe_new(body: &CaseBody) -> Option<Self> { + body.expr.as_ref()?; + if body.when_then_expr.len() < 2 { + return None; + } + + let when_literals: Option<Vec<ScalarValue>> = body + .when_then_expr + .iter() + .map(|(when, _)| { + when.as_any() + .downcast_ref::<Literal>() + .map(|l| l.value().clone()) + }) + .collect(); + + let when_literals = when_literals?; + + let mut seen = IndexSet::new(); + let unique_when_literals: Vec<ScalarValue> = when_literals + .into_iter() + .filter(|v| !v.is_null() && seen.insert(v.clone())) + .collect(); + + // The table is built from the filtered list only, so the indices it reports are positions + // in that list, while evaluation uses them to index `when_then_expr`. If a NULL or duplicate + // literal was dropped the two no longer line up, so leave such CASEs to the generic path. + if unique_when_literals.len() != body.when_then_expr.len() { + return None; + } + + let when_data_type = unique_when_literals[0].data_type(); + if unique_when_literals + .iter() + .any(|l| l.data_type() != when_data_type) + { + return None; + } + + let lookup = try_creating_lookup_table(unique_when_literals).ok()?; + + Some(Self { + lookup, + num_branches: body.when_then_expr.len(), + else_index: body.when_then_expr.len() as u32, + }) + }
+ + fn map_to_branch_indices(&self, base_values: &ArrayRef) -> Result<Vec<u32>> { + self.lookup + .map_to_when_indices(base_values, self.else_index) + } +} + /// The body of a CASE expression which consists of an optional base expression, the "when/then" /// branches and an optional "else" branch. #[derive(Debug, Hash, PartialEq, Eq)] @@ -644,6 +725,10 @@ impl CaseExpr { return Ok(EvalMethod::WithExprScalarLookupTable(mapping)); } + if let Some(table) = ExpressionLookupTable::maybe_new(body) { + return Ok(EvalMethod::WithExprLookupTable(table)); + } + return Ok(EvalMethod::WithExpression(body.project()?)); } @@ -1170,6 +1255,99 @@ impl CaseExpr { Ok(result) } + + /// Evaluates the CASE by bucketing row indices per matched branch, then evaluating each THEN + /// (and the ELSE, if present) only on a batch filtered down to its own rows. + /// + /// An empty batch yields an empty array of the return type. If every row lands in a single + /// bucket, that branch is evaluated on the whole batch (or a NULL scalar is returned when the + /// bucket is the ELSE one and there is no ELSE expression). + fn with_expr_lookup_table( + &self, + batch: &RecordBatch, + lookup_table: &ExpressionLookupTable, + ) -> Result<ColumnarValue> { + let return_type = self.data_type(&batch.schema())?; + let row_count = batch.num_rows(); + + if row_count == 0 { + return Ok(ColumnarValue::Array(new_empty_array(&return_type))); + } + + // NOTE(review): relies on the base expression being present; `ExpressionLookupTable::maybe_new` + // returns `None` without one - confirm this method is only reached through such a table. + let base_expr = self.body.expr.as_ref().unwrap(); + let base_values = base_expr.evaluate(batch)?.into_array(row_count)?; + let branch_indices = lookup_table.map_to_branch_indices(&base_values)?; + + // One bucket per branch plus a trailing bucket whose rows are handled as the ELSE rows. + let mut branch_rows: Vec<Vec<u32>> = + vec![Vec::new(); lookup_table.num_branches + 1]; + for (row_idx, &branch_idx) in branch_indices.iter().enumerate() { + branch_rows[branch_idx as usize].push(row_idx as u32); + } + + for (branch_idx, rows) in branch_rows.iter().enumerate() { + if rows.len() == row_count { + if branch_idx < lookup_table.num_branches { + return self.body.when_then_expr[branch_idx].1.evaluate(batch); + } else { + return if let Some(else_expr) = &self.body.else_expr { + else_expr.evaluate(batch) + } else { + Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + &return_type, + )?)) + }; + } + } + } + + let mut result_builder = ResultBuilder::new(&return_type, row_count); + + for (branch_idx, rows) in branch_rows + .iter() + .enumerate() + .take(lookup_table.num_branches) + { + if rows.is_empty() { + continue; + }
+ + let row_indices = Arc::new(UInt32Array::from(rows.clone())) as ArrayRef; + let filter_predicate = create_filter_from_indices(rows, row_count); + let filtered_batch = filter_record_batch(batch, &filter_predicate)?; + + let then_expr = &self.body.when_then_expr[branch_idx].1; + let then_value = then_expr.evaluate(&filtered_batch)?; + + result_builder.add_branch_result(&row_indices, then_value)?; + } + + let else_rows = &branch_rows[lookup_table.num_branches]; + if !else_rows.is_empty() + && let Some(else_expr) = &self.body.else_expr + { + let row_indices = Arc::new(UInt32Array::from(else_rows.clone())) as ArrayRef; + let filter_predicate = create_filter_from_indices(else_rows, row_count); + let filtered_batch = filter_record_batch(batch, &filter_predicate)?; + let else_value = else_expr.evaluate(&filtered_batch)?; + result_builder.add_branch_result(&row_indices, else_value)?; + } + + result_builder.finish() + } +} + +/// Builds a [`FilterPredicate`] from a `total_rows`-long boolean mask in which only the positions +/// listed in `indices` are set. +fn create_filter_from_indices(indices: &[u32], total_rows: usize) -> FilterPredicate { + let mut builder = BooleanBufferBuilder::new(total_rows); + builder.append_n(total_rows, false); + for &idx in indices { + builder.set_bit(idx as usize, true); + } + let bool_array = BooleanArray::new(builder.finish(), None); + create_filter(&bool_array, true) } impl PhysicalExpr for CaseExpr { @@ -1268,6 +1446,9 @@ impl PhysicalExpr for CaseExpr { EvalMethod::WithExprScalarLookupTable(lookup_table) => { self.with_lookup_table(batch, lookup_table) } + EvalMethod::WithExprLookupTable(lookup_table) => { + self.with_expr_lookup_table(batch, lookup_table) + } } } diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs index 67b045f9988f8..8e380b09b5cf4 100644 --- a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -231,7 +231,9 @@ impl LiteralLookupTable { /// 
``` /// /// this will map to 0, to 1, to 2, to 3 -pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync { +pub(in super::super) trait WhenLiteralIndexMap: + Debug + Send + Sync +{ /// Given an array of values, returns a vector of WHEN clause indices corresponding to each value in the provided array. /// /// For example, for this CASE expression: @@ -260,7 +262,7 @@ pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync { ) -> datafusion_common::Result>; } -fn try_creating_lookup_table( +pub(in super::super) fn try_creating_lookup_table( unique_non_null_literals: Vec, ) -> datafusion_common::Result> { assert_ne!(