Skip to content

Commit 29c7625

Browse files
authored
[Calcite Engine] Support In expression (#3429)
* [Calcite Engine] Support In expression Signed-off-by: Heng Qian <qianheng@amazon.com> * Address comments Signed-off-by: Heng Qian <qianheng@amazon.com> * Support NOT IN Signed-off-by: Heng Qian <qianheng@amazon.com> * Address comments Signed-off-by: Heng Qian <qianheng@amazon.com> * Add more UT Signed-off-by: Heng Qian <qianheng@amazon.com> * Address comments Signed-off-by: Heng Qian <qianheng@amazon.com> --------- Signed-off-by: Heng Qian <qianheng@amazon.com>
1 parent 137935b commit 29c7625

File tree

7 files changed

+161
-7
lines changed

7 files changed

+161
-7
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.opensearch.sql.ast.expression.Compare;
3636
import org.opensearch.sql.ast.expression.EqualTo;
3737
import org.opensearch.sql.ast.expression.Function;
38+
import org.opensearch.sql.ast.expression.In;
3839
import org.opensearch.sql.ast.expression.Let;
3940
import org.opensearch.sql.ast.expression.Literal;
4041
import org.opensearch.sql.ast.expression.Not;
@@ -51,7 +52,9 @@
5152
import org.opensearch.sql.ast.expression.subquery.ScalarSubquery;
5253
import org.opensearch.sql.ast.tree.UnresolvedPlan;
5354
import org.opensearch.sql.calcite.utils.BuiltinFunctionUtils;
55+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
5456
import org.opensearch.sql.common.utils.StringUtils;
57+
import org.opensearch.sql.data.type.ExprType;
5558
import org.opensearch.sql.exception.CalciteUnsupportedException;
5659
import org.opensearch.sql.exception.SemanticCheckException;
5760

@@ -143,6 +146,29 @@ public RexNode visitNot(Not node, CalcitePlanContext context) {
143146
return context.relBuilder.not(expr);
144147
}
145148

149+
@Override
150+
public RexNode visitIn(In node, CalcitePlanContext context) {
151+
final RexNode field = analyze(node.getField(), context);
152+
final List<RexNode> valueList =
153+
node.getValueList().stream().map(value -> analyze(value, context)).toList();
154+
final List<RelDataType> dataTypes =
155+
new java.util.ArrayList<>(valueList.stream().map(RexNode::getType).toList());
156+
dataTypes.add(field.getType());
157+
RelDataType commonType = context.rexBuilder.getTypeFactory().leastRestrictive(dataTypes);
158+
if (commonType != null) {
159+
List<RexNode> newValueList =
160+
valueList.stream().map(value -> context.rexBuilder.makeCast(commonType, value)).toList();
161+
return context.rexBuilder.makeIn(field, newValueList);
162+
} else {
163+
List<ExprType> exprTypes =
164+
dataTypes.stream().map(OpenSearchTypeFactory::convertRelDataTypeToExprType).toList();
165+
throw new SemanticCheckException(
166+
StringUtils.format(
167+
"In expression types are incompatible: fields type %s, values type %s",
168+
exprTypes.getLast(), exprTypes.subList(0, exprTypes.size() - 1)));
169+
}
170+
}
171+
146172
@Override
147173
public RexNode visitCompare(Compare node, CalcitePlanContext context) {
148174
SqlOperator op = BuiltinFunctionUtils.translate(node.getOperator());
@@ -164,7 +190,9 @@ public RexNode visitBetween(Between node, CalcitePlanContext context) {
164190
throw new SemanticCheckException(
165191
StringUtils.format(
166192
"BETWEEN expression types are incompatible: [%s, %s, %s]",
167-
value.getType(), lowerBound.getType(), upperBound.getType()));
193+
OpenSearchTypeFactory.convertRelDataTypeToExprType(value.getType()),
194+
OpenSearchTypeFactory.convertRelDataTypeToExprType(lowerBound.getType()),
195+
OpenSearchTypeFactory.convertRelDataTypeToExprType(upperBound.getType())));
168196
}
169197
return context.relBuilder.between(value, lowerBound, upperBound);
170198
}

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteWhereCommandIT.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,25 @@
99
import org.junit.Ignore;
1010
import org.opensearch.sql.ppl.WhereCommandIT;
1111

