4848import java .util .Locale ;
4949import java .util .Map ;
5050import java .util .Set ;
51+ import org .apache .calcite .rel .RelNode ;
5152import org .apache .calcite .rel .type .RelDataType ;
5253import org .apache .calcite .rex .RexCall ;
5354import org .apache .calcite .rex .RexInputRef ;
6465import org .opensearch .index .query .QueryBuilder ;
6566import org .opensearch .index .query .RangeQueryBuilder ;
6667import org .opensearch .sql .calcite .plan .OpenSearchConstants ;
68+ import org .opensearch .sql .opensearch .data .type .OpenSearchDataType ;
69+ import org .opensearch .sql .opensearch .data .type .OpenSearchDataType .MappingType ;
70+ import org .opensearch .sql .opensearch .data .type .OpenSearchTextType ;
6771
6872/**
6973 * Query predicate analyzer. Uses visitor pattern to traverse existing expression and convert it to
@@ -92,8 +96,8 @@ public static final class PredicateAnalyzerException extends RuntimeException {
9296 }
9397
9498 /**
95- * Exception that is thrown when a {@link org.apache.calcite.rel. RelNode} expression cannot be
96- * processed (or converted into an OpenSearch query).
99+ * Exception that is thrown when a {@link RelNode} expression cannot be processed (or converted
100+ * into an OpenSearch query).
97101 */
98102 public static class ExpressionNotAnalyzableException extends Exception {
99103 ExpressionNotAnalyzableException (String message , Throwable cause ) {
@@ -112,15 +116,19 @@ private PredicateAnalyzer() {}
112116 * filters.
113117 *
114118 * @param expression expression to analyze
119+ * @param schema current schema of scan operator
120+ * @param typeMapping mapping of OpenSearch field name to OpenSearchDataType
115121 * @return search query which can be used to query OS cluster
116122 * @throws ExpressionNotAnalyzableException when expression can't processed by this analyzer
117123 */
118- public static QueryBuilder analyze (RexNode expression , List <String > schema )
124+ public static QueryBuilder analyze (
125+ RexNode expression , List <String > schema , Map <String , OpenSearchDataType > typeMapping )
119126 throws ExpressionNotAnalyzableException {
120127 requireNonNull (expression , "expression" );
121128 try {
122129 // visits expression tree
123- QueryExpression queryExpression = (QueryExpression ) expression .accept (new Visitor (schema ));
130+ QueryExpression queryExpression =
131+ (QueryExpression ) expression .accept (new Visitor (schema , typeMapping ));
124132
125133 if (queryExpression != null && queryExpression .isPartial ()) {
126134 throw new UnsupportedOperationException (
@@ -137,15 +145,17 @@ public static QueryBuilder analyze(RexNode expression, List<String> schema)
137145 private static class Visitor extends RexVisitorImpl <Expression > {
138146
139147 List <String > schema ;
148+ Map <String , OpenSearchDataType > typeMapping ;
140149
141- private Visitor (List <String > schema ) {
150+ private Visitor (List <String > schema , Map < String , OpenSearchDataType > typeMapping ) {
142151 super (true );
143152 this .schema = schema ;
153+ this .typeMapping = typeMapping ;
144154 }
145155
146156 @ Override
147157 public Expression visitInputRef (RexInputRef inputRef ) {
148- return new NamedFieldExpression (inputRef , schema );
158+ return new NamedFieldExpression (inputRef , schema , typeMapping );
149159 }
150160
151161 @ Override
@@ -246,7 +256,7 @@ public Expression visitCall(RexCall call) {
246256
247257 SqlSyntax syntax = call .getOperator ().getSyntax ();
248258 if (!supportedRexCall (call )) {
249- String message = String . format (Locale .ROOT , "Unsupported call: [%s]" , call );
259+ String message = format (Locale .ROOT , "Unsupported call: [%s]" , call );
250260 throw new PredicateAnalyzerException (message );
251261 }
252262
@@ -299,7 +309,7 @@ private static String convertQueryString(List<Expression> fields, Expression que
299309 for (Expression expr : fields ) {
300310 if (expr instanceof NamedFieldExpression ) {
301311 NamedFieldExpression field = (NamedFieldExpression ) expr ;
302- String fieldIndexString = String . format (Locale .ROOT , "$%d" , index ++);
312+ String fieldIndexString = format (Locale .ROOT , "$%d" , index ++);
303313 fieldMap .put (fieldIndexString , field .getReference ());
304314 }
305315 }
@@ -315,7 +325,7 @@ private QueryExpression prefix(RexCall call) {
315325 call .getKind () == SqlKind .NOT , "Expected %s got %s" , SqlKind .NOT , call .getKind ());
316326
317327 if (call .getOperands ().size () != 1 ) {
318- String message = String . format (Locale .ROOT , "Unsupported NOT operator: [%s]" , call );
328+ String message = format (Locale .ROOT , "Unsupported NOT operator: [%s]" , call );
319329 throw new PredicateAnalyzerException (message );
320330 }
321331
@@ -326,7 +336,7 @@ private QueryExpression prefix(RexCall call) {
326336 private QueryExpression postfix (RexCall call ) {
327337 checkArgument (call .getKind () == SqlKind .IS_NULL || call .getKind () == SqlKind .IS_NOT_NULL );
328338 if (call .getOperands ().size () != 1 ) {
329- String message = String . format (Locale .ROOT , "Unsupported operator: [%s]" , call );
339+ String message = format (Locale .ROOT , "Unsupported operator: [%s]" , call );
330340 throw new PredicateAnalyzerException (message );
331341 }
332342 Expression a = call .getOperands ().get (0 ).accept (this );
@@ -415,7 +425,7 @@ private QueryExpression binary(RexCall call) {
415425 default :
416426 break ;
417427 }
418- String message = String . format (Locale .ROOT , "Unable to handle call: [%s]" , call );
428+ String message = format (Locale .ROOT , "Unable to handle call: [%s]" , call );
419429 throw new PredicateAnalyzerException (message );
420430 }
421431
@@ -446,16 +456,15 @@ private QueryExpression andOr(RexCall call) {
446456 if (firstError != null ) {
447457 throw firstError ;
448458 } else {
449- final String message =
450- String .format (Locale .ROOT , "Unable to handle call: [%s]" , call );
459+ final String message = format (Locale .ROOT , "Unable to handle call: [%s]" , call );
451460 throw new PredicateAnalyzerException (message );
452461 }
453462 }
454463 return CompoundQueryExpression .or (expressions );
455464 case AND :
456465 return CompoundQueryExpression .and (partial , expressions );
457466 default :
458- String message = String . format (Locale .ROOT , "Unable to handle call: [%s]" , call );
467+ String message = format (Locale .ROOT , "Unable to handle call: [%s]" , call );
459468 throw new PredicateAnalyzerException (message );
460469 }
461470 }
@@ -514,7 +523,7 @@ private static SwapResult swap(Expression left, Expression right) {
514523
515524 if (literal == null || terminal == null ) {
516525 String message =
517- String . format (
526+ format (
518527 Locale .ROOT ,
519528 "Unexpected combination of expressions [left: %s] [right: %s]" ,
520529 left ,
@@ -618,7 +627,7 @@ public static QueryExpression create(TerminalExpression expression) {
618627 if (expression instanceof NamedFieldExpression ) {
619628 return new SimpleQueryExpression ((NamedFieldExpression ) expression );
620629 } else {
621- String message = String . format (Locale .ROOT , "Unsupported expression: [%s]" , expression );
630+ String message = format (Locale .ROOT , "Unsupported expression: [%s]" , expression );
622631 throw new PredicateAnalyzerException (message );
623632 }
624633 }
@@ -777,6 +786,10 @@ private String getFieldReference() {
777786 return rel .getReference ();
778787 }
779788
789+ private String getFieldReferenceForTermQuery () {
790+ return rel .getReferenceForTermQuery ();
791+ }
792+
780793 private SimpleQueryExpression (NamedFieldExpression rel ) {
781794 this .rel = rel ;
782795 }
@@ -840,9 +853,7 @@ public QueryExpression equals(LiteralExpression literal) {
840853 .must (addFormatIfNecessary (literal , rangeQuery (getFieldReference ()).gte (value )))
841854 .must (addFormatIfNecessary (literal , rangeQuery (getFieldReference ()).lte (value )));
842855 } else {
843- // TODO: equal(textFieldType, "value") should not rewrite as termQuery,
844- // it should be addressed by issue: https://github.com/opensearch-project/sql/issues/3334
845- builder = termQuery (getFieldReference (), value );
856+ builder = termQuery (getFieldReferenceForTermQuery (), value );
846857 }
847858 return this ;
848859 }
@@ -860,7 +871,7 @@ public QueryExpression notEquals(LiteralExpression literal) {
860871 boolQuery ()
861872 // NOT LIKE should return false when field is NULL
862873 .must (existsQuery (getFieldReference ()))
863- .mustNot (termQuery (getFieldReference (), value ));
874+ .mustNot (termQuery (getFieldReferenceForTermQuery (), value ));
864875 }
865876 return this ;
866877 }
@@ -900,21 +911,21 @@ public QueryExpression queryString(String query) {
900911
901912 @ Override
902913 public QueryExpression isTrue () {
903- builder = termQuery (getFieldReference (), true );
914+ builder = termQuery (getFieldReferenceForTermQuery (), true );
904915 return this ;
905916 }
906917
907918 @ Override
908919 public QueryExpression in (LiteralExpression literal ) {
909920 Collection <?> collection = (Collection <?>) literal .value ();
910- builder = termsQuery (getFieldReference (), collection );
921+ builder = termsQuery (getFieldReferenceForTermQuery (), collection );
911922 return this ;
912923 }
913924
914925 @ Override
915926 public QueryExpression notIn (LiteralExpression literal ) {
916927 Collection <?> collection = (Collection <?>) literal .value ();
917- builder = boolQuery ().mustNot (termsQuery (getFieldReference (), collection ));
928+ builder = boolQuery ().mustNot (termsQuery (getFieldReferenceForTermQuery (), collection ));
918929 return this ;
919930 }
920931 }
@@ -970,31 +981,64 @@ static boolean isCastExpression(Expression exp) {
970981 static final class NamedFieldExpression implements TerminalExpression {
971982
972983 private final String name ;
984+ private final OpenSearchDataType type ;
973985
974986 private NamedFieldExpression () {
975987 this .name = null ;
988+ this .type = null ;
976989 }
977990
978- private NamedFieldExpression (RexInputRef ref , List <String > schema ) {
991+ private NamedFieldExpression (
992+ RexInputRef ref , List <String > schema , Map <String , OpenSearchDataType > typeMapping ) {
979993 this .name =
980994 (ref == null || ref .getIndex () >= schema .size ()) ? null : schema .get (ref .getIndex ());
995+ this .type = typeMapping .get (name );
981996 }
982997
983998 private NamedFieldExpression (RexLiteral literal ) {
984999 this .name = literal == null ? null : RexLiteral .stringValue (literal );
1000+ this .type = null ;
9851001 }
9861002
9871003 String getRootName () {
9881004 return name ;
9891005 }
9901006
1007+ OpenSearchDataType getOpenSearchDataType () {
1008+ return type ;
1009+ }
1010+
1011+ boolean isTextType () {
1012+ return type != null && type .getMappingType () == OpenSearchDataType .MappingType .Text ;
1013+ }
1014+
1015+ String toKeywordSubField () {
1016+ if (type instanceof OpenSearchTextType ) {
1017+ OpenSearchTextType textType = (OpenSearchTextType ) type ;
1018+ // Find the first subfield with type keyword, return null if non-exist.
1019+ return textType .getFields ().entrySet ().stream ()
1020+ .filter (e -> e .getValue ().getMappingType () == MappingType .Keyword )
1021+ .findFirst ()
1022+ .map (e -> name + "." + e .getKey ())
1023+ .orElse (null );
1024+ }
1025+ return null ;
1026+ }
1027+
9911028 boolean isMetaField () {
9921029 return OpenSearchConstants .METADATAFIELD_TYPE_MAP .containsKey (getRootName ());
9931030 }
9941031
9951032 String getReference () {
9961033 return getRootName ();
9971034 }
1035+
1036+ String getReferenceForTermQuery () {
1037+ if (isTextType ()) {
1038+ return toKeywordSubField ();
1039+ }
1040+ return getRootName ();
1041+ }
9981042 }
9991043
10001044 /** Literal like {@code 'foo' or 42 or true} etc. */
0 commit comments