diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java index 40d7fd221d09f..b4ea1a99d0574 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java @@ -96,4 +96,8 @@ public Map asMap(String outputField) { map.put(outputField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight))); return map; } + + public boolean hasWeightedTokens() { + return this.getWeightedTokens().isEmpty() == false; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 6d972bcf5863a..33ad73f4ccf92 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -208,7 +208,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { } private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResults textExpansionResults) { - if (tokenPruningConfig != null) { + if (tokenPruningConfig != null || textExpansionResults.hasWeightedTokens()) { WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder( fieldName, textExpansionResults.getWeightedTokens(), diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index 00d50e0d0d7bb..c62d42c6fa783 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -283,14 +283,11 @@ protected String[] shuffleProtectedFields() { return new String[] { WeightedTokensQueryBuilder.TOKENS_FIELD.getPreferredName() }; } - public void testThatTokensAreCorrectlyPruned() { + public void testQueryWasRewrittenToASupportedType() { SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); TextExpansionQueryBuilder queryBuilder = createTestQueryBuilder(); QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext); - if (queryBuilder.getTokenPruningConfig() == null) { - assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder); - } else { - assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder); - } + + assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder || rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder); } } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml index 21a5a4736675d..9a2271f87cfcd 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml @@ -345,3 +345,15 @@ setup: model_id: text_expansion_model model_text: "octopus comforter smells" pruning_config: { } +--- +"Test text-expansion that displays error for invalid queried field type without prune config": + - do: + catch: /\[keyword\] is not an appropriate field type for this query/ + search: + index: index-with-rank-features + body: + query: + text_expansion: + source_text: + model_id: text_expansion_model + model_text: "octopus comforter smells"