Skip to content

Commit fa5cfe7

Browse files
fix knn combined query
1 parent 37bfc43 commit fa5cfe7

File tree

2 files changed

+55
-30
lines changed

2 files changed

+55
-30
lines changed

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -265,24 +265,6 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException {
265265
this.queryVectorSupplier = null;
266266
}
267267

268-
public KnnVectorQueryBuilder(KnnVectorQueryBuilder queryBuilder) {
269-
this(queryBuilder, queryBuilder.getFieldName(), queryBuilder.queryVectorBuilder());
270-
}
271-
272-
public KnnVectorQueryBuilder(KnnVectorQueryBuilder queryBuilder, String fieldName, QueryVectorBuilder queryVectorBuilder) {
273-
this(
274-
fieldName,
275-
queryBuilder.queryVector(),
276-
queryVectorBuilder,
277-
null,
278-
queryBuilder.k(),
279-
queryBuilder.numCands(),
280-
queryBuilder.rescoreVectorBuilder(),
281-
queryBuilder.getVectorSimilarity()
282-
);
283-
this.filterQueries.addAll(queryBuilder.filterQueries());
284-
}
285-
286268
public String getFieldName() {
287269
return fieldName;
288270
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticKnnVectorQueryRewriteInterceptor.java

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,16 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
5252
assert (queryBuilder instanceof KnnVectorQueryBuilder);
5353
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
5454
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
55-
QueryBuilder finalQueryBuilder;
5655
if (inferenceIdsIndices.size() == 1) {
5756
// Simple case, everything uses the same inference ID
5857
Map.Entry<String, List<String>> inferenceIdIndex = inferenceIdsIndices.entrySet().iterator().next();
5958
String searchInferenceId = inferenceIdIndex.getKey();
6059
List<String> indices = inferenceIdIndex.getValue();
61-
finalQueryBuilder = buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
60+
return buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
6261
} else {
6362
// Multiple inference IDs, construct a boolean query
64-
finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
63+
return buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
6564
}
66-
finalQueryBuilder.boost(queryBuilder.boost());
67-
finalQueryBuilder.queryName(queryBuilder.queryName());
68-
return finalQueryBuilder;
6965
}
7066

7167
private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
@@ -106,8 +102,6 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
106102
)
107103
);
108104
}
109-
boolQueryBuilder.boost(queryBuilder.boost());
110-
boolQueryBuilder.queryName(queryBuilder.queryName());
111105
return boolQueryBuilder;
112106
}
113107

@@ -124,17 +118,37 @@ private QueryBuilder buildNestedQueryFromKnnVectorQuery(
124118
}
125119
return QueryBuilders.nestedQuery(
126120
SemanticTextField.getChunksFieldName(filteredKnnVectorQueryBuilder.getFieldName()),
127-
new KnnVectorQueryBuilder(
128-
filteredKnnVectorQueryBuilder,
121+
buildNewKnnVectorQuery(
129122
SemanticTextField.getEmbeddingsFieldName(filteredKnnVectorQueryBuilder.getFieldName()),
123+
filteredKnnVectorQueryBuilder,
130124
queryVectorBuilder
131125
),
132126
ScoreMode.Max
133-
);
127+
).queryName(knnVectorQueryBuilder.queryName()).boost(knnVectorQueryBuilder.boost());
134128
}
135129

136130
private KnnVectorQueryBuilder addIndexFilterToKnnVectorQuery(Collection<String> indices, KnnVectorQueryBuilder original) {
137-
KnnVectorQueryBuilder copy = new KnnVectorQueryBuilder(original);
131+
KnnVectorQueryBuilder copy;
132+
if (original.queryVectorBuilder() != null) {
133+
copy = new KnnVectorQueryBuilder(
134+
original.getFieldName(),
135+
original.queryVectorBuilder(),
136+
original.k(),
137+
original.numCands(),
138+
original.getVectorSimilarity()
139+
);
140+
} else {
141+
copy = new KnnVectorQueryBuilder(
142+
original.getFieldName(),
143+
original.queryVector(),
144+
original.k(),
145+
original.numCands(),
146+
original.rescoreVectorBuilder(),
147+
original.getVectorSimilarity()
148+
);
149+
}
150+
151+
copy.addFilterQueries(original.filterQueries());
138152
copy.addFilterQuery(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
139153
return copy;
140154
}
@@ -148,6 +162,35 @@ private TextEmbeddingQueryVectorBuilder getTextEmbeddingQueryBuilderFromQuery(Kn
148162
return (TextEmbeddingQueryVectorBuilder) queryVectorBuilder;
149163
}
150164

165+
private KnnVectorQueryBuilder buildNewKnnVectorQuery(
166+
String fieldName,
167+
KnnVectorQueryBuilder original,
168+
QueryVectorBuilder queryVectorBuilder
169+
) {
170+
KnnVectorQueryBuilder newQueryBuilder;
171+
if (original.queryVectorBuilder() != null) {
172+
newQueryBuilder = new KnnVectorQueryBuilder(
173+
fieldName,
174+
queryVectorBuilder,
175+
original.k(),
176+
original.numCands(),
177+
original.getVectorSimilarity()
178+
);
179+
} else {
180+
newQueryBuilder = new KnnVectorQueryBuilder(
181+
fieldName,
182+
original.queryVector(),
183+
original.k(),
184+
original.numCands(),
185+
original.rescoreVectorBuilder(),
186+
original.getVectorSimilarity()
187+
);
188+
}
189+
190+
newQueryBuilder.addFilterQueries(original.filterQueries());
191+
return newQueryBuilder;
192+
}
193+
151194
@Override
152195
public String getQueryName() {
153196
return KnnVectorQueryBuilder.NAME;

0 commit comments

Comments
 (0)