Skip to content

Commit d760242

Browse files
committed
use TopDocs and not ScoreDocs
1 parent a7c7070 commit d760242

File tree

5 files changed

+113
-107
lines changed

5 files changed

+113
-107
lines changed

server/src/main/java/org/elasticsearch/search/diversification/ResultDiversification.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
package org.elasticsearch.search.diversification;
1111

1212
import org.apache.lucene.index.VectorSimilarityFunction;
13-
import org.elasticsearch.search.SearchHit;
14-
import org.elasticsearch.search.SearchHits;
13+
import org.apache.lucene.search.ScoreDoc;
14+
import org.apache.lucene.search.TopDocs;
1515
import org.elasticsearch.search.vectors.VectorData;
1616

1717
import java.io.IOException;
@@ -23,24 +23,24 @@
2323
*/
2424
public abstract class ResultDiversification {
2525

26-
public abstract SearchHits diversify(SearchHits hits, ResultDiversificationContext diversificationContext) throws IOException;
26+
public abstract TopDocs diversify(TopDocs hits, ResultDiversificationContext diversificationContext) throws IOException;
2727

2828
/**
 * Scans the given hits, records each doc id's position in the input array, and returns the
 * per-document vector map.
 *
 * In the removed version of this diff the vector was read from the hit's field named by
 * {@code context.getField()} and stored as {@link VectorData}. In the added version that code is
 * commented out, so nothing is ever put into the returned map.
 *
 * @param docs              hits to scan ({@code searchHits} in the removed version)
 * @param context           diversification context; only referenced by the removed and commented-out lines
 * @param docIdIndexMapping output map, filled with doc id to index-in-array entries
 * @return map of doc id to vector; always empty in the added version
 */
protected Map<Integer, VectorData> getFieldVectorsForHits(
29-
SearchHit[] searchHits,
29+
ScoreDoc[] docs,
3030
ResultDiversificationContext context,
3131
Map<Integer, Integer> docIdIndexMapping
3232
) {
3333
Map<Integer, VectorData> fieldVectors = new HashMap<>();
34-
for (int i = 0; i < searchHits.length; i++) {
35-
SearchHit hit = searchHits[i];
36-
int docId = hit.docId();
34+
for (int i = 0; i < docs.length; i++) {
35+
ScoreDoc hit = docs[i];
36+
int docId = hit.doc;
3737
// remember where this doc sits in the input array
docIdIndexMapping.put(docId, i);
38-
Object collapseValue = hit.field(context.getField()).getValue();
39-
if (collapseValue instanceof float[] vecData) {
40-
fieldVectors.put(docId, new VectorData(vecData));
41-
} else if (collapseValue instanceof byte[] byteVecData) {
42-
fieldVectors.put(docId, new VectorData(byteVecData));
43-
}
38+
// NOTE(review): vector extraction is disabled below, so fieldVectors stays empty.
// Object collapseValue = hit.field(context.getField()).getValue();
39+
// if (collapseValue instanceof float[] vecData) {
40+
// fieldVectors.put(docId, new VectorData(vecData));
41+
// } else if (collapseValue instanceof byte[] byteVecData) {
42+
// fieldVectors.put(docId, new VectorData(byteVecData));
43+
// }
4444
}
4545
return fieldVectors;
4646
}

server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,32 @@
1313
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1414
import org.elasticsearch.search.vectors.VectorData;
1515

16+
import java.util.Map;
17+
import java.util.Set;
18+
1619
public abstract class ResultDiversificationContext {
1720
private final String field;
1821
private final int numCandidates;
1922
private final DenseVectorFieldMapper fieldMapper;
2023
private final IndexVersion indexVersion;
2124
private final VectorData queryVector;
25+
private final Map<Integer, VectorData> fieldVectors;
2226

2327
// Field _must_ be a dense_vector type
2428
/**
 * Stores the supplied values in the corresponding final fields without validation.
 *
 * @param field         name of the field to diversify on
 * @param numCandidates stored as-is and exposed to subclasses/callers
 * @param queryVector   the query vector
 * @param fieldMapper   mapper for the field
 * @param indexVersion  index version
 * @param fieldVectors  (added in this diff) vectors looked up by doc id; the map is kept by
 *                      reference, no defensive copy is made
 */
protected ResultDiversificationContext(
2529
String field,
2630
int numCandidates,
2731
VectorData queryVector,
2832
DenseVectorFieldMapper fieldMapper,
29-
IndexVersion indexVersion
33+
IndexVersion indexVersion,
34+
Map<Integer, VectorData> fieldVectors
3035
) {
3136
this.field = field;
3237
this.numCandidates = numCandidates;
3338
this.fieldMapper = fieldMapper;
3439
this.indexVersion = indexVersion;
3540
this.queryVector = queryVector;
41+
// kept by reference; later changes to the caller's map are visible through this context
this.fieldVectors = fieldVectors;
3642
}
3743

3844
public String getField() {
@@ -58,4 +64,12 @@ public IndexVersion getIndexVersion() {
5864
/**
 * Returns the query vector that was passed to the constructor.
 *
 * @return the stored query vector (not copied)
 */
public VectorData getQueryVector() {
5965
return queryVector;
6066
}
67+
68+
/**
 * Looks up the vector stored for a document.
 *
 * @param docId document id used as the key into the field-vector map
 * @return the vector mapped to {@code docId}, or {@code null} when the map has no entry for it
 */
public VectorData getFieldVector(int docId) {
69+
// a null default makes this behave like Map.get: absent key -> null
return fieldVectors.getOrDefault(docId, null);
70+
}
71+
72+
/**
 * Exposes all (doc id, vector) pairs held by this context.
 *
 * @return the entry set of the underlying field-vector map, returned directly rather than as a copy
 */
public Set<Map.Entry<Integer, VectorData>> getFieldVectorsEntrySet() {
73+
return fieldVectors.entrySet();
74+
}
6175
}

server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
package org.elasticsearch.search.diversification.mmr;
1111

1212
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.apache.lucene.search.ScoreDoc;
14+
import org.apache.lucene.search.TopDocs;
1315
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
14-
import org.elasticsearch.search.SearchHit;
15-
import org.elasticsearch.search.SearchHits;
1616
import org.elasticsearch.search.diversification.ResultDiversification;
1717
import org.elasticsearch.search.diversification.ResultDiversificationContext;
1818
import org.elasticsearch.search.vectors.VectorData;
@@ -26,20 +26,21 @@
2626
public class MMRResultDiversification extends ResultDiversification {
2727

2828
@Override
29-
public SearchHits diversify(SearchHits hits, ResultDiversificationContext diversificationContext) throws IOException {
30-
if (hits == null || ((diversificationContext instanceof MMRResultDiversificationContext) == false)) {
31-
return hits;
29+
public TopDocs diversify(TopDocs topDocs, ResultDiversificationContext diversificationContext) throws IOException {
30+
if (topDocs == null || ((diversificationContext instanceof MMRResultDiversificationContext) == false)) {
31+
return topDocs;
3232
}
3333

3434
MMRResultDiversificationContext context = (MMRResultDiversificationContext) diversificationContext;
35-
SearchHit[] searchHits = hits.getHits(); // NOTE: by reference, not new array
3635

37-
if (searchHits.length == 0) {
38-
return hits;
36+
if (topDocs.scoreDocs == null || topDocs.scoreDocs.length == 0) {
37+
return topDocs;
3938
}
4039

4140
Map<Integer, Integer> docIdIndexMapping = new HashMap<>();
42-
Map<Integer, VectorData> fieldVectors = getFieldVectorsForHits(searchHits, context, docIdIndexMapping);
41+
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
42+
docIdIndexMapping.put(topDocs.scoreDocs[i].doc, i);
43+
}
4344

4445
VectorSimilarityFunction similarityFunction = DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT.vectorSimilarityFunction(
4546
context.getIndexVersion(),
@@ -52,48 +53,55 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
5253
// always add the highest scoring doc to the list
5354
int highestScoreDocId = -1;
5455
float highestScore = Float.MIN_VALUE;
55-
for (SearchHit hit : searchHits) {
56-
if (hit.getScore() > highestScore) {
57-
highestScoreDocId = hit.docId();
58-
highestScore = hit.getScore();
56+
for (ScoreDoc doc : topDocs.scoreDocs) {
57+
if (doc.score > highestScore) {
58+
highestScoreDocId = doc.doc;
59+
highestScore = doc.score;
5960
}
6061
}
6162
selectedDocIds.add(highestScoreDocId);
6263

6364
// test the vector to see if we are using floats or bytes
64-
VectorData firstVec = fieldVectors.get(highestScoreDocId);
65+
VectorData firstVec = context.getFieldVector(highestScoreDocId);
6566
boolean useFloat = firstVec.isFloat();
6667

6768
// cache the similarity scores for the query vector vs. each candidate doc
68-
Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(searchHits, fieldVectors, similarityFunction, useFloat, context);
69+
Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(topDocs.scoreDocs, similarityFunction, useFloat, context);
6970

7071
Map<Integer, Map<Integer, Float>> cachedSimilarities = new HashMap<>();
7172
int numCandidates = context.getNumCandidates();
7273

73-
for (int x = 0; x < numCandidates && selectedDocIds.size() < numCandidates && selectedDocIds.size() < searchHits.length; x++) {
74+
for (int x = 0; x < numCandidates
75+
&& selectedDocIds.size() < numCandidates
76+
&& selectedDocIds.size() < topDocs.scoreDocs.length; x++) {
7477
int thisMaxMMRDocId = -1;
7578
float thisMaxMMRScore = Float.NEGATIVE_INFINITY;
76-
for (SearchHit thisHit : searchHits) {
77-
int docId = thisHit.docId();
79+
for (ScoreDoc doc : topDocs.scoreDocs) {
80+
int docId = doc.doc;
7881

7982
if (selectedDocIds.contains(docId)) {
8083
continue;
8184
}
8285

83-
var thisDocVector = fieldVectors.get(docId);
86+
var thisDocVector = context.getFieldVector(docId);
87+
if (thisDocVector == null) {
88+
continue;
89+
}
90+
8491
var cachedScoresForDoc = cachedSimilarities.getOrDefault(docId, new HashMap<>());
8592

8693
// compute MMR scores for the remaining (not yet selected) docs
8794
float highestMMRScore = getHighestScoreForSelectedVectors(
88-
fieldVectors,
95+
docId,
96+
context,
8997
similarityFunction,
9098
useFloat,
9199
thisDocVector,
92100
cachedScoresForDoc
93101
);
94102

95103
// compute MMR
96-
float querySimilarityScore = querySimilarity.getOrDefault(thisHit.docId(), 0.0f);
104+
float querySimilarityScore = querySimilarity.getOrDefault(doc.doc, 0.0f);
97105
float mmr = (context.getLambda() * querySimilarityScore) - ((1 - context.getLambda()) * highestMMRScore);
98106
if (mmr > thisMaxMMRScore) {
99107
thisMaxMMRScore = mmr;
@@ -110,34 +118,29 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
110118
}
111119

112120
// return only the score docs that were selected
113-
SearchHit[] ret = new SearchHit[selectedDocIds.size()];
121+
ScoreDoc[] ret = new ScoreDoc[selectedDocIds.size()];
114122
for (int i = 0; i < selectedDocIds.size(); i++) {
115123
int scoredDocIndex = docIdIndexMapping.get(selectedDocIds.get(i));
116-
ret[i] = searchHits[scoredDocIndex];
124+
ret[i] = topDocs.scoreDocs[scoredDocIndex];
117125
}
118126

119-
// cleanup for GC
120-
searchHits = null;
121-
122-
return new SearchHits(
123-
ret,
124-
hits.getTotalHits(),
125-
hits.getMaxScore(),
126-
hits.getSortFields(),
127-
hits.getCollapseField(),
128-
hits.getCollapseValues()
129-
);
127+
return new TopDocs(topDocs.totalHits, ret);
130128
}
131129

132130
private float getHighestScoreForSelectedVectors(
133-
Map<Integer, VectorData> selectedVectors,
131+
int docId,
132+
MMRResultDiversificationContext context,
134133
VectorSimilarityFunction similarityFunction,
135134
boolean useFloat,
136135
VectorData thisDocVector,
137136
Map<Integer, Float> cachedScoresForDoc
138137
) {
139138
float highestScore = Float.MIN_VALUE;
140-
for (var vec : selectedVectors.entrySet()) {
139+
for (var vec : context.getFieldVectorsEntrySet()) {
140+
if (vec.getKey().equals(docId)) {
141+
continue;
142+
}
143+
141144
if (cachedScoresForDoc.containsKey(vec.getKey())) {
142145
float score = cachedScoresForDoc.get(vec.getKey());
143146
if (score > highestScore) {
@@ -156,19 +159,17 @@ private float getHighestScoreForSelectedVectors(
156159
}
157160

158161
protected Map<Integer, Float> getQuerySimilarityForDocs(
159-
SearchHit[] searchHits,
160-
Map<Integer, VectorData> fieldVectors,
162+
ScoreDoc[] docs,
161163
VectorSimilarityFunction similarityFunction,
162164
boolean useFloat,
163165
ResultDiversificationContext context
164166
) {
165167
Map<Integer, Float> querySimilarity = new HashMap<>();
166-
for (SearchHit searchHit : searchHits) {
167-
int docId = searchHit.docId();
168-
VectorData vectorData = fieldVectors.get(docId);
168+
for (ScoreDoc doc : docs) {
169+
VectorData vectorData = context.getFieldVector(doc.doc);
169170
if (vectorData != null) {
170171
float querySimilarityScore = getVectorComparisonScore(similarityFunction, useFloat, vectorData, context.getQueryVector());
171-
querySimilarity.put(docId, querySimilarityScore);
172+
querySimilarity.put(doc.doc, querySimilarityScore);
172173
}
173174
}
174175
return querySimilarity;

server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationContext.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import org.elasticsearch.search.diversification.ResultDiversificationContext;
1515
import org.elasticsearch.search.vectors.VectorData;
1616

17+
import java.util.Map;
18+
1719
public class MMRResultDiversificationContext extends ResultDiversificationContext {
1820

1921
private final float lambda;
@@ -24,9 +26,10 @@ public MMRResultDiversificationContext(
2426
int numCandidates,
2527
VectorData queryVector,
2628
DenseVectorFieldMapper fieldMapper,
27-
IndexVersion indexVersion
29+
IndexVersion indexVersion,
30+
Map<Integer, VectorData> fieldVectors
2831
) {
29-
super(field, numCandidates, queryVector, fieldMapper, indexVersion);
32+
super(field, numCandidates, queryVector, fieldMapper, indexVersion, fieldVectors);
3033
this.lambda = lambda;
3134
}
3235

0 commit comments

Comments
 (0)