Skip to content

Commit 9df2a86

Browse files
committed
using weighted tokens as criteria for rewriting text expansion query
1 parent aea3034 commit 9df2a86

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,8 @@ public Map<String, Object> asMap(String outputField) {
9696
map.put(outputField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
9797
return map;
9898
}
99+
100+
public boolean hasWeightedTokens() {
101+
return this.getWeightedTokens().isEmpty() == false;
102+
}
99103
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import java.io.IOException;
3939
import java.util.List;
4040
import java.util.Objects;
41-
import java.util.Set;
4241

4342
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
4443
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -60,8 +59,6 @@ public class TextExpansionQueryBuilder extends AbstractQueryBuilder<TextExpansio
6059
private SetOnce<TextExpansionResults> weightedTokensSupplier;
6160
private final TokenPruningConfig tokenPruningConfig;
6261

63-
private static final Set<String> ALLOWED_FIELD_TYPES = Set.of("sparse_vector", "rank_features");
64-
6562
private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(ParseField.class);
6663
public static final String TEXT_EXPANSION_DEPRECATION_MESSAGE = NAME + " is deprecated. Use sparse_vector instead.";
6764

@@ -161,8 +158,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
161158
return weightedTokensToQuery(fieldName, weightedTokensSupplier.get());
162159
}
163160

164-
// Do field type check if query won't be rewritten as a WeightedTokensQuery
165-
166161
CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
167162
modelId,
168163
List.of(modelText),
@@ -213,7 +208,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
213208
}
214209

215210
private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResults textExpansionResults) {
216-
if (tokenPruningConfig != null) {
211+
if (tokenPruningConfig != null || textExpansionResults.hasWeightedTokens()) {
217212
WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder(
218213
fieldName,
219214
textExpansionResults.getWeightedTokens(),

0 commit comments

Comments
 (0)