Skip to content

Commit 2cdc289

Browse files
authored
Move SparseVectorQueryBuilder and TextExpansionQueryBuilder to x-pack core (#117857) (#117896)
This commit moves the SparseVectorQueryBuilder and TextExpansionQueryBuilder classes to the x-pack core module, enabling other modules to utilize these query builders. Additionally, it introduces a SparseVectorQueryWrapper to extract sparse vector queries from standard Lucene queries. This is needed for supporting semantic highlighting with sparse vector fields as follow up.
1 parent d62be0f commit 2cdc289

File tree

10 files changed

+125
-53
lines changed

10 files changed

+125
-53
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@
7171
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
7272
import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
7373
import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
74+
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
75+
import org.elasticsearch.xpack.core.ml.search.TextExpansionQueryBuilder;
7476
import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder;
7577
import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage;
7678
import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage;
@@ -398,6 +400,14 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
398400
@Override
399401
public List<SearchPlugin.QuerySpec<?>> getQueries() {
400402
return List.of(
403+
new QuerySpec<>(SparseVectorQueryBuilder.NAME, SparseVectorQueryBuilder::new, SparseVectorQueryBuilder::fromXContent),
404+
new QuerySpec<QueryBuilder>(
405+
TextExpansionQueryBuilder.NAME,
406+
TextExpansionQueryBuilder::new,
407+
TextExpansionQueryBuilder::fromXContent
408+
),
409+
// TODO: The WeightedTokensBuilder is slated for removal after the SparseVectorQueryBuilder is available.
410+
// The logic to create a Boolean query based on weighted tokens will remain and/or be moved to server.
401411
new SearchPlugin.QuerySpec<QueryBuilder>(
402412
WeightedTokensQueryBuilder.NAME,
403413
WeightedTokensQueryBuilder::new,
Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.ml.queries;
8+
package org.elasticsearch.xpack.core.ml.search;
99

1010
import org.apache.lucene.search.MatchNoDocsQuery;
1111
import org.apache.lucene.search.Query;
@@ -33,9 +33,6 @@
3333
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
3434
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
3535
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
36-
import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
37-
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
38-
import org.elasticsearch.xpack.core.ml.search.WeightedTokensUtils;
3936

4037
import java.io.IOException;
4138
import java.util.ArrayList;
@@ -210,7 +207,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
210207

211208
return (shouldPruneTokens)
212209
? WeightedTokensUtils.queryBuilderWithPrunedTokens(fieldName, tokenPruningConfig, queryVectors, ft, context)
213-
: WeightedTokensUtils.queryBuilderWithAllTokens(queryVectors, ft, context);
210+
: WeightedTokensUtils.queryBuilderWithAllTokens(fieldName, queryVectors, ft, context);
214211
}
215212

216213
@Override
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.search;
9+
10+
import org.apache.lucene.search.BooleanClause;
11+
import org.apache.lucene.search.IndexSearcher;
12+
import org.apache.lucene.search.Query;
13+
import org.apache.lucene.search.QueryVisitor;
14+
import org.apache.lucene.search.ScoreMode;
15+
import org.apache.lucene.search.Weight;
16+
import org.elasticsearch.index.query.SearchExecutionContext;
17+
18+
import java.io.IOException;
19+
import java.util.Objects;
20+
21+
/**
22+
* A wrapper class for the Lucene query generated by {@link SparseVectorQueryBuilder#toQuery(SearchExecutionContext)}.
23+
* This wrapper facilitates the extraction of the complete sparse vector query using a {@link QueryVisitor}.
24+
*/
25+
public class SparseVectorQueryWrapper extends Query {
26+
private final String fieldName;
27+
private final Query termsQuery;
28+
29+
public SparseVectorQueryWrapper(String fieldName, Query termsQuery) {
30+
this.fieldName = fieldName;
31+
this.termsQuery = termsQuery;
32+
}
33+
34+
public Query getTermsQuery() {
35+
return termsQuery;
36+
}
37+
38+
@Override
39+
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
40+
var rewrite = termsQuery.rewrite(indexSearcher);
41+
if (rewrite != termsQuery) {
42+
return new SparseVectorQueryWrapper(fieldName, rewrite);
43+
}
44+
return this;
45+
}
46+
47+
@Override
48+
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
49+
return termsQuery.createWeight(searcher, scoreMode, boost);
50+
}
51+
52+
@Override
53+
public String toString(String field) {
54+
return termsQuery.toString(field);
55+
}
56+
57+
@Override
58+
public void visit(QueryVisitor visitor) {
59+
if (visitor.acceptField(fieldName)) {
60+
termsQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this));
61+
}
62+
}
63+
64+
@Override
65+
public boolean equals(Object obj) {
66+
if (sameClassAs(obj) == false) {
67+
return false;
68+
}
69+
SparseVectorQueryWrapper that = (SparseVectorQueryWrapper) obj;
70+
return fieldName.equals(that.fieldName) && termsQuery.equals(that.termsQuery);
71+
}
72+
73+
@Override
74+
public int hashCode() {
75+
return Objects.hash(classHash(), fieldName, termsQuery);
76+
}
77+
}
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.ml.queries;
8+
package org.elasticsearch.xpack.core.ml.search;
99

