Skip to content

Commit 3f99211

Browse files
authored
Add support for knn vector queries on semantic_text fields (#119011)
* First cut at KNN interceptor
* Move TextEmbeddingQueryVectorBuilder to xpack-core
* Infer model ID
* Fix test compilation errors
* Update docs/changelog/119011.yaml
* Update changelog
* Update test
* Cleanup
* PR feedback
* Add yaml test
* Cleanup tests
* Update test
* Minor PR feedback
* Refactor pre-filter indices for knn
* PR feedback
1 parent 45383c8 commit 3f99211

File tree

11 files changed

+735
-8
lines changed

11 files changed

+735
-8
lines changed

docs/changelog/119011.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 119011
2+
summary: "Add support for knn vector queries on `semantic_text` fields"
3+
area: Search
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ static TransportVersion def(int id) {
149149
public static final TransportVersion TRANSFORMS_UPGRADE_MODE = def(8_814_00_0);
150150
public static final TransportVersion NODE_SHUTDOWN_EPHEMERAL_ID_ADDED = def(8_815_00_0);
151151
public static final TransportVersion ESQL_CCS_TELEMETRY_STATS = def(8_816_00_0);
152+
public static final TransportVersion TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID = def(8_817_00_0);
152153

153154
/*
154155
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/core/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@
126126
exports org.elasticsearch.xpack.core.ml.stats;
127127
exports org.elasticsearch.xpack.core.ml.utils.time;
128128
exports org.elasticsearch.xpack.core.ml.utils;
129+
exports org.elasticsearch.xpack.core.ml.vectors;
129130
exports org.elasticsearch.xpack.core.ml;
130131
exports org.elasticsearch.xpack.core.monitoring.action;
131132
exports org.elasticsearch.xpack.core.monitoring.exporter;
Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.ml.vectors;
8+
package org.elasticsearch.xpack.core.ml.vectors;
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
@@ -30,7 +30,9 @@
3030
import java.util.List;
3131
import java.util.Objects;
3232

33+
import static org.elasticsearch.TransportVersions.TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID;
3334
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
35+
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
3436
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
3537
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
3638

@@ -46,7 +48,7 @@ public class TextEmbeddingQueryVectorBuilder implements QueryVectorBuilder {
4648
);
4749

4850
static {
49-
PARSER.declareString(constructorArg(), TrainedModelConfig.MODEL_ID);
51+
PARSER.declareString(optionalConstructorArg(), TrainedModelConfig.MODEL_ID);
5052
PARSER.declareString(constructorArg(), MODEL_TEXT);
5153
}
5254

@@ -63,7 +65,11 @@ public TextEmbeddingQueryVectorBuilder(String modelId, String modelText) {
6365
}
6466

6567
public TextEmbeddingQueryVectorBuilder(StreamInput in) throws IOException {
66-
this.modelId = in.readString();
68+
if (in.getTransportVersion().onOrAfter(TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID)) {
69+
this.modelId = in.readOptionalString();
70+
} else {
71+
this.modelId = in.readString();
72+
}
6773
this.modelText = in.readString();
6874
}
6975

@@ -79,28 +85,40 @@ public TransportVersion getMinimalSupportedVersion() {
7985

8086
@Override
8187
public void writeTo(StreamOutput out) throws IOException {
82-
out.writeString(modelId);
88+
if (out.getTransportVersion().onOrAfter(TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID)) {
89+
out.writeOptionalString(modelId);
90+
} else {
91+
out.writeString(modelId);
92+
}
8393
out.writeString(modelText);
8494
}
8595

8696
@Override
8797
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
8898
builder.startObject();
89-
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
99+
if (modelId != null) {
100+
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
101+
}
90102
builder.field(MODEL_TEXT.getPreferredName(), modelText);
91103
builder.endObject();
92104
return builder;
93105
}
94106

95107
@Override
96108
public void buildVector(Client client, ActionListener<float[]> listener) {
109+
110+
if (modelId == null) {
111+
throw new IllegalArgumentException("[model_id] must not be null.");
112+
}
113+
97114
CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
98115
modelId,
99116
List.of(modelText),
100117
TextEmbeddingConfigUpdate.EMPTY_INSTANCE,
101118
false,
102119
InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
103120
);
121+
104122
inferRequest.setHighPriority(true);
105123
inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);
106124

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.Set;
1919

20+
import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
2021
import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
2122
import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
2223

@@ -50,7 +51,8 @@ public Set<NodeFeature> getTestFeatures() {
5051
SEMANTIC_TEXT_HIGHLIGHTER,
5152
SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
5253
SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
53-
SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES
54+
SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES,
55+
SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED
5456
);
5557
}
5658
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
7979
import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper;
8080
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
81+
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
8182
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
8283
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
8384
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
@@ -445,7 +446,11 @@ public List<QuerySpec<?>> getQueries() {
445446

446447
@Override
447448
public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
448-
return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor());
449+
return List.of(
450+
new SemanticKnnVectorQueryRewriteInterceptor(),
451+
new SemanticMatchQueryRewriteInterceptor(),
452+
new SemanticSparseVectorQueryRewriteInterceptor()
453+
);
449454
}
450455

451456
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.apache.lucene.search.join.ScoreMode;
11+
import org.elasticsearch.features.NodeFeature;
12+
import org.elasticsearch.index.mapper.IndexFieldMapper;
13+
import org.elasticsearch.index.query.BoolQueryBuilder;
14+
import org.elasticsearch.index.query.QueryBuilder;
15+
import org.elasticsearch.index.query.QueryBuilders;
16+
import org.elasticsearch.index.query.TermsQueryBuilder;
17+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
18+
import org.elasticsearch.search.vectors.QueryVectorBuilder;
19+
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
20+
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
21+
22+
import java.util.Collection;
23+
import java.util.List;
24+
import java.util.Map;
25+
26+
public class SemanticKnnVectorQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
27+
28+
public static final NodeFeature SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
29+
"search.semantic_knn_vector_query_rewrite_interception_supported"
30+
);
31+
32+
public SemanticKnnVectorQueryRewriteInterceptor() {}
33+
34+
@Override
35+
protected String getFieldName(QueryBuilder queryBuilder) {
36+
assert (queryBuilder instanceof KnnVectorQueryBuilder);
37+
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
38+
return knnVectorQueryBuilder.getFieldName();
39+
}
40+
41+
@Override
42+
protected String getQuery(QueryBuilder queryBuilder) {
43+
assert (queryBuilder instanceof KnnVectorQueryBuilder);
44+
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
45+
TextEmbeddingQueryVectorBuilder queryVectorBuilder = getTextEmbeddingQueryBuilderFromQuery(knnVectorQueryBuilder);
46+
return queryVectorBuilder != null ? queryVectorBuilder.getModelText() : null;
47+
}
48+
49+
@Override
50+
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
51+
assert (queryBuilder instanceof KnnVectorQueryBuilder);
52+
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
53+
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
54+
if (inferenceIdsIndices.size() == 1) {
55+
// Simple case, everything uses the same inference ID
56+
Map.Entry<String, List<String>> inferenceIdIndex = inferenceIdsIndices.entrySet().iterator().next();
57+
String searchInferenceId = inferenceIdIndex.getKey();
58+
List<String> indices = inferenceIdIndex.getValue();
59+
return buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
60+
} else {
61+
// Multiple inference IDs, construct a boolean query
62+
return buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
63+
}
64+
}
65+
66+
private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
67+
KnnVectorQueryBuilder queryBuilder,
68+
Map<String, List<String>> inferenceIdsIndices
69+
) {
70+
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
71+
for (String inferenceId : inferenceIdsIndices.keySet()) {
72+
boolQueryBuilder.should(
73+
createSubQueryForIndices(
74+
inferenceIdsIndices.get(inferenceId),
75+
buildNestedQueryFromKnnVectorQuery(queryBuilder, inferenceIdsIndices.get(inferenceId), inferenceId)
76+
)
77+
);
78+
}
79+
return boolQueryBuilder;
80+
}
81+
82+
@Override
83+
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
84+
QueryBuilder queryBuilder,
85+
InferenceIndexInformationForField indexInformation
86+
) {
87+
assert (queryBuilder instanceof KnnVectorQueryBuilder);
88+
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
89+
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
90+
91+
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
92+
boolQueryBuilder.should(addIndexFilterToKnnVectorQuery(indexInformation.nonInferenceIndices(), knnVectorQueryBuilder));
93+
94+
// We always perform nested subqueries on semantic_text fields, to support knn queries using query vectors.
95+
// Both pre and post filtering are required here to ensure we get the results we need without errors based on field types.
96+
for (String inferenceId : inferenceIdsIndices.keySet()) {
97+
boolQueryBuilder.should(
98+
createSubQueryForIndices(
99+
inferenceIdsIndices.get(inferenceId),
100+
buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, inferenceIdsIndices.get(inferenceId), inferenceId)
101+
)
102+
);
103+
}
104+
return boolQueryBuilder;
105+
}
106+
107+
private QueryBuilder buildNestedQueryFromKnnVectorQuery(
108+
KnnVectorQueryBuilder knnVectorQueryBuilder,
109+
List<String> indices,
110+
String searchInferenceId
111+
) {
112+
KnnVectorQueryBuilder filteredKnnVectorQueryBuilder = addIndexFilterToKnnVectorQuery(indices, knnVectorQueryBuilder);
113+
TextEmbeddingQueryVectorBuilder queryVectorBuilder = getTextEmbeddingQueryBuilderFromQuery(filteredKnnVectorQueryBuilder);
114+
if (queryVectorBuilder != null && queryVectorBuilder.getModelId() == null && searchInferenceId != null) {
115+
// If the model ID was not specified, we infer the inference ID associated with the semantic_text field.
116+
queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(searchInferenceId, queryVectorBuilder.getModelText());
117+
}
118+
return QueryBuilders.nestedQuery(
119+
SemanticTextField.getChunksFieldName(filteredKnnVectorQueryBuilder.getFieldName()),
120+
buildNewKnnVectorQuery(
121+
SemanticTextField.getEmbeddingsFieldName(filteredKnnVectorQueryBuilder.getFieldName()),
122+
filteredKnnVectorQueryBuilder,
123+
queryVectorBuilder
124+
),
125+
ScoreMode.Max
126+
);
127+
}
128+
129+
private KnnVectorQueryBuilder addIndexFilterToKnnVectorQuery(Collection<String> indices, KnnVectorQueryBuilder original) {
130+
KnnVectorQueryBuilder copy;
131+
if (original.queryVectorBuilder() != null) {
132+
copy = new KnnVectorQueryBuilder(
133+
original.getFieldName(),
134+
original.queryVectorBuilder(),
135+
original.k(),
136+
original.numCands(),
137+
original.getVectorSimilarity()
138+
);
139+
} else {
140+
copy = new KnnVectorQueryBuilder(
141+
original.getFieldName(),
142+
original.queryVector(),
143+
original.k(),
144+
original.numCands(),
145+
original.rescoreVectorBuilder(),
146+
original.getVectorSimilarity()
147+
);
148+
}
149+
150+
copy.addFilterQuery(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
151+
return copy;
152+
}
153+
154+
private TextEmbeddingQueryVectorBuilder getTextEmbeddingQueryBuilderFromQuery(KnnVectorQueryBuilder knnVectorQueryBuilder) {
155+
QueryVectorBuilder queryVectorBuilder = knnVectorQueryBuilder.queryVectorBuilder();
156+
if (queryVectorBuilder == null) {
157+
return null;
158+
}
159+
assert (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder);
160+
return (TextEmbeddingQueryVectorBuilder) queryVectorBuilder;
161+
}
162+
163+
private KnnVectorQueryBuilder buildNewKnnVectorQuery(
164+
String fieldName,
165+
KnnVectorQueryBuilder original,
166+
QueryVectorBuilder queryVectorBuilder
167+
) {
168+
if (original.queryVectorBuilder() != null) {
169+
return new KnnVectorQueryBuilder(
170+
fieldName,
171+
queryVectorBuilder,
172+
original.k(),
173+
original.numCands(),
174+
original.getVectorSimilarity()
175+
);
176+
} else {
177+
return new KnnVectorQueryBuilder(
178+
fieldName,
179+
original.queryVector(),
180+
original.k(),
181+
original.numCands(),
182+
original.rescoreVectorBuilder(),
183+
original.getVectorSimilarity()
184+
);
185+
}
186+
}
187+
188+
@Override
189+
public String getQueryName() {
190+
return KnnVectorQueryBuilder.NAME;
191+
}
192+
}

0 commit comments

Comments
 (0)