Skip to content

Commit d62f262

Browse files
authored
feat(substrait): support order_by in aggregate functions (#13114)
1 parent 89e71ef commit d62f262

File tree

3 files changed

+143
-3
lines changed

3 files changed

+143
-3
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -714,14 +714,27 @@ pub async fn from_substrait_rel(
714714
}
715715
_ => false,
716716
};
717+
let order_by = if !f.sorts.is_empty() {
718+
Some(
719+
from_substrait_sorts(
720+
ctx,
721+
&f.sorts,
722+
input.schema(),
723+
extensions,
724+
)
725+
.await?,
726+
)
727+
} else {
728+
None
729+
};
730+
717731
from_substrait_agg_func(
718732
ctx,
719733
f,
720734
input.schema(),
721735
extensions,
722736
filter,
723-
// TODO: Add parsing of order_by also
724-
None,
737+
order_by,
725738
distinct,
726739
)
727740
.await

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -685,6 +685,19 @@ async fn aggregate_wo_projection_consume() -> Result<()> {
685685
.await
686686
}
687687

688+
#[tokio::test]
689+
async fn aggregate_wo_projection_sorted_consume() -> Result<()> {
690+
let proto_plan =
691+
read_json("tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json");
692+
693+
assert_expected_plan_substrait(
694+
proto_plan,
695+
"Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a DESC NULLS FIRST] AS countA]]\
696+
\n TableScan: data projection=[a]",
697+
)
698+
.await
699+
}
700+
688701
#[tokio::test]
689702
async fn simple_intersect_consume() -> Result<()> {
690703
let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json");
@@ -1025,8 +1038,9 @@ async fn roundtrip_aggregate_udf() -> Result<()> {
10251038

10261039
let ctx = create_context().await?;
10271040
ctx.register_udaf(dummy_agg);
1041+
roundtrip_with_ctx("select dummy_agg(a) from data", ctx.clone()).await?;
1042+
roundtrip_with_ctx("select dummy_agg(a order by a) from data", ctx.clone()).await?;
10281043

1029-
roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await?;
10301044
Ok(())
10311045
}
10321046

datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json

Lines changed: 113 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,113 @@
1+
{
2+
"extensionUris": [
3+
{
4+
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml"
5+
}
6+
],
7+
"extensions": [
8+
{
9+
"extensionFunction": {
10+
"functionAnchor": 185,
11+
"name": "count:any"
12+
}
13+
}
14+
],
15+
"relations": [
16+
{
17+
"root": {
18+
"input": {
19+
"aggregate": {
20+
"input": {
21+
"read": {
22+
"common": {
23+
"direct": {}
24+
},
25+
"baseSchema": {
26+
"names": [
27+
"a"
28+
],
29+
"struct": {
30+
"types": [
31+
{
32+
"i64": {
33+
"nullability": "NULLABILITY_NULLABLE"
34+
}
35+
}
36+
],
37+
"nullability": "NULLABILITY_NULLABLE"
38+
}
39+
},
40+
"namedTable": {
41+
"names": [
42+
"data"
43+
]
44+
}
45+
}
46+
},
47+
"groupings": [
48+
{
49+
"groupingExpressions": [
50+
{
51+
"selection": {
52+
"directReference": {
53+
"structField": {}
54+
},
55+
"rootReference": {}
56+
}
57+
}
58+
]
59+
}
60+
],
61+
"measures": [
62+
{
63+
"measure": {
64+
"functionReference": 185,
65+
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
66+
"outputType": {
67+
"i64": {}
68+
},
69+
"arguments": [
70+
{
71+
"value": {
72+
"selection": {
73+
"directReference": {
74+
"structField": {}
75+
},
76+
"rootReference": {}
77+
}
78+
}
79+
}
80+
],
81+
"sorts": [
82+
{
83+
"expr": {
84+
"selection": {
85+
"directReference": {
86+
"structField": {
87+
"field": 0
88+
}
89+
},
90+
"rootReference": {
91+
}
92+
}
93+
},
94+
"direction": "SORT_DIRECTION_DESC_NULLS_FIRST"
95+
}
96+
]
97+
}
98+
}
99+
]
100+
}
101+
},
102+
"names": [
103+
"a",
104+
"countA"
105+
]
106+
}
107+
}
108+
],
109+
"version": {
110+
"minorNumber": 54,
111+
"producer": "manual"
112+
}
113+
}

0 commit comments

Comments
 (0)