1010package org .elasticsearch .search .diversification .mmr ;
1111
1212import org .apache .lucene .index .VectorSimilarityFunction ;
13+ import org .apache .lucene .search .ScoreDoc ;
14+ import org .apache .lucene .search .TopDocs ;
1315import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
14- import org .elasticsearch .search .SearchHit ;
15- import org .elasticsearch .search .SearchHits ;
1616import org .elasticsearch .search .diversification .ResultDiversification ;
1717import org .elasticsearch .search .diversification .ResultDiversificationContext ;
1818import org .elasticsearch .search .vectors .VectorData ;
2626public class MMRResultDiversification extends ResultDiversification {
2727
2828 @ Override
29- public SearchHits diversify (SearchHits hits , ResultDiversificationContext diversificationContext ) throws IOException {
30- if (hits == null || ((diversificationContext instanceof MMRResultDiversificationContext ) == false )) {
31- return hits ;
29+ public TopDocs diversify (TopDocs topDocs , ResultDiversificationContext diversificationContext ) throws IOException {
30+ if (topDocs == null || ((diversificationContext instanceof MMRResultDiversificationContext ) == false )) {
31+ return topDocs ;
3232 }
3333
3434 MMRResultDiversificationContext context = (MMRResultDiversificationContext ) diversificationContext ;
35- SearchHit [] searchHits = hits .getHits (); // NOTE: by reference, not new array
3635
37- if (searchHits .length == 0 ) {
38- return hits ;
36+ if (topDocs . scoreDocs == null || topDocs . scoreDocs .length == 0 ) {
37+ return topDocs ;
3938 }
4039
4140 Map <Integer , Integer > docIdIndexMapping = new HashMap <>();
42- Map <Integer , VectorData > fieldVectors = getFieldVectorsForHits (searchHits , context , docIdIndexMapping );
41+ for (int i = 0 ; i < topDocs .scoreDocs .length ; i ++) {
42+ docIdIndexMapping .put (topDocs .scoreDocs [i ].doc , i );
43+ }
4344
4445 VectorSimilarityFunction similarityFunction = DenseVectorFieldMapper .VectorSimilarity .MAX_INNER_PRODUCT .vectorSimilarityFunction (
4546 context .getIndexVersion (),
@@ -52,48 +53,55 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
5253 // always add the highest scoring doc to the list
5354 int highestScoreDocId = -1 ;
5455 float highestScore = Float .MIN_VALUE ;
55- for (SearchHit hit : searchHits ) {
56- if (hit . getScore () > highestScore ) {
57- highestScoreDocId = hit . docId () ;
58- highestScore = hit . getScore () ;
56+ for (ScoreDoc doc : topDocs . scoreDocs ) {
57+ if (doc . score > highestScore ) {
58+ highestScoreDocId = doc . doc ;
59+ highestScore = doc . score ;
5960 }
6061 }
6162 selectedDocIds .add (highestScoreDocId );
6263
6364 // test the vector to see if we are using floats or bytes
64- VectorData firstVec = fieldVectors . get (highestScoreDocId );
65+ VectorData firstVec = context . getFieldVector (highestScoreDocId );
6566 boolean useFloat = firstVec .isFloat ();
6667
6768 // cache the similarity scores for the query vector vs. searchHits
68- Map <Integer , Float > querySimilarity = getQuerySimilarityForDocs (searchHits , fieldVectors , similarityFunction , useFloat , context );
69+ Map <Integer , Float > querySimilarity = getQuerySimilarityForDocs (topDocs . scoreDocs , similarityFunction , useFloat , context );
6970
7071 Map <Integer , Map <Integer , Float >> cachedSimilarities = new HashMap <>();
7172 int numCandidates = context .getNumCandidates ();
7273
73- for (int x = 0 ; x < numCandidates && selectedDocIds .size () < numCandidates && selectedDocIds .size () < searchHits .length ; x ++) {
74+ for (int x = 0 ; x < numCandidates
75+ && selectedDocIds .size () < numCandidates
76+ && selectedDocIds .size () < topDocs .scoreDocs .length ; x ++) {
7477 int thisMaxMMRDocId = -1 ;
7578 float thisMaxMMRScore = Float .NEGATIVE_INFINITY ;
76- for (SearchHit thisHit : searchHits ) {
77- int docId = thisHit . docId () ;
79+ for (ScoreDoc doc : topDocs . scoreDocs ) {
80+ int docId = doc . doc ;
7881
7982 if (selectedDocIds .contains (docId )) {
8083 continue ;
8184 }
8285
83- var thisDocVector = fieldVectors .get (docId );
86+ var thisDocVector = context .getFieldVector (docId );
87+ if (thisDocVector == null ) {
88+ continue ;
89+ }
90+
8491 var cachedScoresForDoc = cachedSimilarities .getOrDefault (docId , new HashMap <>());
8592
8693 // compute MMR scores for remaining searchHits
8794 float highestMMRScore = getHighestScoreForSelectedVectors (
88- fieldVectors ,
95+ docId ,
96+ context ,
8997 similarityFunction ,
9098 useFloat ,
9199 thisDocVector ,
92100 cachedScoresForDoc
93101 );
94102
95103 // compute MMR
96- float querySimilarityScore = querySimilarity .getOrDefault (thisHit . docId () , 0.0f );
104+ float querySimilarityScore = querySimilarity .getOrDefault (doc . doc , 0.0f );
97105 float mmr = (context .getLambda () * querySimilarityScore ) - ((1 - context .getLambda ()) * highestMMRScore );
98106 if (mmr > thisMaxMMRScore ) {
99107 thisMaxMMRScore = mmr ;
@@ -110,34 +118,29 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
110118 }
111119
112120 // our return should be only those searchHits that are selected
113- SearchHit [] ret = new SearchHit [selectedDocIds .size ()];
121+ ScoreDoc [] ret = new ScoreDoc [selectedDocIds .size ()];
114122 for (int i = 0 ; i < selectedDocIds .size (); i ++) {
115123 int scoredDocIndex = docIdIndexMapping .get (selectedDocIds .get (i ));
116- ret [i ] = searchHits [scoredDocIndex ];
124+ ret [i ] = topDocs . scoreDocs [scoredDocIndex ];
117125 }
118126
119- // cleanup for GC
120- searchHits = null ;
121-
122- return new SearchHits (
123- ret ,
124- hits .getTotalHits (),
125- hits .getMaxScore (),
126- hits .getSortFields (),
127- hits .getCollapseField (),
128- hits .getCollapseValues ()
129- );
127+ return new TopDocs (topDocs .totalHits , ret );
130128 }
131129
132130 private float getHighestScoreForSelectedVectors (
133- Map <Integer , VectorData > selectedVectors ,
131+ int docId ,
132+ MMRResultDiversificationContext context ,
134133 VectorSimilarityFunction similarityFunction ,
135134 boolean useFloat ,
136135 VectorData thisDocVector ,
137136 Map <Integer , Float > cachedScoresForDoc
138137 ) {
139138 float highestScore = Float .MIN_VALUE ;
140- for (var vec : selectedVectors .entrySet ()) {
139+ for (var vec : context .getFieldVectorsEntrySet ()) {
140+ if (vec .getKey ().equals (docId )) {
141+ continue ;
142+ }
143+
141144 if (cachedScoresForDoc .containsKey (vec .getKey ())) {
142145 float score = cachedScoresForDoc .get (vec .getKey ());
143146 if (score > highestScore ) {
@@ -156,19 +159,17 @@ private float getHighestScoreForSelectedVectors(
156159 }
157160
158161 protected Map <Integer , Float > getQuerySimilarityForDocs (
159- SearchHit [] searchHits ,
160- Map <Integer , VectorData > fieldVectors ,
162+ ScoreDoc [] docs ,
161163 VectorSimilarityFunction similarityFunction ,
162164 boolean useFloat ,
163165 ResultDiversificationContext context
164166 ) {
165167 Map <Integer , Float > querySimilarity = new HashMap <>();
166- for (SearchHit searchHit : searchHits ) {
167- int docId = searchHit .docId ();
168- VectorData vectorData = fieldVectors .get (docId );
168+ for (ScoreDoc doc : docs ) {
169+ VectorData vectorData = context .getFieldVector (doc .doc );
169170 if (vectorData != null ) {
170171 float querySimilarityScore = getVectorComparisonScore (similarityFunction , useFloat , vectorData , context .getQueryVector ());
171- querySimilarity .put (docId , querySimilarityScore );
172+ querySimilarity .put (doc . doc , querySimilarityScore );
172173 }
173174 }
174175 return querySimilarity ;
0 commit comments