1010
import org.apache.lucene.search.Query;
1111
import org.apache.lucene.util.SetOnce;
@@ -32,8 +32,6 @@
3232
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
3333
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
3434
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
35-
import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
36-
import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder;
3735

3836
import java.io.IOException;
3937
import java.util.List;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
125125
}
126126

127127
return (this.tokenPruningConfig == null)
128-
? WeightedTokensUtils.queryBuilderWithAllTokens(tokens, ft, context)
128+
? WeightedTokensUtils.queryBuilderWithAllTokens(fieldName, tokens, ft, context)
129129
: WeightedTokensUtils.queryBuilderWithPrunedTokens(fieldName, tokenPruningConfig, tokens, ft, context);
130130
}
131131

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,18 @@ public final class WeightedTokensUtils {
2424

2525
private WeightedTokensUtils() {}
2626

27-
public static Query queryBuilderWithAllTokens(List<WeightedToken> tokens, MappedFieldType ft, SearchExecutionContext context) {
27+
public static Query queryBuilderWithAllTokens(
28+
String fieldName,
29+
List<WeightedToken> tokens,
30+
MappedFieldType ft,
31+
SearchExecutionContext context
32+
) {
2833
var qb = new BooleanQuery.Builder();
2934

3035
for (var token : tokens) {
3136
qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD);
3237
}
33-
return qb.setMinimumNumberShouldMatch(1).build();
38+
return new SparseVectorQueryWrapper(fieldName, qb.setMinimumNumberShouldMatch(1).build());
3439
}
3540

3641
public static Query queryBuilderWithPrunedTokens(
@@ -64,7 +69,7 @@ public static Query queryBuilderWithPrunedTokens(
6469
}
6570
}
6671

67-
return qb.setMinimumNumberShouldMatch(1).build();
72+
return new SparseVectorQueryWrapper(fieldName, qb.setMinimumNumberShouldMatch(1).build());
6873
}
6974

