Skip to content

Commit d46cb4c

Browse files
authored
Support join field list and join options (#3803)
* Support join field list and join options Signed-off-by: Lantao Jin <[email protected]> * Add SPL-compatible syntax setting Signed-off-by: Lantao Jin <[email protected]> * Revert SPL settings Signed-off-by: Lantao Jin <[email protected]> * Fix IT Signed-off-by: Lantao Jin <[email protected]> * Fix IT Signed-off-by: Lantao Jin <[email protected]> * Support max=n option Signed-off-by: Lantao Jin <[email protected]> * support max=n in sql-like join syntax Signed-off-by: Lantao Jin <[email protected]> * Add Explain IT for new join syntax Signed-off-by: Lantao Jin <[email protected]> * Refactor the user doc Signed-off-by: Lantao Jin <[email protected]> * Fix conflicts Signed-off-by: Lantao Jin <[email protected]> * Fix conflicts Signed-off-by: Lantao Jin <[email protected]> * Disable the collapse pushdown Signed-off-by: Lantao Jin <[email protected]> * refactor Signed-off-by: Lantao Jin <[email protected]> * Fix IT Signed-off-by: Lantao Jin <[email protected]> --------- Signed-off-by: Lantao Jin <[email protected]>
1 parent e9cff8a commit d46cb4c

File tree

31 files changed

+1408
-328
lines changed

31 files changed

+1408
-328
lines changed

common/src/main/java/org/opensearch/sql/common/setting/Settings.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public enum Key {
3737
CALCITE_PUSHDOWN_ENABLED("plugins.calcite.pushdown.enabled"),
3838
CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR(
3939
"plugins.calcite.pushdown.rowcount.estimation.factor"),
40+
CALCITE_SUPPORT_ALL_JOIN_TYPES("plugins.calcite.all_join_types.allowed"),
4041

4142
/** Query Settings. */
4243
FIELD_TYPE_TOLERANCE("plugins.query.field_type_tolerance"),

core/src/main/java/org/opensearch/sql/ast/expression/Argument.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,24 @@ public static class ArgumentMap {
3939
private final Map<String, Literal> map;
4040

4141
public ArgumentMap(List<Argument> arguments) {
42-
this.map =
43-
arguments.stream()
44-
.collect(java.util.stream.Collectors.toMap(Argument::getArgName, Argument::getValue));
42+
if (arguments == null || arguments.isEmpty()) {
43+
this.map = Map.of();
44+
} else {
45+
this.map =
46+
arguments.stream()
47+
.collect(
48+
java.util.stream.Collectors.toMap(Argument::getArgName, Argument::getValue));
49+
}
4550
}
4651

4752
/** Static factory: wraps the given argument list in a new {@code ArgumentMap} via the constructor. */
public static ArgumentMap of(List<Argument> arguments) {
4853
return new ArgumentMap(arguments);
4954
}
5055

56+
public static ArgumentMap empty() {
57+
return new ArgumentMap(null);
58+
}
59+
5160
/**
5261
* Get argument value by name.
5362
*

core/src/main/java/org/opensearch/sql/ast/expression/Literal.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,8 @@ public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
4646
public String toString() {
4747
return String.valueOf(value);
4848
}
49+
50+
public static Literal TRUE = new Literal(true, DataType.BOOLEAN);
51+
public static Literal FALSE = new Literal(false, DataType.BOOLEAN);
52+
public static Literal ZERO = new Literal(Integer.valueOf("0"), DataType.INTEGER);
4953
}

core/src/main/java/org/opensearch/sql/ast/tree/Join.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import lombok.RequiredArgsConstructor;
1616
import lombok.ToString;
1717
import org.opensearch.sql.ast.AbstractNodeVisitor;
18+
import org.opensearch.sql.ast.expression.Argument;
19+
import org.opensearch.sql.ast.expression.Field;
1820
import org.opensearch.sql.ast.expression.UnresolvedExpression;
1921

2022
@ToString
@@ -28,20 +30,26 @@ public class Join extends UnresolvedPlan {
2830
private final JoinType joinType;
2931
private final Optional<UnresolvedExpression> joinCondition;
3032
private final JoinHint joinHint;
33+
private final Optional<List<Field>> joinFields;
34+
private final Argument.ArgumentMap argumentMap;
3135

3236
public Join(
3337
UnresolvedPlan right,
3438
Optional<String> leftAlias,
3539
Optional<String> rightAlias,
3640
JoinType joinType,
3741
Optional<UnresolvedExpression> joinCondition,
38-
JoinHint joinHint) {
42+
JoinHint joinHint,
43+
Optional<List<Field>> joinFields,
44+
Argument.ArgumentMap argumentMap) {
3945
this.right = right;
4046
this.leftAlias = leftAlias;
4147
this.rightAlias = rightAlias;
4248
this.joinType = joinType;
4349
this.joinCondition = joinCondition;
4450
this.joinHint = joinHint;
51+
this.joinFields = joinFields;
52+
this.argumentMap = argumentMap;
4553
}
4654

4755
@Override
@@ -89,6 +97,11 @@ public enum JoinType {
8997
FULL
9098
}
9199

100+
/** Returns the join types treated as performance sensitive: RIGHT, CROSS and FULL. */
101+
public static List<JoinType> highCostJoinTypes() {
102+
return List.of(JoinType.RIGHT, JoinType.CROSS, JoinType.FULL);
103+
}
104+
92105
@Getter
93106
@RequiredArgsConstructor
94107
public static class JoinHint {

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 180 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.apache.calcite.rex.RexInputRef;
4949
import org.apache.calcite.rex.RexLiteral;
5050
import org.apache.calcite.rex.RexNode;
51+
import org.apache.calcite.rex.RexVisitorImpl;
5152
import org.apache.calcite.rex.RexWindowBounds;
5253
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
5354
import org.apache.calcite.sql.type.SqlTypeFamily;
@@ -903,6 +904,70 @@ private Optional<RexLiteral> extractAliasLiteral(RexNode node) {
903904
public RelNode visitJoin(Join node, CalcitePlanContext context) {
904905
List<UnresolvedPlan> children = node.getChildren();
905906
children.forEach(c -> analyze(c, context));
907+
if (node.getJoinCondition().isEmpty()) {
908+
// join-with-field-list grammar
909+
List<String> leftColumns = context.relBuilder.peek(1).getRowType().getFieldNames();
910+
List<String> rightColumns = context.relBuilder.peek().getRowType().getFieldNames();
911+
List<String> duplicatedFieldNames =
912+
leftColumns.stream().filter(rightColumns::contains).toList();
913+
RexNode joinCondition;
914+
if (node.getJoinFields().isPresent()) {
915+
joinCondition =
916+
node.getJoinFields().get().stream()
917+
.map(field -> buildJoinConditionByFieldName(context, field.getField().toString()))
918+
.reduce(context.rexBuilder::and)
919+
.orElse(context.relBuilder.literal(true));
920+
} else {
921+
joinCondition =
922+
duplicatedFieldNames.stream()
923+
.map(fieldName -> buildJoinConditionByFieldName(context, fieldName))
924+
.reduce(context.rexBuilder::and)
925+
.orElse(context.relBuilder.literal(true));
926+
}
927+
if (node.getJoinType() == SEMI || node.getJoinType() == ANTI) {
928+
// Semi and anti joins return only the left table's output
929+
context.relBuilder.join(
930+
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
931+
return context.relBuilder.peek();
932+
}
933+
List<RexNode> toBeRemovedFields;
934+
if (node.getArgumentMap().get("overwrite") == null // 'overwrite' default value is true
935+
|| (node.getArgumentMap().get("overwrite").equals(Literal.TRUE))) {
936+
toBeRemovedFields =
937+
duplicatedFieldNames.stream()
938+
.map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, true, context))
939+
.toList();
940+
} else {
941+
toBeRemovedFields =
942+
duplicatedFieldNames.stream()
943+
.map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, false, context))
944+
.toList();
945+
}
946+
Literal max = node.getArgumentMap().get("max");
947+
if (max != null && !max.equals(Literal.ZERO)) {
948+
// max != 0 means the right side must be deduplicated before joining
949+
Integer allowedDuplication = (Integer) max.getValue();
950+
if (allowedDuplication < 0) {
951+
throw new SemanticCheckException("max option must be a positive integer");
952+
}
953+
List<RexNode> dedupeFields =
954+
node.getJoinFields().isPresent()
955+
? node.getJoinFields().get().stream()
956+
.map(a -> (RexNode) context.relBuilder.field(a.getField().toString()))
957+
.toList()
958+
: duplicatedFieldNames.stream()
959+
.map(a -> (RexNode) context.relBuilder.field(a))
960+
.toList();
961+
buildDedupNotNull(context, dedupeFields, allowedDuplication);
962+
}
963+
context.relBuilder.join(
964+
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
965+
if (!toBeRemovedFields.isEmpty()) {
966+
context.relBuilder.projectExcept(toBeRemovedFields);
967+
}
968+
return context.relBuilder.peek();
969+
}
970+
// The join-with-criteria grammar doesn't allow an empty join condition
906971
RexNode joinCondition =
907972
node.getJoinCondition()
908973
.map(c -> rexVisitor.analyzeJoinCondition(c, context))
@@ -938,6 +1003,19 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
9381003
.orElse(rightTableQualifiedName + "." + col)
9391004
: col)
9401005
.toList();
1006+
1007+
Literal max = node.getArgumentMap().get("max");
1008+
if (max != null && !max.equals(Literal.ZERO)) {
1009+
// max != 0 means the right side must be deduplicated before joining
1010+
Integer allowedDuplication = (Integer) max.getValue();
1011+
if (allowedDuplication < 0) {
1012+
throw new SemanticCheckException("max option must be a positive integer");
1013+
}
1014+
List<RexNode> dedupeFields =
1015+
getRightColumnsInJoinCriteria(context.relBuilder, joinCondition);
1016+
1017+
buildDedupNotNull(context, dedupeFields, allowedDuplication);
1018+
}
9411019
context.relBuilder.join(
9421020
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);
9431021
JoinAndLookupUtils.renameToExpectedFields(
@@ -946,6 +1024,37 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
9461024
return context.relBuilder.peek();
9471025
}
9481026

1027+
private List<RexNode> getRightColumnsInJoinCriteria(
1028+
RelBuilder relBuilder, RexNode joinCondition) {
1029+
int stackSize = relBuilder.size();
1030+
int leftFieldCount = relBuilder.peek(stackSize - 1).getRowType().getFieldCount();
1031+
RelNode right = relBuilder.peek(stackSize - 2);
1032+
List<String> allColumnNamesOfRight = right.getRowType().getFieldNames();
1033+
1034+
List<Integer> rightColumnIndexes = new ArrayList<>();
1035+
joinCondition.accept(
1036+
new RexVisitorImpl<Void>(true) {
1037+
@Override
1038+
public Void visitInputRef(RexInputRef inputRef) {
1039+
if (inputRef.getIndex() >= leftFieldCount) {
1040+
rightColumnIndexes.add(inputRef.getIndex() - leftFieldCount);
1041+
}
1042+
return super.visitInputRef(inputRef);
1043+
}
1044+
});
1045+
return rightColumnIndexes.stream()
1046+
.map(allColumnNamesOfRight::get)
1047+
.map(n -> (RexNode) relBuilder.field(n))
1048+
.toList();
1049+
}
1050+
1051+
private static RexNode buildJoinConditionByFieldName(
1052+
CalcitePlanContext context, String fieldName) {
1053+
RexNode lookupKey = JoinAndLookupUtils.analyzeFieldsForLookUp(fieldName, false, context);
1054+
RexNode sourceKey = JoinAndLookupUtils.analyzeFieldsForLookUp(fieldName, true, context);
1055+
return context.rexBuilder.equals(sourceKey, lookupKey);
1056+
}
1057+
9491058
@Override
9501059
public RelNode visitSubqueryAlias(SubqueryAlias node, CalcitePlanContext context) {
9511060
visitChildren(node, context);
@@ -1068,74 +1177,82 @@ public RelNode visitDedupe(Dedupe node, CalcitePlanContext context) {
10681177
List<RexNode> dedupeFields =
10691178
node.getFields().stream().map(f -> rexVisitor.analyze(f, context)).toList();
10701179
if (keepEmpty) {
1071-
/*
1072-
* | dedup 2 a, b keepempty=false
1073-
* DropColumns('_row_number_dedup_)
1074-
* +- Filter ('_row_number_dedup_ <= n OR isnull('a) OR isnull('b))
1075-
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
1076-
* +- ...
1077-
*/
1078-
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
1079-
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a
1080-
// ASC
1081-
// NULLS FIRST, 'b ASC NULLS FIRST]
1082-
RexNode rowNumber =
1083-
context
1084-
.relBuilder
1085-
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
1086-
.over()
1087-
.partitionBy(dedupeFields)
1088-
.orderBy(dedupeFields)
1089-
.rowsTo(RexWindowBounds.CURRENT_ROW)
1090-
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
1091-
context.relBuilder.projectPlus(rowNumber);
1092-
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
1093-
// Filter (isnull('a) OR isnull('b) OR '_row_number_dedup_ <= n)
1094-
context.relBuilder.filter(
1095-
context.relBuilder.or(
1096-
context.relBuilder.or(dedupeFields.stream().map(context.relBuilder::isNull).toList()),
1097-
context.relBuilder.lessThanOrEqual(
1098-
_row_number_dedup_, context.relBuilder.literal(allowedDuplication))));
1099-
// DropColumns('_row_number_)
1100-
context.relBuilder.projectExcept(_row_number_dedup_);
1180+
buildDedupOrNull(context, dedupeFields, allowedDuplication);
11011181
} else {
1102-
/*
1103-
* | dedup 2 a, b keepempty=false
1104-
* DropColumns('_row_number_dedup_)
1105-
* +- Filter ('_row_number_dedup_ <= n)
1106-
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
1107-
* +- Filter (isnotnull('a) AND isnotnull('b))
1108-
* +- ...
1109-
*/
1110-
// Filter (isnotnull('a) AND isnotnull('b))
1111-
context.relBuilder.filter(
1112-
context.relBuilder.and(
1113-
dedupeFields.stream().map(context.relBuilder::isNotNull).toList()));
1114-
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
1115-
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a
1116-
// ASC
1117-
// NULLS FIRST, 'b ASC NULLS FIRST]
1118-
RexNode rowNumber =
1119-
context
1120-
.relBuilder
1121-
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
1122-
.over()
1123-
.partitionBy(dedupeFields)
1124-
.orderBy(dedupeFields)
1125-
.rowsTo(RexWindowBounds.CURRENT_ROW)
1126-
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
1127-
context.relBuilder.projectPlus(rowNumber);
1128-
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
1129-
// Filter ('_row_number_dedup_ <= n)
1130-
context.relBuilder.filter(
1131-
context.relBuilder.lessThanOrEqual(
1132-
_row_number_dedup_, context.relBuilder.literal(allowedDuplication)));
1133-
// DropColumns('_row_number_dedup_)
1134-
context.relBuilder.projectExcept(_row_number_dedup_);
1182+
buildDedupNotNull(context, dedupeFields, allowedDuplication);
11351183
}
11361184
return context.relBuilder.peek();
11371185
}
11381186

1187+
private static void buildDedupOrNull(
1188+
CalcitePlanContext context, List<RexNode> dedupeFields, Integer allowedDuplication) {
1189+
/*
1190+
* | dedup 2 a, b keepempty=false
1191+
* DropColumns('_row_number_dedup_)
1192+
* +- Filter ('_row_number_dedup_ <= n OR isnull('a) OR isnull('b))
1193+
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
1194+
* +- ...
1195+
*/
1196+
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
1197+
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a
1198+
// ASC
1199+
// NULLS FIRST, 'b ASC NULLS FIRST]
1200+
RexNode rowNumber =
1201+
context
1202+
.relBuilder
1203+
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
1204+
.over()
1205+
.partitionBy(dedupeFields)
1206+
.orderBy(dedupeFields)
1207+
.rowsTo(RexWindowBounds.CURRENT_ROW)
1208+
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
1209+
context.relBuilder.projectPlus(rowNumber);
1210+
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
1211+
// Filter (isnull('a) OR isnull('b) OR '_row_number_dedup_ <= n)
1212+
context.relBuilder.filter(
1213+
context.relBuilder.or(
1214+
context.relBuilder.or(dedupeFields.stream().map(context.relBuilder::isNull).toList()),
1215+
context.relBuilder.lessThanOrEqual(
1216+
_row_number_dedup_, context.relBuilder.literal(allowedDuplication))));
1217+
// DropColumns('_row_number_dedup_)
1218+
context.relBuilder.projectExcept(_row_number_dedup_);
1219+
}
1220+
1221+
private static void buildDedupNotNull(
1222+
CalcitePlanContext context, List<RexNode> dedupeFields, Integer allowedDuplication) {
1223+
/*
1224+
* | dedup 2 a, b keepempty=false
1225+
* DropColumns('_row_number_dedup_)
1226+
* +- Filter ('_row_number_dedup_ <= n)
1227+
* +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
1228+
* +- Filter (isnotnull('a) AND isnotnull('b))
1229+
* +- ...
1230+
*/
1231+
// Filter (isnotnull('a) AND isnotnull('b))
1232+
context.relBuilder.filter(
1233+
context.relBuilder.and(dedupeFields.stream().map(context.relBuilder::isNotNull).toList()));
1234+
// Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST,
1235+
// specifiedwindowoundedpreceding$(), currentrow$())) AS _row_number_dedup_], ['a, 'b], ['a ASC
1236+
// NULLS FIRST, 'b ASC NULLS FIRST]
1237+
RexNode rowNumber =
1238+
context
1239+
.relBuilder
1240+
.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
1241+
.over()
1242+
.partitionBy(dedupeFields)
1243+
.orderBy(dedupeFields)
1244+
.rowsTo(RexWindowBounds.CURRENT_ROW)
1245+
.as(ROW_NUMBER_COLUMN_FOR_DEDUP);
1246+
context.relBuilder.projectPlus(rowNumber);
1247+
RexNode _row_number_dedup_ = context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_DEDUP);
1248+
// Filter ('_row_number_dedup_ <= n)
1249+
context.relBuilder.filter(
1250+
context.relBuilder.lessThanOrEqual(
1251+
_row_number_dedup_, context.relBuilder.literal(allowedDuplication)));
1252+
// DropColumns('_row_number_dedup_)
1253+
context.relBuilder.projectExcept(_row_number_dedup_);
1254+
}
1255+
11391256
@Override
11401257
public RelNode visitWindow(Window node, CalcitePlanContext context) {
11411258
visitChildren(node, context);

0 commit comments

Comments
 (0)