Skip to content

Commit 0dab8ea

Browse files
committed
Add test for knn retriever
1 parent 4fbbadd commit 0dab8ea

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,11 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
257257
searchSourceBuilder.knnSearch(knnSearchBuilders);
258258
}
259259

260-
// ---- FOR TESTING XCONTENT PARSING ----
260+
RescoreVectorBuilder rescoreVectorBuilder() {
261+
return rescoreVectorBuilder;
262+
}
263+
264+
// ---- FOR TESTING XCONTENT PARSING ----
261265

262266
@Override
263267
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
@@ -278,7 +282,9 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
278282
}
279283

280284
if (rescoreVectorBuilder != null) {
281-
builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder);
285+
builder.startObject(RESCORE_FIELD.getPreferredName());
286+
rescoreVectorBuilder.toXContent(builder, params);
287+
builder.endObject();
282288
}
283289
}
284290

server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ public void testRewrite() throws IOException {
105105
assertNull(source.query());
106106
assertThat(source.knnSearch().size(), equalTo(1));
107107
assertThat(source.knnSearch().get(0).getFilterQueries().size(), equalTo(knnRetriever.preFilterQueryBuilders.size()));
108+
assertThat(source.knnSearch().get(0).getRescoreVectorBuilder(), equalTo(knnRetriever.rescoreVectorBuilder()));
108109
for (int j = 0; j < knnRetriever.preFilterQueryBuilders.size(); j++) {
109110
assertThat(
110111
source.knnSearch().get(0).getFilterQueries().get(j),

0 commit comments

Comments
 (0)