diff --git a/Cargo.lock b/Cargo.lock index 2ee907b30cf02..7cfdc916396bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2309,6 +2309,7 @@ name = "datafusion-functions-window" version = "51.0.0" dependencies = [ "arrow", + "criterion", "datafusion-common", "datafusion-doc", "datafusion-expr", diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index 42690907ae26c..fae71e180e34c 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -51,3 +51,11 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = { workspace = true } + +[dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } +criterion = { workspace = true } + +[[bench]] +name = "nth_value" +harness = false diff --git a/datafusion/functions-window/benches/nth_value.rs b/datafusion/functions-window/benches/nth_value.rs new file mode 100644 index 0000000000000..00daf9fa4f9ba --- /dev/null +++ b/datafusion/functions-window/benches/nth_value.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::ops::Range; +use std::slice; +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, FieldRef, Int64Type}; +use arrow::util::bench_util::create_primitive_array; + +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_expr::{PartitionEvaluator, WindowUDFImpl}; +use datafusion_functions_window::nth_value::{NthValue, NthValueKind}; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr::expressions::{Column, Literal}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +const ARRAY_SIZE: usize = 8192; + +/// Creates a partition evaluator for FIRST_VALUE, LAST_VALUE, or NTH_VALUE +fn create_evaluator( + kind: NthValueKind, + ignore_nulls: bool, + n: Option, +) -> Box { + let expr = Arc::new(Column::new("c", 0)) as Arc; + let input_field: FieldRef = Field::new("c", DataType::Int64, true).into(); + let input_fields = vec![input_field]; + + let (nth_value, exprs): (NthValue, Vec>) = match kind { + NthValueKind::First => (NthValue::first(), vec![expr]), + NthValueKind::Last => (NthValue::last(), vec![expr]), + NthValueKind::Nth => { + let n_value = + Arc::new(Literal::new(ScalarValue::Int64(n))) as Arc; + (NthValue::nth(), vec![expr, n_value]) + } + }; + + let args = PartitionEvaluatorArgs::new(&exprs, &input_fields, false, ignore_nulls); + nth_value.partition_evaluator(args).unwrap() +} + +fn bench_nth_value_ignore_nulls(c: &mut Criterion) { + let mut group = c.benchmark_group("nth_value_ignore_nulls"); + + // Test different null densities + let null_densities = [0.0, 0.3, 0.5, 0.8]; + + for null_density in null_densities { + let values = Arc::new(create_primitive_array::( + ARRAY_SIZE, + null_density, + )) as ArrayRef; + let null_pct = (null_density * 100.0) as u32; + + // FIRST_VALUE with ignore_nulls - expanding window + group.bench_function( + BenchmarkId::new("first_value_expanding", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::First, true, None); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // LAST_VALUE with ignore_nulls - expanding window + group.bench_function( + BenchmarkId::new("last_value_expanding", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::Last, true, None); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // NTH_VALUE(col, 10) with ignore_nulls - get 10th non-null value + group.bench_function( + BenchmarkId::new("nth_value_10_expanding", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = + create_evaluator(NthValueKind::Nth, true, Some(10)); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // NTH_VALUE(col, -10) with ignore_nulls - get 10th from last non-null value + group.bench_function( + BenchmarkId::new("nth_value_neg10_expanding", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = + create_evaluator(NthValueKind::Nth, true, Some(-10)); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // Sliding window benchmarks with 100-row window + let window_size: usize = 100; + + group.bench_function( + BenchmarkId::new("first_value_sliding_100", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::First, true, None); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let start = i.saturating_sub(window_size - 1); + let range = Range { start, end: i + 1 }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + group.bench_function( + BenchmarkId::new("last_value_sliding_100", format!("{null_pct}%_nulls")), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::Last, true, None); + let values_slice = slice::from_ref(&values); + for i in 0..values.len() { + let start = i.saturating_sub(window_size - 1); + let range = Range { start, end: i + 1 }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + } + + group.finish(); + + // Comparison benchmarks: ignore_nulls vs respect_nulls + let mut comparison_group = c.benchmark_group("nth_value_nulls_comparison"); + let values_with_nulls = + Arc::new(create_primitive_array::(ARRAY_SIZE, 0.5)) as ArrayRef; + + // FIRST_VALUE comparison + comparison_group.bench_function( + BenchmarkId::new("first_value", "ignore_nulls"), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::First, true, None); + let values_slice = slice::from_ref(&values_with_nulls); + for i in 0..values_with_nulls.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + comparison_group.bench_function( + BenchmarkId::new("first_value", "respect_nulls"), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::First, false, None); + let values_slice = slice::from_ref(&values_with_nulls); + for i in 0..values_with_nulls.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + // NTH_VALUE comparison + comparison_group.bench_function( + BenchmarkId::new("nth_value_10", "ignore_nulls"), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::Nth, true, Some(10)); + let values_slice = slice::from_ref(&values_with_nulls); + for i in 0..values_with_nulls.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + comparison_group.bench_function( + BenchmarkId::new("nth_value_10", "respect_nulls"), + |b| { + b.iter(|| { + let mut evaluator = create_evaluator(NthValueKind::Nth, false, Some(10)); + let values_slice = slice::from_ref(&values_with_nulls); + for i in 0..values_with_nulls.len() { + let range = Range { + start: 0, + end: i + 1, + }; + black_box(evaluator.evaluate(values_slice, &range).unwrap()); + } + }) + }, + ); + + comparison_group.finish(); +} + +criterion_group!(benches, bench_nth_value_ignore_nulls); +criterion_main!(benches); diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index be08f25ec404b..c62f0a9ae4e89 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -19,6 +19,7 @@ use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use arrow::buffer::NullBuffer; use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::{DataType, Field}; @@ -370,6 +371,33 @@ impl PartitionEvaluator for NthValueEvaluator { fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { let out = &state.out_col; let size = out.len(); + if self.ignore_nulls { + match self.state.kind { + // Prune on first non-null output in case of FIRST_VALUE + NthValueKind::First => { + if let Some(nulls) = out.nulls() { + if self.state.finalized_result.is_none() { + if let Some(valid_index) = nulls.valid_indices().next() { + let result = + ScalarValue::try_from_array(out, valid_index)?; + self.state.finalized_result = Some(result); + } else { + // The output is empty or all nulls, ignore + } + } + if state.window_frame_range.start < state.window_frame_range.end { + state.window_frame_range.start = + state.window_frame_range.end - 1; + } + return Ok(()); + } else { + // Fall through to the main case because there are no nulls + } + } + // Do not memoize for other kinds when nulls are ignored + NthValueKind::Last | NthValueKind::Nth => return Ok(()), + } + } let mut buffer_size = 1; // Decide if we arrived at a final result yet: let (is_prunable, is_reverse_direction) = match self.state.kind { @@ -397,8 +425,7 @@ impl PartitionEvaluator for NthValueEvaluator { } } }; - // Do not memoize results when nulls are ignored. - if is_prunable && !self.ignore_nulls { + if is_prunable { if self.state.finalized_result.is_none() && !is_reverse_direction { let result = ScalarValue::try_from_array(out, size - 1)?; self.state.finalized_result = Some(result); @@ -424,99 +451,90 @@ impl PartitionEvaluator for NthValueEvaluator { // We produce None if the window is empty. return ScalarValue::try_from(arr.data_type()); } + match self.valid_index(arr, range) { + Some(index) => ScalarValue::try_from_array(arr, index), + None => ScalarValue::try_from(arr.data_type()), + } + } + } - // If null values exist and need to be ignored, extract the valid indices. - let valid_indices = if self.ignore_nulls { - // Calculate valid indices, inside the window frame boundaries. - let slice = arr.slice(range.start, n_range); - match slice.nulls() { - Some(nulls) => { - let valid_indices = nulls - .valid_indices() - .map(|idx| { - // Add offset `range.start` to valid indices, to point correct index in the original arr. - idx + range.start - }) - .collect::>(); - if valid_indices.is_empty() { - // If all values are null, return directly. - return ScalarValue::try_from(arr.data_type()); - } - Some(valid_indices) - } - None => None, - } - } else { - None - }; - match self.state.kind { - NthValueKind::First => { - if let Some(valid_indices) = &valid_indices { - ScalarValue::try_from_array(arr, valid_indices[0]) + fn supports_bounded_execution(&self) -> bool { + true + } + + fn uses_window_frame(&self) -> bool { + true + } +} + +impl NthValueEvaluator { + fn valid_index(&self, array: &ArrayRef, range: &Range) -> Option { + let n_range = range.end - range.start; + if self.ignore_nulls { + // Calculate valid indices, inside the window frame boundaries. + let slice = array.slice(range.start, n_range); + if let Some(nulls) = slice.nulls() + && nulls.null_count() > 0 + { + return self.valid_index_with_nulls(nulls, range.start); + } + } + // Either no nulls, or nulls are regarded as valid rows + match self.state.kind { + NthValueKind::First => Some(range.start), + NthValueKind::Last => Some(range.end - 1), + NthValueKind::Nth => match self.n.cmp(&0) { + Ordering::Greater => { + // SQL indices are not 0-based. + let index = (self.n as usize) - 1; + if index >= n_range { + // Outside the range, return NULL: + None } else { - ScalarValue::try_from_array(arr, range.start) + Some(range.start + index) } } - NthValueKind::Last => { - if let Some(valid_indices) = &valid_indices { - ScalarValue::try_from_array( - arr, - valid_indices[valid_indices.len() - 1], - ) + Ordering::Less => { + let reverse_index = (-self.n) as usize; + if n_range < reverse_index { + // Outside the range, return NULL: + None } else { - ScalarValue::try_from_array(arr, range.end - 1) + Some(range.end - reverse_index) } } - NthValueKind::Nth => { - match self.n.cmp(&0) { - Ordering::Greater => { - // SQL indices are not 0-based. - let index = (self.n as usize) - 1; - if index >= n_range { - // Outside the range, return NULL: - ScalarValue::try_from(arr.data_type()) - } else if let Some(valid_indices) = valid_indices { - if index >= valid_indices.len() { - return ScalarValue::try_from(arr.data_type()); - } - ScalarValue::try_from_array(&arr, valid_indices[index]) - } else { - ScalarValue::try_from_array(arr, range.start + index) - } - } - Ordering::Less => { - let reverse_index = (-self.n) as usize; - if n_range < reverse_index { - // Outside the range, return NULL: - ScalarValue::try_from(arr.data_type()) - } else if let Some(valid_indices) = valid_indices { - if reverse_index > valid_indices.len() { - return ScalarValue::try_from(arr.data_type()); - } - let new_index = - valid_indices[valid_indices.len() - reverse_index]; - ScalarValue::try_from_array(&arr, new_index) - } else { - ScalarValue::try_from_array( - arr, - range.start + n_range - reverse_index, - ) - } + Ordering::Equal => None, + }, + } + } + + fn valid_index_with_nulls(&self, nulls: &NullBuffer, offset: usize) -> Option { + match self.state.kind { + NthValueKind::First => nulls.valid_indices().next().map(|idx| idx + offset), + NthValueKind::Last => nulls.valid_indices().last().map(|idx| idx + offset), + NthValueKind::Nth => { + match self.n.cmp(&0) { + Ordering::Greater => { + // SQL indices are not 0-based. + let index = (self.n as usize) - 1; + nulls.valid_indices().nth(index).map(|idx| idx + offset) + } + Ordering::Less => { + let reverse_index = (-self.n) as usize; + let valid_indices_len = nulls.len() - nulls.null_count(); + if reverse_index > valid_indices_len { + return None; } - Ordering::Equal => ScalarValue::try_from(arr.data_type()), + nulls + .valid_indices() + .nth(valid_indices_len - reverse_index) + .map(|idx| idx + offset) } + Ordering::Equal => None, } } } } - - fn supports_bounded_execution(&self) -> bool { - true - } - - fn uses_window_frame(&self) -> bool { - true - } } #[cfg(test)]