Skip to content

Commit 23f5caa

Browse files
committed
Internal MMR based result diversification method
1 parent 1521291 commit 23f5caa

File tree

5 files changed

+474
-0
lines changed

5 files changed

+474
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.diversification;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.apache.lucene.search.Explanation;
14+
import org.elasticsearch.search.SearchHit;
15+
import org.elasticsearch.search.SearchHits;
16+
import org.elasticsearch.search.vectors.VectorData;
17+
18+
import java.io.IOException;
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
22+
/**
23+
* Base interface for result diversification.
24+
*/
25+
public abstract class ResultDiversification {
26+
27+
public abstract SearchHits diversify(SearchHits hits, ResultDiversificationContext diversificationContext) throws IOException;
28+
29+
public abstract Explanation explain(
30+
int topLevelDocId,
31+
ResultDiversificationContext diversificationContext,
32+
Explanation sourceExplanation
33+
) throws IOException;
34+
35+
protected Map<Integer, VectorData> getFieldVectorsForHits(
36+
SearchHit[] hits,
37+
ResultDiversificationContext context,
38+
Map<Integer, Integer> docIdIndexMapping
39+
) {
40+
Map<Integer, VectorData> fieldVectors = new HashMap<>();
41+
for (int i = 0; i < hits.length; i++) {
42+
SearchHit hit = hits[i];
43+
int docId = hit.docId();
44+
docIdIndexMapping.put(docId, i);
45+
Object collapseValue = hit.field(context.getField()).getValue();
46+
if (collapseValue instanceof float[] vecData) {
47+
fieldVectors.put(docId, new VectorData(vecData));
48+
} else if (collapseValue instanceof byte[] byteVecData) {
49+
fieldVectors.put(docId, new VectorData(byteVecData));
50+
}
51+
}
52+
return fieldVectors;
53+
}
54+
55+
protected float getVectorComparisonScore(
56+
VectorSimilarityFunction similarityFunction,
57+
boolean useFloat,
58+
VectorData thisDocVector,
59+
VectorData comparisonVector
60+
) {
61+
return useFloat
62+
? similarityFunction.compare(thisDocVector.floatVector(), comparisonVector.floatVector())
63+
: similarityFunction.compare(thisDocVector.byteVector(), comparisonVector.byteVector());
64+
}
65+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.diversification;
11+
12+
import org.elasticsearch.index.IndexVersion;
13+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
14+
import org.elasticsearch.search.vectors.VectorData;
15+
16+
public abstract class ResultDiversificationContext {
17+
private final String field;
18+
private final int numCandidates;
19+
private final DenseVectorFieldMapper fieldMapper;
20+
private final IndexVersion indexVersion;
21+
private final VectorData queryVector;
22+
23+
// Field _must_ be a dense_vector type
24+
protected ResultDiversificationContext(
25+
String field,
26+
int numCandidates,
27+
VectorData queryVector,
28+
DenseVectorFieldMapper fieldMapper,
29+
IndexVersion indexVersion
30+
) {
31+
this.field = field;
32+
this.numCandidates = numCandidates;
33+
this.fieldMapper = fieldMapper;
34+
this.indexVersion = indexVersion;
35+
this.queryVector = queryVector;
36+
}
37+
38+
public String getField() {
39+
return field;
40+
}
41+
42+
public int getNumCandidates() {
43+
return numCandidates;
44+
}
45+
46+
public DenseVectorFieldMapper getFieldMapper() {
47+
return fieldMapper;
48+
}
49+
50+
public DenseVectorFieldMapper.ElementType getElementType() {
51+
return fieldMapper.fieldType().getElementType();
52+
}
53+
54+
public IndexVersion getIndexVersion() {
55+
return indexVersion;
56+
}
57+
58+
public VectorData getQueryVector() {
59+
return queryVector;
60+
}
61+
}
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.diversification.mmr;
11+
12+
import org.apache.lucene.index.VectorSimilarityFunction;
13+
import org.apache.lucene.search.Explanation;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
15+
import org.elasticsearch.search.SearchHit;
16+
import org.elasticsearch.search.SearchHits;
17+
import org.elasticsearch.search.diversification.ResultDiversification;
18+
import org.elasticsearch.search.diversification.ResultDiversificationContext;
19+
import org.elasticsearch.search.vectors.VectorData;
20+
21+
import java.io.IOException;
22+
import java.util.ArrayList;
23+
import java.util.HashMap;
24+
import java.util.List;
25+
import java.util.Map;
26+
27+
public class MMRResultDiversification extends ResultDiversification {
28+
29+
@Override
30+
public SearchHits diversify(SearchHits hits, ResultDiversificationContext diversificationContext) throws IOException {
31+
if (hits == null || ((diversificationContext instanceof MMRResultDiversificationContext) == false)) {
32+
return hits;
33+
}
34+
35+
MMRResultDiversificationContext context = (MMRResultDiversificationContext) diversificationContext;
36+
SearchHit[] docs = hits.getHits(); // NOTE: by reference, not new array
37+
38+
if (docs.length == 0) {
39+
return hits;
40+
}
41+
42+
Map<Integer, Integer> docIdIndexMapping = new HashMap<>();
43+
Map<Integer, VectorData> fieldVectors = getFieldVectorsForHits(docs, context, docIdIndexMapping);
44+
45+
VectorSimilarityFunction similarityFunction = DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT.vectorSimilarityFunction(
46+
context.getIndexVersion(),
47+
diversificationContext.getElementType()
48+
);
49+
50+
List<Integer> rerankedDocIds = new ArrayList<>();
51+
Map<Integer, VectorData> selectedVectors = new HashMap<>();
52+
53+
// always add the highest scoring doc to the list
54+
int highestDocIdIndex = -1;
55+
float highestScore = Float.MIN_VALUE;
56+
for (int i = 0; i < docs.length; i++) {
57+
if (docs[i].getScore() > highestScore) {
58+
highestDocIdIndex = i;
59+
highestScore = docs[i].getScore();
60+
}
61+
}
62+
int firstDocId = docs[highestDocIdIndex].docId();
63+
rerankedDocIds.add(firstDocId);
64+
65+
// and add the vector for the first items
66+
VectorData firstVec = fieldVectors.get(firstDocId);
67+
selectedVectors.put(firstDocId, firstVec);
68+
boolean useFloat = firstVec.isFloat();
69+
70+
// cache the similarity scores for the query vector vs. docs
71+
Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(docs, fieldVectors, similarityFunction, useFloat, context);
72+
73+
Map<Integer, Map<Integer, Float>> cachedSimilarities = new HashMap<>();
74+
int numCandidates = context.getNumCandidates();
75+
76+
for (int x = 0; x < numCandidates && rerankedDocIds.size() < numCandidates && rerankedDocIds.size() < docs.length; x++) {
77+
int thisMaxMMRDocId = -1;
78+
float thisMaxMMRScore = Float.MIN_VALUE;
79+
for (SearchHit thisHit : docs) {
80+
int docId = thisHit.docId();
81+
82+
if (rerankedDocIds.contains(docId)) {
83+
continue;
84+
}
85+
86+
var thisDocVector = fieldVectors.get(docId);
87+
88+
var cachedScoresForDoc = cachedSimilarities.getOrDefault(docId, new HashMap<>());
89+
90+
// compute MMR scores for remaining docs
91+
float highestMMRScore = getHighestScoreForSelectedVectors(
92+
selectedVectors,
93+
similarityFunction,
94+
useFloat,
95+
thisDocVector,
96+
cachedScoresForDoc
97+
);
98+
99+
// compute MMR
100+
float querySimilarityScore = querySimilarity.getOrDefault(thisHit.docId(), 0.0f);
101+
float mmr = (context.getLambda() * querySimilarityScore) - ((1 - context.getLambda()) * highestMMRScore);
102+
if (mmr > thisMaxMMRScore) {
103+
thisMaxMMRScore = mmr;
104+
thisMaxMMRDocId = docId;
105+
}
106+
107+
// cache these scores
108+
cachedSimilarities.put(docId, cachedScoresForDoc);
109+
}
110+
111+
rerankedDocIds.add(thisMaxMMRDocId);
112+
selectedVectors.put(thisMaxMMRDocId, fieldVectors.get(thisMaxMMRDocId));
113+
}
114+
115+
// our return should be only those docs that are selected
116+
SearchHit[] ret = new SearchHit[rerankedDocIds.size()];
117+
for (int i = 0; i < rerankedDocIds.size(); i++) {
118+
int scoredDocIndex = docIdIndexMapping.get(rerankedDocIds.get(i));
119+
ret[i] = docs[scoredDocIndex];
120+
}
121+
122+
return new SearchHits(
123+
ret,
124+
hits.getTotalHits(),
125+
hits.getMaxScore(),
126+
hits.getSortFields(),
127+
hits.getCollapseField(),
128+
hits.getCollapseValues()
129+
);
130+
}
131+
132+
@Override
133+
public Explanation explain(int topLevelDocId, ResultDiversificationContext diversificationContext, Explanation sourceExplanation)
134+
throws IOException {
135+
// TODO
136+
return null;
137+
}
138+
139+
private float getHighestScoreForSelectedVectors(
140+
Map<Integer, VectorData> selectedVectors,
141+
VectorSimilarityFunction similarityFunction,
142+
boolean useFloat,
143+
VectorData thisDocVector,
144+
Map<Integer, Float> cachedScoresForDoc
145+
) {
146+
float highestScore = Float.MIN_VALUE;
147+
for (var vec : selectedVectors.entrySet()) {
148+
if (cachedScoresForDoc.containsKey(vec.getKey())) {
149+
float score = cachedScoresForDoc.get(vec.getKey());
150+
if (score > highestScore) {
151+
highestScore = score;
152+
}
153+
} else {
154+
VectorData comparisonVector = vec.getValue();
155+
float score = getVectorComparisonScore(similarityFunction, useFloat, thisDocVector, comparisonVector);
156+
cachedScoresForDoc.put(vec.getKey(), score);
157+
if (score > highestScore) {
158+
highestScore = score;
159+
}
160+
}
161+
}
162+
return highestScore;
163+
}
164+
165+
protected Map<Integer, Float> getQuerySimilarityForDocs(
166+
SearchHit[] docs,
167+
Map<Integer, VectorData> fieldVectors,
168+
VectorSimilarityFunction similarityFunction,
169+
boolean useFloat,
170+
ResultDiversificationContext context
171+
) {
172+
Map<Integer, Float> querySimilarity = new HashMap<>();
173+
for (int i = 0; i < docs.length; i++) {
174+
int docId = docs[i].docId();
175+
VectorData vectorData = fieldVectors.get(docId);
176+
if (vectorData != null) {
177+
float querySimilarityScore = getVectorComparisonScore(similarityFunction, useFloat, vectorData, context.getQueryVector());
178+
querySimilarity.put(docId, querySimilarityScore);
179+
}
180+
}
181+
return querySimilarity;
182+
}
183+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.search.diversification.mmr;
11+
12+
import org.elasticsearch.index.IndexVersion;
13+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
14+
import org.elasticsearch.search.diversification.ResultDiversificationContext;
15+
import org.elasticsearch.search.vectors.VectorData;
16+
17+
public class MMRResultDiversificationContext extends ResultDiversificationContext {
18+
19+
private final float lambda;
20+
21+
public MMRResultDiversificationContext(
22+
String field,
23+
float lambda,
24+
int numCandidates,
25+
VectorData queryVector,
26+
DenseVectorFieldMapper fieldMapper,
27+
IndexVersion indexVersion
28+
) {
29+
super(field, numCandidates, queryVector, fieldMapper, indexVersion);
30+
this.lambda = lambda;
31+
}
32+
33+
public float getLambda() {
34+
return lambda;
35+
}
36+
}

0 commit comments

Comments
 (0)