|
23 | 23 | import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; |
24 | 24 |
|
25 | 25 | import java.util.ArrayList; |
| 26 | +import java.util.Arrays; |
26 | 27 | import java.util.List; |
27 | 28 | import java.util.Map; |
28 | 29 |
|
@@ -210,22 +211,23 @@ private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> ra |
210 | 211 | } |
211 | 212 |
|
212 | 213 | private float[] extractScoresFromRankedSnippets(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) { |
213 | | - int numSnippets = rankedDocs.size() / featureDocs.length; |
214 | 214 | float[] scores = new float[featureDocs.length]; |
215 | 215 | boolean[] hasScore = new boolean[featureDocs.length]; |
216 | 216 |
|
| 217 | + // We need to correlate the index/doc values of each RankedDoc in correlation with its associated RankFeatureDoc. |
| 218 | + int[] rankedDocToFeatureDoc = Arrays.stream(featureDocs) |
| 219 | + .flatMapToInt( |
| 220 | + doc -> java.util.stream.IntStream.generate(() -> Arrays.asList(featureDocs).indexOf(doc)).limit(doc.featureData.size()) |
| 221 | + ) |
| 222 | + .limit(rankedDocs.size()) |
| 223 | + .toArray(); |
| 224 | + |
217 | 225 | for (int i = 0; i < rankedDocs.size(); i++) { |
218 | | - // TODO this naively assumes that we always get the requested number of snippets per ranked document |
219 | 226 | RankedDocsResults.RankedDoc rankedDoc = rankedDocs.get(i); |
220 | | - int docId = rankedDoc.index() / numSnippets; |
| 227 | + int docId = rankedDocToFeatureDoc[rankedDoc.index()]; |
221 | 228 | float score = rankedDoc.relevanceScore(); |
222 | | - |
223 | | - if (hasScore[docId] == false) { |
224 | | - scores[docId] = score; |
225 | | - hasScore[docId] = true; |
226 | | - } else { |
227 | | - scores[docId] = Math.max(scores[docId], score); |
228 | | - } |
| 229 | + scores[docId] = hasScore[docId] == false ? score : Math.max(scores[docId], score); |
| 230 | + hasScore[docId] = true; |
229 | 231 | } |
230 | 232 |
|
231 | 233 | float[] result = new float[featureDocs.length]; |
|
0 commit comments