Skip to content

Commit 1b001a1

Browse files
vbarua and alamb authored
fix(substrait): schema errors for Aggregates with no groupings (#17909)
## Which issue does this PR close?

Closes #16590

## Rationale for this change

When consuming Substrait plans containing aggregates with no groupings, we would see the following error:

```
Error: Substrait("Named schema must contain names for all fields")
```

The Substrait plan had one _less_ field than DataFusion expected because DataFusion was adding an extra "__grouping_id" to the output of the Aggregate node. This happens when the condition at https://github.com/apache/datafusion/blob/daeb6597a0c7344735460bb2dce13879fd89d7bd/datafusion/expr/src/logical_plan/plan.rs#L3418 is true.

A natural follow-up question to this is "Why are we creating an Aggregate with a single empty GroupingSet for the group by, instead of just leaving the group by empty entirely?".

## What changes are included in this PR?

Instead of setting group_exprs to a vector with a single empty grouping set, let's just leave group_exprs empty entirely. This means that the `is_grouping_set` condition is not triggered, so the DataFusion schema matches the Substrait schema.

## Are these changes tested?

Yes, I have added direct tests via example Substrait plans.

## Are there any user-facing changes?

Substrait plans that were not consumable before are now consumable.

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 522403b commit 1b001a1

File tree

5 files changed

+274
-0
lines changed

5 files changed

+274
-0
lines changed

datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub async fn from_aggregate_rel(
4040
let mut aggr_exprs = vec![];
4141

4242
match agg.groupings.len() {
43+
0 => {}
4344
1 => {
4445
group_exprs.extend_from_slice(
4546
&from_substrait_grouping(

datafusion/substrait/tests/cases/aggregation_tests.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Tests to verify aggregation relation handling in Substrait
19+
20+
#[cfg(test)]
21+
mod tests {
22+
use crate::utils::test::{add_plan_schemas_to_ctx, read_json};
23+
use datafusion::common::Result;
24+
use datafusion::dataframe::DataFrame;
25+
use datafusion::prelude::SessionContext;
26+
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
27+
use insta::assert_snapshot;
28+
29+
#[tokio::test]
30+
async fn no_grouping_set() -> Result<()> {
31+
let proto_plan =
32+
read_json("tests/testdata/test_plans/aggregate_groupings/no_groupings.json");
33+
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
34+
let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
35+
36+
assert_snapshot!(
37+
plan,
38+
@r#"
39+
Aggregate: groupBy=[[]], aggr=[[sum(c0) AS summation]]
40+
EmptyRelation: rows=0
41+
"#
42+
);
43+
44+
// Trigger execution to ensure plan validity
45+
DataFrame::new(ctx.state(), plan).show().await?;
46+
47+
Ok(())
48+
}
49+
50+
#[tokio::test]
51+
async fn one_grouping_set() -> Result<()> {
52+
let proto_plan = read_json(
53+
"tests/testdata/test_plans/aggregate_groupings/single_grouping.json",
54+
);
55+
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
56+
let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
57+
58+
assert_snapshot!(
59+
plan,
60+
@r#"
61+
Aggregate: groupBy=[[c0]], aggr=[[sum(c0) AS summation]]
62+
EmptyRelation: rows=0
63+
"#
64+
);
65+
66+
// Trigger execution to ensure plan validity
67+
DataFrame::new(ctx.state(), plan).show().await?;
68+
69+
Ok(())
70+
}
71+
}

datafusion/substrait/tests/cases/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod aggregation_tests;
1819
mod builtin_expr_semantics_tests;
1920
mod consumer_integration;
2021
mod emit_kind_tests;

datafusion/substrait/tests/testdata/test_plans/aggregate_groupings/no_groupings.json

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
{
2+
"extensionUris": [
3+
{
4+
"extensionUriAnchor": 1,
5+
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
6+
}
7+
],
8+
"extensions": [
9+
{
10+
"extensionFunction": {
11+
"extensionUriReference": 1,
12+
"functionAnchor": 1,
13+
"name": "sum:i8"
14+
}
15+
}
16+
],
17+
"relations": [
18+
{
19+
"root": {
20+
"input": {
21+
"aggregate": {
22+
"common": {
23+
"direct": {}
24+
},
25+
"input": {
26+
"read": {
27+
"baseSchema": {
28+
"names": [
29+
"c0",
30+
"c1"
31+
],
32+
"struct": {
33+
"nullability": "NULLABILITY_REQUIRED",
34+
"types": [
35+
{
36+
"i8": {
37+
"nullability": "NULLABILITY_NULLABLE"
38+
}
39+
},
40+
{
41+
"i8": {
42+
"nullability": "NULLABILITY_NULLABLE"
43+
}
44+
}
45+
]
46+
}
47+
},
48+
"common": {
49+
"direct": {}
50+
},
51+
"virtualTable": {}
52+
}
53+
},
54+
"measures": [
55+
{
56+
"measure": {
57+
"arguments": [
58+
{
59+
"value": {
60+
"selection": {
61+
"directReference": {
62+
"structField": {}
63+
},
64+
"rootReference": {}
65+
}
66+
}
67+
}
68+
],
69+
"functionReference": 1,
70+
"invocation": "AGGREGATION_INVOCATION_ALL",
71+
"outputType": {
72+
"i8": {
73+
"nullability": "NULLABILITY_NULLABLE"
74+
}
75+
},
76+
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT"
77+
}
78+
}
79+
]
80+
}
81+
},
82+
"names": [
83+
"summation"
84+
]
85+
}
86+
}
87+
],
88+
"version": {
89+
"minorNumber": 29,
90+
"producer": "substrait-go v4.2.0"
91+
}
92+
}

datafusion/substrait/tests/testdata/test_plans/aggregate_groupings/single_grouping.json

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
{
2+
"extensionUris": [
3+
{
4+
"extensionUriAnchor": 1,
5+
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
6+
}
7+
],
8+
"extensions": [
9+
{
10+
"extensionFunction": {
11+
"extensionUriReference": 1,
12+
"functionAnchor": 1,
13+
"name": "sum:i8"
14+
}
15+
}
16+
],
17+
"relations": [
18+
{
19+
"root": {
20+
"input": {
21+
"aggregate": {
22+
"common": {
23+
"direct": {}
24+
},
25+
"input": {
26+
"read": {
27+
"baseSchema": {
28+
"names": [
29+
"c0",
30+
"c1"
31+
],
32+
"struct": {
33+
"nullability": "NULLABILITY_REQUIRED",
34+
"types": [
35+
{
36+
"i8": {
37+
"nullability": "NULLABILITY_NULLABLE"
38+
}
39+
},
40+
{
41+
"i8": {
42+
"nullability": "NULLABILITY_NULLABLE"
43+
}
44+
}
45+
]
46+
}
47+
},
48+
"common": {
49+
"direct": {}
50+
},
51+
"virtualTable": {}
52+
}
53+
},
54+
"groupingExpressions": [
55+
{
56+
"selection": {
57+
"directReference": {
58+
"structField": {}
59+
},
60+
"rootReference": {}
61+
}
62+
}
63+
],
64+
"groupings": [
65+
{
66+
"expressionReferences": [0]
67+
}
68+
69+
],
70+
"measures": [
71+
{
72+
"measure": {
73+
"arguments": [
74+
{
75+
"value": {
76+
"selection": {
77+
"directReference": {
78+
"structField": {}
79+
},
80+
"rootReference": {}
81+
}
82+
}
83+
}
84+
],
85+
"functionReference": 1,
86+
"invocation": "AGGREGATION_INVOCATION_ALL",
87+
"outputType": {
88+
"i8": {
89+
"nullability": "NULLABILITY_NULLABLE"
90+
}
91+
},
92+
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT"
93+
}
94+
}
95+
]
96+
}
97+
},
98+
"names": [
99+
"c0",
100+
"summation"
101+
]
102+
}
103+
}
104+
],
105+
"version": {
106+
"minorNumber": 29,
107+
"producer": "substrait-go v4.2.0"
108+
}
109+
}

0 commit comments

Comments
 (0)