Skip to content

Commit 44f3791

Browse files
authored
Updating toXContent implementation for retrievers (#114017)
1 parent 7bbebbd commit 44f3791

File tree

10 files changed

+187
-25
lines changed

10 files changed

+187
-25
lines changed

server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,19 @@ public ActionRequestValidationException validate(
251251
@Override
252252
public final XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
253253
builder.startObject();
254+
builder.startObject(getName());
254255
if (preFilterQueryBuilders.isEmpty() == false) {
255256
builder.field(PRE_FILTER_FIELD.getPreferredName(), preFilterQueryBuilders);
256257
}
258+
if (minScore != null) {
259+
builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
260+
}
261+
if (retrieverName != null) {
262+
builder.field(NAME_FIELD.getPreferredName(), retrieverName);
263+
}
257264
doToXContent(builder, params);
258265
builder.endObject();
266+
builder.endObject();
259267

260268
return builder;
261269
}

server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
import org.elasticsearch.search.collapse.CollapseBuilderTests;
4242
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
4343
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
44+
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
45+
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
4446
import org.elasticsearch.search.slice.SliceBuilder;
4547
import org.elasticsearch.search.sort.FieldSortBuilder;
4648
import org.elasticsearch.search.sort.ScoreSortBuilder;
@@ -600,6 +602,75 @@ public void testNegativeTrackTotalHits() throws IOException {
600602
}
601603
}
602604

