Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/120930.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 120930
summary: Normalize negative scores for `text_similarity_reranker` retriever
area: Ranking
type: bug
issues:
- 120201
17 changes: 17 additions & 0 deletions docs/reference/search/retriever.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,23 @@ You have the following options:
** Then set up an <<inference-example-eland,{es} service inference endpoint>> with the `rerank` task type.
** Refer to the <<text-similarity-reranker-retriever-example-eland,example>> on this page for a step-by-step guide.

[IMPORTANT]
====
Scores from the re-ranking process are normalized using the following formula before being returned to the user,
to avoid having negative scores.
[source,text]
----
score = max(score, 0) + min(exp(score), 1)
----
Using the above, any initially negative scores are projected to (0, 1) and non-negative scores to [1, infinity).
To revert back if needed, one can use:
[source, text]
----
score = score - 1, if score >= 1
score = ln(score), if score < 1
----
====

===== Parameters

`retriever`::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ public static class TopQuery extends Query {
this.queryNames = queryNames;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
for (RankDoc doc : docs) {
if (false == doc.score >= 0) {
throw new IllegalArgumentException("RankDoc scores must be positive values. Missing a normalization step?");
}
}
}

@Override
Expand Down Expand Up @@ -160,7 +165,7 @@ public float getMaxScore(int docId) {

@Override
public float score() {
return docs[upTo].score;
return Math.max(docs[upTo].score, Float.MIN_VALUE);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,16 @@ public void testUnknownField() throws IOException {
public void testValidOutput() throws IOException {
// no-op since RankDocsQueryBuilder is an internal only API
}

/**
 * Verifies that building a {@code RankDocsQuery} from a {@link RankDocsQueryBuilder}
 * containing a negatively-scored {@code RankDoc} throws {@link IllegalArgumentException},
 * i.e. that the normalization step is enforced at query-construction time.
 */
public void testShouldThrowForNegativeScores() throws IOException {
    // NOTE: renamed from shouldThrowForNegativeScores — without the "test" prefix
    // the JUnit-3-style method discovery used by the randomized runner never runs it.
    try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) {
        // a single empty document is enough: validation happens when the query is built
        iw.addDocument(new Document());
        try (IndexReader reader = iw.getReader()) {
            SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader));
            // a RankDoc carrying a negative score must be rejected
            RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false);
            IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> queryBuilder.doToQuery(context));
            assertEquals("RankDoc scores must be positive values. Missing a normalization step?", ex.getMessage());
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ public enum ThrowingRankBuilderType {

protected abstract Collection<Class<? extends Plugin>> pluginsNeeded();

// Hook for subclasses: return false to skip the exact-score assertions in the
// reranker tests (useful when the test inference service emits non-deterministic
// scores). Defaults to true, i.e. scores are checked.
protected boolean shouldCheckScores() {
return true;
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return pluginsNeeded();
Expand Down Expand Up @@ -95,9 +99,11 @@ public void testRerankerNoExceptions() throws Exception {
int rank = 1;
for (SearchHit searchHit : response.getHits().getHits()) {
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
assertThat(searchHit, hasRank(rank));
assertNotNull(searchHit.getFields().get(searchField));
if (shouldCheckScores()) {
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
}
rank++;
}
}
Expand Down Expand Up @@ -140,9 +146,11 @@ public void testRerankerPagination() throws Exception {
int rank = 3;
for (SearchHit searchHit : response.getHits().getHits()) {
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
assertThat(searchHit, hasRank(rank));
assertNotNull(searchHit.getFields().get(searchField));
if (shouldCheckScores()) {
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
}
rank++;
}
}
Expand Down Expand Up @@ -222,9 +230,11 @@ public void testNotAllShardsArePresentInFetchPhase() throws Exception {
int rank = 1;
for (SearchHit searchHit : response.getHits().getHits()) {
assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
assertThat(searchHit, hasRank(rank));
assertNotNull(searchHit.getFields().get(searchField));
if (shouldCheckScores()) {
assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f);
}
rank++;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

public abstract class AbstractTestInferenceService implements InferenceService {

protected static final Random random = new Random(
System.getProperty("tests.seed") == null
? System.currentTimeMillis()
: Long.parseUnsignedLong(System.getProperty("tests.seed").split(":")[0], 16)
);

protected static int stringWeight(String input, int position) {
int hashCode = input.hashCode();
if (hashCode < 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.Map;

public class TestRerankingServiceExtension implements InferenceServiceExtension {

@Override
public List<Factory> getInferenceServiceFactories() {
return List.of(TestInferenceService::new);
Expand Down Expand Up @@ -149,9 +150,12 @@ public void chunkedInfer(
private RankedDocsResults makeResults(List<String> input) {
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
int totalResults = input.size();
float minScore = random.nextFloat(-1f, 1f);
float resultDiff = 0.2f;
for (int i = 0; i < input.size(); i++) {
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, resultDiff * (totalResults - i), input.get(i)));
results.add(
new RankedDocsResults.RankedDoc(totalResults - 1 - i, minScore + resultDiff * (totalResults - i), input.get(i))
);
}
return new RankedDocsResults(results);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
Expand Down Expand Up @@ -130,10 +131,14 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
*/
@Override
protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
return Arrays.stream(originalDocs)
.filter(doc -> minScore == null || doc.score >= minScore)
.sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed())
.toArray(RankFeatureDoc[]::new);
List<RankFeatureDoc> docs = new ArrayList<>();
for (RankFeatureDoc doc : originalDocs) {
if (minScore == null || doc.score >= minScore) {
doc.score = normalizeScore(doc.score);
docs.add(doc);
}
}
return docs.stream().sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()).toArray(RankFeatureDoc[]::new);
}

protected InferenceAction.Request generateRequest(List<String> docFeatures) {
Expand All @@ -154,7 +159,15 @@ private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> ra
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
scores[rankedDoc.index()] = rankedDoc.relevanceScore();
}

return scores;
}

/**
 * Maps a raw model score onto a strictly positive value using
 * {@code max(score, 0) + min(exp(score), 1)}. Non-negative inputs are shifted up
 * by one into [1, +inf); negative inputs are squashed into (0, 1). This keeps all
 * reranker scores positive even for models that emit negative relevance values.
 */
private static float normalizeScore(float score) {
    // identity on the positive part, zero otherwise
    final float linearPart = Math.max(score, 0);
    // exp(score) capped at 1 — contributes the (0, 1] component for score <= 0
    final float boundedExpPart = Math.min((float) Math.exp(score), 1);
    return linearPart + boundedExpPart;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
ScoreDoc scoreDoc = scoreDocs[i];
assert scoreDoc.score >= 0;
if (explain) {
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(
scoreDoc.doc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,9 @@ public void testQueryPhaseShardThrowingAllShardsFail() throws Exception {
public void testQueryPhaseCoordinatorThrowingAllShardsFail() throws Exception {
// no-op
}

@Override
protected boolean shouldCheckScores() {
// NOTE(review): scores are not asserted in this suite — presumably because the
// test reranking service applies a random score offset, making exact values
// non-deterministic; confirm against TestRerankingServiceExtension.
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ public void testRerank() {
// Verify order, rank and score of results
SearchHit[] hits = response.getHits().getHits();
assertEquals(5, hits.length);
assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4");
assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3");
assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2");
assertHitHasRankScoreAndText(hits[3], 4, 1.0f, "1");
assertHitHasRankScoreAndText(hits[4], 5, 0.0f, "0");
// we add + 1 to all expected scores due to the default normalization being applied, which shifts non-negative scores up by 1
assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
assertHitHasRankScoreAndText(hits[3], 4, 1.0f + 1f, "1");
assertHitHasRankScoreAndText(hits[4], 5, 0.0f + 1f, "0");
}
);
}
Expand All @@ -150,9 +151,9 @@ public void testRerankWithMinScore() {
// Verify order, rank and score of results
SearchHit[] hits = response.getHits().getHits();
assertEquals(3, hits.length);
assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4");
assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3");
assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2");
assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
}
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase {

@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.systemProperty("tests.seed", System.getProperty("tests.seed"))
.setting("xpack.security.enabled", "false")
.setting("xpack.security.http.ssl.enabled", "false")
.setting("xpack.license.self_generated.type", "trial")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ setup:
- length: { hits.hits: 2 }

- match: { hits.hits.0._id: "doc_2" }
- close_to: { hits.hits.0._score: { value: 0.4, error: 0.001 } }

- match: { hits.hits.1._id: "doc_1" }
- close_to: { hits.hits.1._score: { value: 0.2, error: 0.001 } }

---
"Simple text similarity rank retriever and filtering":
Expand Down Expand Up @@ -123,8 +120,6 @@ setup:
- length: { hits.hits: 1 }

- match: { hits.hits.0._id: "doc_1" }
- close_to: { hits.hits.0._score: { value: 0.2, error: 0.001 } }


---
"Text similarity reranking fails if the inference ID does not exist":
Expand Down Expand Up @@ -211,7 +206,6 @@ setup:
- contains: { hits.hits: { _id: "doc_2" } }
- contains: { hits.hits: { _id: "doc_1" } }

- close_to: { hits.hits.0._explanation.value: { value: 0.4, error: 0.000001 } }
- match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" }
- match: {hits.hits.0._explanation.details.0.description: "/weight.*science.*/" }

Expand Down