Skip to content

Commit ac06da4

Browse files
authored
Coerce any to specific types (#4753)
Signed-off-by: Tomoyuki Morita <[email protected]>
1 parent 4246a49 commit ac06da4

File tree

4 files changed

+111
-55
lines changed

4 files changed

+111
-55
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,16 +247,17 @@ public RelNode visitRegex(Regex node, CalcitePlanContext context) {
247247
RexNode fieldRex = rexVisitor.analyze(node.getField(), context);
248248
RexNode patternRex = rexVisitor.analyze(node.getPattern(), context);
249249

250-
if (!SqlTypeFamily.CHARACTER.contains(fieldRex.getType())) {
250+
if (!SqlTypeFamily.CHARACTER.contains(fieldRex.getType())
251+
&& !SqlTypeName.ANY.equals(fieldRex.getType().getSqlTypeName())) {
251252
throw new IllegalArgumentException(
252253
String.format(
253254
"Regex command requires field of string type, but got %s for field '%s'",
254255
fieldRex.getType().getSqlTypeName(), node.getField().toString()));
255256
}
256257

257258
RexNode regexCondition =
258-
context.rexBuilder.makeCall(
259-
org.apache.calcite.sql.fun.SqlLibraryOperators.REGEXP_CONTAINS, fieldRex, patternRex);
259+
PPLFuncImpTable.INSTANCE.resolve(
260+
context.rexBuilder, BuiltinFunctionName.REGEX_MATCH, fieldRex, patternRex);
260261

261262
if (node.isNegated()) {
262263
regexCondition = context.rexBuilder.makeCall(SqlStdOperatorTable.NOT, regexCondition);

core/src/main/java/org/opensearch/sql/expression/function/CoercionUtils.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import java.util.function.BinaryOperator;
1818
import java.util.stream.Collectors;
1919
import javax.annotation.Nullable;
20+
import org.apache.calcite.rel.type.RelDataType;
2021
import org.apache.calcite.rex.RexBuilder;
2122
import org.apache.calcite.rex.RexNode;
23+
import org.apache.calcite.sql.type.SqlTypeName;
2224
import org.apache.commons.lang3.tuple.Pair;
2325
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
2426
import org.opensearch.sql.data.type.ExprCoreType;
@@ -106,6 +108,14 @@ public final class CoercionUtils {
106108

107109
private static @Nullable RexNode cast(RexBuilder builder, ExprType targetType, RexNode arg) {
108110
ExprType argType = OpenSearchTypeFactory.convertRelDataTypeToExprType(arg.getType());
111+
112+
// Special handling for ANY type (dynamic fields)
113+
if (isAnyType(arg.getType())) {
114+
// ANY type can be cast to any target type
115+
return builder.makeCast(
116+
OpenSearchTypeFactory.convertExprTypeToRelDataType(targetType), arg, true, true);
117+
}
118+
109119
if (!argType.shouldCast(targetType)) {
110120
return arg;
111121
}
@@ -176,6 +186,10 @@ public static boolean hasString(List<RexNode> rexNodeList) {
176186

177187
private static final List<CoercionRule> COMMON_COERCION_RULES =
178188
List.of(
189+
// ANY type coercion: if one side is ANY, use the other type
190+
CoercionRule.of(
191+
(left, right) -> isAnyType(left) || isAnyType(right),
192+
(left, right) -> isAnyType(left) ? right : left),
179193
CoercionRule.of(
180194
(left, right) -> areDateAndTime(left, right),
181195
(left, right) -> ExprCoreType.TIMESTAMP),
@@ -212,6 +226,7 @@ static CoercionRule of(
212226

213227
private static final int IMPOSSIBLE_WIDENING = Integer.MAX_VALUE;
214228
private static final int TYPE_EQUAL = 0;
229+
private static final int ANY_TYPE_DISTANCE = 1;
215230

216231
private static int distance(ExprType type1, ExprType type2) {
217232
return distance(type1, type2, TYPE_EQUAL);
@@ -222,6 +237,9 @@ private static int distance(ExprType type1, ExprType type2, int distance) {
222237
return distance;
223238
} else if (type1 == UNKNOWN) {
224239
return IMPOSSIBLE_WIDENING;
240+
} else if (isAnyType(type1)) {
241+
// ANY type (from dynamic fields) can coerce to any other type with distance 1
242+
return distance + ANY_TYPE_DISTANCE;
225243
} else if (type1 == ExprCoreType.STRING && type2 == ExprCoreType.DOUBLE) {
226244
return 1;
227245
} else {
@@ -232,6 +250,20 @@ private static int distance(ExprType type1, ExprType type2, int distance) {
232250
}
233251
}
234252

253+
private static boolean isAnyType(RelDataType type) {
254+
return type.getSqlTypeName() == SqlTypeName.ANY;
255+
}
256+
257+
private static boolean isAnyType(ExprType type) {
258+
// UNDEFINED is the ExprType representation of SqlTypeName.ANY
259+
// but we need to distinguish it from actual UNDEFINED (NULL type)
260+
// In the context of dynamic fields, UNDEFINED with no parents represents ANY type
261+
if (type == ExprCoreType.UNDEFINED) {
262+
return type.getParent().isEmpty();
263+
}
264+
return false;
265+
}
266+
235267
/**
236268
* The max type among two types. The max is defined as follow if type1 could widen to type2, then
237269
* max is type2, vice versa if type1 couldn't widen to type2 and type2 could't widen to type1,

core/src/test/java/org/opensearch/sql/expression/function/CoercionUtilsTest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.apache.calcite.rel.type.RelDataType;
1717
import org.apache.calcite.rex.RexBuilder;
1818
import org.apache.calcite.rex.RexNode;
19+
import org.apache.calcite.sql.type.SqlTypeName;
1920
import org.junit.jupiter.api.Test;
2021
import org.junit.jupiter.params.ParameterizedTest;
2122
import org.junit.jupiter.params.provider.Arguments;
@@ -32,6 +33,11 @@ private static RexNode nullLiteral(ExprCoreType type) {
3233
return REX_BUILDER.makeNullLiteral(OpenSearchTypeFactory.convertExprTypeToRelDataType(type));
3334
}
3435

36+
private static RexNode anyTypedLiteral() {
37+
return REX_BUILDER.makeNullLiteral(
38+
OpenSearchTypeFactory.TYPE_FACTORY.createSqlType(SqlTypeName.ANY));
39+
}
40+
3541
private static Stream<Arguments> commonWidestTypeArguments() {
3642
return Stream.of(
3743
Arguments.of(STRING, INTEGER, DOUBLE),
@@ -82,6 +88,26 @@ void castArgumentsReturnsNullWhenNoCompatibleSignatureExists() {
8288
assertNull(CoercionUtils.castArguments(REX_BUILDER, typeChecker, arguments));
8389
}
8490

91+
@Test
92+
void coerceAnyToString() {
93+
testAnyToSpecificTypeCoercion(STRING);
94+
testAnyToSpecificTypeCoercion(INTEGER);
95+
testAnyToSpecificTypeCoercion(DOUBLE);
96+
testAnyToSpecificTypeCoercion(BOOLEAN);
97+
testAnyToSpecificTypeCoercion(ExprCoreType.TIMESTAMP);
98+
}
99+
100+
void testAnyToSpecificTypeCoercion(ExprCoreType toType) {
101+
PPLTypeChecker typeChecker = new StubTypeChecker(List.of(List.of(toType)));
102+
List<RexNode> arguments = List.of(anyTypedLiteral());
103+
104+
List<RexNode> result = CoercionUtils.castArguments(REX_BUILDER, typeChecker, arguments);
105+
106+
assertEquals(1, result.size());
107+
assertEquals(
108+
toType, OpenSearchTypeFactory.convertRelDataTypeToExprType(result.getFirst().getType()));
109+
}
110+
85111
private static class StubTypeChecker implements PPLTypeChecker {
86112
private final List<List<ExprType>> signatures;
87113

integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalciteDynamicFieldsCommandIT.java

Lines changed: 49 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,6 @@ public void testRegex() throws IOException {
124124
String query =
125125
source(
126126
TEST_INDEX_DYNAMIC, "regex department=\"Eng.*\" | fields account_number, department");
127-
assertThrows(RuntimeException.class, () -> executeQuery(query));
128-
}
129-
130-
@Test
131-
public void testCastAndRegex() throws IOException {
132-
String query =
133-
source(
134-
TEST_INDEX_DYNAMIC,
135-
"eval department=CAST(department AS string) | regex department=\"Eng.*\" | fields"
136-
+ " account_number, department");
137127
JSONObject result = executeQuery(query);
138128

139129
verifySchema(result, schema("account_number", "bigint"), schema("department", "string"));
@@ -145,40 +135,39 @@ public void testRex() throws IOException {
145135
String query =
146136
source(
147137
TEST_INDEX_DYNAMIC,
148-
"rex field=firstname \"(?<initial>[A-Z])\" | fields firstname, initial | head 1");
138+
"rex field=department '(?<initial>[A-Z])'" + " | fields department, initial | head 1");
149139
JSONObject result = executeQuery(query);
150140

151-
verifySchema(result, schema("firstname", "string"), schema("initial", "string"));
152-
verifyDataRows(result, rows("John", "J"));
141+
verifySchema(result, schema("department", "string"), schema("initial", "string"));
142+
verifyDataRows(result, rows("Engineering", "E"));
153143
}
154144

155145
@Test
156-
public void testCastAndRex() throws IOException {
146+
public void testRexWithNumeric() throws IOException {
157147
String query =
158148
source(
159149
TEST_INDEX_DYNAMIC,
160-
"eval department=CAST(department AS string) | rex field=department '(?<initial>[A-Z])'"
161-
+ " | fields department, initial | head 1");
150+
"rex field=salary '(?<initial>[0-9])'" + " | fields salary, initial | head 1");
162151
JSONObject result = executeQuery(query);
163152

164-
verifySchema(result, schema("department", "string"), schema("initial", "string"));
165-
verifyDataRows(result, rows("Engineering", "E"));
153+
verifySchema(result, schema("salary", "int"), schema("initial", "string"));
154+
verifyDataRows(result, rows(75000, "7"));
166155
}
167156

168157
@Test
169-
public void testCastAndParse() throws IOException {
158+
public void testParse() throws IOException {
170159
String query =
171160
source(
172161
TEST_INDEX_DYNAMIC,
173-
"eval department=CAST(department AS string) | fields department | parse department"
174-
+ " '(?<initial>[A-Z]).*' | fields department, initial | head 1");
162+
"parse department '(?<initial>[A-Z]).*' | fields department, initial | head 1");
163+
175164
assertExplainYaml(
176165
query,
177166
"calcite:\n"
178167
+ " logical: |\n"
179168
+ " LogicalSystemLimit(fetch=[200], type=[QUERY_SIZE_LIMIT])\n"
180169
+ " LogicalSort(fetch=[1])\n"
181-
+ " LogicalProject(department=[SAFE_CAST(ITEM($9, 'department'))],"
170+
+ " LogicalProject(department=[ITEM($9, 'department')],"
182171
+ " initial=[ITEM(PARSE(SAFE_CAST(ITEM($9, 'department')),"
183172
+ " '(?<initial>[A-Z]).*':VARCHAR, 'regex':VARCHAR), 'initial':VARCHAR)])\n"
184173
+ " CalciteLogicalIndexScan(table=[[OpenSearch, test_dynamic_fields]])\n"
@@ -188,7 +177,7 @@ public void testCastAndParse() throws IOException {
188177
+ " expr#11=[ITEM($t9, $t10)], expr#12=[SAFE_CAST($t11)],"
189178
+ " expr#13=['(?<initial>[A-Z]).*':VARCHAR], expr#14=['regex':VARCHAR],"
190179
+ " expr#15=[PARSE($t12, $t13, $t14)], expr#16=['initial':VARCHAR], expr#17=[ITEM($t15,"
191-
+ " $t16)], department=[$t12], initial=[$t17])\n"
180+
+ " $t16)], department=[$t11], initial=[$t17])\n"
192181
+ " EnumerableLimit(fetch=[1])\n"
193182
+ " CalciteEnumerableIndexScan(table=[[OpenSearch, test_dynamic_fields]])\n");
194183

@@ -235,28 +224,41 @@ public void testRevers() throws IOException {
235224

236225
@Test
237226
public void testBin() throws IOException {
238-
String query = source(TEST_INDEX_DYNAMIC, "bin salary span=10000 | head 1");
227+
String query =
228+
source(TEST_INDEX_DYNAMIC, "bin salary span=10000 | fields account_number, salary");
239229
JSONObject result = executeQuery(query);
240230

241-
verifySchema(
242-
result,
243-
schema("account_number", "bigint"),
244-
schema("firstname", "string"),
245-
schema("lastname", "string"),
246-
schema("salary", "string"),
247-
schema("city", "string"),
248-
schema("department", "string"),
249-
schema("json", "string"));
231+
assertExplainYaml(
232+
query,
233+
"calcite:\n"
234+
+ " logical: |\n"
235+
+ " LogicalSystemLimit(fetch=[200], type=[QUERY_SIZE_LIMIT])\n"
236+
+ " LogicalProject(account_number=[$0], salary=[SPAN_BUCKET(ITEM($9, 'salary'),"
237+
+ " 10000)])\n"
238+
+ " CalciteLogicalIndexScan(table=[[OpenSearch, test_dynamic_fields]])\n"
239+
+ " physical: |\n"
240+
+ " EnumerableCalc(expr#0..9=[{inputs}], expr#10=['salary'], expr#11=[ITEM($t9,"
241+
+ " $t10)], expr#12=[10000], expr#13=[SPAN_BUCKET($t11, $t12)], account_number=[$t0],"
242+
+ " salary=[$t13])\n"
243+
+ " EnumerableLimit(fetch=[200])\n"
244+
+ " CalciteEnumerableIndexScan(table=[[OpenSearch, test_dynamic_fields]])\n");
245+
246+
verifySchema(result, schema("account_number", "bigint"), schema("salary", "string"));
250247
verifyDataRows(
251-
result, rows(1, "John", "Doe", "70000-80000", "NYC", "{\"n\":1}", "Engineering"));
248+
result,
249+
rows(1, "70000-80000"),
250+
rows(2, "60000-70000"),
251+
rows(3, null),
252+
rows(4, "80000-90000"),
253+
rows(5, "60000-70000"));
252254
}
253255

254256
@Test
255-
public void testCastAndPatterns() throws IOException {
257+
public void testPatterns() throws IOException {
256258
String query =
257259
source(
258260
TEST_INDEX_DYNAMIC,
259-
"eval department=CAST(department as string) | patterns department method=simple_pattern"
261+
"patterns department method=simple_pattern"
260262
+ " | fields department, patterns_field | head 1");
261263
JSONObject result = executeQuery(query);
262264

@@ -265,12 +267,12 @@ public void testCastAndPatterns() throws IOException {
265267
}
266268

267269
@Test
268-
public void testCastAndPatternsWithAggregation() throws IOException {
270+
public void testPatternsWithAggregation() throws IOException {
269271
// TODO:
270272
String query =
271273
source(
272274
TEST_INDEX_DYNAMIC,
273-
"eval department=CAST(department as string) | patterns department mode=aggregation"
275+
"patterns department mode=aggregation"
274276
+ " method=simple_pattern | fields patterns_field, pattern_count, sample_logs |"
275277
+ " head 1");
276278
JSONObject result = executeQuery(query);
@@ -490,34 +492,29 @@ public void testEval() throws IOException {
490492
String query =
491493
source(
492494
TEST_INDEX_DYNAMIC,
493-
"eval salary = cast(salary as int) * 2 | fields firstname,"
494-
+ " lastname, salary | head 1");
495+
"eval salary = 1 + salary * 2 | fields account_number, salary | head 1");
495496

496497
assertExplainYaml(
497498
query,
498499
"calcite:\n"
499500
+ " logical: |\n"
500501
+ " LogicalSystemLimit(fetch=[200], type=[QUERY_SIZE_LIMIT])\n"
501502
+ " LogicalSort(fetch=[1])\n"
502-
+ " LogicalProject(firstname=[$1], lastname=[$2], salary=[*(SAFE_CAST(ITEM($9,"
503-
+ " 'salary')), 2)])\n"
503+
+ " LogicalProject(account_number=[$0], salary=[+(1, *(SAFE_CAST(ITEM($9,"
504+
+ " 'salary')), 2.0E0:DOUBLE))])\n"
504505
+ " CalciteLogicalIndexScan(table=[[OpenSearch, test_dynamic_fields]])\n"
505506
+ " physical: |\n"
506507
+ " EnumerableLimit(fetch=[200])\n"
507-
+ " EnumerableCalc(expr#0..9=[{inputs}], expr#10=['salary'], expr#11=[ITEM($t9,"
508-
+ " $t10)], expr#12=[SAFE_CAST($t11)], expr#13=[2], expr#14=[*($t12, $t13)],"
509-
+ " firstname=[$t1], lastname=[$t2], salary=[$t14])\n"
508+
+ " EnumerableCalc(expr#0..9=[{inputs}], expr#10=[1], expr#11=['salary'],"
509+
+ " expr#12=[ITEM($t9, $t11)], expr#13=[SAFE_CAST($t12)], expr#14=[2.0E0:DOUBLE],"
510+
+ " expr#15=[*($t13, $t14)], expr#16=[+($t10, $t15)], account_number=[$t0],"
511+
+ " salary=[$t16])\n"
510512
+ " EnumerableLimit(fetch=[1])\n"
511-
+ " CalciteEnumerableIndexScan(table=[[OpenSearch, test_dynamic_fields]])\n"
512-
+ "");
513+
+ " CalciteEnumerableIndexScan(table=[[OpenSearch, test_dynamic_fields]])\n");
513514

514515
JSONObject result = executeQuery(query);
515516

516-
verifySchema(
517-
result,
518-
schema("firstname", "string"),
519-
schema("lastname", "string"),
520-
schema("salary", "int"));
517+
verifySchema(result, schema("account_number", "bigint"), schema("salary", "double"));
521518
}
522519

523520
private void createTestIndexWithUnmappedFields() throws IOException {

0 commit comments

Comments
 (0)