Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,9 @@ teardown:

- match: { hits.total.value: 10 }
- length: { hits.hits: 3 }
- match: { hits.hits.0._source.textbody: "second text" }
- match: { hits.hits.1._source.textbody: "sixth text" }
- match: { hits.hits.2._source.textbody: "ninth text" }
- match: { hits.hits.0._source.textbody: "sixth text" }
- match: { hits.hits.1._source.textbody: "ninth text" }
- match: { hits.hits.2._source.textbody: "fourth text" }

- do:
search:
Expand All @@ -346,9 +346,9 @@ teardown:

- match: { hits.total.value: 10 }
- length: { hits.hits: 3 }
- match: { hits.hits.0._source.textbody: "second text" }
- match: { hits.hits.1._source.textbody: "fourth text" }
- match: { hits.hits.2._source.textbody: "fifth text" }
- match: { hits.hits.0._source.textbody: "fourth text" }
- match: { hits.hits.1._source.textbody: "fifth text" }
- match: { hits.hits.2._source.textbody: "seventh text" }

---
"Test MMR result diversification byte vector type":
Expand Down Expand Up @@ -584,6 +584,27 @@ teardown:
- match: { status: 400 }
- match: { error.type: illegal_argument_exception }

- do:
catch: /\[diversify\] MMR result diversification \[size\] of -3 must be greater than zero/
search:
index: test-result-diversification-index
body:
retriever:
diversify:
type: "mmr"
field: "textvector"
size: -3
lambda: 0.7
retriever:
knn:
field: "textvector"
query_vector: [ 0.5, 0.2, 0.4, 0.4 ]
k: 6
num_candidates: 6

- match: { status: 400 }
- match: { error.type: action_request_validation_exception }

- do:
catch: /\[diversify\] MMR result diversification must have a \[lambda\] between 0.0 and 1.0. The value provided was null/
search:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,19 @@ public ActionRequestValidationException validate(
}

private ActionRequestValidationException validateMMRDiversification(ActionRequestValidationException validationException) {
if (this.size <= 0) {
validationException = addValidationError(
String.format(
Locale.ROOT,
"[%s] MMR result diversification [%s] of %d must be greater than zero",
getName(),
SIZE_FIELD.getPreferredName(),
this.size
),
validationException
);
}

if (this.size > this.rankWindowSize) {
validationException = addValidationError(
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,16 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException {
// our chosen DocIDs to keep
List<Integer> selectedDocRanks = new ArrayList<>();

// always add the highest scoring doc to the list
RankDoc highestScoreDoc = Arrays.stream(docs).max(Comparator.comparingDouble(doc -> doc.score)).orElse(docs[0]);
int highestScoreDocRank = highestScoreDoc.rank;
selectedDocRanks.add(highestScoreDocRank);

// test the vector to see if we are using floats or bytes
VectorData firstVec = context.getFieldVector(highestScoreDocRank);
VectorData firstVec = context.getFieldVector(docs[0].rank);
boolean useFloat = firstVec.isFloat();

// cache the similarity scores for the query vector vs. searchHits
Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(docs, useFloat, context);

// always add the highest relevant doc to the list
selectedDocRanks.add(getHighestRelevantDocRank(docs, querySimilarity));

Map<Integer, Map<Integer, Float>> cachedSimilarities = new HashMap<>();
int topDocsSize = context.getSize();

Expand Down Expand Up @@ -113,6 +111,20 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException {
return ret;
}

/**
 * Picks the entry used to seed the diversified selection.
 *
 * <p>Returns the key of the {@code querySimilarity} entry holding the largest value. When the
 * map has no entries, falls back to the {@code rank} of the document in {@code docs} with the
 * largest {@code score}.
 *
 * <p>NOTE(review): the fallback reads {@code docs[0]} unconditionally, so an empty {@code docs}
 * array combined with an empty map raises {@link ArrayIndexOutOfBoundsException}; callers are
 * presumably expected to pass a non-empty array — confirm.
 *
 * @param docs            candidate documents, consulted only when {@code querySimilarity} is empty
 * @param querySimilarity similarity values keyed by what the caller treats as document ranks
 * @return the key with the highest similarity, or the rank of the highest-scoring document
 */
private Integer getHighestRelevantDocRank(RankDoc[] docs, Map<Integer, Float> querySimilarity) {
    Map.Entry<Integer, Float> best = null;
    for (Map.Entry<Integer, Float> candidate : querySimilarity.entrySet()) {
        // strict comparison keeps the first entry encountered when values tie
        if (best == null || Double.compare(candidate.getValue(), best.getValue()) > 0) {
            best = candidate;
        }
    }

    if (best == null) {
        // no similarity information available: fall back to the raw document scores
        Comparator<RankDoc> byScore = Comparator.comparingDouble(doc -> doc.score);
        RankDoc topScored = Arrays.stream(docs).max(byScore).orElse(docs[0]);
        return topScored.rank;
    }

    return best.getKey();
}

private float getHighestScoreForSelectedVectors(
int docRank,
MMRResultDiversificationContext context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,38 @@ public class DiversifyRetrieverBuilderTests extends ESTestCase {
public void testValidate() {
SearchSourceBuilder source = new SearchSourceBuilder();

var retrieverWithZeroSize = new DiversifyRetrieverBuilder(
getInnerRetriever(),
ResultDiversificationType.MMR,
"test_field",
10,
0,
getRandomQueryVector(),
0.3f
);
var validationZeroSize = retrieverWithZeroSize.validate(source, null, false, false);
assertEquals(1, validationZeroSize.validationErrors().size());
assertEquals(
"[diversify] MMR result diversification [size] of 0 must be greater than zero",
validationZeroSize.validationErrors().getFirst()
);

var retrieverWithNegativeSize = new DiversifyRetrieverBuilder(
getInnerRetriever(),
ResultDiversificationType.MMR,
"test_field",
10,
-1,
getRandomQueryVector(),
0.3f
);
var validationNegativeSize = retrieverWithNegativeSize.validate(source, null, false, false);
assertEquals(1, validationNegativeSize.validationErrors().size());
assertEquals(
"[diversify] MMR result diversification [size] of -1 must be greater than zero",
validationNegativeSize.validationErrors().getFirst()
);

var retrieverWithLargeSize = new DiversifyRetrieverBuilder(
getInnerRetriever(),
ResultDiversificationType.MMR,
Expand Down Expand Up @@ -184,8 +216,8 @@ public void testMmrResultDiversification() {
var result = retriever.combineInnerRetrieverResults(docs, false);

assertEquals(3, result.length);
assertEquals(1, result[0].rank);
assertEquals(3, result[1].rank);
assertEquals(3, result[0].rank);
assertEquals(4, result[1].rank);
assertEquals(6, result[2].rank);

var retrieverWithoutRewrite = new DiversifyRetrieverBuilder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private MMRResultDiversificationContext getRandomFloatContext(List<Integer> expe
)
);

expectedDocIds.addAll(List.of(1, 3, 6));
expectedDocIds.addAll(List.of(3, 4, 6));

return diversificationContext;
}
Expand Down Expand Up @@ -124,7 +124,7 @@ private MMRResultDiversificationContext getRandomByteContext(List<Integer> expec
)
);

expectedDocIds.addAll(List.of(1, 3, 6));
expectedDocIds.addAll(List.of(2, 3, 6));

return diversificationContext;
}
Expand Down