Skip to content

Commit bc81bd0

Browse files
committed
Add a test for lenient rerankers
1 parent c47b8a5 commit bc81bd0

File tree

8 files changed: 122 additions and 125 deletions.

server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -182,11 +182,6 @@ private void onPhaseDone(
182182
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext,
183183
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
184184
) {
185-
RankFeatureDoc[] docs = rankPhaseResults.getSuccessfulResults()
186-
.flatMap(r -> Arrays.stream(r.rankFeatureResult().shardResult().rankFeatureDocs))
187-
.filter(rfd -> rfd.featureData != null)
188-
.toArray(RankFeatureDoc[]::new);
189-
190185
ThreadedActionListener<RankFeatureDoc[]> rankResultListener = new ThreadedActionListener<>(
191186
context::execute,
192187
new ActionListener<>() {
@@ -204,18 +199,29 @@ public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
204199
public void onFailure(Exception e) {
205200
if (rankFeaturePhaseRankCoordinatorContext.isLenient()) {
206201
// TODO: handle the exception somewhere
207-
logger.warn("Exception computing updated ranks. Continuing with existing ranks.", e);
208-
// use the existing docs as-is
202+
// don't want to log the entire stack trace, it's not helpful here
203+
logger.warn("Exception computing updated ranks: {}. Continuing with existing ranks.", e.toString());
204+
// use the existing score docs as-is
205+
RankFeatureDoc[] existingScores = Arrays.stream(reducedQueryPhase.sortedTopDocs().scoreDocs())
206+
.map(sd -> new RankFeatureDoc(sd.doc, sd.score, sd.shardIndex))
207+
.toArray(RankFeatureDoc[]::new);
208+
209209
// AbstractThreadedActionListener forks onFailure to the same executor as onResponse,
210210
// so we can just call this direct
211-
onResponse(docs);
211+
onResponse(existingScores);
212212
} else {
213213
context.onPhaseFailure(NAME, "Computing updated ranks for results failed", e);
214214
}
215215
}
216216
}
217217
);
218-
rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(docs, rankResultListener);
218+
rankFeaturePhaseRankCoordinatorContext.computeRankScoresForGlobalResults(
219+
rankPhaseResults.getSuccessfulResults()
220+
.flatMap(r -> Arrays.stream(r.rankFeatureResult().shardResult().rankFeatureDocs))
221+
.filter(rfd -> rfd.featureData != null)
222+
.toArray(RankFeatureDoc[]::new),
223+
rankResultListener
224+
);
219225
}
220226

221227
private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults(

server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -12,13 +12,9 @@
1212
import org.apache.lucene.search.ScoreDoc;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
15-
import org.elasticsearch.search.rank.feature.RankFeatureResult;
16-
import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
1715

18-
import java.util.ArrayList;
1916
import java.util.Arrays;
2017
import java.util.Comparator;
21-
import java.util.List;
2218

2319
import static org.elasticsearch.search.SearchService.DEFAULT_FROM;
2420
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
@@ -97,17 +93,4 @@ public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) {
9793
}
9894
return topResults;
9995
}
100-
101-
private RankFeatureDoc[] extractFeatureDocs(List<RankFeatureResult> rankSearchResults) {
102-
List<RankFeatureDoc> docFeatures = new ArrayList<>();
103-
for (RankFeatureResult rankFeatureResult : rankSearchResults) {
104-
RankFeatureShardResult shardResult = rankFeatureResult.shardResult();
105-
for (RankFeatureDoc rankFeatureDoc : shardResult.rankFeatureDocs) {
106-
if (rankFeatureDoc.featureData != null) {
107-
docFeatures.add(rankFeatureDoc);
108-
}
109-
}
110-
}
111-
return docFeatures.toArray(new RankFeatureDoc[0]);
112-
}
11396
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -184,17 +184,22 @@ public Float minScore() {
184184
return minScore;
185185
}
186186

187+
public boolean isLenient() {
188+
return lenient;
189+
}
190+
187191
@Override
188192
protected boolean doEquals(RankBuilder other) {
189193
TextSimilarityRankBuilder that = (TextSimilarityRankBuilder) other;
190194
return Objects.equals(inferenceId, that.inferenceId)
191195
&& Objects.equals(inferenceText, that.inferenceText)
192196
&& Objects.equals(field, that.field)
193-
&& Objects.equals(minScore, that.minScore);
197+
&& Objects.equals(minScore, that.minScore)
198+
&& lenient == that.lenient;
194199
}
195200

196201
@Override
197202
protected int doHashCode() {
198-
return Objects.hash(inferenceId, inferenceText, field, minScore);
203+
return Objects.hash(inferenceId, inferenceText, field, minScore, lenient);
199204
}
200205
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,10 @@ public int rankWindowSize() {
188188
return rankWindowSize;
189189
}
190190

