Skip to content

Commit c5d5d55

Browse files
authored
Minor cleanups for MMR Diversify Retriever (#138910)
* ensure size >= 0; use query vec selection * assert size >0 (instead of >= 0) * update error catch > 0 in YAML test
1 parent f974070 commit c5d5d55

File tree

5 files changed

+94
-16
lines changed

5 files changed

+94
-16
lines changed

rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,9 @@ teardown:
322322

323323
- match: { hits.total.value: 10 }
324324
- length: { hits.hits: 3 }
325-
- match: { hits.hits.0._source.textbody: "second text" }
326-
- match: { hits.hits.1._source.textbody: "sixth text" }
327-
- match: { hits.hits.2._source.textbody: "ninth text" }
325+
- match: { hits.hits.0._source.textbody: "sixth text" }
326+
- match: { hits.hits.1._source.textbody: "ninth text" }
327+
- match: { hits.hits.2._source.textbody: "fourth text" }
328328

329329
- do:
330330
search:
@@ -346,9 +346,9 @@ teardown:
346346

347347
- match: { hits.total.value: 10 }
348348
- length: { hits.hits: 3 }
349-
- match: { hits.hits.0._source.textbody: "second text" }
350-
- match: { hits.hits.1._source.textbody: "fourth text" }
351-
- match: { hits.hits.2._source.textbody: "fifth text" }
349+
- match: { hits.hits.0._source.textbody: "fourth text" }
350+
- match: { hits.hits.1._source.textbody: "fifth text" }
351+
- match: { hits.hits.2._source.textbody: "seventh text" }
352352

353353
---
354354
"Test MMR result diversification byte vector type":
@@ -584,6 +584,27 @@ teardown:
584584
- match: { status: 400 }
585585
- match: { error.type: illegal_argument_exception }
586586

587+
- do:
588+
catch: /\[diversify\] MMR result diversification \[size\] of -3 must be greater than zero/
589+
search:
590+
index: test-result-diversification-index
591+
body:
592+
retriever:
593+
diversify:
594+
type: "mmr"
595+
field: "textvector"
596+
size: -3
597+
lambda: 0.7
598+
retriever:
599+
knn:
600+
field: "textvector"
601+
query_vector: [ 0.5, 0.2, 0.4, 0.4 ]
602+
k: 6
603+
num_candidates: 6
604+
605+
- match: { status: 400 }
606+
- match: { error.type: action_request_validation_exception }
607+
587608
- do:
588609
catch: /\[diversify\] MMR result diversification must have a \[lambda\] between 0.0 and 1.0. The value provided was null/
589610
search:

server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,19 @@ public ActionRequestValidationException validate(
189189
}
190190

191191
private ActionRequestValidationException validateMMRDiversification(ActionRequestValidationException validationException) {
192+
if (this.size <= 0) {
193+
validationException = addValidationError(
194+
String.format(
195+
Locale.ROOT,
196+
"[%s] MMR result diversification [%s] of %d must be greater than zero",
197+
getName(),
198+
SIZE_FIELD.getPreferredName(),
199+
this.size
200+
),
201+
validationException
202+
);
203+
}
204+
192205
if (this.size > this.rankWindowSize) {
193206
validationException = addValidationError(
194207
String.format(

server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,16 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException {
4545
// our chosen DocIDs to keep
4646
List<Integer> selectedDocRanks = new ArrayList<>();
4747

48-
// always add the highest scoring doc to the list
49-
RankDoc highestScoreDoc = Arrays.stream(docs).max(Comparator.comparingDouble(doc -> doc.score)).orElse(docs[0]);
50-
int highestScoreDocRank = highestScoreDoc.rank;
51-
selectedDocRanks.add(highestScoreDocRank);
52-
5348
// test the vector to see if we are using floats or bytes
54-
VectorData firstVec = context.getFieldVector(highestScoreDocRank);
49+
VectorData firstVec = context.getFieldVector(docs[0].rank);
5550
boolean useFloat = firstVec.isFloat();
5651

5752
// cache the similarity scores for the query vector vs. searchHits
5853
Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(docs, useFloat, context);
5954

55+
// always add the highest relevant doc to the list
56+
selectedDocRanks.add(getHighestRelevantDocRank(docs, querySimilarity));
57+
6058
Map<Integer, Map<Integer, Float>> cachedSimilarities = new HashMap<>();
6159
int topDocsSize = context.getSize();
6260

@@ -113,6 +111,20 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException {
113111
return ret;
114112
}
115113

114+
private Integer getHighestRelevantDocRank(RankDoc[] docs, Map<Integer, Float> querySimilarity) {
115+
Map.Entry<Integer, Float> highestRelevantDoc = querySimilarity.entrySet()
116+
.stream()
117+
.max(Comparator.comparingDouble(Map.Entry::getValue))
118+
.orElse(null);
119+
120+
if (highestRelevantDoc != null) {
121+
return highestRelevantDoc.getKey();
122+
}
123+
124+
RankDoc highestScoreDoc = Arrays.stream(docs).max(Comparator.comparingDouble(doc -> doc.score)).orElse(docs[0]);
125+
return highestScoreDoc.rank;
126+
}
127+
116128
private float getHighestScoreForSelectedVectors(
117129
int docRank,
118130
MMRResultDiversificationContext context,

server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderTests.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,38 @@ public class DiversifyRetrieverBuilderTests extends ESTestCase {
5757
public void testValidate() {
5858
SearchSourceBuilder source = new SearchSourceBuilder();
5959

60+
var retrieverWithZeroSize = new DiversifyRetrieverBuilder(
61+
getInnerRetriever(),
62+
ResultDiversificationType.MMR,
63+
"test_field",
64+
10,
65+
0,
66+
getRandomQueryVector(),
67+
0.3f
68+
);
69+
var validationZeroSize = retrieverWithZeroSize.validate(source, null, false, false);
70+
assertEquals(1, validationZeroSize.validationErrors().size());
71+
assertEquals(
72+
"[diversify] MMR result diversification [size] of 0 must be greater than zero",
73+
validationZeroSize.validationErrors().getFirst()
74+
);
75+
76+
var retrieverWithNegativeSize = new DiversifyRetrieverBuilder(
77+
getInnerRetriever(),
78+
ResultDiversificationType.MMR,
79+
"test_field",
80+
10,
81+
-1,
82+
getRandomQueryVector(),
83+
0.3f
84+
);
85+
var validationNegativeSize = retrieverWithNegativeSize.validate(source, null, false, false);
86+
assertEquals(1, validationNegativeSize.validationErrors().size());
87+
assertEquals(
88+
"[diversify] MMR result diversification [size] of -1 must be greater than zero",
89+
validationNegativeSize.validationErrors().getFirst()
90+
);
91+
6092
var retrieverWithLargeSize = new DiversifyRetrieverBuilder(
6193
getInnerRetriever(),
6294
ResultDiversificationType.MMR,
@@ -184,8 +216,8 @@ public void testMmrResultDiversification() {
184216
var result = retriever.combineInnerRetrieverResults(docs, false);
185217

186218
assertEquals(3, result.length);
187-
assertEquals(1, result[0].rank);
188-
assertEquals(3, result[1].rank);
219+
assertEquals(3, result[0].rank);
220+
assertEquals(4, result[1].rank);
189221
assertEquals(6, result[2].rank);
190222

191223
var retrieverWithoutRewrite = new DiversifyRetrieverBuilder(

server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ private MMRResultDiversificationContext getRandomFloatContext(List<Integer> expe
8989
)
9090
);
9191

92-
expectedDocIds.addAll(List.of(1, 3, 6));
92+
expectedDocIds.addAll(List.of(3, 4, 6));
9393

9494
return diversificationContext;
9595
}
@@ -124,7 +124,7 @@ private MMRResultDiversificationContext getRandomByteContext(List<Integer> expec
124124
)
125125
);
126126

127-
expectedDocIds.addAll(List.of(1, 3, 6));
127+
expectedDocIds.addAll(List.of(2, 3, 6));
128128

129129
return diversificationContext;
130130
}

0 commit comments

Comments
 (0)