Skip to content

Commit 23494e4

Browse files
committed
cleanup and complete implementation
1 parent 23f5caa commit 23494e4

File tree

3 files changed

+39
-56
lines changed

3 files changed

+39
-56
lines changed

server/src/main/java/org/elasticsearch/search/diversification/ResultDiversification.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
package org.elasticsearch.search.diversification;
1111

1212
import org.apache.lucene.index.VectorSimilarityFunction;
13-
import org.apache.lucene.search.Explanation;
1413
import org.elasticsearch.search.SearchHit;
1514
import org.elasticsearch.search.SearchHits;
1615
import org.elasticsearch.search.vectors.VectorData;
@@ -26,20 +25,14 @@ public abstract class ResultDiversification {
2625

2726
public abstract SearchHits diversify(SearchHits hits, ResultDiversificationContext diversificationContext) throws IOException;
2827

29-
public abstract Explanation explain(
30-
int topLevelDocId,
31-
ResultDiversificationContext diversificationContext,
32-
Explanation sourceExplanation
33-
) throws IOException;
34-
3528
protected Map<Integer, VectorData> getFieldVectorsForHits(
36-
SearchHit[] hits,
29+
SearchHit[] searchHits,
3730
ResultDiversificationContext context,
3831
Map<Integer, Integer> docIdIndexMapping
3932
) {
4033
Map<Integer, VectorData> fieldVectors = new HashMap<>();
41-
for (int i = 0; i < hits.length; i++) {
42-
SearchHit hit = hits[i];
34+
for (int i = 0; i < searchHits.length; i++) {
35+
SearchHit hit = searchHits[i];
4336
int docId = hit.docId();
4437
docIdIndexMapping.put(docId, i);
4538
Object collapseValue = hit.field(context.getField()).getValue();

server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
package org.elasticsearch.search.diversification.mmr;
1111

1212
import org.apache.lucene.index.VectorSimilarityFunction;
13-
import org.apache.lucene.search.Explanation;
1413
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1514
import org.elasticsearch.search.SearchHit;
1615
import org.elasticsearch.search.SearchHits;
@@ -33,63 +32,60 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
3332
}
3433

3534
MMRResultDiversificationContext context = (MMRResultDiversificationContext) diversificationContext;
36-
SearchHit[] docs = hits.getHits(); // NOTE: by reference, not new array
35+
SearchHit[] searchHits = hits.getHits(); // NOTE: by reference, not new array
3736

38-
if (docs.length == 0) {
37+
if (searchHits.length == 0) {
3938
return hits;
4039
}
4140

4241
Map<Integer, Integer> docIdIndexMapping = new HashMap<>();
43-
Map<Integer, VectorData> fieldVectors = getFieldVectorsForHits(docs, context, docIdIndexMapping);
42+
Map<Integer, VectorData> fieldVectors = getFieldVectorsForHits(searchHits, context, docIdIndexMapping);
4443

4544
VectorSimilarityFunction similarityFunction = DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT.vectorSimilarityFunction(
4645
context.getIndexVersion(),
4746
diversificationContext.getElementType()
4847
);
4948

50-
List<Integer> rerankedDocIds = new ArrayList<>();
51-
Map<Integer, VectorData> selectedVectors = new HashMap<>();
49+
// our chosen DocIDs to keep
50+
List<Integer> selectedDocIds = new ArrayList<>();
5251

5352
// always add the highest scoring doc to the list
54-
int highestDocIdIndex = -1;
53+
int highestScoreDocId = -1;
5554
float highestScore = Float.MIN_VALUE;
56-
for (int i = 0; i < docs.length; i++) {
57-
if (docs[i].getScore() > highestScore) {
58-
highestDocIdIndex = i;
59-
highestScore = docs[i].getScore();
55+
for (SearchHit hit : searchHits) {
56+
if (hit.getScore() > highestScore) {
57+
highestScoreDocId = hit.docId();
58+
highestScore = hit.getScore();
6059
}
6160
}
62-
int firstDocId = docs[highestDocIdIndex].docId();
63-
rerankedDocIds.add(firstDocId);
61+
selectedDocIds.add(highestScoreDocId);
6462

65-
// and add the vector for the first items
66-
VectorData firstVec = fieldVectors.get(firstDocId);
67-
selectedVectors.put(firstDocId, firstVec);
63+
// test the vector to see if we are using floats or bytes
64+
VectorData firstVec = fieldVectors.get(highestScoreDocId);
6865
boolean useFloat = firstVec.isFloat();
6966

70-
// cache the similarity scores for the query vector vs. docs
71-
Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(docs, fieldVectors, similarityFunction, useFloat, context);
67+
// cache the similarity scores for the query vector vs. searchHits
68+
Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(searchHits, fieldVectors, similarityFunction, useFloat, context);
7269

7370
Map<Integer, Map<Integer, Float>> cachedSimilarities = new HashMap<>();
7471
int numCandidates = context.getNumCandidates();
7572

76-
for (int x = 0; x < numCandidates && rerankedDocIds.size() < numCandidates && rerankedDocIds.size() < docs.length; x++) {
73+
for (int x = 0; x < numCandidates && selectedDocIds.size() < numCandidates && selectedDocIds.size() < searchHits.length; x++) {
7774
int thisMaxMMRDocId = -1;
78-
float thisMaxMMRScore = Float.MIN_VALUE;
79-
for (SearchHit thisHit : docs) {
75+
float thisMaxMMRScore = Float.NEGATIVE_INFINITY;
76+
for (SearchHit thisHit : searchHits) {
8077
int docId = thisHit.docId();
8178

82-
if (rerankedDocIds.contains(docId)) {
79+
if (selectedDocIds.contains(docId)) {
8380
continue;
8481
}
8582

8683
var thisDocVector = fieldVectors.get(docId);
87-
8884
var cachedScoresForDoc = cachedSimilarities.getOrDefault(docId, new HashMap<>());
8985

90-
// compute MMR scores for remaining docs
86+
// compute MMR scores for remaining searchHits
9187
float highestMMRScore = getHighestScoreForSelectedVectors(
92-
selectedVectors,
88+
fieldVectors,
9389
similarityFunction,
9490
useFloat,
9591
thisDocVector,
@@ -108,15 +104,16 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
108104
cachedSimilarities.put(docId, cachedScoresForDoc);
109105
}
110106

111-
rerankedDocIds.add(thisMaxMMRDocId);
112-
selectedVectors.put(thisMaxMMRDocId, fieldVectors.get(thisMaxMMRDocId));
107+
if (thisMaxMMRDocId >= 0) {
108+
selectedDocIds.add(thisMaxMMRDocId);
109+
}
113110
}
114111

115-
// our return should be only those docs that are selected
116-
SearchHit[] ret = new SearchHit[rerankedDocIds.size()];
117-
for (int i = 0; i < rerankedDocIds.size(); i++) {
118-
int scoredDocIndex = docIdIndexMapping.get(rerankedDocIds.get(i));
119-
ret[i] = docs[scoredDocIndex];
112+
// our return should be only those searchHits that are selected
113+
SearchHit[] ret = new SearchHit[selectedDocIds.size()];
114+
for (int i = 0; i < selectedDocIds.size(); i++) {
115+
int scoredDocIndex = docIdIndexMapping.get(selectedDocIds.get(i));
116+
ret[i] = searchHits[scoredDocIndex];
120117
}
121118

122119
return new SearchHits(
@@ -129,13 +126,6 @@ public SearchHits diversify(SearchHits hits, ResultDiversificationContext divers
129126
);
130127
}
131128

132-
@Override
133-
public Explanation explain(int topLevelDocId, ResultDiversificationContext diversificationContext, Explanation sourceExplanation)
134-
throws IOException {
135-
// TODO
136-
return null;
137-
}
138-
139129
private float getHighestScoreForSelectedVectors(
140130
Map<Integer, VectorData> selectedVectors,
141131
VectorSimilarityFunction similarityFunction,
@@ -163,15 +153,15 @@ private float getHighestScoreForSelectedVectors(
163153
}
164154

165155
protected Map<Integer, Float> getQuerySimilarityForDocs(
166-
SearchHit[] docs,
156+
SearchHit[] searchHits,
167157
Map<Integer, VectorData> fieldVectors,
168158
VectorSimilarityFunction similarityFunction,
169159
boolean useFloat,
170160
ResultDiversificationContext context
171161
) {
172162
Map<Integer, Float> querySimilarity = new HashMap<>();
173-
for (int i = 0; i < docs.length; i++) {
174-
int docId = docs[i].docId();
163+
for (SearchHit searchHit : searchHits) {
164+
int docId = searchHit.docId();
175165
VectorData vectorData = fieldVectors.get(docId);
176166
if (vectorData != null) {
177167
float querySimilarityScore = getVectorComparisonScore(similarityFunction, useFloat, vectorData, context.getQueryVector());

server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public void testMMRDiversification() throws IOException {
4242
var queryVectorData = new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f });
4343
var diversificationContext = new MMRResultDiversificationContext(
4444
"dense_vector_field",
45-
0.6f,
45+
0.3f,
4646
3,
4747
queryVectorData,
4848
fieldMapper,
@@ -54,8 +54,8 @@ public void testMMRDiversification() throws IOException {
5454
generateSearchHit(2, 1.8f, 2, new float[] { 0.4f, 0.2f, 0.3f, 0.3f }),
5555
generateSearchHit(3, 1.6f, 3, new float[] { 0.4f, 0.1f, 0.3f, 0.3f }),
5656
generateSearchHit(4, 1.0f, 4, new float[] { 0.1f, 0.9f, 0.5f, 0.9f }),
57-
generateSearchHit(5, 0.9f, 5, new float[] { 0.1f, 0.9f, 0.5f, 0.8f }),
58-
generateSearchHit(6, 0.5f, 6, new float[] { 0.05f, 0.05f, 0.05f, 0.05f }) };
57+
generateSearchHit(5, 0.8f, 5, new float[] { 0.1f, 0.9f, 0.5f, 0.9f }),
58+
generateSearchHit(6, 0.8f, 6, new float[] { 0.05f, 0.05f, 0.05f, 0.05f }) };
5959

6060
TotalHits totalHits = new TotalHits(6L, TotalHits.Relation.EQUAL_TO);
6161
SearchHits searchHits = new SearchHits(hits, totalHits, 2.0f);
@@ -66,7 +66,7 @@ public void testMMRDiversification() throws IOException {
6666

6767
assertEquals(3, diversifiedHits.getHits().length);
6868
assertEquals(1, diversifiedHits.getHits()[0].docId());
69-
assertEquals(4, diversifiedHits.getHits()[1].docId());
69+
assertEquals(6, diversifiedHits.getHits()[1].docId());
7070
assertEquals(3, diversifiedHits.getHits()[2].docId());
7171
}
7272

0 commit comments

Comments
 (0)