191+
public boolean isLenient() {
192+
return lenient;
193+
}
194+
191195
@Override
192196
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
193197
builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever());

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ protected RankBuilder getThrowingRankBuilder(int rankWindowSize, String rankFeat
3434
inferenceId,
3535
inferenceText,
3636
minScore,
37+
false,
3738
type.name()
3839
);
3940
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java

Lines changed: 24 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -118,36 +118,41 @@ public void testParserDefaults() throws IOException {
118118
parser,
119119
new RetrieverParserContext(new SearchUsage(), nf -> true)
120120
);
121-
assertEquals(DEFAULT_RANK_WINDOW_SIZE, parsed.rankWindowSize());
122-
assertEquals(DEFAULT_RERANK_ID, parsed.inferenceId());
121+
assertThat(parsed.rankWindowSize(), equalTo(DEFAULT_RANK_WINDOW_SIZE));
122+
assertThat(parsed.inferenceId(), equalTo(DEFAULT_RERANK_ID));
123+
assertThat(parsed.isLenient(), equalTo(false));
124+
123125
}
124126
}
125127

126128
public void testTextSimilarityRetrieverParsing() throws IOException {
127-
String restContent = "{"
128-
+ " \"retriever\": {"
129-
+ " \"text_similarity_reranker\": {"
130-
+ " \"retriever\": {"
131-
+ " \"test\": {"
132-
+ " \"value\": \"my-test-retriever\""
133-
+ " }"
134-
+ " },"
135-
+ " \"field\": \"my-field\","
136-
+ " \"inference_id\": \"my-inference-id\","
137-
+ " \"inference_text\": \"my-inference-text\","
138-
+ " \"rank_window_size\": 100,"
139-
+ " \"min_score\": 20.0,"
140-
+ " \"_name\": \"foo_reranker\""
141-
+ " }"
142-
+ " }"
143-
+ "}";
129+
String restContent = """
130+
{
131+
"retriever": {
132+
"text_similarity_reranker": {
133+
"retriever": {
134+
"test": {
135+
"value": "my-test-retriever"
136+
}
137+
},
138+
"field": "my-field",
139+
"inference_id": "my-inference-id",
140+
"inference_text": "my-inference-text",
141+
"rank_window_size": 100,
142+
"min_score": 20.0,
143+
"lenient": true,
144+
"_name": "foo_reranker"
145+
}
146+
}
147+
}""";
144148
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
145149
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
146150
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
147151
assertThat(source.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class));
148152
TextSimilarityRankRetrieverBuilder parsed = (TextSimilarityRankRetrieverBuilder) source.retriever();
149153
assertThat(parsed.minScore(), equalTo(20f));
150154
assertThat(parsed.retrieverName(), equalTo("foo_reranker"));
155+
assertThat(parsed.isLenient(), equalTo(true));
151156
try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) {
152157
SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent(
153158
parseSerialized,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java

Lines changed: 62 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,17 @@
2121
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
2222
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2323
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
24+
import org.hamcrest.Matcher;
2425
import org.junit.Before;
2526

2627
import java.util.Collection;
2728
import java.util.Collections;
2829
import java.util.List;
2930
import java.util.Map;
30-
import java.util.Objects;
3131

32+
import static org.elasticsearch.test.LambdaMatchers.transformedMatch;
33+
import static org.hamcrest.Matchers.allOf;
34+
import static org.hamcrest.Matchers.arrayContaining;
3235
import static org.hamcrest.Matchers.containsString;
3336
import static org.hamcrest.Matchers.equalTo;
3437

@@ -66,9 +69,10 @@ public InferenceResultCountAcceptingTextSimilarityRankBuilder(
6669
String inferenceText,
6770
int rankWindowSize,
6871
Float minScore,
72+
boolean lenient,
6973
int inferenceResultCount
7074
) {
71-
super(field, inferenceId, inferenceText, rankWindowSize, minScore, false);
75+
super(field, inferenceId, inferenceText, rankWindowSize, minScore, lenient);
7276
this.inferenceResultCount = inferenceResultCount;
7377
}
7478

@@ -82,7 +86,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
8286
inferenceId,
8387
inferenceText,
8488
minScore,
85-
false
89+
isLenient()
8690
) {
8791
@Override
8892
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
@@ -130,14 +134,17 @@ public void testRerank() {
130134
.setQuery(QueryBuilders.matchAllQuery()),
131135
response -> {
132136
// Verify order, rank and score of results
133-
SearchHit[] hits = response.getHits().getHits();
134-
assertEquals(5, hits.length);
135-
// we add + 1 to all expected scores due to the default normalization being applied which shifts positive scores to by 1
136-
assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
137-
assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
138-
assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
139-
assertHitHasRankScoreAndText(hits[3], 4, 1.0f + 1f, "1");
140-
assertHitHasRankScoreAndText(hits[4], 5, 0.0f + 1f, "0");
137+
assertThat(
138+
response.getHits().getHits(),
139+
arrayContaining(
140+
// add 1 to all expected scores due to the default normalization being applied which shifts positive scores by 1
141+
searchHitWith(1, 4.0f + 1f, "4"),
142+
searchHitWith(2, 3.0f + 1f, "3"),
143+
searchHitWith(3, 2.0f + 1f, "2"),
144+
searchHitWith(4, 1.0f + 1f, "1"),
145+
searchHitWith(5, 0.0f + 1f, "0")
146+
)
147+
);
141148
}
142149
);
143150
}
@@ -150,11 +157,10 @@ public void testRerankWithMinScore() {
150157
.setQuery(QueryBuilders.matchAllQuery()),
151158
response -> {
152159
// Verify order, rank and score of results
153-
SearchHit[] hits = response.getHits().getHits();
154-
assertEquals(3, hits.length);
155-
assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
156-
assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
157-
assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
160+
assertThat(
161+
response.getHits().getHits(),
162+
arrayContaining(searchHitWith(1, 4.0f + 1f, "4"), searchHitWith(2, 3.0f + 1f, "3"), searchHitWith(3, 2.0f + 1f, "2"))
163+
);
158164
}
159165
);
160166
}
@@ -170,6 +176,7 @@ public void testRerankInferenceFailure() {
170176
"my-rerank-model",
171177
"my query",
172178
0.7f,
179+
false,
173180
AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name()
174181
)
175182
)
@@ -179,6 +186,38 @@ public void testRerankInferenceFailure() {
179186
);
180187
}
181188