12-
@Ignore("Not all boolean functions are supported in Calcite now")
1312
public class CalciteWhereCommandIT extends WhereCommandIT {
1413
@Override
1514
public void init() throws IOException {
1615
enableCalcite();
1716
disallowCalciteFallback();
1817
super.init();
1918
}
19+
20+
@Ignore("https://github.com/opensearch-project/sql/issues/3428")
21+
@Override
22+
public void testIsNotNullFunction() throws IOException {}
23+
24+
@Ignore("https://github.com/opensearch-project/sql/issues/3333")
25+
@Override
26+
public void testWhereWithMetadataFields() throws IOException {}
27+
28+
@Override
29+
protected String getIncompatibleTypeErrMsg() {
30+
return "In expression types are incompatible: fields type LONG, values type [INTEGER, INTEGER,"
31+
+ " STRING]";
32+
}
2033
}

integ-test/src/test/java/org/opensearch/sql/ppl/WhereCommandIT.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66
package org.opensearch.sql.ppl;
77

8+
import static org.hamcrest.CoreMatchers.containsString;
89
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT;
910
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK_WITH_NULL_VALUES;
1011
import static org.opensearch.sql.util.MatcherUtils.rows;
1112
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;
1213

1314
import java.io.IOException;
15+
import org.hamcrest.MatcherAssert;
1416
import org.json.JSONObject;
1517
import org.junit.jupiter.api.Test;
1618

