Skip to content

Commit c5a57fc

Browse files
authored
[8.16] backporting fix for negative scores in text_similarity_ranker retriever (#121056)
1 parent 97c4bdc commit c5a57fc

File tree

13 files changed

+105
-25
lines changed

13 files changed

+105
-25
lines changed

docs/changelog/120930.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 120930
2+
summary: Normalize negative scores for `text_similarity_reranker` retriever
3+
area: Ranking
4+
type: bug
5+
issues:
6+
- 120201

docs/reference/search/retriever.asciidoc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,23 @@ Currently you can:
367367
** Then set up an <<inference-example-eland,{es} service inference endpoint>> with the `rerank` task type
368368
** Refer to the <<text-similarity-reranker-retriever-example-eland,example>> on this page for a step-by-step guide.
369369

370+
[IMPORTANT]
371+
====
372+
Scores from the re-ranking process are normalized using the following formula before being returned to the user,
373+
to avoid having negative scores.
374+
[source,text]
375+
----
376+
score = max(score, 0) + min(exp(score), 1)
377+
----
378+
Using the above, any initially negative scores are projected to (0, 1) and non-negative scores to [1, infinity).
379+
To recover the original score from a normalized one if needed, one can use:
380+
[source, text]
381+
----
382+
score = score - 1, if score >= 1
383+
score = ln(score), if score < 1
384+
----
385+
====
386+
370387
===== Parameters
371388
`retriever`::
372389
(Required, <<retriever, retriever>>)

server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ public static class TopQuery extends Query {
5858
this.queryNames = queryNames;
5959
this.segmentStarts = segmentStarts;
6060
this.contextIdentity = contextIdentity;
61+
for (RankDoc doc : docs) {
62+
if (false == doc.score >= 0) {
63+
throw new IllegalArgumentException("RankDoc scores must be positive values. Missing a normalization step?");
64+
}
65+
}
6166
}
6267

6368
@Override
@@ -161,7 +166,11 @@ public float getMaxScore(int docId) {
161166

162167
@Override
163168
public float score() {
164-
return docs[upTo].score;
169+
// We could still end up with a valid 0 score for a RankDoc
170+
// so here we want to differentiate between this and all the tailQuery matches
171+
// that would also produce a 0 score due to filtering, by setting the score to `Float.MIN_VALUE` instead for
172+
// RankDoc matches.
173+
return Math.max(docs[upTo].score, Float.MIN_VALUE);
165174
}
166175

167176
@Override

server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,4 +251,16 @@ public void testUnknownField() throws IOException {
251251
public void testValidOutput() throws IOException {
252252
// no-op since RankDocsQueryBuilder is an internal only API
253253
}
254+
255+
public void shouldThrowForNegativeScores() throws IOException {
256+
try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
257+
iw.addDocument(new Document());
258+
try (IndexReader reader = iw.getReader()) {
259+
SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader));
260+
RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false);
261+
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> queryBuilder.doToQuery(context));
262+
assertEquals("RankDoc scores must be positive values. Missing a normalization step?", ex.getMessage());
263+
}
264+
}
265+
}
254266
}

