Skip to content

Commit bd75b7c

Browse files
author
Selina Song
committed
WIP: sort flip collation
Signed-off-by: Selina Song <[email protected]>
1 parent f244bd6 commit bd75b7c

File tree

5 files changed

+205
-51
lines changed

5 files changed

+205
-51
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
import java.util.stream.Stream;
3939
import org.apache.calcite.plan.RelOptTable;
4040
import org.apache.calcite.plan.ViewExpanders;
41+
import org.apache.calcite.rel.RelCollation;
42+
import org.apache.calcite.rel.RelCollations;
4143
import org.apache.calcite.rel.RelNode;
4244
import org.apache.calcite.rel.core.Aggregate;
4345
import org.apache.calcite.rel.core.JoinRelType;
@@ -567,19 +569,29 @@ public RelNode visitHead(Head node, CalcitePlanContext context) {
567569
public RelNode visitReverse(
568570
org.opensearch.sql.ast.tree.Reverse node, CalcitePlanContext context) {
569571
visitChildren(node, context);
570-
// Add ROW_NUMBER() column
571-
RexNode rowNumber =
572-
context
573-
.relBuilder
574-
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
575-
.over()
576-
.rowsTo(RexWindowBounds.CURRENT_ROW)
577-
.as(REVERSE_ROW_NUM);
578-
context.relBuilder.projectPlus(rowNumber);
579-
// Sort by row number descending
580-
context.relBuilder.sort(context.relBuilder.desc(context.relBuilder.field(REVERSE_ROW_NUM)));
581-
// Remove row number column
582-
context.relBuilder.projectExcept(context.relBuilder.field(REVERSE_ROW_NUM));
572+
573+
RelCollation collation = context.relBuilder.peek().getTraitSet().getCollation();
574+
if (collation == null || collation == RelCollations.EMPTY) {
575+
// If no collation exists, use the traditional row_number approach
576+
// Add ROW_NUMBER() column
577+
RexNode rowNumber =
578+
context
579+
.relBuilder
580+
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
581+
.over()
582+
.rowsTo(RexWindowBounds.CURRENT_ROW)
583+
.as(REVERSE_ROW_NUM);
584+
context.relBuilder.projectPlus(rowNumber);
585+
// Sort by row number descending
586+
context.relBuilder.sort(context.relBuilder.desc(context.relBuilder.field(REVERSE_ROW_NUM)));
587+
// Remove row number column
588+
context.relBuilder.projectExcept(context.relBuilder.field(REVERSE_ROW_NUM));
589+
} else {
590+
// If collation exists, reverse the sort direction
591+
RelCollation reversedCollation = PlanUtils.reverseCollation(collation);
592+
context.relBuilder.sort(reversedCollation);
593+
}
594+
583595
return context.relBuilder.peek();
584596
}
585597

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.rule;
7+
8+
import org.apache.calcite.plan.RelOptRule;
9+
import org.apache.calcite.plan.RelOptRuleCall;
10+
import org.apache.calcite.rel.logical.LogicalSort;
11+
import org.opensearch.sql.calcite.plan.LogicalSystemLimit;
12+
13+
/**
14+
* Optimization rule that collapses two directly stacked sorts whose leading sort key is the same field.
15+
* Detects: LogicalSort(field, direction1) -> LogicalSort(field, direction2)
16+
* Converts to: LogicalSort(field, direction1) on the inner sort's input (keeps the outer sort's collation, offset and fetch)
17+
*/
18+
public class SortReverseOptimizationRule extends RelOptRule {
19+
20+
public static final SortReverseOptimizationRule INSTANCE = new SortReverseOptimizationRule();
21+
22+
private SortReverseOptimizationRule() {
23+
super(operand(LogicalSort.class,
24+
operand(LogicalSort.class, any())),
25+
"SortReverseOptimizationRule");
26+
}
27+
28+
@Override
29+
public boolean matches(RelOptRuleCall call) {
30+
LogicalSort outerSort = call.rel(0);
31+
LogicalSort innerSort = call.rel(1);
32+
33+
// Don't optimize if outer sort is a LogicalSystemLimit - we want to preserve system limits
34+
if (call.rel(0) instanceof LogicalSystemLimit) {
35+
return false;
36+
}
37+
38+
return hasSameField(outerSort, innerSort);
39+
}
40+
41+
@Override
42+
public void onMatch(RelOptRuleCall call) {
43+
LogicalSort outerSort = call.rel(0);
44+
LogicalSort innerSort = call.rel(1);
45+
46+
LogicalSort optimizedSort = LogicalSort.create(
47+
innerSort.getInput(),
48+
outerSort.getCollation(),
49+
outerSort.offset,
50+
outerSort.fetch);
51+
52+
call.transformTo(optimizedSort);
53+
}
54+
55+
private boolean hasSameField(LogicalSort outerSort, LogicalSort innerSort) {
56+
if (outerSort.getCollation().getFieldCollations().isEmpty()
57+
|| innerSort.getCollation().getFieldCollations().isEmpty()) {
58+
return false;
59+
}
60+
61+
int outerField = outerSort.getCollation().getFieldCollations().get(0).getFieldIndex();
62+
int innerField = innerSort.getCollation().getFieldCollations().get(0).getFieldIndex();
63+
return outerField == innerField;
64+
}
65+
}

core/src/main/java/org/opensearch/sql/calcite/utils/PlanUtils.java

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import java.util.List;
1717
import javax.annotation.Nullable;
1818
import org.apache.calcite.plan.RelOptTable;
19+
import org.apache.calcite.rel.RelCollation;
20+
import org.apache.calcite.rel.RelCollations;
21+
import org.apache.calcite.rel.RelFieldCollation;
1922
import org.apache.calcite.rel.RelHomogeneousShuttle;
2023
import org.apache.calcite.rel.RelNode;
2124
import org.apache.calcite.rel.RelShuttle;
@@ -355,40 +358,29 @@ static RexNode derefMapCall(RexNode rexNode) {
355358
return rexNode;
356359
}
357360

358-
/** Check if contains RexOver */
359-
static boolean containsRowNumberDedup(LogicalProject project) {
360-
return project.getProjects().stream()
361-
.anyMatch(p -> p instanceof RexOver && p.getKind() == SqlKind.ROW_NUMBER);
362-
}
361+
/**
362+
* Reverses the direction of a RelCollation.
363+
*
364+
* @param original The original collation to reverse
365+
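* @return A new RelCollation with each field's direction reversed and its null direction unchanged, or the input itself when it is null or has no field collations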
366+
*/
367+
public static RelCollation reverseCollation(RelCollation original) {
368+
if (original == null || original.getFieldCollations().isEmpty()) {
369+
return original;
370+
}
363371

364-
/** Get all RexWindow list from LogicalProject */
365-
static List<RexWindow> getRexWindowFromProject(LogicalProject project) {
366-
final List<RexWindow> res = new ArrayList<>();
367-
final RexVisitorImpl<Void> visitor =
368-
new RexVisitorImpl<>(true) {
369-
@Override
370-
public Void visitOver(RexOver over) {
371-
res.add(over.getWindow());
372-
return null;
373-
}
374-
};
375-
visitor.visitEach(project.getProjects());
376-
return res;
377-
}
372+
List<RelFieldCollation> reversedFields = new ArrayList<>();
373+
for (RelFieldCollation field : original.getFieldCollations()) {
374+
RelFieldCollation.Direction reversedDirection = field.direction.reverse();
375+
376+
RelFieldCollation reversedField = new RelFieldCollation(
377+
field.getFieldIndex(),
378+
reversedDirection,
379+
field.nullDirection
380+
);
381+
reversedFields.add(reversedField);
382+
}
378383

379-
static List<Integer> getSelectColumns(List<RexNode> rexNodes) {
380-
final List<Integer> selectedColumns = new ArrayList<>();
381-
final RexVisitorImpl<Void> visitor =
382-
new RexVisitorImpl<Void>(true) {
383-
@Override
384-
public Void visitInputRef(RexInputRef inputRef) {
385-
if (!selectedColumns.contains(inputRef.getIndex())) {
386-
selectedColumns.add(inputRef.getIndex());
387-
}
388-
return null;
389-
}
390-
};
391-
visitor.visitEach(rexNodes);
392-
return selectedColumns;
384+
return RelCollations.of(reversedFields);
393385
}
394386
}

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ public void testFilterFunctionScriptPushDownExplain() throws Exception {
207207
public void testExplainWithReverse() throws IOException {
208208
String result =
209209
executeWithReplace(
210-
"explain source=opensearch-sql_test_index_account | sort age | reverse | head 5");
210+
"explain source=opensearch-sql_test_index_account | reverse | head 5");
211211

212212
// Verify that the plan contains a LogicalSort with fetch (from head 5)
213213
assertTrue(result.contains("LogicalSort") && result.contains("fetch=[5]"));
@@ -216,6 +216,37 @@ public void testExplainWithReverse() throws IOException {
216216
assertTrue(result.contains("ROW_NUMBER()"));
217217
assertTrue(result.contains("dir0=[DESC]"));
218218
}
219+
220+
@Test
221+
public void testExplainWithReversePushdown() throws IOException {
222+
// Test with a sort operation that should use the reverse pushdown optimization
223+
String result =
224+
executeWithReplace(
225+
"explain source=opensearch-sql_test_index_account | sort - age | reverse");
226+
227+
// Verify that the plan contains a LogicalSort with ascending direction (reversed from DESC)
228+
assertTrue(result.contains("LogicalSort"));
229+
assertTrue(result.contains("dir0=[ASC]"));
230+
231+
// Verify that ROW_NUMBER is NOT used (since we're using the collation-based optimization)
232+
assertFalse(result.contains("ROW_NUMBER()"));
233+
}
234+
235+
@Test
236+
public void testExplainWithReversePushdownMultipleFields() throws IOException {
237+
// Test with multiple sort fields that should use the reverse pushdown optimization
238+
String result =
239+
executeWithReplace(
240+
"explain source=opensearch-sql_test_index_account | sort - age, + firstname | reverse");
241+
242+
// Verify that the plan contains a LogicalSort with reversed directions
243+
assertTrue(result.contains("LogicalSort"));
244+
assertTrue(result.contains("dir0=[ASC]")); // age was DESC, now ASC
245+
assertTrue(result.contains("dir1=[DESC]")); // firstname was ASC, now DESC
246+
247+
// Verify that ROW_NUMBER is NOT used (since we're using the collation-based optimization)
248+
assertFalse(result.contains("ROW_NUMBER()"));
249+
}
219250

220251
@Test
221252
public void testExplainWithTimechartAvg() throws IOException {

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteReverseCommandIT.java

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,68 @@ public void testReverseWithComplexPipeline() throws IOException {
9797
}
9898

9999
@Test
100-
public void testReverseWithMultipleSorts() throws IOException {
101-
// Use the existing BANK data but with a simpler, more predictable query
100+
public void testReverseWithDescendingSort() throws IOException {
101+
// Test reverse with descending sort (- account_number)
102102
JSONObject result =
103103
executeQuery(
104104
String.format(
105-
"source=%s | sort account_number | fields account_number | reverse | head 3",
105+
"source=%s | sort - account_number | fields account_number | reverse",
106106
TEST_INDEX_BANK));
107107
verifySchema(result, schema("account_number", "bigint"));
108-
verifyDataRowsInOrder(result, rows(32), rows(25), rows(20));
108+
verifyDataRowsInOrder(
109+
result, rows(1), rows(6), rows(13), rows(18), rows(20), rows(25), rows(32));
110+
}
111+
112+
@Test
113+
public void testReverseWithMixedSortDirections() throws IOException {
114+
// Test reverse with mixed sort directions (- account_number, + firstname)
115+
JSONObject result =
116+
executeQuery(
117+
String.format(
118+
"source=%s | sort - account_number, + firstname | fields account_number, firstname | reverse",
119+
TEST_INDEX_BANK));
120+
verifySchema(result, schema("account_number", "bigint"), schema("firstname", "string"));
121+
verifyDataRowsInOrder(
122+
result,
123+
rows(1, "Amber JOHnny"),
124+
rows(6, "Hattie"),
125+
rows(13, "Nanette"),
126+
rows(18, "Dale"),
127+
rows(20, "Elinor"),
128+
rows(25, "Virginia"),
129+
rows(32, "Dillard"));
130+
}
131+
132+
@Test
133+
public void testDoubleReverseWithDescendingSort() throws IOException {
134+
// Test double reverse with descending sort (- account_number)
135+
JSONObject result =
136+
executeQuery(
137+
String.format(
138+
"source=%s | sort - account_number | fields account_number | reverse | reverse",
139+
TEST_INDEX_BANK));
140+
verifySchema(result, schema("account_number", "bigint"));
141+
verifyDataRowsInOrder(
142+
result, rows(32), rows(25), rows(20), rows(18), rows(13), rows(6), rows(1));
143+
}
144+
145+
@Test
146+
public void testDoubleReverseWithMixedSortDirections() throws IOException {
147+
// Test double reverse with mixed sort directions (- age, + firstname)
148+
JSONObject result =
149+
executeQuery(
150+
String.format(
151+
"source=%s | sort - account_number, + firstname | fields account_number, firstname | reverse | reverse",
152+
TEST_INDEX_BANK));
153+
verifySchema(result, schema("account_number", "bigint"), schema("firstname", "string"));
154+
verifyDataRowsInOrder(
155+
result,
156+
rows(32, "Dillard"),
157+
rows(25, "Virginia"),
158+
rows(20, "Elinor"),
159+
rows(18, "Dale"),
160+
rows(13, "Nanette"),
161+
rows(6, "Hattie"),
162+
rows(1, "Amber JOHnny"));
109163
}
110164
}

0 commit comments

Comments
 (0)