Skip to content

Commit e2587bd

Browse files
committed
Cast join condition automatically
Signed-off-by: Tomoyuki Morita <[email protected]>
1 parent 9ed76e5 commit e2587bd

File tree

7 files changed

+385
-120
lines changed

7 files changed

+385
-120
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,7 +1296,7 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
12961296
node.getJoinCondition()
12971297
.map(c -> rexVisitor.analyzeJoinCondition(c, context))
12981298
.orElse(context.relBuilder.literal(true));
1299-
JoinAndLookupUtils.verifyJoinConditionNotUseAnyType(joinCondition, context);
1299+
joinCondition = context.rexBuilder.castAnyToAlignTypes(joinCondition, context);
13001300
if (node.getJoinType() == SEMI || node.getJoinType() == ANTI) {
13011301
// semi and anti join only return left table outputs
13021302
context.relBuilder.join(
@@ -1376,14 +1376,9 @@ public Void visitInputRef(RexInputRef inputRef) {
13761376

13771377
private static RexNode buildJoinConditionByFieldName(
13781378
CalcitePlanContext context, String fieldName) {
1379-
RexNode lookupKey = JoinAndLookupUtils.analyzeFieldsInRight(fieldName, context);
13801379
RexNode sourceKey = JoinAndLookupUtils.analyzeFieldsInLeft(fieldName, context);
1381-
if (context.fieldBuilder.isAnyType(sourceKey)) {
1382-
throw new IllegalArgumentException(
1383-
String.format(
1384-
"Source key `%s` needs to be specific type. Please cast explicitly.", fieldName));
1385-
}
1386-
return context.rexBuilder.equals(sourceKey, lookupKey);
1380+
RexNode lookupKey = JoinAndLookupUtils.analyzeFieldsInRight(fieldName, context);
1381+
return context.rexBuilder.equalsWithCastAsNeeded(sourceKey, lookupKey);
13871382
}
13881383

13891384
@Override

core/src/main/java/org/opensearch/sql/calcite/ExtendedRexBuilder.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,19 @@
1212
import org.apache.calcite.avatica.util.TimeUnit;
1313
import org.apache.calcite.rel.type.RelDataType;
1414
import org.apache.calcite.rex.RexBuilder;
15+
import org.apache.calcite.rex.RexCall;
1516
import org.apache.calcite.rex.RexLiteral;
1617
import org.apache.calcite.rex.RexNode;
1718
import org.apache.calcite.sql.SqlIntervalQualifier;
19+
import org.apache.calcite.sql.SqlKind;
1820
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
1921
import org.apache.calcite.sql.parser.SqlParserPos;
2022
import org.apache.calcite.sql.type.SqlTypeName;
2123
import org.apache.calcite.sql.type.SqlTypeUtil;
2224
import org.opensearch.sql.ast.expression.SpanUnit;
2325
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
2426
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
27+
import org.opensearch.sql.calcite.utils.RexConverter;
2528
import org.opensearch.sql.data.type.ExprCoreType;
2629
import org.opensearch.sql.exception.ExpressionEvaluationException;
2730
import org.opensearch.sql.exception.SemanticCheckException;
@@ -41,6 +44,20 @@ public RexNode equals(RexNode n1, RexNode n2) {
4144
return this.makeCall(SqlStdOperatorTable.EQUALS, n1, n2);
4245
}
4346

47+
/** Make equals call with adding cast in case the node type is ANY. */
48+
public RexNode equalsWithCastAsNeeded(RexNode n1, RexNode n2) {
49+
if (isAnyType(n1) && isAnyType(n2)) {
50+
n1 = castToString(n1);
51+
n2 = castToString(n2);
52+
} else if (isAnyType(n1)) {
53+
n1 = castToTargetType(n1, n2);
54+
} else if (isAnyType(n2)) {
55+
n2 = castToTargetType(n2, n1);
56+
}
57+
58+
return equals(n1, n2);
59+
}
60+
4461
public RexNode and(RexNode left, RexNode right) {
4562
final RelDataType booleanType = this.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN);
4663
return this.makeCall(booleanType, SqlStdOperatorTable.AND, List.of(left, right));
@@ -163,4 +180,36 @@ else if ((SqlTypeUtil.isApproximateNumeric(sourceType) || SqlTypeUtil.isDecimal(
163180
}
164181
return super.makeCast(pos, type, exp, matchNullability, safe, format);
165182
}
183+
184+
public boolean isAnyType(RexNode node) {
185+
return node.getType().getSqlTypeName().equals(SqlTypeName.ANY);
186+
}
187+
188+
public RexNode castToString(RexNode node) {
189+
RelDataType stringType = getTypeFactory().createSqlType(SqlTypeName.VARCHAR);
190+
RelDataType nullableStringType = getTypeFactory().createTypeWithNullability(stringType, true);
191+
return makeCast(nullableStringType, node, true, true);
192+
}
193+
194+
/** cast node to the same type as target */
195+
public RexNode castToTargetType(RexNode node, RexNode target) {
196+
return makeCast(target.getType(), node, true, true);
197+
}
198+
199+
/** Utility to cast ANY to specific types to avoid compare issue */
200+
RexNode castAnyToAlignTypes(RexNode rexNode, CalcitePlanContext context) {
201+
return rexNode.accept(
202+
new RexConverter() {
203+
@Override
204+
public RexNode visitCall(RexCall call) {
205+
if (call.getKind() == SqlKind.EQUALS) {
206+
RexNode n0 = call.operands.get(0);
207+
RexNode n1 = call.operands.get(1);
208+
return super.visitCall((RexCall) equalsWithCastAsNeeded(n0, n1));
209+
} else {
210+
return super.visitCall(call);
211+
}
212+
}
213+
});
214+
}
166215
}

core/src/main/java/org/opensearch/sql/calcite/rel/RelFieldBuilder.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.apache.calcite.rel.type.RelDataTypeField;
2020
import org.apache.calcite.rex.RexInputRef;
2121
import org.apache.calcite.rex.RexNode;
22-
import org.apache.calcite.sql.type.SqlTypeName;
2322
import org.apache.calcite.tools.RelBuilder;
2423
import org.opensearch.sql.calcite.ExtendedRexBuilder;
2524
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
@@ -104,10 +103,6 @@ public ImmutableList<RexNode> staticFields(List<? extends Number> ordinals) {
104103
return relBuilder.fields(ordinals);
105104
}
106105

107-
public boolean isAnyType(RexNode node) {
108-
return node.getType().getSqlTypeName().equals(SqlTypeName.ANY);
109-
}
110-
111106
public boolean isDynamicFieldsExist() {
112107
return isDynamicFieldsExist(1, 0);
113108
}

core/src/main/java/org/opensearch/sql/calcite/utils/JoinAndLookupUtils.java

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
import java.util.Map;
1212
import java.util.stream.Collectors;
1313
import org.apache.calcite.rel.core.JoinRelType;
14-
import org.apache.calcite.rex.RexCall;
1514
import org.apache.calcite.rex.RexNode;
16-
import org.apache.calcite.rex.RexVisitorImpl;
17-
import org.apache.calcite.sql.SqlKind;
1815
import org.apache.calcite.util.Pair;
1916
import org.opensearch.sql.ast.tree.Join;
2017
import org.opensearch.sql.ast.tree.Lookup;
@@ -78,39 +75,14 @@ static void addProjectionIfNecessary(Lookup node, CalcitePlanContext context) {
7875
}
7976
}
8077

81-
/** Utility to verify join condition does not use ANY typed field to avoid */
82-
static void verifyJoinConditionNotUseAnyType(RexNode rexNode, CalcitePlanContext context) {
83-
rexNode.accept(
84-
new RexVisitorImpl<Void>(true) {
85-
@Override
86-
public Void visitCall(RexCall call) {
87-
if (call.getKind() == SqlKind.EQUALS) {
88-
RexNode left = call.operands.get(0);
89-
RexNode right = call.operands.get(1);
90-
if (context.fieldBuilder.isAnyType(left) || context.fieldBuilder.isAnyType(right)) {
91-
throw new IllegalArgumentException(
92-
"Join condition needs to use specific type. Please cast explicitly.");
93-
}
94-
}
95-
return super.visitCall(call);
96-
}
97-
});
98-
}
99-
10078
static void addJoinForLookUp(Lookup node, CalcitePlanContext context) {
10179
RexNode joinCondition =
10280
node.getMappingAliasMap().entrySet().stream()
10381
.map(
10482
entry -> {
10583
RexNode lookupKey = analyzeFieldsInRight(entry.getKey(), context);
10684
RexNode sourceKey = analyzeFieldsInLeft(entry.getValue(), context);
107-
if (context.fieldBuilder.isAnyType(sourceKey)) {
108-
throw new IllegalArgumentException(
109-
String.format(
110-
"Source key `%s` needs to be specific type. Please cast explicitly.",
111-
entry.getValue()));
112-
}
113-
return context.rexBuilder.equals(sourceKey, lookupKey);
85+
return context.rexBuilder.equalsWithCastAsNeeded(sourceKey, lookupKey);
11486
})
11587
.reduce(context.rexBuilder::and)
11688
.orElse(context.relBuilder.literal(true));
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.utils;
7+
8+
import java.util.List;
9+
import java.util.stream.Collectors;
10+
import org.apache.calcite.rex.RexCall;
11+
import org.apache.calcite.rex.RexCorrelVariable;
12+
import org.apache.calcite.rex.RexDynamicParam;
13+
import org.apache.calcite.rex.RexFieldAccess;
14+
import org.apache.calcite.rex.RexInputRef;
15+
import org.apache.calcite.rex.RexLambda;
16+
import org.apache.calcite.rex.RexLambdaRef;
17+
import org.apache.calcite.rex.RexLiteral;
18+
import org.apache.calcite.rex.RexLocalRef;
19+
import org.apache.calcite.rex.RexNode;
20+
import org.apache.calcite.rex.RexOver;
21+
import org.apache.calcite.rex.RexPatternFieldRef;
22+
import org.apache.calcite.rex.RexRangeRef;
23+
import org.apache.calcite.rex.RexSubQuery;
24+
import org.apache.calcite.rex.RexTableInputRef;
25+
import org.apache.calcite.rex.RexVisitor;
26+
27+
/**
28+
* Base class for converting specific portions of a RexNode tree by overriding the node types of
29+
* interest. This class implements the visitor pattern for RexNode traversal, providing default
30+
* implementations that return nodes unchanged. Subclasses can override specific visit methods to
31+
* transform particular node types while leaving others untouched.
32+
*/
33+
public class RexConverter implements RexVisitor<RexNode> {
34+
35+
@Override
36+
public RexNode visitInputRef(RexInputRef inputRef) {
37+
return inputRef;
38+
}
39+
40+
@Override
41+
public RexNode visitLocalRef(RexLocalRef localRef) {
42+
return localRef;
43+
}
44+
45+
@Override
46+
public RexNode visitLiteral(RexLiteral literal) {
47+
return literal;
48+
}
49+
50+
@Override
51+
public RexNode visitCall(RexCall call) {
52+
List<RexNode> operands =
53+
call.getOperands().stream()
54+
.map(operand -> operand.accept(this))
55+
.collect(Collectors.toList());
56+
if (operands.equals(call.getOperands())) {
57+
return call;
58+
}
59+
return call.clone(call.getType(), operands);
60+
}
61+
62+
@Override
63+
public RexNode visitOver(RexOver over) {
64+
List<RexNode> operands =
65+
over.getOperands().stream()
66+
.map(operand -> operand.accept(this))
67+
.collect(Collectors.toList());
68+
if (operands.equals(over.getOperands())) {
69+
return over;
70+
}
71+
return over.clone(over.getType(), operands);
72+
}
73+
74+
@Override
75+
public RexNode visitCorrelVariable(RexCorrelVariable correlVariable) {
76+
return correlVariable;
77+
}
78+
79+
@Override
80+
public RexNode visitDynamicParam(RexDynamicParam dynamicParam) {
81+
return dynamicParam;
82+
}
83+
84+
@Override
85+
public RexNode visitRangeRef(RexRangeRef rangeRef) {
86+
return rangeRef;
87+
}
88+
89+
@Override
90+
public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
91+
RexNode expr = fieldAccess.getReferenceExpr().accept(this);
92+
if (expr == fieldAccess.getReferenceExpr()) {
93+
return fieldAccess;
94+
}
95+
throw new UnsupportedOperationException(
96+
"RexFieldAccess transformation not supported. Override visitFieldAccess() to handle this"
97+
+ " case.");
98+
}
99+
100+
@Override
101+
public RexNode visitSubQuery(RexSubQuery subQuery) {
102+
List<RexNode> operands =
103+
subQuery.getOperands().stream()
104+
.map(operand -> operand.accept(this))
105+
.collect(Collectors.toList());
106+
if (operands.equals(subQuery.getOperands())) {
107+
return subQuery;
108+
}
109+
return subQuery.clone(subQuery.getType(), operands);
110+
}
111+
112+
@Override
113+
public RexNode visitTableInputRef(RexTableInputRef fieldRef) {
114+
return fieldRef;
115+
}
116+
117+
@Override
118+
public RexNode visitPatternFieldRef(RexPatternFieldRef fieldRef) {
119+
return fieldRef;
120+
}
121+
122+
@Override
123+
public RexNode visitLambda(RexLambda lambda) {
124+
RexNode expr = lambda.getExpression().accept(this);
125+
if (expr == lambda.getExpression()) {
126+
return lambda;
127+
}
128+
throw new UnsupportedOperationException(
129+
"RexLambda transformation not supported. Override visitLambda() to handle this case.");
130+
}
131+
132+
@Override
133+
public RexNode visitLambdaRef(RexLambdaRef lambdaRef) {
134+
return lambdaRef;
135+
}
136+
}

0 commit comments

Comments
 (0)