Skip to content

Commit 13fb42e

Browse files
authored
Implement retract_batch for AvgAccumulator (apache#4846)
* Implement retract_batch for AvgAccumulator, Add avg to custom window frame tests
* fmt
1 parent 42f7dd5 commit 13fb42e

File tree

2 files changed: +33 additions, -18 deletions (lines changed)

2 files changed

+33
-18
lines changed

datafusion/core/tests/sql/window.rs

Lines changed: 20 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -503,21 +503,22 @@ async fn window_frame_rows_preceding() -> Result<()> {
503503
register_aggregate_csv(&ctx).await?;
504504
let sql = "SELECT \
505505
SUM(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
506+
AVG(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
506507
COUNT(*) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
507508
FROM aggregate_test_100 \
508509
ORDER BY c9 \
509510
LIMIT 5";
510511
let actual = execute_to_batches(&ctx, sql).await;
511512
let expected = vec![
512-
"+----------------------------+-----------------+",
513-
"| SUM(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
514-
"+----------------------------+-----------------+",
515-
"| -48302 | 3 |",
516-
"| 11243 | 3 |",
517-
"| -51311 | 3 |",
518-
"| -2391 | 3 |",
519-
"| 46756 | 3 |",
520-
"+----------------------------+-----------------+",
513+
"+----------------------------+----------------------------+-----------------+",
514+
"| SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
515+
"+----------------------------+----------------------------+-----------------+",
516+
"| -48302 | -16100.666666666666 | 3 |",
517+
"| 11243 | 3747.6666666666665 | 3 |",
518+
"| -51311 | -17103.666666666668 | 3 |",
519+
"| -2391 | -797 | 3 |",
520+
"| 46756 | 15585.333333333334 | 3 |",
521+
"+----------------------------+----------------------------+-----------------+",
521522
];
522523
assert_batches_eq!(expected, &actual);
523524
Ok(())
@@ -529,21 +530,22 @@ async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<
529530
register_aggregate_csv(&ctx).await?;
530531
let sql = "SELECT \
531532
SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
533+
AVG(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
532534
COUNT(*) OVER(PARTITION BY c2 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
533535
FROM aggregate_test_100 \
534536
ORDER BY c9 \
535537
LIMIT 5";
536538
let actual = execute_to_batches(&ctx, sql).await;
537539
let expected = vec![
538-
"+----------------------------+-----------------+",
539-
"| SUM(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
540-
"+----------------------------+-----------------+",
541-
"| -38611 | 2 |",
542-
"| 17547 | 2 |",
543-
"| -1301 | 2 |",
544-
"| 26638 | 3 |",
545-
"| 26861 | 3 |",
546-
"+----------------------------+-----------------+",
540+
"+----------------------------+----------------------------+-----------------+",
541+
"| SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(UInt8(1)) |",
542+
"+----------------------------+----------------------------+-----------------+",
543+
"| -38611 | -19305.5 | 2 |",
544+
"| 17547 | 8773.5 | 2 |",
545+
"| -1301 | -650.5 | 2 |",
546+
"| 26638 | 13319 | 3 |",
547+
"| 26861 | 8953.666666666666 | 3 |",
548+
"+----------------------------+----------------------------+-----------------+",
547549
];
548550
assert_batches_eq!(expected, &actual);
549551
Ok(())

datafusion/physical-expr/src/aggregate/average.rs

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ use crate::aggregate::row_accumulator::{
2525
is_row_accumulator_support_dtype, RowAccumulator,
2626
};
2727
use crate::aggregate::sum;
28+
use crate::aggregate::sum::sum_batch;
2829
use crate::expressions::format_state_name;
2930
use crate::{AggregateExpr, PhysicalExpr};
3031
use arrow::compute;
@@ -119,6 +120,10 @@ impl AggregateExpr for Avg {
119120
self.data_type.clone(),
120121
)))
121122
}
123+
124+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
125+
Ok(Box::new(AvgAccumulator::try_new(&self.data_type)?))
126+
}
122127
}
123128

124129
/// An accumulator to compute the average
@@ -154,6 +159,14 @@ impl Accumulator for AvgAccumulator {
154159
Ok(())
155160
}
156161

162+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
163+
let values = &values[0];
164+
self.count -= (values.len() - values.data().null_count()) as u64;
165+
let delta = sum_batch(values, &self.sum.get_datatype())?;
166+
self.sum = self.sum.sub(&delta)?;
167+
Ok(())
168+
}
169+
157170
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
158171
let counts = downcast_value!(states[0], UInt64Array);
159172
// counts are summed

0 commit comments

Comments
 (0)