Skip to content

Commit 62658cd

Browse files
authored
implement var distinct (#19706)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #2410. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> - This PR adds support for evaluating `var(distinct)` and `var_pop(distinct)` expressions. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> - A new `DistinctVarianceAccumulator` is implemented, which stores the distinct values and calculates the variance over those distinct values. - Update the `VarianceSample` and `VariancePopulation` structs to report the state fields of the distinct accumulator when `distinct` is requested. - Update the distinct tests in `aggregate.slt`. ## Are these changes tested? - `cargo test --profile=ci --test sqllogictests -- aggregate.slt` <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent a55b77e commit 62658cd

File tree

2 files changed

+129
-23
lines changed

2 files changed

+129
-23
lines changed

datafusion/functions-aggregate/src/variance.rs

Lines changed: 123 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,21 @@
1818
//! [`VarianceSample`]: variance sample aggregations.
1919
//! [`VariancePopulation`]: variance population aggregations.
2020
21-
use arrow::datatypes::FieldRef;
21+
use arrow::datatypes::{FieldRef, Float64Type};
2222
use arrow::{
2323
array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
2424
buffer::NullBuffer,
2525
compute::kernels::cast,
2626
datatypes::{DataType, Field},
2727
};
28-
use datafusion_common::{Result, ScalarValue, downcast_value, not_impl_err, plan_err};
28+
use datafusion_common::{Result, ScalarValue, downcast_value, plan_err};
2929
use datafusion_expr::{
3030
Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
3131
Volatility,
3232
function::{AccumulatorArgs, StateFieldsArgs},
3333
utils::format_state_name,
3434
};
35+
use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
3536
use datafusion_functions_aggregate_common::{
3637
aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
3738
};
@@ -110,19 +111,35 @@ impl AggregateUDFImpl for VarianceSample {
110111

111112
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
112113
let name = args.name;
113-
Ok(vec![
114-
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
115-
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
116-
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
117-
]
118-
.into_iter()
119-
.map(Arc::new)
120-
.collect())
114+
match args.is_distinct {
115+
false => Ok(vec![
116+
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
117+
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
118+
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
119+
]
120+
.into_iter()
121+
.map(Arc::new)
122+
.collect()),
123+
true => {
124+
let field = Field::new_list_field(DataType::Float64, true);
125+
let state_name = "distinct_var";
126+
Ok(vec![
127+
Field::new(
128+
format_state_name(name, state_name),
129+
DataType::List(Arc::new(field)),
130+
true,
131+
)
132+
.into(),
133+
])
134+
}
135+
}
121136
}
122137

123138
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
124139
if acc_args.is_distinct {
125-
return not_impl_err!("VAR(DISTINCT) aggregations are not available");
140+
return Ok(Box::new(DistinctVarianceAccumulator::new(
141+
StatsType::Sample,
142+
)));
126143
}
127144

128145
Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
@@ -206,20 +223,38 @@ impl AggregateUDFImpl for VariancePopulation {
206223
}
207224

208225
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
209-
let name = args.name;
210-
Ok(vec![
211-
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
212-
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
213-
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
214-
]
215-
.into_iter()
216-
.map(Arc::new)
217-
.collect())
226+
match args.is_distinct {
227+
false => {
228+
let name = args.name;
229+
Ok(vec![
230+
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
231+
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
232+
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
233+
]
234+
.into_iter()
235+
.map(Arc::new)
236+
.collect())
237+
}
238+
true => {
239+
let field = Field::new_list_field(DataType::Float64, true);
240+
let state_name = "distinct_var";
241+
Ok(vec![
242+
Field::new(
243+
format_state_name(args.name, state_name),
244+
DataType::List(Arc::new(field)),
245+
true,
246+
)
247+
.into(),
248+
])
249+
}
250+
}
218251
}
219252

220253
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
221254
if acc_args.is_distinct {
222-
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
255+
return Ok(Box::new(DistinctVarianceAccumulator::new(
256+
StatsType::Population,
257+
)));
223258
}
224259

225260
Ok(Box::new(VarianceAccumulator::try_new(
@@ -581,6 +616,73 @@ impl GroupsAccumulator for VarianceGroupsAccumulator {
581616
}
582617
}
583618

619+
#[derive(Debug)]
620+
pub struct DistinctVarianceAccumulator {
621+
distinct_values: GenericDistinctBuffer<Float64Type>,
622+
stat_type: StatsType,
623+
}
624+
625+
impl DistinctVarianceAccumulator {
626+
pub fn new(stat_type: StatsType) -> Self {
627+
Self {
628+
distinct_values: GenericDistinctBuffer::<Float64Type>::new(DataType::Float64),
629+
stat_type,
630+
}
631+
}
632+
}
633+
634+
impl Accumulator for DistinctVarianceAccumulator {
635+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
636+
let cast_values = cast(&values[0], &DataType::Float64)?;
637+
self.distinct_values
638+
.update_batch(vec![cast_values].as_ref())
639+
}
640+
641+
fn evaluate(&mut self) -> Result<ScalarValue> {
642+
let values = self
643+
.distinct_values
644+
.values
645+
.iter()
646+
.map(|v| v.0)
647+
.collect::<Vec<_>>();
648+
649+
let count = match self.stat_type {
650+
StatsType::Sample => {
651+
if !values.is_empty() {
652+
values.len() - 1
653+
} else {
654+
0
655+
}
656+
}
657+
StatsType::Population => values.len(),
658+
};
659+
660+
let mean = values.iter().sum::<f64>() / values.len() as f64;
661+
let m2 = values.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>();
662+
663+
Ok(ScalarValue::Float64(match values.len() {
664+
0 => None,
665+
1 => match self.stat_type {
666+
StatsType::Population => Some(0.0),
667+
StatsType::Sample => None,
668+
},
669+
_ => Some(m2 / count as f64),
670+
}))
671+
}
672+
673+
fn size(&self) -> usize {
674+
size_of_val(self) + self.distinct_values.size()
675+
}
676+
677+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
678+
self.distinct_values.state()
679+
}
680+
681+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
682+
self.distinct_values.merge_batch(states)
683+
}
684+
}
685+
584686
#[cfg(test)]
585687
mod tests {
586688
use datafusion_expr::EmitTo;

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,17 +700,21 @@ SELECT var(distinct c2) FROM aggregate_test_100
700700
----
701701
2.5
702702

703-
statement error DataFusion error: This feature is not implemented: VAR\(DISTINCT\) aggregations are not available
703+
query RR
704704
SELECT var(c2), var(distinct c2) FROM aggregate_test_100
705+
----
706+
1.886363636364 2.5
705707

706708
# csv_query_distinct_variance_population
707709
query R
708710
SELECT var_pop(distinct c2) FROM aggregate_test_100
709711
----
710712
2
711713

712-
statement error DataFusion error: This feature is not implemented: VAR_POP\(DISTINCT\) aggregations are not available
714+
query RR
713715
SELECT var_pop(c2), var_pop(distinct c2) FROM aggregate_test_100
716+
----
717+
1.8675 2
714718

715719
# csv_query_variance_5
716720
query R

0 commit comments

Comments
 (0)