test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ public enum ThrowingRankBuilderType {
5656

5757
protected abstract Collection<Class<? extends Plugin>> pluginsNeeded();
5858

59+
protected boolean shouldCheckScores() {
60+
return true;
61+
}
62+
5963
@Override
6064
protected Collection<Class<? extends Plugin>> nodePlugins() {
6165
return pluginsNeeded();
@@ -95,9 +99,11 @@ public void testRerankerNoExceptions() throws Exception {
9599
int rank = 1;
96100
for (SearchHit searchHit : response.getHits().getHits()) {
97101
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
98-
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
99102
assertThat(searchHit, hasRank(rank));
100103
assertNotNull(searchHit.getFields().get(searchField));
104+
if (shouldCheckScores()) {
105+
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
106+
}
101107
rank++;
102108
}
103109
}
@@ -140,9 +146,11 @@ public void testRerankerPagination() throws Exception {
140146
int rank = 3;
141147
for (SearchHit searchHit : response.getHits().getHits()) {
142148
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
143-
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
144149
assertThat(searchHit, hasRank(rank));
145150
assertNotNull(searchHit.getFields().get(searchField));
151+
if (shouldCheckScores()) {
152+
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
153+
}
146154
rank++;
147155
}
148156
}
@@ -222,9 +230,11 @@ public void testNotAllShardsArePresentInFetchPhase() throws Exception {
222230
int rank = 1;
223231
for (SearchHit searchHit : response.getHits().getHits()) {
224232
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
225-
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
226233
assertThat(searchHit, hasRank(rank));
227234
assertNotNull(searchHit.getFields().get(searchField));
235+
if (shouldCheckScores()) {
236+
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
237+
}
228238
rank++;
229239
}
230240
}

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,16 @@
2525
import java.io.IOException;
2626
import java.util.HashMap;
2727
import java.util.Map;
28+
import java.util.Random;
2829

2930
public abstract class AbstractTestInferenceService implements InferenceService {
3031

32+
protected static final Random random = new Random(
33+
System.getProperty("tests.seed") == null
34+
? System.currentTimeMillis()
35+
: Long.parseUnsignedLong(System.getProperty("tests.seed").split(":")[0], 16)
36+
);
37+
3138
protected static int stringWeight(String input, int position) {
3239
int hashCode = input.hashCode();
3340
if (hashCode < 0) {

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import java.util.Map;
3737

3838
public class TestRerankingServiceExtension implements InferenceServiceExtension {
39+
3940
@Override
4041
public List<Factory> getInferenceServiceFactories() {
4142
return List.of(TestInferenceService::new);
@@ -122,9 +123,12 @@ public void chunkedInfer(
122123
private RankedDocsResults makeResults(List<String> input) {
123124
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
124125
int totalResults = input.size();
126+
float minScore = random.nextFloat(-1f, 1f);
125127
float resultDiff = 0.2f;
126128
for (int i = 0; i < input.size(); i++) {
127-
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, resultDiff * (totalResults - i), input.get(i)));
129+
results.add(
130+
new RankedDocsResults.RankedDoc(totalResults - 1 - i, minScore + resultDiff * (totalResults - i), input.get(i))
131+
);
128132
}
129133
return new RankedDocsResults(results);
130134
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
2121
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
2222

23+
import java.util.ArrayList;
2324
import java.util.Arrays;
24-
import java.util.Comparator;
2525
import java.util.List;
2626
import java.util.Map;
2727

@@ -130,10 +130,15 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
130130
*/
131131
@Override
132132
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
133-
return Arrays.stream(originalDocs)
134-
.filter(doc -> minScore == null || doc.score >= minScore)
135-
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
136-
.toArray(RankFeatureDoc[]::new);
133+
List<RankFeatureDoc> docs = new ArrayList<>();
134+
for (RankFeatureDoc doc : originalDocs) {
135+
if (minScore == null || doc.score >= minScore) {
136+
doc.score = normalizeScore(doc.score);
137+
docs.add(doc);
138+
}
139+
}
140+
docs.sort(RankFeatureDoc::compareTo);
141+
return docs.toArray(new RankFeatureDoc[0]);
137142
}
138143

139144
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
@@ -154,7 +159,15 @@ private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> ra
154159
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
155160
scores[rankedDoc.index()] = rankedDoc.relevanceScore();
156161
}
157-
158162
return scores;
159163
}
164+
165+
private static float normalizeScore(float score) {
166+
// As some models might produce negative scores, we want to ensure that all scores will be positive
167+
// so we will make use of the following normalization formula:
168+
// score = max(score, 0) + min(exp(score), 1)
169+
// this will ensure that all positive scores lie in the [1, inf) range,
170+
// while negative values (and 0) will be shifted to (0, 1]
171+
return Math.max(score, 0) + Math.min((float) Math.exp(score), 1);
172+
}
160173
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
156156
TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
157157
for (int i = 0; i < scoreDocs.length; i++) {
158158
ScoreDoc scoreDoc = scoreDocs[i];
159+
assert scoreDoc.score >= 0;
159160
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, inferenceId, field);
160161
}
161162
return textSimilarityRankDocs;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,9 @@ public void testQueryPhaseShardThrowingAllShardsFail() throws Exception {
5050
public void testQueryPhaseCoordinatorThrowingAllShardsFail() throws Exception {
5151
// no-op
5252
}
53+
54+
@Override
55+
protected boolean shouldCheckScores() {
56+
return false;
57+
}
5358
}

0 commit comments

Comments
 (0)