4848import java .util .Locale ;
4949import java .util .Map ;
5050import java .util .Set ;
51+ import org .apache .calcite .rel .RelNode ;
5152import org .apache .calcite .rel .type .RelDataType ;
5253import org .apache .calcite .rex .RexCall ;
5354import org .apache .calcite .rex .RexInputRef ;
6465import org .opensearch .index .query .QueryBuilder ;
6566import org .opensearch .index .query .RangeQueryBuilder ;
6667import org .opensearch .sql .calcite .plan .OpenSearchConstants ;
68+ import org .opensearch .sql .opensearch .data .type .OpenSearchDataType ;
69+ import org .opensearch .sql .opensearch .data .type .OpenSearchDataType .MappingType ;
70+ import org .opensearch .sql .opensearch .data .type .OpenSearchTextType ;
6771
6872/**
6973 * Query predicate analyzer. Uses visitor pattern to traverse existing expression and convert it to
@@ -92,8 +96,8 @@ public static final class PredicateAnalyzerException extends RuntimeException {
9296 }
9397
9498 /**
95- * Exception that is thrown when a {@link org.apache.calcite.rel. RelNode} expression cannot be
96- * processed (or converted into an OpenSearch query).
99+ * Exception that is thrown when a {@link RelNode} expression cannot be processed (or converted
100+ * into an OpenSearch query).
97101 */
98102 public static class ExpressionNotAnalyzableException extends Exception {
99103 ExpressionNotAnalyzableException (String message , Throwable cause ) {
@@ -112,15 +116,19 @@ private PredicateAnalyzer() {}
112116 * filters.
113117 *
114118 * @param expression expression to analyze
119+ * @param schema current schema of scan operator
120+ * @param typeMapping mapping of OpenSearch field name to OpenSearchDataType
115121 * @return search query which can be used to query OS cluster
116122 * @throws ExpressionNotAnalyzableException when expression can't be processed by this analyzer
117123 */
118- public static QueryBuilder analyze (RexNode expression , List <String > schema )
124+ public static QueryBuilder analyze (
125+ RexNode expression , List <String > schema , Map <String , OpenSearchDataType > typeMapping )
119126 throws ExpressionNotAnalyzableException {
120127 requireNonNull (expression , "expression" );
121128 try {
122129 // visits expression tree
123- QueryExpression queryExpression = (QueryExpression ) expression .accept (new Visitor (schema ));
130+ QueryExpression queryExpression =
131+ (QueryExpression ) expression .accept (new Visitor (schema , typeMapping ));
124132
125133 if (queryExpression != null && queryExpression .isPartial ()) {
126134 throw new UnsupportedOperationException (
@@ -137,15 +145,17 @@ public static QueryBuilder analyze(RexNode expression, List<String> schema)
137145 private static class Visitor extends RexVisitorImpl <Expression > {
138146
139147 List <String > schema ;
148+ Map <String , OpenSearchDataType > typeMapping ;
140149
/**
 * Creates a visitor bound to the scan operator's column names and field type mapping.
 *
 * @param schema field names of the scan, indexed by input-ref position
 * @param typeMapping OpenSearch field name to {@link OpenSearchDataType}
 */
private Visitor(List<String> schema, Map<String, OpenSearchDataType> typeMapping) {
  // "true" is forwarded to the RexVisitorImpl constructor unchanged.
  super(true);
  this.typeMapping = typeMapping;
  this.schema = schema;
}
145155
146156 @ Override
147157 public Expression visitInputRef (RexInputRef inputRef ) {
148- return new NamedFieldExpression (inputRef , schema );
158+ return new NamedFieldExpression (inputRef , schema , typeMapping );
149159 }
150160
151161 @ Override
@@ -246,7 +256,7 @@ public Expression visitCall(RexCall call) {
246256
247257 SqlSyntax syntax = call .getOperator ().getSyntax ();
248258 if (!supportedRexCall (call )) {
249- String message = String . format (Locale .ROOT , "Unsupported call: [%s]" , call );
259+ String message = format (Locale .ROOT , "Unsupported call: [%s]" , call );
250260 throw new PredicateAnalyzerException (message );
251261 }
252262
@@ -262,7 +272,7 @@ public Expression visitCall(RexCall call) {
262272 case CAST -> toCastExpression (call );
263273 case LIKE , CONTAINS -> binary (call );
264274 default -> {
265- String message = String . format (Locale .ROOT , "Unsupported call: [%s]" , call );
275+ String message = format (Locale .ROOT , "Unsupported call: [%s]" , call );
266276 throw new PredicateAnalyzerException (message );
267277 }
268278 };
@@ -291,7 +301,7 @@ private static String convertQueryString(List<Expression> fields, Expression que
291301 for (Expression expr : fields ) {
292302 if (expr instanceof NamedFieldExpression ) {
293303 NamedFieldExpression field = (NamedFieldExpression ) expr ;
294- String fieldIndexString = String . format (Locale .ROOT , "$%d" , index ++);
304+ String fieldIndexString = format (Locale .ROOT , "$%d" , index ++);
295305 fieldMap .put (fieldIndexString , field .getReference ());
296306 }
297307 }
@@ -307,7 +317,7 @@ private QueryExpression prefix(RexCall call) {
307317 call .getKind () == SqlKind .NOT , "Expected %s got %s" , SqlKind .NOT , call .getKind ());
308318
309319 if (call .getOperands ().size () != 1 ) {
310- String message = String . format (Locale .ROOT , "Unsupported NOT operator: [%s]" , call );
320+ String message = format (Locale .ROOT , "Unsupported NOT operator: [%s]" , call );
311321 throw new PredicateAnalyzerException (message );
312322 }
313323
@@ -318,7 +328,7 @@ private QueryExpression prefix(RexCall call) {
318328 private QueryExpression postfix (RexCall call ) {
319329 checkArgument (call .getKind () == SqlKind .IS_NULL || call .getKind () == SqlKind .IS_NOT_NULL );
320330 if (call .getOperands ().size () != 1 ) {
321- String message = String . format (Locale .ROOT , "Unsupported operator: [%s]" , call );
331+ String message = format (Locale .ROOT , "Unsupported operator: [%s]" , call );
322332 throw new PredicateAnalyzerException (message );
323333 }
324334 Expression a = call .getOperands ().get (0 ).accept (this );
@@ -407,7 +417,7 @@ private QueryExpression binary(RexCall call) {
407417 default :
408418 break ;
409419 }
410- String message = String . format (Locale .ROOT , "Unable to handle call: [%s]" , call );
420+ String message = format (Locale .ROOT , "Unable to handle call: [%s]" , call );
411421 throw new PredicateAnalyzerException (message );
412422 }
413423
@@ -438,16 +448,15 @@ private QueryExpression andOr(RexCall call) {
438448 if (firstError != null ) {
439449 throw firstError ;
440450 } else {
441- final String message =
442- String .format (Locale .ROOT , "Unable to handle call: [%s]" , call );
451+ final String message = format (Locale .ROOT , "Unable to handle call: [%s]" , call );
443452 throw new PredicateAnalyzerException (message );
444453 }
445454 }
446455 return CompoundQueryExpression .or (expressions );
447456 case AND :
448457 return CompoundQueryExpression .and (partial , expressions );
449458 default :
450- String message = String . format (Locale .ROOT , "Unable to handle call: [%s]" , call );
459+ String message = format (Locale .ROOT , "Unable to handle call: [%s]" , call );
451460 throw new PredicateAnalyzerException (message );
452461 }
453462 }
@@ -506,7 +515,7 @@ private static SwapResult swap(Expression left, Expression right) {
506515
507516 if (literal == null || terminal == null ) {
508517 String message =
509- String . format (
518+ format (
510519 Locale .ROOT ,
511520 "Unexpected combination of expressions [left: %s] [right: %s]" ,
512521 left ,
@@ -610,7 +619,7 @@ public static QueryExpression create(TerminalExpression expression) {
610619 if (expression instanceof NamedFieldExpression ) {
611620 return new SimpleQueryExpression ((NamedFieldExpression ) expression );
612621 } else {
613- String message = String . format (Locale .ROOT , "Unsupported expression: [%s]" , expression );
622+ String message = format (Locale .ROOT , "Unsupported expression: [%s]" , expression );
614623 throw new PredicateAnalyzerException (message );
615624 }
616625 }
@@ -769,6 +778,10 @@ private String getFieldReference() {
769778 return rel .getReference ();
770779 }
771780
781+ private String getFieldReferenceForTermQuery () {
782+ return rel .getReferenceForTermQuery ();
783+ }
784+
772785 private SimpleQueryExpression (NamedFieldExpression rel ) {
773786 this .rel = rel ;
774787 }
@@ -832,9 +845,7 @@ public QueryExpression equals(LiteralExpression literal) {
832845 .must (addFormatIfNecessary (literal , rangeQuery (getFieldReference ()).gte (value )))
833846 .must (addFormatIfNecessary (literal , rangeQuery (getFieldReference ()).lte (value )));
834847 } else {
835- // TODO: equal(textFieldType, "value") should not rewrite as termQuery,
836- // it should be addressed by issue: https://github.com/opensearch-project/sql/issues/3334
837- builder = termQuery (getFieldReference (), value );
848+ builder = termQuery (getFieldReferenceForTermQuery (), value );
838849 }
839850 return this ;
840851 }
@@ -852,7 +863,7 @@ public QueryExpression notEquals(LiteralExpression literal) {
852863 boolQuery ()
853864 // NOT LIKE should return false when field is NULL
854865 .must (existsQuery (getFieldReference ()))
855- .mustNot (termQuery (getFieldReference (), value ));
866+ .mustNot (termQuery (getFieldReferenceForTermQuery (), value ));
856867 }
857868 return this ;
858869 }
@@ -892,21 +903,21 @@ public QueryExpression queryString(String query) {
892903
893904 @ Override
894905 public QueryExpression isTrue () {
895- builder = termQuery (getFieldReference (), true );
906+ builder = termQuery (getFieldReferenceForTermQuery (), true );
896907 return this ;
897908 }
898909
899910 @ Override
900911 public QueryExpression in (LiteralExpression literal ) {
901912 Collection <?> collection = (Collection <?>) literal .value ();
902- builder = termsQuery (getFieldReference (), collection );
913+ builder = termsQuery (getFieldReferenceForTermQuery (), collection );
903914 return this ;
904915 }
905916
906917 @ Override
907918 public QueryExpression notIn (LiteralExpression literal ) {
908919 Collection <?> collection = (Collection <?>) literal .value ();
909- builder = boolQuery ().mustNot (termsQuery (getFieldReference (), collection ));
920+ builder = boolQuery ().mustNot (termsQuery (getFieldReferenceForTermQuery (), collection ));
910921 return this ;
911922 }
912923 }
@@ -962,31 +973,64 @@ static boolean isCastExpression(Expression exp) {
962973 static final class NamedFieldExpression implements TerminalExpression {
963974
964975 private final String name ;
976+ private final OpenSearchDataType type ;
965977
966978 private NamedFieldExpression () {
967979 this .name = null ;
980+ this .type = null ;
968981 }
969982
970- private NamedFieldExpression (RexInputRef ref , List <String > schema ) {
983+ private NamedFieldExpression (
984+ RexInputRef ref , List <String > schema , Map <String , OpenSearchDataType > typeMapping ) {
971985 this .name =
972986 (ref == null || ref .getIndex () >= schema .size ()) ? null : schema .get (ref .getIndex ());
987+ this .type = typeMapping .get (name );
973988 }
974989
975990 private NamedFieldExpression (RexLiteral literal ) {
976991 this .name = literal == null ? null : RexLiteral .stringValue (literal );
992+ this .type = null ;
977993 }
978994
979995 String getRootName () {
980996 return name ;
981997 }
982998
999+ OpenSearchDataType getOpenSearchDataType () {
1000+ return type ;
1001+ }
1002+
1003+ boolean isTextType () {
1004+ return type != null && type .getMappingType () == OpenSearchDataType .MappingType .Text ;
1005+ }
1006+
1007+ String toKeywordSubField () {
1008+ if (type instanceof OpenSearchTextType ) {
1009+ OpenSearchTextType textType = (OpenSearchTextType ) type ;
1010+ // Find the first subfield with type keyword, return null if non-exist.
1011+ return textType .getFields ().entrySet ().stream ()
1012+ .filter (e -> e .getValue ().getMappingType () == MappingType .Keyword )
1013+ .findFirst ()
1014+ .map (e -> name + "." + e .getKey ())
1015+ .orElse (null );
1016+ }
1017+ return null ;
1018+ }
1019+
9831020 boolean isMetaField () {
9841021 return OpenSearchConstants .METADATAFIELD_TYPE_MAP .containsKey (getRootName ());
9851022 }
9861023
9871024 String getReference () {
9881025 return getRootName ();
9891026 }
1027+
1028+ String getReferenceForTermQuery () {
1029+ if (isTextType ()) {
1030+ return toKeywordSubField ();
1031+ }
1032+ return getRootName ();
1033+ }
9901034 }
9911035
9921036 /** Literal like {@code 'foo' or 42 or true} etc. */
0 commit comments