Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions datafusion/physical-expr/benches/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,31 @@ fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) {
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
});

// Benchmark for CASE expr WHEN literal THEN column with many branches
// This tests the WithExprLookupTable optimization
c.bench_function(
format!(
"case_when {}x{}: CASE c1 WHEN 0 THEN c2 WHEN 1 THEN c3 ... (20 branches, column THEN)",
batch.num_rows(),
batch.num_columns()
)
.as_str(),
|b| {
// Create 20 branches mapping different literal values to columns
let when_thens: Vec<_> = (0..20i32)
.map(|i| {
let col_idx = (i as usize % 3) + 1;
let col_name = format!("c{col_idx}");
(lit(i), col(&col_name, &batch.schema()).unwrap())
})
.collect();
let expr = Arc::new(
case(Some(Arc::clone(&c1)), when_thens, Some(Arc::clone(&c2))).unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
},
);
}

struct Options<T> {
Expand Down
173 changes: 172 additions & 1 deletion datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ use std::borrow::Cow;
use std::hash::Hash;
use std::{any::Any, sync::Arc};

use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
use crate::expressions::case::literal_lookup_table::{
LiteralLookupTable, WhenLiteralIndexMap, try_creating_lookup_table,
};
use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_physical_expr_common::datum::compare_with_eq;
Expand Down Expand Up @@ -81,6 +83,12 @@ enum EvalMethod {
///
/// See [`LiteralLookupTable`] for more details
WithExprScalarLookupTable(LiteralLookupTable),

/// This is a specialization for [`EvalMethod::WithExpression`] when the WHEN values are literals
/// but the THEN expressions can be arbitrary expressions (not just literals).
///
/// Uses HashMap for O(1) branch lookup, then evaluates THEN expressions only for matching rows.
WithExprLookupTable(ExpressionLookupTable),
}

/// Implementing hash so we can use `derive` on [`EvalMethod`].
Expand All @@ -107,6 +115,79 @@ impl PartialEq for LiteralLookupTable {

impl Eq for LiteralLookupTable {}

/// Lookup table for CASE expressions where WHEN values are literals but THEN expressions
/// can be arbitrary expressions.
#[derive(Debug)]
struct ExpressionLookupTable {
lookup: Box<dyn WhenLiteralIndexMap>,
num_branches: usize,
else_index: u32,
}

impl Hash for ExpressionLookupTable {
fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
}

impl PartialEq for ExpressionLookupTable {
fn eq(&self, _other: &Self) -> bool {
true
}
}

impl Eq for ExpressionLookupTable {}

impl ExpressionLookupTable {
fn maybe_new(body: &CaseBody) -> Option<Self> {
body.expr.as_ref()?;
if body.when_then_expr.len() < 2 {
return None;
}

let when_literals: Option<Vec<ScalarValue>> = body
.when_then_expr
.iter()
.map(|(when, _)| {
when.as_any()
.downcast_ref::<Literal>()
.map(|l| l.value().clone())
})
.collect();

let when_literals = when_literals?;

let mut seen = IndexSet::new();
let unique_when_literals: Vec<ScalarValue> = when_literals
.into_iter()
.filter(|v| !v.is_null() && seen.insert(v.clone()))
.collect();

if unique_when_literals.is_empty() {
return None;
}

let when_data_type = unique_when_literals[0].data_type();
if unique_when_literals
.iter()
.any(|l| l.data_type() != when_data_type)
{
return None;
}

let lookup = try_creating_lookup_table(unique_when_literals).ok()?;

Some(Self {
lookup,
num_branches: body.when_then_expr.len(),
else_index: body.when_then_expr.len() as u32,
})
}

fn map_to_branch_indices(&self, base_values: &ArrayRef) -> Result<Vec<u32>> {
self.lookup
.map_to_when_indices(base_values, self.else_index)
}
}

/// The body of a CASE expression which consists of an optional base expression, the "when/then"
/// branches and an optional "else" branch.
#[derive(Debug, Hash, PartialEq, Eq)]
Expand Down Expand Up @@ -644,6 +725,10 @@ impl CaseExpr {
return Ok(EvalMethod::WithExprScalarLookupTable(mapping));
}

if let Some(table) = ExpressionLookupTable::maybe_new(body) {
return Ok(EvalMethod::WithExprLookupTable(table));
}

return Ok(EvalMethod::WithExpression(body.project()?));
}

Expand Down Expand Up @@ -1170,6 +1255,89 @@ impl CaseExpr {

Ok(result)
}

fn with_expr_lookup_table(
&self,
batch: &RecordBatch,
lookup_table: &ExpressionLookupTable,
) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;
let row_count = batch.num_rows();

if row_count == 0 {
return Ok(ColumnarValue::Array(new_empty_array(&return_type)));
}

let base_expr = self.body.expr.as_ref().unwrap();
let base_values = base_expr.evaluate(batch)?.into_array(row_count)?;
let branch_indices = lookup_table.map_to_branch_indices(&base_values)?;

let mut branch_rows: Vec<Vec<u32>> =
vec![Vec::new(); lookup_table.num_branches + 1];
for (row_idx, &branch_idx) in branch_indices.iter().enumerate() {
branch_rows[branch_idx as usize].push(row_idx as u32);
}

for (branch_idx, rows) in branch_rows.iter().enumerate() {
if rows.len() == row_count {
if branch_idx < lookup_table.num_branches {
return self.body.when_then_expr[branch_idx].1.evaluate(batch);
} else {
return if let Some(else_expr) = &self.body.else_expr {
else_expr.evaluate(batch)
} else {
Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
&return_type,
)?))
};
}
}
}

