1010package org .elasticsearch .search .diversification .mmr ;
1111
1212import org .apache .lucene .index .VectorSimilarityFunction ;
13- import org .apache .lucene .search .Explanation ;
1413import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
1514import org .elasticsearch .search .SearchHit ;
1615import org .elasticsearch .search .SearchHits ;
@@ -33,63 +32,60 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
3332 }
3433
3534 MMRResultDiversificationContext context = (MMRResultDiversificationContext ) diversificationContext ;
36- SearchHit [] docs = hits .getHits (); // NOTE: by reference, not new array
35+ SearchHit [] searchHits = hits .getHits (); // NOTE: by reference, not new array
3736
38- if (docs .length == 0 ) {
37+ if (searchHits .length == 0 ) {
3938 return hits ;
4039 }
4140
4241 Map <Integer , Integer > docIdIndexMapping = new HashMap <>();
43- Map <Integer , VectorData > fieldVectors = getFieldVectorsForHits (docs , context , docIdIndexMapping );
42+ Map <Integer , VectorData > fieldVectors = getFieldVectorsForHits (searchHits , context , docIdIndexMapping );
4443
4544 VectorSimilarityFunction similarityFunction = DenseVectorFieldMapper .VectorSimilarity .MAX_INNER_PRODUCT .vectorSimilarityFunction (
4645 context .getIndexVersion (),
4746 diversificationContext .getElementType ()
4847 );
4948
50- List < Integer > rerankedDocIds = new ArrayList <>();
51- Map <Integer , VectorData > selectedVectors = new HashMap <>();
49+ // our chosen DocIDs to keep
50+ List <Integer > selectedDocIds = new ArrayList <>();
5251
5352 // always add the highest scoring doc to the list
54- int highestDocIdIndex = -1 ;
53+ int highestScoreDocId = -1 ;
5554 float highestScore = Float .MIN_VALUE ;
56- for (int i = 0 ; i < docs . length ; i ++ ) {
57- if (docs [ i ] .getScore () > highestScore ) {
58- highestDocIdIndex = i ;
59- highestScore = docs [ i ] .getScore ();
55+ for (SearchHit hit : searchHits ) {
56+ if (hit .getScore () > highestScore ) {
57+ highestScoreDocId = hit . docId () ;
58+ highestScore = hit .getScore ();
6059 }
6160 }
62- int firstDocId = docs [highestDocIdIndex ].docId ();
63- rerankedDocIds .add (firstDocId );
61+ selectedDocIds .add (highestScoreDocId );
6462
65- // and add the vector for the first items
66- VectorData firstVec = fieldVectors .get (firstDocId );
67- selectedVectors .put (firstDocId , firstVec );
63+ // test the vector to see if we are using floats or bytes
64+ VectorData firstVec = fieldVectors .get (highestScoreDocId );
6865 boolean useFloat = firstVec .isFloat ();
6966
70- // cache the similarity scores for the query vector vs. docs
71- Map <Integer , Float > querySimilarity = getQuerySimilarityForDocs (docs , fieldVectors , similarityFunction , useFloat , context );
67+ // cache the similarity scores for the query vector vs. searchHits
68+ Map <Integer , Float > querySimilarity = getQuerySimilarityForDocs (searchHits , fieldVectors , similarityFunction , useFloat , context );
7269
7370 Map <Integer , Map <Integer , Float >> cachedSimilarities = new HashMap <>();
7471 int numCandidates = context .getNumCandidates ();
7572
76- for (int x = 0 ; x < numCandidates && rerankedDocIds .size () < numCandidates && rerankedDocIds .size () < docs .length ; x ++) {
73+ for (int x = 0 ; x < numCandidates && selectedDocIds .size () < numCandidates && selectedDocIds .size () < searchHits .length ; x ++) {
7774 int thisMaxMMRDocId = -1 ;
78- float thisMaxMMRScore = Float .MIN_VALUE ;
79- for (SearchHit thisHit : docs ) {
75+ float thisMaxMMRScore = Float .NEGATIVE_INFINITY ;
76+ for (SearchHit thisHit : searchHits ) {
8077 int docId = thisHit .docId ();
8178
82- if (rerankedDocIds .contains (docId )) {
79+ if (selectedDocIds .contains (docId )) {
8380 continue ;
8481 }
8582
8683 var thisDocVector = fieldVectors .get (docId );
87-
8884 var cachedScoresForDoc = cachedSimilarities .getOrDefault (docId , new HashMap <>());
8985
90- // compute MMR scores for remaining docs
86+ // compute MMR scores for remaining searchHits
9187 float highestMMRScore = getHighestScoreForSelectedVectors (
92- selectedVectors ,
88+ fieldVectors ,
9389 similarityFunction ,
9490 useFloat ,
9591 thisDocVector ,
@@ -108,15 +104,16 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
108104 cachedSimilarities .put (docId , cachedScoresForDoc );
109105 }
110106
111- rerankedDocIds .add (thisMaxMMRDocId );
112- selectedVectors .put (thisMaxMMRDocId , fieldVectors .get (thisMaxMMRDocId ));
107+ if (thisMaxMMRDocId >= 0 ) {
108+ selectedDocIds .add (thisMaxMMRDocId );
109+ }
113110 }
114111
115- // our return should be only those docs that are selected
116- SearchHit [] ret = new SearchHit [rerankedDocIds .size ()];
117- for (int i = 0 ; i < rerankedDocIds .size (); i ++) {
118- int scoredDocIndex = docIdIndexMapping .get (rerankedDocIds .get (i ));
119- ret [i ] = docs [scoredDocIndex ];
112+ // our return should be only those searchHits that are selected
113+ SearchHit [] ret = new SearchHit [selectedDocIds .size ()];
114+ for (int i = 0 ; i < selectedDocIds .size (); i ++) {
115+ int scoredDocIndex = docIdIndexMapping .get (selectedDocIds .get (i ));
116+ ret [i ] = searchHits [scoredDocIndex ];
120117 }
121118
122119 return new SearchHits (
@@ -129,13 +126,6 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
129126 );
130127 }
131128
132- @ Override
133- public Explanation explain (int topLevelDocId , ResultDiversificationContext diversificationContext , Explanation sourceExplanation )
134- throws IOException {
135- // TODO
136- return null ;
137- }
138-
139129 private float getHighestScoreForSelectedVectors (
140130 Map <Integer , VectorData > selectedVectors ,
141131 VectorSimilarityFunction similarityFunction ,
@@ -163,15 +153,15 @@ private float getHighestScoreForSelectedVectors(
163153 }
164154
165155 protected Map <Integer , Float > getQuerySimilarityForDocs (
166- SearchHit [] docs ,
156+ SearchHit [] searchHits ,
167157 Map <Integer , VectorData > fieldVectors ,
168158 VectorSimilarityFunction similarityFunction ,
169159 boolean useFloat ,
170160 ResultDiversificationContext context
171161 ) {
172162 Map <Integer , Float > querySimilarity = new HashMap <>();
173- for (int i = 0 ; i < docs . length ; i ++ ) {
174- int docId = docs [ i ] .docId ();
163+ for (SearchHit searchHit : searchHits ) {
164+ int docId = searchHit .docId ();
175165 VectorData vectorData = fieldVectors .get (docId );
176166 if (vectorData != null ) {
177167 float querySimilarityScore = getVectorComparisonScore (similarityFunction , useFloat , vectorData , context .getQueryVector ());
0 commit comments