189+
public void testLenientRerankInference() {
190+
ElasticsearchAssertions.assertNoFailuresAndResponse(
191+
// Execute search with text similarity reranking
192+
client.prepareSearch()
193+
.setRankBuilder(
194+
new TextSimilarityTestPlugin.ThrowingMockRequestActionBasedRankBuilder(
195+
100,
196+
"text",
197+
"my-rerank-model",
198+
"my query",
199+
null,
200+
true,
201+
AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name()
202+
)
203+
)
204+
.setQuery(QueryBuilders.matchAllQuery()),
205+
response -> {
206+
// these will all have a score of 2 (default 1 + normalization)
207+
assertThat(
208+
response.getHits().getHits(),
209+
arrayContaining(
210+
searchHitWith(1, 2.0f, "0"),
211+
searchHitWith(2, 2.0f, "1"),
212+
searchHitWith(3, 2.0f, "2"),
213+
searchHitWith(4, 2.0f, "3"),
214+
searchHitWith(5, 2.0f, "4")
215+
)
216+
);
217+
}
218+
);
219+
}
220+
182221
public void testRerankTopNConfigurationAndRankWindowSizeMismatch() {
183222
SearchPhaseExecutionException ex = expectThrows(
184223
SearchPhaseExecutionException.class,
@@ -205,18 +244,19 @@ public void testRerankInputSizeAndInferenceResultsMismatch() {
205244
client.prepareSearch()
206245
.setRankBuilder(
207246
// Simulate reranker returning different number of results from input
208-
new InferenceResultCountAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, 4)
247+
new InferenceResultCountAcceptingTextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, false, 4)
209248
)
210249
.setQuery(QueryBuilders.matchAllQuery())
211250
);
212251
assertThat(ex.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
213252
assertThat(ex.getDetailedMessage(), containsString("Reranker input document count and returned score count mismatch"));
214253
}
215254

216-
private static void assertHitHasRankScoreAndText(SearchHit hit, int expectedRank, float expectedScore, String expectedText) {
217-
assertEquals(expectedRank, hit.getRank());
218-
assertEquals(expectedScore, hit.getScore(), 0.0f);
219-
assertEquals(expectedText, Objects.requireNonNull(hit.getSourceAsMap()).get("text"));
255+
private static Matcher<SearchHit> searchHitWith(int expectedRank, float expectedScore, String expectedText) {
256+
return allOf(
257+
transformedMatch(SearchHit::getRank, equalTo(expectedRank)),
258+
transformedMatch(SearchHit::getScore, equalTo(expectedScore)),
259+
transformedMatch(hit -> hit.getSourceAsMap().get("text"), equalTo(expectedText))
260+
);
220261
}
221-
222262
}

0 commit comments

Comments
 (0)