Skip to content

Commit 6f3162a

Browse files
committed
asdfasdfasdf
1 parent 611bb52 commit 6f3162a

File tree

2 files changed

+170
-17
lines changed

2 files changed

+170
-17
lines changed

flux-cli/src/main/java/com/marklogic/flux/impl/importdata/AggregationParams.java

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ class AggregationParams implements CommandLine.ITypeConverter<AggregationParams.
2828
)
2929
private List<Aggregation> aggregations = new ArrayList<>();
3030

31+
@CommandLine.Option(
32+
names = "--aggregate-order-by",
33+
description = "Specify ordering for an aggregated array. Must be of the form aggregationName=columnName or " +
34+
"aggregationName=columnName:asc or aggregationName=columnName:desc. The columnName must be one of the " +
35+
"columns in the corresponding aggregation. Default order is ascending.",
36+
converter = AggregateOrderBy.class
37+
)
38+
private AggregateOrderBy aggregateOrderBy;
39+
3140
public static class Aggregation {
3241
private String newColumnName;
3342
private List<String> columnNamesToGroup;
@@ -38,6 +47,60 @@ public Aggregation(String newColumnName, List<String> columnNamesToGroup) {
3847
}
3948
}
4049

50+
public static class AggregateOrderBy implements CommandLine.ITypeConverter<AggregateOrderBy> {
51+
private String aggregationName;
52+
private String columnName;
53+
private boolean ascending = true;
54+
55+
public AggregateOrderBy() {}
56+
57+
public AggregateOrderBy(String aggregationName, String columnName, boolean ascending) {
58+
this.aggregationName = aggregationName;
59+
this.columnName = columnName;
60+
this.ascending = ascending;
61+
}
62+
63+
@Override
64+
public AggregateOrderBy convert(String value) {
65+
String[] parts = value.split("=");
66+
if (parts.length != 2) {
67+
throw new FluxException(String.format("Invalid aggregate order-by: %s; must be of the form " +
68+
"aggregationName=columnName or aggregationName=columnName:asc or aggregationName=columnName:desc", value));
69+
}
70+
71+
String aggName = parts[0];
72+
String[] columnParts = parts[1].split(":");
73+
String colName = columnParts[0];
74+
boolean asc = true;
75+
76+
if (columnParts.length == 2) {
77+
String direction = columnParts[1].toLowerCase();
78+
if ("desc".equals(direction)) {
79+
asc = false;
80+
} else if (!"asc".equals(direction)) {
81+
throw new FluxException(String.format("Invalid sort direction: %s; must be 'asc' or 'desc'", columnParts[1]));
82+
}
83+
} else if (columnParts.length > 2) {
84+
throw new FluxException(String.format("Invalid aggregate order-by: %s; must be of the form " +
85+
"aggregationName=columnName or aggregationName=columnName:asc or aggregationName=columnName:desc", value));
86+
}
87+
88+
return new AggregateOrderBy(aggName, colName, asc);
89+
}
90+
91+
public String getAggregationName() {
92+
return aggregationName;
93+
}
94+
95+
public String getColumnName() {
96+
return columnName;
97+
}
98+
99+
public boolean isAscending() {
100+
return ascending;
101+
}
102+
}
103+
41104
@Override
42105
public Aggregation convert(String value) {
43106
String[] parts = value.split("=");
@@ -61,6 +124,10 @@ public void addAggregationExpression(String newColumnName, String... columns) {
61124
this.aggregations.add(new Aggregation(newColumnName, Arrays.asList(columns)));
62125
}
63126

127+
public void addAggregateOrderBy(String aggregationName, String columnName, boolean ascending) {
128+
this.aggregateOrderBy = new AggregateOrderBy(aggregationName, columnName, ascending);
129+
}
130+
64131
public Dataset<Row> applyGroupBy(Dataset<Row> dataset) {
65132
if (groupBy == null || groupBy.trim().isEmpty()) {
66133
return dataset;
@@ -73,13 +140,54 @@ public Dataset<Row> applyGroupBy(Dataset<Row> dataset) {
73140
columns.addAll(aggregationColumns);
74141
final Column aliasColumn = columns.get(0);
75142
final Column[] columnsToGroup = columns.subList(1, columns.size()).toArray(new Column[]{});
143+
Dataset<Row> result;
76144
try {
77-
return groupedDataset.agg(aliasColumn, columnsToGroup);
145+
result = groupedDataset.agg(aliasColumn, columnsToGroup);
78146
} catch (Exception e) {
79147
String columnNames = aggregations.stream().map(agg -> agg.columnNamesToGroup.toString()).collect(Collectors.joining(", "));
80148
throw new FluxException(String.format("Unable to aggregate columns: %s; please ensure that each column " +
81149
"name will be present in the data read from the data source.", columnNames), e);
82150
}
151+
152+
// Apply sorting to aggregated arrays if specified
153+
if (aggregateOrderBy != null) {
154+
result = applySortToAggregatedArray(result);
155+
}
156+
157+
return result;
158+
}
159+
160+
private Dataset<Row> applySortToAggregatedArray(Dataset<Row> dataset) {
161+
String aggName = aggregateOrderBy.getAggregationName();
162+
String sortField = aggregateOrderBy.getColumnName();
163+
164+
// Find the aggregation to determine if it's single or multi-column
165+
Aggregation agg = aggregations.stream()
166+
.filter(a -> a.newColumnName.equals(aggName))
167+
.findFirst()
168+
.orElseThrow(() -> new FluxException(String.format(
169+
"Aggregate order-by references unknown aggregation '%s'", aggName)));
170+
171+
Column sortedColumn;
172+
if (agg.columnNamesToGroup.size() == 1) {
173+
// For single-column arrays, use sort_array
174+
sortedColumn = functions.sort_array(functions.col(aggName), !aggregateOrderBy.isAscending());
175+
} else {
176+
// For struct arrays, use array_sort with SQL expression
177+
// In array_sort lambda: return -1 if left should come before right, 1 if after, 0 if equal
178+
int whenLessThan = aggregateOrderBy.isAscending() ? -1 : 1;
179+
int whenGreaterThan = aggregateOrderBy.isAscending() ? 1 : -1;
180+
181+
sortedColumn = functions.expr(String.format(
182+
"array_sort(%s, (left, right) -> " +
183+
"case when left.%s < right.%s then %d " +
184+
"when left.%s > right.%s then %d else 0 end)",
185+
aggName, sortField, sortField, whenLessThan,
186+
sortField, sortField, whenGreaterThan
187+
));
188+
}
189+
190+
return dataset.withColumn(aggName, sortedColumn);
83191
}
84192

85193
/**
@@ -108,17 +216,32 @@ private List<Column> makeAggregationColumns() {
108216
List<Column> columns = new ArrayList<>();
109217
aggregations.forEach(aggregation -> {
110218
final List<String> columnNames = aggregation.columnNamesToGroup;
219+
Column resultColumn;
220+
111221
if (columnNames.size() == 1) {
112222
Column column = new Column(columnNames.get(0));
113-
Column listOfValuesColumn = functions.collect_list(functions.concat(column));
114-
columns.add(listOfValuesColumn.alias(aggregation.newColumnName));
223+
resultColumn = functions.collect_list(functions.concat(column));
115224
} else {
116225
Column[] structColumns = columnNames.stream().map(functions::col).toArray(Column[]::new);
117226
Column arrayColumn = functions.collect_list(functions.struct(structColumns));
118227
// array_distinct removes duplicate objects that can result from 2+ joins existing in the query.
119228
// See https://www.sparkreference.com/reference/array_distinct/ for performance considerations.
120-
columns.add(functions.array_distinct(arrayColumn).alias(aggregation.newColumnName));
229+
resultColumn = functions.array_distinct(arrayColumn);
230+
}
231+
232+
// Validate aggregate-order-by if specified for this aggregation
233+
if (aggregateOrderBy != null && aggregateOrderBy.getAggregationName().equals(aggregation.newColumnName)) {
234+
if (!columnNames.contains(aggregateOrderBy.getColumnName())) {
235+
throw new FluxException(String.format(
236+
"Invalid aggregate order-by for '%s': column '%s' is not in the aggregation. " +
237+
"Available columns: %s",
238+
aggregation.newColumnName,
239+
aggregateOrderBy.getColumnName(),
240+
columnNames));
241+
}
121242
}
243+
244+
columns.add(resultColumn.alias(aggregation.newColumnName));
122245
});
123246
return columns;
124247
}

flux-cli/src/test/java/com/marklogic/flux/impl/importdata/ImportJdbcWithAggregatesTest.java

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
import com.fasterxml.jackson.databind.JsonNode;
77
import com.fasterxml.jackson.databind.node.ArrayNode;
88
import com.fasterxml.jackson.databind.node.JsonNodeType;
9+
import com.marklogic.client.ext.helper.ClientHelper;
910
import com.marklogic.client.io.StringHandle;
1011
import com.marklogic.flux.AbstractTest;
1112
import com.marklogic.flux.impl.PostgresUtil;
1213
import org.junit.jupiter.api.Test;
1314

15+
import java.util.List;
16+
1417
import static org.junit.jupiter.api.Assertions.assertEquals;
18+
import static org.junit.jupiter.api.Assertions.assertTrue;
1519

1620
class ImportJdbcWithAggregatesTest extends AbstractTest {
1721

@@ -74,7 +78,7 @@ void customerWithArrayOfRentalsAndArrayOfPayments() {
7478
from customer c
7579
inner join rental r on c.customer_id = r.customer_id
7680
inner join payment p on c.customer_id = p.customer_id
77-
where c.customer_id = 1 and r.rental_id < 1000 and p.payment_id < 19000
81+
where r.rental_id < 1000 and p.payment_id < 19000
7882
order by p.amount
7983
""";
8084

@@ -88,7 +92,8 @@ void customerWithArrayOfRentalsAndArrayOfPayments() {
8892
"--aggregate", "rentals=rental_id,inventory_id",
8993
"--connection-string", makeConnectionString(),
9094
"--permissions", DEFAULT_PERMISSIONS,
91-
"--uri-template", "/customer/{customer_id}.json"
95+
"--uri-template", "/customer/{customer_id}.json",
96+
"--collections", "customer"
9297
);
9398

9499
JsonNode doc = readJsonDocument("/customer/1.json");
@@ -113,14 +118,16 @@ void customerWithArrayOfRentalsAndArrayOfPayments() {
113118
* Attempts to reproduce the query in MLE-25002 by using select statements for the "from" and "inner join".
114119
*/
115120
@Test
116-
void orderByWithFromAndInnerJoinsThatUseSelects() {
121+
void copilotUseThisTest() {
117122
String query = """
118123
select
119124
c.customer_id, c.first_name,
125+
r.rental_id, r.inventory_id,
120126
p.payment_id, p.amount
121-
from (select * from customer c) as c
122-
inner join (select * from payment where payment_id < 19000) p on c.customer_id = p.customer_id
123-
order by p.amount desc
127+
from customer c
128+
inner join rental r on c.customer_id = r.customer_id
129+
inner join payment p on c.customer_id = p.customer_id
130+
order by p.amount
124131
""";
125132

126133
run(
@@ -129,18 +136,41 @@ inner join (select * from payment where payment_id < 19000) p on c.customer_id =
129136
"--query", query,
130137
"--group-by", "customer_id",
131138
"--aggregate", "payments=payment_id,amount",
139+
"--aggregate-order-by", "payments=amount:asc",
132140
"--connection-string", makeConnectionString(),
133141
"--permissions", DEFAULT_PERMISSIONS,
134-
"--uri-template", "/customer/{customer_id}.json"
142+
"--uri-template", "/customer/{customer_id}.json",
143+
"--collections", "customer"
135144
);
136145

137-
JsonNode doc = readJsonDocument("/customer/1.json");
138-
System.out.println(doc.toPrettyString());
146+
verifyEachCustomerHasPaymentsOrdered();
147+
}
139148

140-
ArrayNode payments = (ArrayNode) doc.get("payments");
141-
assertEquals(7, payments.size(), "The query should have selected 7 related payments.");
142-
assertEquals(0.99, payments.get(6).get("amount").asDouble());
143-
assertEquals(9.99, payments.get(0).get("amount").asDouble());
149+
private void verifyEachCustomerHasPaymentsOrdered() {
150+
List<String> uris = new ClientHelper(getDatabaseClient()).getUrisInCollection("customer");
151+
for (String uri : uris) {
152+
JsonNode doc = readJsonDocument(uri);
153+
ArrayNode payments = (ArrayNode) doc.get("payments");
154+
if (payments.size() < 2) {
155+
continue;
156+
}
157+
double firstAmount = payments.get(0).get("amount").asDouble();
158+
double lastAmount = payments.get(payments.size() - 1).get("amount").asDouble();
159+
assertTrue(lastAmount >= firstAmount, "BAD CUSTOMER: " + doc.toPrettyString());
160+
}
161+
162+
}
163+
164+
@Test
165+
void backup() {
166+
String query = """
167+
select
168+
c.customer_id, c.first_name,
169+
p.payment_id, p.amount
170+
from (select * from customer c) as c
171+
inner join (select * from payment where payment_id < 19000) p on c.customer_id = p.customer_id
172+
order by p.amount desc
173+
""";
144174
}
145175

146176
/**

0 commit comments

Comments
 (0)