let mut result_builder = ResultBuilder::new(&return_type, row_count);

for (branch_idx, rows) in branch_rows
.iter()
.enumerate()
.take(lookup_table.num_branches)
{
if rows.is_empty() {
continue;
}

let row_indices = Arc::new(UInt32Array::from(rows.clone())) as ArrayRef;
let filter_predicate = create_filter_from_indices(rows, row_count);
let filtered_batch = filter_record_batch(batch, &filter_predicate)?;

let then_expr = &self.body.when_then_expr[branch_idx].1;
let then_value = then_expr.evaluate(&filtered_batch)?;

result_builder.add_branch_result(&row_indices, then_value)?;
}

let else_rows = &branch_rows[lookup_table.num_branches];
if !else_rows.is_empty()
&& let Some(else_expr) = &self.body.else_expr
{
let row_indices = Arc::new(UInt32Array::from(else_rows.clone())) as ArrayRef;
let filter_predicate = create_filter_from_indices(else_rows, row_count);
let filtered_batch = filter_record_batch(batch, &filter_predicate)?;
let else_value = else_expr.evaluate(&filtered_batch)?;
result_builder.add_branch_result(&row_indices, else_value)?;
}

result_builder.finish()
}
}

fn create_filter_from_indices(indices: &[u32], total_rows: usize) -> FilterPredicate {
let mut builder = BooleanBufferBuilder::new(total_rows);
builder.append_n(total_rows, false);
for &idx in indices {
builder.set_bit(idx as usize, true);
}
let bool_array = BooleanArray::new(builder.finish(), None);
create_filter(&bool_array, true)
}

impl PhysicalExpr for CaseExpr {
Expand Down Expand Up @@ -1268,6 +1436,9 @@ impl PhysicalExpr for CaseExpr {
EvalMethod::WithExprScalarLookupTable(lookup_table) => {
self.with_lookup_table(batch, lookup_table)
}
EvalMethod::WithExprLookupTable(lookup_table) => {
self.with_expr_lookup_table(batch, lookup_table)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ impl LiteralLookupTable {
/// ```
///
/// this will map <literal_a> to 0, <literal_b> to 1, <literal_c> to 2, <literal_d> to 3
pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync {
pub(in super::super) trait WhenLiteralIndexMap:
Debug + Send + Sync
{
/// Given an array of values, returns a vector of WHEN clause indices corresponding to each value in the provided array.
///
/// For example, for this CASE expression:
Expand Down Expand Up @@ -260,7 +262,7 @@ pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync {
) -> datafusion_common::Result<Vec<u32>>;
}

fn try_creating_lookup_table(
pub(in super::super) fn try_creating_lookup_table(
unique_non_null_literals: Vec<ScalarValue>,
) -> datafusion_common::Result<Box<dyn WhenLiteralIndexMap>> {
assert_ne!(
Expand Down
Loading