Skip to content

Commit 16dbca0

Browse files
committed
Fix snippet calculation
1 parent e3259d9 commit 16dbca0

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
2424

2525
import java.util.ArrayList;
26+
import java.util.Arrays;
2627
import java.util.List;
2728
import java.util.Map;
2829

@@ -210,22 +211,23 @@ private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> ra
210211
}
211212

212213
private float[] extractScoresFromRankedSnippets(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) {
213-
int numSnippets = rankedDocs.size() / featureDocs.length;
214214
float[] scores = new float[featureDocs.length];
215215
boolean[] hasScore = new boolean[featureDocs.length];
216216

217+
// Map each RankedDoc index to the position of its associated RankFeatureDoc; each feature doc contributes featureData.size() consecutive entries.
218+
int[] rankedDocToFeatureDoc = Arrays.stream(featureDocs)
219+
.flatMapToInt(
220+
doc -> java.util.stream.IntStream.generate(() -> Arrays.asList(featureDocs).indexOf(doc)).limit(doc.featureData.size())
221+
)
222+
.limit(rankedDocs.size())
223+
.toArray();
224+
217225
for (int i = 0; i < rankedDocs.size(); i++) {
218-
// TODO this naively assumes that we always get the requested number of snippets per ranked document
219226
RankedDocsResults.RankedDoc rankedDoc = rankedDocs.get(i);
220-
int docId = rankedDoc.index() / numSnippets;
227+
int docId = rankedDocToFeatureDoc[rankedDoc.index()];
221228
float score = rankedDoc.relevanceScore();
222-
223-
if (hasScore[docId] == false) {
224-
scores[docId] = score;
225-
hasScore[docId] = true;
226-
} else {
227-
scores[docId] = Math.max(scores[docId], score);
228-
}
229+
scores[docId] = hasScore[docId] == false ? score : Math.max(scores[docId], score);
230+
hasScore[docId] = true;
229231
}
230232

231233
float[] result = new float[featureDocs.length];

0 commit comments

Comments
 (0)