7075
/**
Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.ml.queries;
8+
package org.elasticsearch.xpack.core.ml.search;
99

1010
import org.apache.lucene.document.Document;
1111
import org.apache.lucene.document.FeatureField;
@@ -40,17 +40,14 @@
4040
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
4141
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
4242
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
43-
import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
44-
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
45-
import org.elasticsearch.xpack.ml.MachineLearning;
4643

4744
import java.io.IOException;
4845
import java.lang.reflect.Method;
4946
import java.util.ArrayList;
5047
import java.util.Collection;
5148
import java.util.List;
5249

53-
import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.QUERY_VECTOR_FIELD;
50+
import static org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder.QUERY_VECTOR_FIELD;
5451
import static org.hamcrest.CoreMatchers.instanceOf;
5552
import static org.hamcrest.Matchers.either;
5653
import static org.hamcrest.Matchers.hasSize;
@@ -102,7 +99,7 @@ private SparseVectorQueryBuilder createTestQueryBuilder(TokenPruningConfig token
10299

103100
@Override
104101
protected Collection<Class<? extends Plugin>> getPlugins() {
105-
return List.of(MachineLearning.class, MapperExtrasPlugin.class, XPackClientPlugin.class);
102+
return List.of(MapperExtrasPlugin.class, XPackClientPlugin.class);
106103
}
107104

108105
@Override
@@ -156,8 +153,10 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws
156153

157154
@Override
158155
protected void doAssertLuceneQuery(SparseVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) {
159-
assertThat(query, instanceOf(BooleanQuery.class));
160-
BooleanQuery booleanQuery = (BooleanQuery) query;
156+
assertThat(query, instanceOf(SparseVectorQueryWrapper.class));
157+
var sparseQuery = (SparseVectorQueryWrapper) query;
158+
assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
159+
BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery();
161160
assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1);
162161
assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS));
163162

@@ -233,11 +232,13 @@ public void testToQuery() throws IOException {
233232

234233
private void testDoToQuery(SparseVectorQueryBuilder queryBuilder, SearchExecutionContext context) throws IOException {
235234
Query query = queryBuilder.doToQuery(context);
235+
assertTrue(query instanceof SparseVectorQueryWrapper);
236+
var sparseQuery = (SparseVectorQueryWrapper) query;
236237
if (queryBuilder.shouldPruneTokens()) {
237238
// It's possible that all documents were pruned for aggressive pruning configurations
238-
assertTrue(query instanceof BooleanQuery || query instanceof MatchNoDocsQuery);
239+
assertTrue(sparseQuery.getTermsQuery() instanceof BooleanQuery || sparseQuery.getTermsQuery() instanceof MatchNoDocsQuery);
239240
} else {
240-
assertTrue(query instanceof BooleanQuery);
241+
assertTrue(sparseQuery.getTermsQuery() instanceof BooleanQuery);
241242
}
242243
}
243244

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.ml.queries;
8+
package org.elasticsearch.xpack.core.ml.search;
99

1010
import org.apache.lucene.document.Document;
1111
import org.apache.lucene.document.FeatureField;
@@ -35,10 +35,6 @@
3535
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
3636
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
3737
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
38-
import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
39-
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
40-
import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder;
41-
import org.elasticsearch.xpack.ml.MachineLearning;
4238

4339
import java.io.IOException;
4440
import java.lang.reflect.Method;
@@ -77,7 +73,7 @@ protected TextExpansionQueryBuilder doCreateTestQueryBuilder() {
7773

7874
@Override
7975
protected Collection<Class<? extends Plugin>> getPlugins() {
80-
return List.of(MachineLearning.class, MapperExtrasPlugin.class, XPackClientPlugin.class);
76+
return List.of(MapperExtrasPlugin.class, XPackClientPlugin.class);
8177
}
8278

8379
@Override
@@ -129,8 +125,10 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws
129125

130126
@Override
131127
protected void doAssertLuceneQuery(TextExpansionQueryBuilder queryBuilder, Query query, SearchExecutionContext context) {
132-
assertThat(query, instanceOf(BooleanQuery.class));
133-
BooleanQuery booleanQuery = (BooleanQuery) query;
128+
assertThat(query, instanceOf(SparseVectorQueryWrapper.class));
129+
var sparseQuery = (SparseVectorQueryWrapper) query;
130+
assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
131+
BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery();
134132
assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1);
135133
assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS));
136134

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,11 @@ public void testPruningIsAppliedCorrectly() throws IOException {
271271
}
272272

273273
private void assertCorrectLuceneQuery(String name, Query query, List<String> expectedFeatureFields) {
274-
assertTrue(query instanceof BooleanQuery);
275-
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
274+
assertThat(query, instanceOf(SparseVectorQueryWrapper.class));
275+
var sparseQuery = (SparseVectorQueryWrapper) query;
276+
assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
277+
BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery();
278+
List<BooleanClause> booleanClauses = booleanQuery.clauses();
276279
assertEquals(
277280
name + " had " + booleanClauses.size() + " clauses, expected " + expectedFeatureFields.size(),
278281
expectedFeatureFields.size(),
@@ -343,8 +346,10 @@ public void testMustRewrite() throws IOException {
343346

344347
@Override
345348
protected void doAssertLuceneQuery(WeightedTokensQueryBuilder queryBuilder, Query query, SearchExecutionContext context) {
346-
assertThat(query, instanceOf(BooleanQuery.class));
347-
BooleanQuery booleanQuery = (BooleanQuery) query;
349+
assertThat(query, instanceOf(SparseVectorQueryWrapper.class));
350+
var sparseQuery = (SparseVectorQueryWrapper) query;
351+
assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
352+
BooleanQuery booleanQuery = (BooleanQuery) sparseQuery.getTermsQuery();
348353
assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1);
349354
assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS));
350355

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
import org.elasticsearch.features.NodeFeature;
4949
import org.elasticsearch.index.analysis.CharFilterFactory;
5050
import org.elasticsearch.index.analysis.TokenizerFactory;
51-
import org.elasticsearch.index.query.QueryBuilder;
5251
import org.elasticsearch.indices.AssociatedIndexDescriptor;
5352
import org.elasticsearch.indices.SystemIndexDescriptor;
5453
import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
@@ -376,8 +375,6 @@
376375
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
377376
import org.elasticsearch.xpack.ml.process.NativeController;
378377
import org.elasticsearch.xpack.ml.process.NativeStorageProvider;
379-
import org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder;
380-
import org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder;
381378
import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction;
382379
import org.elasticsearch.xpack.ml.rest.RestMlInfoAction;
383380
import org.elasticsearch.xpack.ml.rest.RestMlMemoryAction;
@@ -1764,22 +1761,6 @@ public List<QueryVectorBuilderSpec<?>> getQueryVectorBuilders() {
17641761
);
17651762
}
17661763

1767-
@Override
1768-
public List<QuerySpec<?>> getQueries() {
1769-
return List.of(
1770-
new QuerySpec<QueryBuilder>(
1771-
TextExpansionQueryBuilder.NAME,
1772-
TextExpansionQueryBuilder::new,
1773-
TextExpansionQueryBuilder::fromXContent
1774-
),
1775-
new QuerySpec<QueryBuilder>(
1776-
SparseVectorQueryBuilder.NAME,
1777-
SparseVectorQueryBuilder::new,
1778-
SparseVectorQueryBuilder::fromXContent
1779-
)
1780-
);
1781-
}
1782-
17831764
private <T> ContextParser<String, T> checkAggLicense(ContextParser<String, T> realParser, LicensedFeature.Momentary feature) {
17841765
return (parser, name) -> {
17851766
if (feature.check(getLicenseState()) == false) {

0 commit comments

Comments
 (0)