diff --git a/flux-cli/src/main/java/com/marklogic/flux/impl/importdata/AggregationParams.java b/flux-cli/src/main/java/com/marklogic/flux/impl/importdata/AggregationParams.java index e29f2727..a90f6344 100644 --- a/flux-cli/src/main/java/com/marklogic/flux/impl/importdata/AggregationParams.java +++ b/flux-cli/src/main/java/com/marklogic/flux/impl/importdata/AggregationParams.java @@ -28,6 +28,20 @@ class AggregationParams implements CommandLine.ITypeConverter aggregations = new ArrayList<>(); + @CommandLine.Option( + names = "--aggregate-order-by", + description = "Specify ordering for an aggregated array. Must be of the form aggregationName=columnName. " + + "The columnName must be one of the columns in the corresponding aggregation. Default order is ascending.", + converter = AggregateOrderBy.class + ) + private AggregateOrderBy aggregateOrderBy; + + @CommandLine.Option( + names = "--aggregate-order-desc", + description = "Sort the aggregated array in descending order. Only applies when --aggregate-order-by is specified." 
+ ) + private boolean aggregateOrderDesc = false; + public static class Aggregation { private String newColumnName; private List columnNamesToGroup; @@ -38,6 +52,37 @@ public Aggregation(String newColumnName, List columnNamesToGroup) { } } + public static class AggregateOrderBy implements CommandLine.ITypeConverter { + private String aggregationName; + private String columnName; + + public AggregateOrderBy() {} + + public AggregateOrderBy(String aggregationName, String columnName) { + this.aggregationName = aggregationName; + this.columnName = columnName; + } + + @Override + public AggregateOrderBy convert(String value) { + String[] parts = value.split("="); + if (parts.length != 2) { + throw new FluxException(String.format("Invalid aggregate order-by: %s; must be of the form " + + "aggregationName=columnName", value)); + } + + return new AggregateOrderBy(parts[0], parts[1]); + } + + public String getAggregationName() { + return aggregationName; + } + + public String getColumnName() { + return columnName; + } + } + @Override public Aggregation convert(String value) { String[] parts = value.split("="); @@ -61,6 +106,11 @@ public void addAggregationExpression(String newColumnName, String... 
columns) { this.aggregations.add(new Aggregation(newColumnName, Arrays.asList(columns))); } + public void addAggregateOrderBy(String aggregationName, String columnName, boolean ascending) { + this.aggregateOrderBy = new AggregateOrderBy(aggregationName, columnName); + this.aggregateOrderDesc = !ascending; + } + public Dataset applyGroupBy(Dataset dataset) { if (groupBy == null || groupBy.trim().isEmpty()) { return dataset; @@ -73,13 +123,54 @@ public Dataset applyGroupBy(Dataset dataset) { columns.addAll(aggregationColumns); final Column aliasColumn = columns.get(0); final Column[] columnsToGroup = columns.subList(1, columns.size()).toArray(new Column[]{}); + Dataset result; try { - return groupedDataset.agg(aliasColumn, columnsToGroup); + result = groupedDataset.agg(aliasColumn, columnsToGroup); } catch (Exception e) { String columnNames = aggregations.stream().map(agg -> agg.columnNamesToGroup.toString()).collect(Collectors.joining(", ")); throw new FluxException(String.format("Unable to aggregate columns: %s; please ensure that each column " + "name will be present in the data read from the data source.", columnNames), e); } + + // Apply sorting to aggregated arrays if specified + if (aggregateOrderBy != null) { + result = applySortToAggregatedArray(result); + } + + return result; + } + + private Dataset applySortToAggregatedArray(Dataset dataset) { + String aggName = aggregateOrderBy.getAggregationName(); + String sortField = aggregateOrderBy.getColumnName(); + + // Find the aggregation to determine if it's single or multi-column + Aggregation agg = aggregations.stream() + .filter(a -> a.newColumnName.equals(aggName)) + .findFirst() + .orElseThrow(() -> new FluxException(String.format( + "Aggregate order-by references unknown aggregation '%s'", aggName))); + + Column sortedColumn; + if (agg.columnNamesToGroup.size() == 1) { + // For single-column arrays, use sort_array; its boolean argument means "ascending", so negate the desc flag + sortedColumn = functions.sort_array(functions.col(aggName), !aggregateOrderDesc); + } else { 
+ // For struct arrays, use array_sort with SQL expression + // In array_sort lambda: return -1 if left should come before right, 1 if after, 0 if equal + int whenLessThan = aggregateOrderDesc ? 1 : -1; + int whenGreaterThan = aggregateOrderDesc ? -1 : 1; + + sortedColumn = functions.expr(String.format( + "array_sort(%s, (left, right) -> " + + "case when left.%s < right.%s then %d " + + "when left.%s > right.%s then %d else 0 end)", + aggName, sortField, sortField, whenLessThan, + sortField, sortField, whenGreaterThan + )); + } + + return dataset.withColumn(aggName, sortedColumn); } /** @@ -108,17 +199,32 @@ private List makeAggregationColumns() { List columns = new ArrayList<>(); aggregations.forEach(aggregation -> { final List columnNames = aggregation.columnNamesToGroup; + Column resultColumn; + if (columnNames.size() == 1) { Column column = new Column(columnNames.get(0)); - Column listOfValuesColumn = functions.collect_list(functions.concat(column)); - columns.add(listOfValuesColumn.alias(aggregation.newColumnName)); + resultColumn = functions.collect_list(functions.concat(column)); } else { Column[] structColumns = columnNames.stream().map(functions::col).toArray(Column[]::new); Column arrayColumn = functions.collect_list(functions.struct(structColumns)); // array_distinct removes duplicate objects that can result from 2+ joins existing in the query. // See https://www.sparkreference.com/reference/array_distinct/ for performance considerations. - columns.add(functions.array_distinct(arrayColumn).alias(aggregation.newColumnName)); + resultColumn = functions.array_distinct(arrayColumn); + } + + // Validate aggregate-order-by if specified for this aggregation + if (aggregateOrderBy != null && aggregateOrderBy.getAggregationName().equals(aggregation.newColumnName)) { + if (!columnNames.contains(aggregateOrderBy.getColumnName())) { + throw new FluxException(String.format( + "Invalid aggregate order-by for '%s': column '%s' is not in the aggregation. 
" + + "Available columns: %s", + aggregation.newColumnName, + aggregateOrderBy.getColumnName(), + columnNames)); + } } + + columns.add(resultColumn.alias(aggregation.newColumnName)); }); return columns; } diff --git a/flux-cli/src/test/java/com/marklogic/flux/impl/importdata/ImportJdbcWithAggregatesTest.java b/flux-cli/src/test/java/com/marklogic/flux/impl/importdata/ImportJdbcWithAggregatesTest.java index 9b35ab7c..a0114af4 100644 --- a/flux-cli/src/test/java/com/marklogic/flux/impl/importdata/ImportJdbcWithAggregatesTest.java +++ b/flux-cli/src/test/java/com/marklogic/flux/impl/importdata/ImportJdbcWithAggregatesTest.java @@ -6,12 +6,16 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.JsonNodeType; +import com.marklogic.client.ext.helper.ClientHelper; import com.marklogic.client.io.StringHandle; import com.marklogic.flux.AbstractTest; import com.marklogic.flux.impl.PostgresUtil; import org.junit.jupiter.api.Test; +import java.util.List; + import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; class ImportJdbcWithAggregatesTest extends AbstractTest { @@ -74,7 +78,7 @@ void customerWithArrayOfRentalsAndArrayOfPayments() { from customer c inner join rental r on c.customer_id = r.customer_id inner join payment p on c.customer_id = p.customer_id - where c.customer_id = 1 and r.rental_id < 1000 and p.payment_id < 19000 + where r.rental_id < 1000 and p.payment_id < 19000 order by p.amount """; @@ -86,6 +90,7 @@ void customerWithArrayOfRentalsAndArrayOfPayments() { "--group-by", "customer_id", "--aggregate", "payments=payment_id,amount", "--aggregate", "rentals=rental_id,inventory_id", + "--aggregate-order-by", "payments=amount", "--connection-string", makeConnectionString(), "--permissions", DEFAULT_PERMISSIONS, "--uri-template", "/customer/{customer_id}.json" @@ -109,6 +114,39 @@ void 
customerWithArrayOfRentalsAndArrayOfPayments() { assertEquals(4020, rentals.get(1).get("inventory_id").asInt()); } + /** + * Attempts to reproduce the query in MLE-25002 by using select statements for the "from" and "inner join". + */ + @Test + void orderByPaymentsOnAllCustomers() { + String query = """ + select + c.customer_id, c.first_name, + r.rental_id, r.inventory_id, + p.payment_id, p.amount + from customer c + inner join rental r on c.customer_id = r.customer_id + inner join payment p on c.customer_id = p.customer_id + order by p.amount + """; + + run( + "import-jdbc", + "--jdbc-url", PostgresUtil.URL_WITH_AUTH, "--jdbc-driver", PostgresUtil.DRIVER, + "--query", query, + "--group-by", "customer_id", + "--aggregate", "payments=payment_id,amount", + "--aggregate-order-by", "payments=amount", + "--aggregate-order-desc", + "--connection-string", makeConnectionString(), + "--permissions", DEFAULT_PERMISSIONS, + "--uri-template", "/customer/{customer_id}.json", + "--collections", "customer" + ); + + verifyEachCustomerHasPaymentsOrderedDescending(); + } + /** * Demonstrates that a join can produce an array with atomic/simple values, instead of structs / objects. */ @@ -176,4 +214,18 @@ void badColumnName() { "--permissions", DEFAULT_PERMISSIONS ); } + + private void verifyEachCustomerHasPaymentsOrderedDescending() { + List uris = new ClientHelper(getDatabaseClient()).getUrisInCollection("customer"); + for (String uri : uris) { + JsonNode doc = readJsonDocument(uri); + ArrayNode payments = (ArrayNode) doc.get("payments"); + if (payments.size() < 2) { + continue; + } + double firstAmount = payments.get(0).get("amount").asDouble(); + double lastAmount = payments.get(payments.size() - 1).get("amount").asDouble(); + assertTrue(firstAmount >= lastAmount, "Customer doesn't have payments ordered descending! " + doc.toPrettyString()); + } + } }