605+
public void testStandardRetrieverParsing() throws IOException {
606+
String restContent = "{"
607+
+ " \"retriever\": {"
608+
+ " \"standard\": {"
609+
+ " \"query\": {"
610+
+ " \"match_all\": {}"
611+
+ " },"
612+
+ " \"min_score\": 10,"
613+
+ " \"_name\": \"foo_standard\""
614+
+ " }"
615+
+ " }"
616+
+ "}";
617+
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
618+
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
619+
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
620+
assertThat(source.retriever(), instanceOf(StandardRetrieverBuilder.class));
621+
StandardRetrieverBuilder parsed = (StandardRetrieverBuilder) source.retriever();
622+
assertThat(parsed.minScore(), equalTo(10f));
623+
assertThat(parsed.retrieverName(), equalTo("foo_standard"));
624+
try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
625+
SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
626+
parseSerialized,
627+
true,
628+
searchUsageHolder,
629+
nf -> true
630+
);
631+
assertThat(deserializedSource.retriever(), instanceOf(StandardRetrieverBuilder.class));
632+
StandardRetrieverBuilder deserialized = (StandardRetrieverBuilder) source.retriever();
633+
assertThat(parsed, equalTo(deserialized));
634+
}
635+
}
636+
}
637+
638+
public void testKnnRetrieverParsing() throws IOException {
639+
String restContent = "{"
640+
+ " \"retriever\": {"
641+
+ " \"knn\": {"
642+
+ " \"query_vector\": ["
643+
+ " 3"
644+
+ " ],"
645+
+ " \"field\": \"vector\","
646+
+ " \"k\": 10,"
647+
+ " \"num_candidates\": 15,"
648+
+ " \"min_score\": 10,"
649+
+ " \"_name\": \"foo_knn\""
650+
+ " }"
651+
+ " }"
652+
+ "}";
653+
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
654+
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
655+
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
656+
assertThat(source.retriever(), instanceOf(KnnRetrieverBuilder.class));
657+
KnnRetrieverBuilder parsed = (KnnRetrieverBuilder) source.retriever();
658+
assertThat(parsed.minScore(), equalTo(10f));
659+
assertThat(parsed.retrieverName(), equalTo("foo_knn"));
660+
try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
661+
SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
662+
parseSerialized,
663+
true,
664+
searchUsageHolder,
665+
nf -> true
666+
);
667+
assertThat(deserializedSource.retriever(), instanceOf(KnnRetrieverBuilder.class));
668+
KnnRetrieverBuilder deserialized = (KnnRetrieverBuilder) source.retriever();
669+
assertThat(parsed, equalTo(deserialized));
670+
}
671+
}
672+
}
673+
603674
public void testStoredFieldsUsage() throws IOException {
604675
Set<String> storedFieldRestVariations = Set.of(
605676
"{\"stored_fields\" : [\"_none_\"]}",

server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ protected KnnRetrieverBuilder createTestInstance() {
7474

7575
@Override
7676
protected KnnRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
77-
return KnnRetrieverBuilder.fromXContent(
77+
return (KnnRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
7878
parser,
7979
new RetrieverParserContext(
8080
new SearchUsage(),

server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ protected StandardRetrieverBuilder createTestInstance() {
9898

9999
@Override
100100
protected StandardRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
101-
return StandardRetrieverBuilder.fromXContent(
101+
return (StandardRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
102102
parser,
103103
new RetrieverParserContext(
104104
new SearchUsage(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,7 @@ public int rankWindowSize() {
103103

104104
@Override
105105
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
106-
builder.field(RETRIEVER_FIELD.getPreferredName());
107-
builder.startObject();
108-
builder.field(retrieverBuilder.getName(), retrieverBuilder);
109-
builder.endObject();
106+
builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder);
110107
builder.field(FIELD_FIELD.getPreferredName(), field);
111108
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
112109
if (seed != null) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,17 +179,11 @@ public int rankWindowSize() {
179179

180180
@Override
181181
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
182-
builder.field(RETRIEVER_FIELD.getPreferredName());
183-
builder.startObject();
184-
builder.field(retrieverBuilder.getName(), retrieverBuilder);
185-
builder.endObject();
182+
builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder);
186183
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
187184
builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText);
188185
builder.field(FIELD_FIELD.getPreferredName(), field);
189186
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
190-
if (minScore != null) {
191-
builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore);
192-
}
193187
}
194188

195189
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
import org.elasticsearch.xcontent.ParseField;
1818
import org.elasticsearch.xcontent.XContentParser;
1919
import org.elasticsearch.xcontent.json.JsonXContent;
20-
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
21-
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
2220

2321
import java.io.IOException;
2422
import java.util.ArrayList;
@@ -48,8 +46,8 @@ protected RandomRankRetrieverBuilder createTestInstance() {
4846
}
4947

5048
@Override
51-
protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) {
52-
return RandomRankRetrieverBuilder.PARSER.apply(
49+
protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
50+
return (RandomRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
5351
parser,
5452
new RetrieverParserContext(
5553
new SearchUsage(),
@@ -77,8 +75,8 @@ protected NamedXContentRegistry xContentRegistry() {
7775
entries.add(
7876
new NamedXContentRegistry.Entry(
7977
RetrieverBuilder.class,
80-
new ParseField(TextSimilarityRankBuilder.NAME),
81-
(p, c) -> TextSimilarityRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
78+
new ParseField(RandomRankBuilder.NAME),
79+
(p, c) -> RandomRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c)
8280
)
8381
);
8482
return new NamedXContentRegistry(entries);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.rank.textsimilarity;
99

1010
import org.elasticsearch.action.search.SearchRequest;
11+
import org.elasticsearch.common.Strings;
1112
import org.elasticsearch.index.query.BoolQueryBuilder;
1213
import org.elasticsearch.index.query.MatchAllQueryBuilder;
1314
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
@@ -25,6 +26,8 @@
2526
import org.elasticsearch.test.AbstractXContentTestCase;
2627
import org.elasticsearch.test.ESTestCase;
2728
import org.elasticsearch.usage.SearchUsage;
29+
import org.elasticsearch.usage.SearchUsageHolder;
30+
import org.elasticsearch.usage.UsageService;
2831
import org.elasticsearch.xcontent.NamedXContentRegistry;
2932
import org.elasticsearch.xcontent.ParseField;
3033
import org.elasticsearch.xcontent.XContentParser;
@@ -72,8 +75,8 @@ protected TextSimilarityRankRetrieverBuilder createTestInstance() {
7275
}
7376

7477
@Override
75-
protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) {
76-
return TextSimilarityRankRetrieverBuilder.PARSER.apply(
78+
protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
79+
return (TextSimilarityRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
7780
parser,
7881
new RetrieverParserContext(
7982
new SearchUsage(),
@@ -208,6 +211,45 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
208211
}
209212
}
210213

214+
public void testTextSimilarityRetrieverParsing() throws IOException {
215+
String restContent = "{"
216+
+ " \"retriever\": {"
217+
+ " \"text_similarity_reranker\": {"
218+
+ " \"retriever\": {"
219+
+ " \"test\": {"
220+
+ " \"value\": \"my-test-retriever\""
221+
+ " }"
222+
+ " },"
223+
+ " \"field\": \"my-field\","
224+
+ " \"inference_id\": \"my-inference-id\","
225+
+ " \"inference_text\": \"my-inference-text\","
226+
+ " \"rank_window_size\": 100,"
227+
+ " \"min_score\": 20.0,"
228+
+ " \"_name\": \"foo_reranker\""
229+
+ " }"
230+
+ " }"
231+
+ "}";
232+
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
233+
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
234+
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
235+
assertThat(source.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class));
236+
TextSimilarityRankRetrieverBuilder parsed = (TextSimilarityRankRetrieverBuilder) source.retriever();
237+
assertThat(parsed.minScore(), equalTo(20f));
238+
assertThat(parsed.retrieverName(), equalTo("foo_reranker"));
239+
try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
240+
SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
241+
parseSerialized,
242+
true,
243+
searchUsageHolder,
244+
nf -> true
245+
);
246+
assertThat(deserializedSource.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class));
247+
TextSimilarityRankRetrieverBuilder deserialized = (TextSimilarityRankRetrieverBuilder) source.retriever();
248+
assertThat(parsed, equalTo(deserialized));
249+
}
250+
}
251+
}
252+
211253
public void testIsCompound() {
212254
RetrieverBuilder compoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) {
213255
@Override

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
180180
builder.startArray(RETRIEVERS_FIELD.getPreferredName());
181181

182182
for (var entry : innerRetrievers) {
183-
builder.startObject();
184-
builder.field(entry.retriever().getName());
185183
entry.retriever().toXContent(builder, params);
186-
builder.endObject();
187184
}
188185
builder.endArray();
189186
}

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,27 @@
88
package org.elasticsearch.xpack.rank.rrf;
99

1010
import org.elasticsearch.action.search.SearchRequest;
11+
import org.elasticsearch.common.Strings;
12+
import org.elasticsearch.search.builder.SearchSourceBuilder;
1113
import org.elasticsearch.search.retriever.RetrieverBuilder;
1214
import org.elasticsearch.search.retriever.RetrieverParserContext;
1315
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
1416
import org.elasticsearch.test.AbstractXContentTestCase;
1517
import org.elasticsearch.usage.SearchUsage;
18+
import org.elasticsearch.usage.SearchUsageHolder;
19+
import org.elasticsearch.usage.UsageService;
1620
import org.elasticsearch.xcontent.NamedXContentRegistry;
1721
import org.elasticsearch.xcontent.ParseField;
1822
import org.elasticsearch.xcontent.XContentParser;
23+
import org.elasticsearch.xcontent.json.JsonXContent;
1924

2025
import java.io.IOException;
2126
import java.util.ArrayList;
2227
import java.util.List;
2328

29+
import static org.hamcrest.Matchers.equalTo;
30+
import static org.hamcrest.Matchers.instanceOf;
31+
2432
public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RRFRetrieverBuilder> {
2533

2634
/**
@@ -53,7 +61,10 @@ protected RRFRetrieverBuilder createTestInstance() {
5361

5462
@Override
5563
protected RRFRetrieverBuilder doParseInstance(XContentParser parser) throws IOException {
56-
return RRFRetrieverBuilder.PARSER.apply(parser, new RetrieverParserContext(new SearchUsage(), nf -> true));
64+
return (RRFRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder(
65+
parser,
66+
new RetrieverParserContext(new SearchUsage(), nf -> true)
67+
);
5768
}
5869

5970
@Override
@@ -81,4 +92,48 @@ protected NamedXContentRegistry xContentRegistry() {
8192
);
8293
return new NamedXContentRegistry(entries);
8394
}
95+
96+
public void testRRFRetrieverParsing() throws IOException {
97+
String restContent = "{"
98+
+ " \"retriever\": {"
99+
+ " \"rrf\": {"
100+
+ " \"retrievers\": ["
101+
+ " {"
102+
+ " \"test\": {"
103+
+ " \"value\": \"foo\""
104+
+ " }"
105+
+ " },"
106+
+ " {"
107+
+ " \"test\": {"
108+
+ " \"value\": \"bar\""
109+
+ " }"
110+
+ " }"
111+
+ " ],"
112+
+ " \"rank_window_size\": 100,"
113+
+ " \"rank_constant\": 10,"
114+
+ " \"min_score\": 20.0,"
115+
+ " \"_name\": \"foo_rrf\""
116+
+ " }"
117+
+ " }"
118+
+ "}";
119+
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
120+
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
121+
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
122+
assertThat(source.retriever(), instanceOf(RRFRetrieverBuilder.class));
123+
RRFRetrieverBuilder parsed = (RRFRetrieverBuilder) source.retriever();
124+
assertThat(parsed.minScore(), equalTo(20f));
125+
assertThat(parsed.retrieverName(), equalTo("foo_rrf"));
126+
try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
127+
SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
128+
parseSerialized,
129+
true,
130+
searchUsageHolder,
131+
nf -> true
132+
);
133+
assertThat(deserializedSource.retriever(), instanceOf(RRFRetrieverBuilder.class));
134+
RRFRetrieverBuilder deserialized = (RRFRetrieverBuilder) source.retriever();
135+
assertThat(parsed, equalTo(deserialized));
136+
}
137+
}
138+
}
84139
}

0 commit comments

Comments
 (0)