Skip to content

Commit 30eb8f6

Browse files
committed
iter
1 parent 16ca93f commit 30eb8f6

File tree

4 files changed

+22
-4
lines changed

4 files changed

+22
-4
lines changed

server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.lucene.search.Scorer;
2525
import org.apache.lucene.search.ScorerSupplier;
2626
import org.apache.lucene.search.Weight;
27+
import org.elasticsearch.index.query.RankDocsQueryBuilder;
2728
import org.elasticsearch.search.rank.RankDoc;
2829

2930
import java.io.IOException;
@@ -57,6 +58,11 @@ public static class TopQuery extends Query {
5758
this.queryNames = queryNames;
5859
this.segmentStarts = segmentStarts;
5960
this.contextIdentity = contextIdentity;
61+
for (RankDoc doc : docs) {
62+
if (false == doc.score >= 0) {
63+
throw new IllegalArgumentException("RankDoc scores must be positive values. Missing a normalization step?");
64+
}
65+
}
6066
}
6167

6268
@Override
@@ -160,7 +166,7 @@ public float getMaxScore(int docId) {
160166

161167
@Override
162168
public float score() {
163-
return docs[upTo].score;
169+
return Math.max(docs[upTo].score, Float.MIN_VALUE);
164170
}
165171

166172
@Override

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@
4040
import java.util.HashMap;
4141
import java.util.List;
4242
import java.util.Map;
43+
import java.util.Random;
4344

4445
public class TestRerankingServiceExtension implements InferenceServiceExtension {
46+
4547
@Override
4648
public List<Factory> getInferenceServiceFactories() {
4749
return List.of(TestInferenceService::new);
@@ -149,9 +151,10 @@ public void chunkedInfer(
149151
private RankedDocsResults makeResults(List<String> input) {
150152
List<RankedDocsResults.RankedDoc> results = new ArrayList<>();
151153
int totalResults = input.size();
154+
float minScore = new Random().nextFloat(-1f, 1f);
152155
float resultDiff = 0.2f;
153156
for (int i = 0; i < input.size(); i++) {
154-
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, resultDiff * (totalResults - i), input.get(i)));
157+
results.add(new RankedDocsResults.RankedDoc(totalResults - 1 - i, minScore + resultDiff * (totalResults - i), input.get(i)));
155158
}
156159
return new RankedDocsResults(results);
157160
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,19 @@ protected InferenceAction.Request generateRequest(List<String> docFeatures) {
150150
}
151151

152152
private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
153+
// As some models might produce negative scores, we want to ensure that all scores will be positive
154+
// so we will make use of the following normalization formula:
155+
// score = max(score, 0) + min(exp(score), 1)
156+
// this will ensure that all positive scores lie in the [1, inf) range,
157+
// while negative values (and 0) will be shifted to (0, 1]
153158
float[] scores = new float[rankedDocs.size()];
154159
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
155-
scores[rankedDoc.index()] = rankedDoc.relevanceScore();
160+
scores[rankedDoc.index()] = normalizeScore(rankedDoc.relevanceScore());
156161
}
157-
158162
return scores;
159163
}
164+
165+
private static float normalizeScore(float score) {
166+
return Math.max(score, 0) + Math.min((float) Math.exp(score), 1);
167+
}
160168
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
142142
TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
143143
for (int i = 0; i < scoreDocs.length; i++) {
144144
ScoreDoc scoreDoc = scoreDocs[i];
145+
assert scoreDoc.score >= 0;
145146
if (explain) {
146147
textSimilarityRankDocs[i] = new TextSimilarityRankDoc(
147148
scoreDoc.doc,

0 commit comments

Comments
 (0)