From aea3034b47e413390ab4aa9d6196831538b4e6b8 Mon Sep 17 00:00:00 2001 From: Henrique Paes Date: Wed, 6 Nov 2024 10:19:01 -0500 Subject: [PATCH 1/4] spotting where type check should happen --- .../xpack/ml/queries/TextExpansionQueryBuilder.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 6d972bcf5863a..18abff7a8688e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -38,6 +38,7 @@ import java.io.IOException; import java.util.List; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -59,6 +60,8 @@ public class TextExpansionQueryBuilder extends AbstractQueryBuilder weightedTokensSupplier; private final TokenPruningConfig tokenPruningConfig; + private static final Set ALLOWED_FIELD_TYPES = Set.of("sparse_vector", "rank_features"); + private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(ParseField.class); public static final String TEXT_EXPANSION_DEPRECATION_MESSAGE = NAME + " is deprecated. Use sparse_vector instead."; @@ -158,6 +161,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { return weightedTokensToQuery(fieldName, weightedTokensSupplier.get()); } + // Do field type check if query won't be rewritten as a WeightedTokensQuery + CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput( modelId, List.of(modelText), From 9df2a86aa060f99a23928974d960b9ebd59fcbc3 Mon Sep 17 00:00:00 2001 From: Henrique Paes Date: Thu, 7 Nov 2024 11:31:01 -0500 Subject: [PATCH 2/4] using weighted tokens as criteria for rewriting text expansion query --- .../core/ml/inference/results/TextExpansionResults.java | 4 ++++ .../xpack/ml/queries/TextExpansionQueryBuilder.java | 7 +------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java index 40d7fd221d09f..b4ea1a99d0574 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java @@ -96,4 +96,8 @@ public Map asMap(String outputField) { map.put(outputField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight))); return map; } + + public boolean hasWeightedTokens() { + return this.getWeightedTokens().isEmpty() == false; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 18abff7a8688e..33ad73f4ccf92 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -38,7 +38,6 @@ import java.io.IOException; import java.util.List; import java.util.Objects; -import java.util.Set; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -60,8 +59,6 @@ public class TextExpansionQueryBuilder extends AbstractQueryBuilder weightedTokensSupplier; private final TokenPruningConfig tokenPruningConfig; - private static final Set ALLOWED_FIELD_TYPES = Set.of("sparse_vector", "rank_features"); - private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(ParseField.class); public static final String TEXT_EXPANSION_DEPRECATION_MESSAGE = NAME + " is deprecated. Use sparse_vector instead."; @@ -161,8 +158,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { return weightedTokensToQuery(fieldName, weightedTokensSupplier.get()); } - // Do field type check if query won't be rewritten as a WeightedTokensQuery - CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput( modelId, List.of(modelText), @@ -213,7 +208,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { } private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResults textExpansionResults) { - if (tokenPruningConfig != null) { + if (tokenPruningConfig != null || textExpansionResults.hasWeightedTokens()) { WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder( fieldName, textExpansionResults.getWeightedTokens(), From 7c0a3e72eb4452375056e16a44338fe5c366269e Mon Sep 17 00:00:00 2001 From: Henrique Paes Date: Thu, 7 Nov 2024 11:31:32 -0500 Subject: [PATCH 3/4] modify test to ensure that query is rewritten to a supported type --- .../xpack/ml/queries/TextExpansionQueryBuilderTests.java | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index 00d50e0d0d7bb..c62d42c6fa783 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -283,14 +283,11 @@ protected String[] shuffleProtectedFields() { return new String[] { WeightedTokensQueryBuilder.TOKENS_FIELD.getPreferredName() }; } - public void testThatTokensAreCorrectlyPruned() { + public void testQueryWasRewrittenToASupportedType() { SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); TextExpansionQueryBuilder queryBuilder = createTestQueryBuilder(); QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext); - if (queryBuilder.getTokenPruningConfig() == null) { - assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder); - } else { - assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder); - } + + assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder || rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder); } } From 7d3e655a0ba0209173761cbf1142033043c915a8 Mon Sep 17 00:00:00 2001 From: Henrique Paes Date: Wed, 13 Nov 2024 14:21:21 -0500 Subject: [PATCH 4/4] add a test without pruning to yaml --- .../rest-api-spec/test/ml/text_expansion_search.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml index 21a5a4736675d..9a2271f87cfcd 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml @@ -345,3 +345,15 @@ setup: model_id: text_expansion_model model_text: "octopus comforter smells" pruning_config: { } +--- +"Test text-expansion that displays error for invalid queried field type without prune config": + - do: + catch: /\[keyword\] is not an appropriate field type for this query/ + search: + index: index-with-rank-features + body: + query: + text_expansion: + source_text: + model_id: text_expansion_model + model_text: "octopus comforter smells"