Commit 821d410

feat(spark): Implement collect_list/collect_set aggregate functions (#19699)
## Which issue does this PR close?

- Part of #15914
- Closes #17923
- Closes #17924

## Rationale for this change

## What changes are included in this PR?

Implementation of the Spark `collect_list` and `collect_set` aggregate functions.

## Are these changes tested?

Yes.

## Are there any user-facing changes?

Yes.
1 parent 5c2b123 commit 821d410

File tree

6 files changed: +314 -2 lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 1 addition & 1 deletion
@@ -415,7 +415,7 @@ impl Accumulator for ArrayAggAccumulator {
 }
 
 #[derive(Debug)]
-struct DistinctArrayAggAccumulator {
+pub struct DistinctArrayAggAccumulator {
     values: HashSet<ScalarValue>,
     datatype: DataType,
     sort_options: Option<SortOptions>,

datafusion/spark/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ datafusion-common = { workspace = true }
 datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-functions = { workspace = true, features = ["crypto_expressions"] }
+datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-nested = { workspace = true }
 log = { workspace = true }
 percent-encoding = "2.3.2"
datafusion/spark/src/function/aggregate/collect.rs

Lines changed: 200 additions & 0 deletions

@@ -0,0 +1,200 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::ArrayRef;
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::utils::SingleRowListArrayBuilder;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
+use datafusion_functions_aggregate::array_agg::{
+    ArrayAggAccumulator, DistinctArrayAggAccumulator,
+};
+use std::{any::Any, sync::Arc};
+
+// Spark implementation of collect_list/collect_set aggregate function.
+// Differs from DataFusion ArrayAgg in the following ways:
+// - ignores NULL inputs
+// - returns an empty list when all inputs are NULL
+// - does not support ordering
+
+// <https://spark.apache.org/docs/latest/api/sql/index.html#collect_list>
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkCollectList {
+    signature: Signature,
+}
+
+impl Default for SparkCollectList {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCollectList {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(1, Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for SparkCollectList {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "collect_list"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::List(Arc::new(Field::new_list_field(
+            arg_types[0].clone(),
+            true,
+        ))))
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        Ok(vec![
+            Field::new_list(
+                format_state_name(args.name, "collect_list"),
+                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
+                true,
+            )
+            .into(),
+        ])
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        let field = &acc_args.expr_fields[0];
+        let data_type = field.data_type().clone();
+        let ignore_nulls = true;
+        Ok(Box::new(NullToEmptyListAccumulator::new(
+            ArrayAggAccumulator::try_new(&data_type, ignore_nulls)?,
+            data_type,
+        )))
+    }
+}
+
+// <https://spark.apache.org/docs/latest/api/sql/index.html#collect_set>
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkCollectSet {
+    signature: Signature,
+}
+
+impl Default for SparkCollectSet {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCollectSet {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(1, Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for SparkCollectSet {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "collect_set"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::List(Arc::new(Field::new_list_field(
+            arg_types[0].clone(),
+            true,
+        ))))
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        Ok(vec![
+            Field::new_list(
+                format_state_name(args.name, "collect_set"),
+                Field::new_list_field(args.input_fields[0].data_type().clone(), true),
+                true,
+            )
+            .into(),
+        ])
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        let field = &acc_args.expr_fields[0];
+        let data_type = field.data_type().clone();
+        let ignore_nulls = true;
+        Ok(Box::new(NullToEmptyListAccumulator::new(
+            DistinctArrayAggAccumulator::try_new(&data_type, None, ignore_nulls)?,
+            data_type,
+        )))
+    }
+}
+
+/// Wrapper accumulator that returns an empty list instead of NULL when all inputs are NULL.
+/// This implements Spark's behavior for collect_list and collect_set.
+#[derive(Debug)]
+struct NullToEmptyListAccumulator<T: Accumulator> {
+    inner: T,
+    data_type: DataType,
+}
+
+impl<T: Accumulator> NullToEmptyListAccumulator<T> {
+    pub fn new(inner: T, data_type: DataType) -> Self {
+        Self { inner, data_type }
+    }
+}
+
+impl<T: Accumulator> Accumulator for NullToEmptyListAccumulator<T> {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.inner.update_batch(values)
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.inner.merge_batch(states)
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        self.inner.state()
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let result = self.inner.evaluate()?;
+        if result.is_null() {
+            let empty_array = arrow::array::new_empty_array(&self.data_type);
+            Ok(SingleRowListArrayBuilder::new(empty_array).build_list_scalar())
+        } else {
+            Ok(result)
+        }
+    }
+
+    fn size(&self) -> usize {
+        self.inner.size() + self.data_type.size()
+    }
+}
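As a quick orientation for reviewers, here is a minimal sketch (not part of this diff) of how the two UDAFs can be registered and exercised end to end. It assumes the module path `datafusion_spark::function::aggregate::collect` implied by the `pub mod collect;` line in the mod.rs change below, and that the `datafusion` and `tokio` crates are on the dependency list:

```rust
use datafusion::error::Result;
use datafusion::logical_expr::AggregateUDF;
use datafusion::prelude::SessionContext;
// Assumed module path, following the `pub mod collect;` added in mod.rs below.
use datafusion_spark::function::aggregate::collect::{SparkCollectList, SparkCollectSet};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Register the Spark-flavored aggregates with the session.
    ctx.register_udaf(AggregateUDF::new_from_impl(SparkCollectList::new()));
    ctx.register_udaf(AggregateUDF::new_from_impl(SparkCollectSet::new()));

    // NULL inputs are ignored, so this yields [1, 3] rather than [1, NULL, 3];
    // an all-NULL input would yield [] rather than NULL.
    ctx.sql("SELECT collect_list(a) FROM (VALUES (1), (NULL), (3)) AS t(a)")
        .await?
        .show()
        .await?;
    Ok(())
}
```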

datafusion/spark/src/function/aggregate/mod.rs

Lines changed: 18 additions & 1 deletion
@@ -19,6 +19,7 @@ use datafusion_expr::AggregateUDF;
 use std::sync::Arc;
 
 pub mod avg;
+pub mod collect;
 pub mod try_sum;
 
 pub mod expr_fn {
@@ -30,6 +31,16 @@ pub mod expr_fn {
         "Returns the sum of values for a column, or NULL if overflow occurs",
         arg1
     ));
+    export_functions!((
+        collect_list,
+        "Returns a list created from the values in a column",
+        arg1
+    ));
+    export_functions!((
+        collect_set,
+        "Returns a set created from the values in a column",
+        arg1
+    ));
 }
 
 // TODO: try use something like datafusion_functions_aggregate::create_func!()
@@ -39,7 +50,13 @@ pub fn avg() -> Arc<AggregateUDF> {
 pub fn try_sum() -> Arc<AggregateUDF> {
     Arc::new(AggregateUDF::new_from_impl(try_sum::SparkTrySum::new()))
 }
+pub fn collect_list() -> Arc<AggregateUDF> {
+    Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectList::new()))
+}
+pub fn collect_set() -> Arc<AggregateUDF> {
+    Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectSet::new()))
+}
 
 pub fn functions() -> Vec<Arc<AggregateUDF>> {
-    vec![avg(), try_sum()]
+    vec![avg(), try_sum(), collect_list(), collect_set()]
 }
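The `export_functions!` additions above also generate `Expr`-building helpers, so the aggregates are usable from the DataFrame API as well as SQL. A hedged sketch, assuming a registered table `t` with columns `g` and `a` and the `expr_fn` re-export path shown in this diff:

```rust
use datafusion::error::Result;
use datafusion::prelude::{col, SessionContext};
// Assumed re-export path for the helpers generated by export_functions! above.
use datafusion_spark::function::aggregate::expr_fn::{collect_list, collect_set};

async fn grouped_collect(ctx: &SessionContext) -> Result<()> {
    let df = ctx
        .table("t")
        .await?
        // GROUP BY g, collecting `a` as a list and as a set per group.
        .aggregate(
            vec![col("g")],
            vec![collect_list(col("a")), collect_set(col("a"))],
        )?;
    df.show().await?;
    Ok(())
}
```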
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+query ?
+SELECT collect_list(a) FROM (VALUES (1), (2), (3)) AS t(a);
+----
+[1, 2, 3]
+
+query ?
+SELECT collect_list(a) FROM (VALUES (1), (2), (2), (3), (1)) AS t(a);
+----
+[1, 2, 2, 3, 1]
+
+query ?
+SELECT collect_list(a) FROM (VALUES (1), (NULL), (3)) AS t(a);
+----
+[1, 3]
+
+query ?
+SELECT collect_list(a) FROM (VALUES (CAST(NULL AS INT)), (NULL), (NULL)) AS t(a);
+----
+[]
+
+query I?
+SELECT g, collect_list(a)
+FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10, 20, 10]
+2 [30, 30]
+
+query I?
+SELECT g, collect_list(a)
+FROM (VALUES (1, 10), (1, NULL), (2, 20), (2, NULL)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10]
+2 [20]
+
+# we need to wrap collect_set with array_sort to have consistent outputs
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (3)) AS t(a);
+----
+[1, 2, 3]
+
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (1), (2), (2), (3), (1)) AS t(a);
+----
+[1, 2, 3]
+
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (1), (NULL), (3)) AS t(a);
+----
+[1, 3]
+
+query ?
+SELECT array_sort(collect_set(a)) FROM (VALUES (CAST(NULL AS INT)), (NULL), (NULL)) AS t(a);
+----
+[]
+
+query I?
+SELECT g, array_sort(collect_set(a))
+FROM (VALUES (1, 10), (1, 20), (2, 30), (2, 30), (1, 10)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10, 20]
+2 [30]
+
+query I?
+SELECT g, array_sort(collect_set(a))
+FROM (VALUES (1, 10), (1, NULL), (1, NULL), (2, 20), (2, NULL)) AS t(g, a)
+GROUP BY g
+ORDER BY g;
+----
+1 [10]
+2 [20]
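A note on the `array_sort` wrapping in the tests above: `collect_set` is backed by `DistinctArrayAggAccumulator`, which stores values in a `HashSet<ScalarValue>` (see the array_agg.rs diff), so element order in its output is unspecified. A minimal standalone illustration of why the tests sort before comparing:

```rust
use std::collections::HashSet;

fn main() {
    // HashSet iteration order is unspecified and can vary between runs,
    // which is why the slt tests normalize collect_set output with array_sort.
    let set: HashSet<i32> = [3, 1, 2].into_iter().collect();
    let mut values: Vec<i32> = set.into_iter().collect();
    values.sort();
    assert_eq!(values, vec![1, 2, 3]);
}
```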