@@ -106,4 +108,73 @@ public void testWhereWithMetadataFields() throws IOException {
106108
String.format("source=%s | where _id='1' | fields firstname", TEST_INDEX_ACCOUNT));
107109
verifyDataRows(result, rows("Amber"));
108110
}
111+
112+
@Test
113+
public void testWhereWithIn() throws IOException {
114+
JSONObject result =
115+
executeQuery(
116+
String.format(
117+
"source=%s | where firstname in ('Amber') | fields firstname", TEST_INDEX_ACCOUNT));
118+
verifyDataRows(result, rows("Amber"));
119+
120+
result =
121+
executeQuery(
122+
String.format(
123+
"source=%s | where firstname in ('Amber', 'Dale') | fields firstname",
124+
TEST_INDEX_ACCOUNT));
125+
verifyDataRows(result, rows("Amber"), rows("Dale"));
126+
127+
result =
128+
executeQuery(
129+
String.format(
130+
"source=%s | where balance in (4180, 5686.0) | fields balance",
131+
TEST_INDEX_ACCOUNT));
132+
verifyDataRows(result, rows(4180), rows(5686));
133+
}
134+
135+
@Test
136+
public void testWhereWithNotIn() throws IOException {
137+
JSONObject result =
138+
executeQuery(
139+
String.format(
140+
"source=%s | where account_number < 4 | where firstname not in ('Amber', 'Levine')"
141+
+ " | fields firstname",
142+
TEST_INDEX_ACCOUNT));
143+
verifyDataRows(result, rows("Roberta"), rows("Bradshaw"));
144+
145+
result =
146+
executeQuery(
147+
String.format(
148+
"source=%s | where account_number < 4 | where not firstname in ('Amber', 'Levine')"
149+
+ " | fields firstname",
150+
TEST_INDEX_ACCOUNT));
151+
verifyDataRows(result, rows("Roberta"), rows("Bradshaw"));
152+
153+
result =
154+
executeQuery(
155+
String.format(
156+
"source=%s | where not firstname not in ('Amber', 'Dale') | fields firstname",
157+
TEST_INDEX_ACCOUNT));
158+
verifyDataRows(result, rows("Amber"), rows("Dale"));
159+
}
160+
161+
@Test
162+
public void testInWithIncompatibleType() {
163+
Exception e =
164+
assertThrows(
165+
Exception.class,
166+
() -> {
167+
executeQuery(
168+
String.format(
169+
"source=%s | where balance in (4180, 5686, '6077') | fields firstname",
170+
TEST_INDEX_ACCOUNT));
171+
});
172+
MatcherAssert.assertThat(e.getMessage(), containsString(getIncompatibleTypeErrMsg()));
173+
}
174+
175+
protected String getIncompatibleTypeErrMsg() {
176+
return "function expected"
177+
+ " {[BYTE,BYTE],[SHORT,SHORT],[INTEGER,INTEGER],[LONG,LONG],[FLOAT,FLOAT],[DOUBLE,DOUBLE],[STRING,STRING],[BOOLEAN,BOOLEAN],[DATE,DATE],[TIME,TIME],[TIMESTAMP,TIMESTAMP],[INTERVAL,INTERVAL],[IP,IP],[STRUCT,STRUCT],[ARRAY,ARRAY]},"
178+
+ " but got [LONG,STRING]";
179+
}
109180
}

ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,13 @@ public UnresolvedExpression visitCompareExpr(CompareExprContext ctx) {
138138

139139
@Override
140140
public UnresolvedExpression visitInExpr(InExprContext ctx) {
141-
return new In(
142-
visit(ctx.valueExpression()),
143-
ctx.valueList().literalValue().stream()
144-
.map(this::visitLiteralValue)
145-
.collect(Collectors.toList()));
141+
UnresolvedExpression expr =
142+
new In(
143+
visit(ctx.valueExpression()),
144+
ctx.valueList().literalValue().stream()
145+
.map(this::visitLiteralValue)
146+
.collect(Collectors.toList()));
147+
return ctx.NOT() != null ? new Not(expr) : expr;
146148
}
147149

148150
/** Value Expression. */

ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.opensearch.sql.ast.expression.Compare;
2424
import org.opensearch.sql.ast.expression.Field;
2525
import org.opensearch.sql.ast.expression.Function;
26+
import org.opensearch.sql.ast.expression.In;
2627
import org.opensearch.sql.ast.expression.Interval;
2728
import org.opensearch.sql.ast.expression.Let;
2829
import org.opensearch.sql.ast.expression.Literal;
@@ -421,6 +422,12 @@ public String visitBetween(Between node, String context) {
421422
return StringUtils.format("%s between %s and %s", value, left, right);
422423
}
423424

425+
@Override
426+
public String visitIn(In node, String context) {
427+
String field = analyze(node.getField(), context);
428+
return StringUtils.format("%s in (%s)", field, MASK_LITERAL);
429+
}
430+
424431
@Override
425432
public String visitField(Field node, String context) {
426433
return node.getField().toString();

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLBasicTest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,34 @@ public void testFilterQueryWithOr2() {
141141
verifyPPLToSparkSQL(root, expectedSparkSql);
142142
}
143143

144+
@Test
145+
public void testFilterQueryWithIn() {
146+
String ppl = "source=scott.products_temporal | where ID in ('1000', '2000')";
147+
RelNode root = getRelNode(ppl);
148+
String expectedLogical =
149+
"LogicalFilter(condition=[SEARCH($0, Sarg['1000':VARCHAR(32),"
150+
+ " '2000':VARCHAR(32)]:VARCHAR(32))])\n"
151+
+ " LogicalTableScan(table=[[scott, products_temporal]])\n";
152+
verifyLogical(root, expectedLogical);
153+
154+
String expectedSparkSql =
155+
"SELECT *\nFROM `scott`.`products_temporal`\nWHERE `ID` IN ('1000', '2000')";
156+
verifyPPLToSparkSQL(root, expectedSparkSql);
157+
}
158+
159+
@Test
160+
public void testFilterQueryWithIn2() {
161+
String ppl = "source=EMP | where DEPTNO in (20, 30.0)";
162+
RelNode root = getRelNode(ppl);
163+
String expectedLogical =
164+
"LogicalFilter(condition=[SEARCH($7, Sarg[20.0E0:DOUBLE, 30.0E0:DOUBLE]:DOUBLE)])\n"
165+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
166+
verifyLogical(root, expectedLogical);
167+
168+
String expectedSparkSql = "SELECT *\nFROM `scott`.`EMP`\nWHERE `DEPTNO` IN (2.00E1, 3.00E1)";
169+
verifyPPLToSparkSQL(root, expectedSparkSql);
170+
}
171+
144172
@Test
145173
public void testQueryWithFields() {
146174
String ppl = "source=products_temporal | fields SUPPLIER, ID";

ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ public void testNotExpression() {
158158
assertEquals("source=t | where not a = ***", anonymize("source=t | where not a=1 "));
159159
}
160160

161+
@Test
162+
public void testInExpression() {
163+
assertEquals("source=t | where a in (***)", anonymize("source=t | where a in (1, 2, 3) "));
164+
}
165+
161166
@Test
162167
public void testQualifiedName() {
163168
assertEquals("source=t | fields + field0", anonymize("source=t | fields field0"));

0 commit comments

Comments
 (0)