diff --git a/docs/changelog/120930.yaml b/docs/changelog/120930.yaml new file mode 100644 index 0000000000000..376edb7632a0b --- /dev/null +++ b/docs/changelog/120930.yaml @@ -0,0 +1,6 @@ +pr: 120930 +summary: Normalize negative scores for `text_similarity_reranker` retriever +area: Ranking +type: bug +issues: + - 120201 diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index 21892b4efe5a8..4cccf4d204d99 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -523,6 +523,23 @@ You have the following options: ** Then set up an <> with the `rerank` task type. ** Refer to the <> on this page for a step-by-step guide. +[IMPORTANT] +==== +Scores from the re-ranking process are normalized using the following formula before being returned to the user, +to avoid having negative scores. +[source,text] +---- +score = max(score, 0) + min(exp(score), 1) +---- +Using the above, any initially negative scores are projected to (0, 1) and non-negative scores to [1, infinity). +To revert back to the original score if needed, one can apply the following to the normalized score: +[source, text] +---- +score = score - 1, if score >= 1 +score = ln(score), if score < 1 +---- +==== + ===== Parameters `retriever`:: diff --git a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java index ebbdf58cc8c4f..5920567646030 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java @@ -57,6 +57,11 @@ public static class TopQuery extends Query { this.queryNames = queryNames; this.segmentStarts = segmentStarts; this.contextIdentity = contextIdentity; + for (RankDoc doc : docs) { + if (false == doc.score >= 0) { + throw new IllegalArgumentException("RankDoc scores must be positive values. 
Missing a normalization step?"); + } + } } @Override @@ -160,7 +165,11 @@ public float getMaxScore(int docId) { @Override public float score() { - return docs[upTo].score; + // We could still end up with a valid 0 score for a RankDoc + // so here we want to differentiate between this and all the tailQuery matches + // that would also produce a 0 score due to filtering, by setting the score to `Float.MIN_VALUE` instead for + // RankDoc matches. + return Math.max(docs[upTo].score, Float.MIN_VALUE); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java index ba39702d3d162..9f1d2fbfdefff 100644 --- a/server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java @@ -251,4 +251,16 @@ public void testUnknownField() throws IOException { public void testValidOutput() throws IOException { // no-op since RankDocsQueryBuilder is an internal only API } + + public void shouldThrowForNegativeScores() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + iw.addDocument(new Document()); + try (IndexReader reader = iw.getReader()) { + SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); + RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> queryBuilder.doToQuery(context)); + assertEquals("RankDoc scores must be positive values. 
Missing a normalization step?", ex.getMessage()); + } + } + } } diff --git a/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java b/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java index 06763c27a3536..ad4e5842629e7 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java +++ b/test/framework/src/main/java/org/elasticsearch/search/rank/rerank/AbstractRerankerIT.java @@ -56,6 +56,10 @@ public enum ThrowingRankBuilderType { protected abstract Collection> pluginsNeeded(); + protected boolean shouldCheckScores() { + return true; + } + @Override protected Collection> nodePlugins() { return pluginsNeeded(); @@ -95,9 +99,11 @@ public void testRerankerNoExceptions() throws Exception { int rank = 1; for (SearchHit searchHit : response.getHits().getHits()) { assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); - assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f); assertThat(searchHit, hasRank(rank)); assertNotNull(searchHit.getFields().get(searchField)); + if (shouldCheckScores()) { + assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f); + } rank++; } } @@ -140,9 +146,11 @@ public void testRerankerPagination() throws Exception { int rank = 3; for (SearchHit searchHit : response.getHits().getHits()) { assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); - assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f); assertThat(searchHit, hasRank(rank)); assertNotNull(searchHit.getFields().get(searchField)); + if (shouldCheckScores()) { + assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f); + } rank++; } } @@ -222,9 +230,11 @@ public void testNotAllShardsArePresentInFetchPhase() throws Exception { int rank = 1; for (SearchHit searchHit : response.getHits().getHits()) { assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1)))); - assertEquals(0.5f - ((rank - 1) * 0.1f), 
searchHit.getScore(), 1e-5f); assertThat(searchHit, hasRank(rank)); assertNotNull(searchHit.getFields().get(searchField)); + if (shouldCheckScores()) { + assertEquals(0.5f - ((rank - 1) * 0.1f), searchHit.getScore(), 1e-5f); + } rank++; } } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 3be85ee857bbb..3c29cef47d628 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -26,9 +26,16 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.Random; public abstract class AbstractTestInferenceService implements InferenceService { + protected static final Random random = new Random( + System.getProperty("tests.seed") == null + ? 
System.currentTimeMillis() + : Long.parseUnsignedLong(System.getProperty("tests.seed").split(":")[0], 16) + ); + protected static int stringWeight(String input, int position) { int hashCode = input.hashCode(); if (hashCode < 0) { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index e79c8b9bad522..765c69e28a9ad 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -42,6 +42,7 @@ import java.util.Map; public class TestRerankingServiceExtension implements InferenceServiceExtension { + @Override public List getInferenceServiceFactories() { return List.of(TestInferenceService::new); @@ -149,9 +150,12 @@ public void chunkedInfer( private RankedDocsResults makeResults(List input) { List results = new ArrayList<>(); int totalResults = input.size(); + float minScore = random.nextFloat(-1f, 1f); float resultDiff = 0.2f; for (int i = 0; i < input.size(); i++) { - results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, resultDiff * (totalResults - i), input.get(i))); + results.add( + new RankedDocsResults.RankedDoc(totalResults - 1 - i, minScore + resultDiff * (totalResults - i), input.get(i)) + ); } return new RankedDocsResults(results); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 0ff48bfd493ba..63274e5104207 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -20,8 +20,8 @@ import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; +import java.util.ArrayList; import java.util.Arrays; -import java.util.Comparator; import java.util.List; import java.util.Map; @@ -130,10 +130,15 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener minScore == null || doc.score >= minScore) - .sorted(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed()) - .toArray(RankFeatureDoc[]::new); + List docs = new ArrayList<>(); + for (RankFeatureDoc doc : originalDocs) { + if (minScore == null || doc.score >= minScore) { + doc.score = normalizeScore(doc.score); + docs.add(doc); + } + } + docs.sort(RankFeatureDoc::compareTo); + return docs.toArray(new RankFeatureDoc[0]); } protected InferenceAction.Request generateRequest(List docFeatures) { @@ -154,7 +159,15 @@ private float[] extractScoresFromRankedDocs(List ra for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) { scores[rankedDoc.index()] = rankedDoc.relevanceScore(); } - return scores; } + + private static float normalizeScore(float score) { + // As some models might produce negative scores, we want to ensure that all scores will be positive + // so we will make use of the following normalization formula: + // score = max(score, 0) + min(exp(score), 1) + // this will ensure that all positive scores lie in the [1, inf) range, + // while negative values (and 0) will be shifted to (0, 1] + return Math.max(score, 0) + Math.min((float) Math.exp(score), 1); + } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 10a1bc324fd2b..165c42fdb7d1f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -142,6 +142,7 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length]; for (int i = 0; i < scoreDocs.length; i++) { ScoreDoc scoreDoc = scoreDocs[i]; + assert scoreDoc.score >= 0; if (explain) { textSimilarityRankDocs[i] = new TextSimilarityRankDoc( scoreDoc.doc, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java index daed03c198e0d..27a8f0e962761 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java @@ -50,4 +50,9 @@ public void testQueryPhaseShardThrowingAllShardsFail() throws Exception { public void testQueryPhaseCoordinatorThrowingAllShardsFail() throws Exception { // no-op } + + @Override + protected boolean shouldCheckScores() { + return false; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index f81f2965c392e..0969a902870b6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -131,11 +131,12 @@ public void testRerank() { // Verify order, rank and score of results SearchHit[] hits = response.getHits().getHits(); assertEquals(5, hits.length); - assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4"); - assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3"); - assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2"); - assertHitHasRankScoreAndText(hits[3], 4, 1.0f, "1"); - assertHitHasRankScoreAndText(hits[4], 5, 0.0f, "0"); + // we add + 1 to all expected scores due to the default normalization being applied which shifts non-negative scores up by 1 + assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4"); + assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3"); + assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2"); + assertHitHasRankScoreAndText(hits[3], 4, 1.0f + 1f, "1"); + assertHitHasRankScoreAndText(hits[4], 5, 0.0f + 1f, "0"); } ); } @@ -150,9 +151,9 @@ public void testRerankWithMinScore() { // Verify order, rank and score of results SearchHit[] hits = response.getHits().getHits(); assertEquals(3, hits.length); - assertHitHasRankScoreAndText(hits[0], 1, 4.0f, "4"); - assertHitHasRankScoreAndText(hits[1], 2, 3.0f, "3"); - assertHitHasRankScoreAndText(hits[2], 3, 2.0f, "2"); + assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4"); + assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3"); + assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2"); } ); } diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java 
b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index 8d8ad94d608d7..da01459b057b6 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -20,6 +20,7 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase { @ClassRule public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .systemProperty("tests.seed", System.getProperty("tests.seed")) .setting("xpack.security.enabled", "false") .setting("xpack.security.http.ssl.enabled", "false") .setting("xpack.license.self_generated.type", "trial") diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml index 88569daaa6070..9a6ecffe29d4d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml @@ -89,10 +89,7 @@ setup: - length: { hits.hits: 2 } - match: { hits.hits.0._id: "doc_2" } - - close_to: { hits.hits.0._score: { value: 0.4, error: 0.001 } } - - match: { hits.hits.1._id: "doc_1" } - - close_to: { hits.hits.1._score: { value: 0.2, error: 0.001 } } --- "Simple text similarity rank retriever and filtering": @@ -123,8 +120,6 @@ setup: - length: { hits.hits: 1 } - match: { hits.hits.0._id: "doc_1" } - - close_to: { hits.hits.0._score: { value: 0.2, error: 0.001 } } - --- "Text similarity reranking fails if the inference ID does not exist": @@ -211,7 +206,6 @@ setup: - contains: { hits.hits: { _id: "doc_2" } } - contains: { hits.hits: { _id: "doc_1" } } - - close_to: { 
hits.hits.0._explanation.value: { value: 0.4, error: 0.000001 } } - match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" } - match: {hits.hits.0._explanation.details.0.description: "/weight.*science.*/" }