@@ -28,6 +28,15 @@ class AggregationParams implements CommandLine.ITypeConverter<AggregationParams.
2828 )
2929 private List <Aggregation > aggregations = new ArrayList <>();
3030
// Optional ordering for the elements of one aggregated array column.
// The raw option value is parsed by AggregateOrderBy (see the converter attribute below).
// The field has no initializer, so it stays null when the option is not supplied; only a single
// ordering can be held at a time.
31+ @ CommandLine .Option (
32+ names = "--aggregate-order-by" ,
33+ description = "Specify ordering for an aggregated array. Must be of the form aggregationName=columnName or " +
34+ "aggregationName=columnName:asc or aggregationName=columnName:desc. The columnName must be one of the " +
35+ "columns in the corresponding aggregation. Default order is ascending." ,
36+ converter = AggregateOrderBy .class
37+ )
38+ private AggregateOrderBy aggregateOrderBy ;
39+
3140 public static class Aggregation {
3241 private String newColumnName ;
3342 private List <String > columnNamesToGroup ;
@@ -38,6 +47,60 @@ public Aggregation(String newColumnName, List<String> columnNamesToGroup) {
3847 }
3948 }
4049
50+ public static class AggregateOrderBy implements CommandLine .ITypeConverter <AggregateOrderBy > {
51+ private String aggregationName ;
52+ private String columnName ;
53+ private boolean ascending = true ;
54+
55+ public AggregateOrderBy () {}
56+
57+ public AggregateOrderBy (String aggregationName , String columnName , boolean ascending ) {
58+ this .aggregationName = aggregationName ;
59+ this .columnName = columnName ;
60+ this .ascending = ascending ;
61+ }
62+
63+ @ Override
64+ public AggregateOrderBy convert (String value ) {
65+ String [] parts = value .split ("=" );
66+ if (parts .length != 2 ) {
67+ throw new FluxException (String .format ("Invalid aggregate order-by: %s; must be of the form " +
68+ "aggregationName=columnName or aggregationName=columnName:asc or aggregationName=columnName:desc" , value ));
69+ }
70+
71+ String aggName = parts [0 ];
72+ String [] columnParts = parts [1 ].split (":" );
73+ String colName = columnParts [0 ];
74+ boolean asc = true ;
75+
76+ if (columnParts .length == 2 ) {
77+ String direction = columnParts [1 ].toLowerCase ();
78+ if ("desc" .equals (direction )) {
79+ asc = false ;
80+ } else if (!"asc" .equals (direction )) {
81+ throw new FluxException (String .format ("Invalid sort direction: %s; must be 'asc' or 'desc'" , columnParts [1 ]));
82+ }
83+ } else if (columnParts .length > 2 ) {
84+ throw new FluxException (String .format ("Invalid aggregate order-by: %s; must be of the form " +
85+ "aggregationName=columnName or aggregationName=columnName:asc or aggregationName=columnName:desc" , value ));
86+ }
87+
88+ return new AggregateOrderBy (aggName , colName , asc );
89+ }
90+
91+ public String getAggregationName () {
92+ return aggregationName ;
93+ }
94+
95+ public String getColumnName () {
96+ return columnName ;
97+ }
98+
99+ public boolean isAscending () {
100+ return ascending ;
101+ }
102+ }
103+
41104 @ Override
42105 public Aggregation convert (String value ) {
43106 String [] parts = value .split ("=" );
@@ -61,6 +124,10 @@ public void addAggregationExpression(String newColumnName, String... columns) {
61124 this .aggregations .add (new Aggregation (newColumnName , Arrays .asList (columns )));
62125 }
63126
127+ public void addAggregateOrderBy (String aggregationName , String columnName , boolean ascending ) {
128+ this .aggregateOrderBy = new AggregateOrderBy (aggregationName , columnName , ascending );
129+ }
130+
64131 public Dataset <Row > applyGroupBy (Dataset <Row > dataset ) {
65132 if (groupBy == null || groupBy .trim ().isEmpty ()) {
66133 return dataset ;
@@ -73,13 +140,54 @@ public Dataset<Row> applyGroupBy(Dataset<Row> dataset) {
73140 columns .addAll (aggregationColumns );
74141 final Column aliasColumn = columns .get (0 );
75142 final Column [] columnsToGroup = columns .subList (1 , columns .size ()).toArray (new Column []{});
143+ Dataset <Row > result ;
76144 try {
77- return groupedDataset .agg (aliasColumn , columnsToGroup );
145+ result = groupedDataset .agg (aliasColumn , columnsToGroup );
78146 } catch (Exception e ) {
79147 String columnNames = aggregations .stream ().map (agg -> agg .columnNamesToGroup .toString ()).collect (Collectors .joining (", " ));
80148 throw new FluxException (String .format ("Unable to aggregate columns: %s; please ensure that each column " +
81149 "name will be present in the data read from the data source." , columnNames ), e );
82150 }
151+
152+ // Apply sorting to aggregated arrays if specified
153+ if (aggregateOrderBy != null ) {
154+ result = applySortToAggregatedArray (result );
155+ }
156+
157+ return result ;
158+ }
159+
160+ private Dataset <Row > applySortToAggregatedArray (Dataset <Row > dataset ) {
161+ String aggName = aggregateOrderBy .getAggregationName ();
162+ String sortField = aggregateOrderBy .getColumnName ();
163+
164+ // Find the aggregation to determine if it's single or multi-column
165+ Aggregation agg = aggregations .stream ()
166+ .filter (a -> a .newColumnName .equals (aggName ))
167+ .findFirst ()
168+ .orElseThrow (() -> new FluxException (String .format (
169+ "Aggregate order-by references unknown aggregation '%s'" , aggName )));
170+
171+ Column sortedColumn ;
172+ if (agg .columnNamesToGroup .size () == 1 ) {
173+ // For single-column arrays, use sort_array
174+ sortedColumn = functions .sort_array (functions .col (aggName ), !aggregateOrderBy .isAscending ());
175+ } else {
176+ // For struct arrays, use array_sort with SQL expression
177+ // In array_sort lambda: return -1 if left should come before right, 1 if after, 0 if equal
178+ int whenLessThan = aggregateOrderBy .isAscending () ? -1 : 1 ;
179+ int whenGreaterThan = aggregateOrderBy .isAscending () ? 1 : -1 ;
180+
181+ sortedColumn = functions .expr (String .format (
182+ "array_sort(%s, (left, right) -> " +
183+ "case when left.%s < right.%s then %d " +
184+ "when left.%s > right.%s then %d else 0 end)" ,
185+ aggName , sortField , sortField , whenLessThan ,
186+ sortField , sortField , whenGreaterThan
187+ ));
188+ }
189+
190+ return dataset .withColumn (aggName , sortedColumn );
83191 }
84192
85193 /**
@@ -108,17 +216,32 @@ private List<Column> makeAggregationColumns() {
108216 List <Column > columns = new ArrayList <>();
109217 aggregations .forEach (aggregation -> {
110218 final List <String > columnNames = aggregation .columnNamesToGroup ;
219+ Column resultColumn ;
220+
111221 if (columnNames .size () == 1 ) {
112222 Column column = new Column (columnNames .get (0 ));
113- Column listOfValuesColumn = functions .collect_list (functions .concat (column ));
114- columns .add (listOfValuesColumn .alias (aggregation .newColumnName ));
223+ resultColumn = functions .collect_list (functions .concat (column ));
115224 } else {
116225 Column [] structColumns = columnNames .stream ().map (functions ::col ).toArray (Column []::new );
117226 Column arrayColumn = functions .collect_list (functions .struct (structColumns ));
118227 // array_distinct removes duplicate objects that can result from 2+ joins existing in the query.
119228 // See https://www.sparkreference.com/reference/array_distinct/ for performance considerations.
120- columns .add (functions .array_distinct (arrayColumn ).alias (aggregation .newColumnName ));
229+ resultColumn = functions .array_distinct (arrayColumn );
230+ }
231+
232+ // Validate aggregate-order-by if specified for this aggregation
233+ if (aggregateOrderBy != null && aggregateOrderBy .getAggregationName ().equals (aggregation .newColumnName )) {
234+ if (!columnNames .contains (aggregateOrderBy .getColumnName ())) {
235+ throw new FluxException (String .format (
236+ "Invalid aggregate order-by for '%s': column '%s' is not in the aggregation. " +
237+ "Available columns: %s" ,
238+ aggregation .newColumnName ,
239+ aggregateOrderBy .getColumnName (),
240+ columnNames ));
241+ }
121242 }
243+
244+ columns .add (resultColumn .alias (aggregation .newColumnName ));
122245 });
123246 return columns ;
124247 }
0 commit comments