diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml index dd01deea53299..914ffd28c7177 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/result-diversification/10_mmr_result_diversification_retriever.yml @@ -322,9 +322,9 @@ teardown: - match: { hits.total.value: 10 } - length: { hits.hits: 3 } - - match: { hits.hits.0._source.textbody: "second text" } - - match: { hits.hits.1._source.textbody: "sixth text" } - - match: { hits.hits.2._source.textbody: "ninth text" } + - match: { hits.hits.0._source.textbody: "sixth text" } + - match: { hits.hits.1._source.textbody: "ninth text" } + - match: { hits.hits.2._source.textbody: "fourth text" } - do: search: @@ -346,9 +346,9 @@ teardown: - match: { hits.total.value: 10 } - length: { hits.hits: 3 } - - match: { hits.hits.0._source.textbody: "second text" } - - match: { hits.hits.1._source.textbody: "fourth text" } - - match: { hits.hits.2._source.textbody: "fifth text" } + - match: { hits.hits.0._source.textbody: "fourth text" } + - match: { hits.hits.1._source.textbody: "fifth text" } + - match: { hits.hits.2._source.textbody: "seventh text" } --- "Test MMR result diversification byte vector type": @@ -584,6 +584,27 @@ teardown: - match: { status: 400 } - match: { error.type: illegal_argument_exception } + - do: + catch: /\[diversify\] MMR result diversification \[size\] of -3 must be greater than zero/ + search: + index: test-result-diversification-index + body: + retriever: + diversify: + type: "mmr" + field: "textvector" + size: -3 + lambda: 0.7 + retriever: + knn: + field: "textvector" + query_vector: [ 0.5, 0.2, 0.4, 0.4 ] + k: 6 + num_candidates: 6 + + - match: { status: 400 } + - match: { error.type: action_request_validation_exception } + - do: catch: /\[diversify\] MMR result diversification must have a \[lambda\] between 0.0 and 1.0. The value provided was null/ search: diff --git a/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java index 32a6a3fbe6bdd..03751d2767e64 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java @@ -189,6 +189,19 @@ public ActionRequestValidationException validate( } private ActionRequestValidationException validateMMRDiversification(ActionRequestValidationException validationException) { + if (this.size <= 0) { + validationException = addValidationError( + String.format( + Locale.ROOT, + "[%s] MMR result diversification [%s] of %d must be greater than zero", + getName(), + SIZE_FIELD.getPreferredName(), + this.size + ), + validationException + ); + } + if (this.size > this.rankWindowSize) { validationException = addValidationError( String.format( diff --git a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java index c68c1cca488cc..861b7e9130a63 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java @@ -45,18 +45,16 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { // our chosen DocIDs to keep List selectedDocRanks = new ArrayList<>(); - // always add the highest scoring doc to the list - RankDoc highestScoreDoc = Arrays.stream(docs).max(Comparator.comparingDouble(doc -> doc.score)).orElse(docs[0]); - int highestScoreDocRank = highestScoreDoc.rank; - selectedDocRanks.add(highestScoreDocRank); - // test the vector to see if we are using floats or bytes - VectorData firstVec = context.getFieldVector(highestScoreDocRank); + VectorData firstVec = context.getFieldVector(docs[0].rank); boolean useFloat = firstVec.isFloat(); // cache the similarity scores for the query vector vs. searchHits Map querySimilarity = getQuerySimilarityForDocs(docs, useFloat, context); + // always add the highest relevant doc to the list + selectedDocRanks.add(getHighestRelevantDocRank(docs, querySimilarity)); + Map> cachedSimilarities = new HashMap<>(); int topDocsSize = context.getSize(); @@ -113,6 +111,20 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { return ret; } + private Integer getHighestRelevantDocRank(RankDoc[] docs, Map querySimilarity) { + Map.Entry highestRelevantDoc = querySimilarity.entrySet() + .stream() + .max(Comparator.comparingDouble(Map.Entry::getValue)) + .orElse(null); + + if (highestRelevantDoc != null) { + return highestRelevantDoc.getKey(); + } + + RankDoc highestScoreDoc = Arrays.stream(docs).max(Comparator.comparingDouble(doc -> doc.score)).orElse(docs[0]); + return highestScoreDoc.rank; + } + private float getHighestScoreForSelectedVectors( int docRank, MMRResultDiversificationContext context, diff --git a/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderTests.java index e34ab0c1c6b57..2b52eb0f3ce1b 100644 --- a/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderTests.java @@ -57,6 +57,38 @@ public class DiversifyRetrieverBuilderTests extends ESTestCase { public void testValidate() { SearchSourceBuilder source = new SearchSourceBuilder(); + var retrieverWithZeroSize = new DiversifyRetrieverBuilder( + getInnerRetriever(), + ResultDiversificationType.MMR, + "test_field", + 10, + 0, + getRandomQueryVector(), + 0.3f + ); + var validationZeroSize = retrieverWithZeroSize.validate(source, null, false, false); + assertEquals(1, validationZeroSize.validationErrors().size()); + assertEquals( + "[diversify] MMR result diversification [size] of 0 must be greater than zero", + validationZeroSize.validationErrors().getFirst() + ); + + var retrieverWithNegativeSize = new DiversifyRetrieverBuilder( + getInnerRetriever(), + ResultDiversificationType.MMR, + "test_field", + 10, + -1, + getRandomQueryVector(), + 0.3f + ); + var validationNegativeSize = retrieverWithNegativeSize.validate(source, null, false, false); + assertEquals(1, validationNegativeSize.validationErrors().size()); + assertEquals( + "[diversify] MMR result diversification [size] of -1 must be greater than zero", + validationNegativeSize.validationErrors().getFirst() + ); + var retrieverWithLargeSize = new DiversifyRetrieverBuilder( getInnerRetriever(), ResultDiversificationType.MMR, @@ -184,8 +216,8 @@ public void testMmrResultDiversification() { var result = retriever.combineInnerRetrieverResults(docs, false); assertEquals(3, result.length); - assertEquals(1, result[0].rank); - assertEquals(3, result[1].rank); + assertEquals(3, result[0].rank); + assertEquals(4, result[1].rank); assertEquals(6, result[2].rank); var retrieverWithoutRewrite = new DiversifyRetrieverBuilder( diff --git a/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java index 9dc58180a4b9e..7306045a28818 100644 --- a/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java +++ b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java @@ -89,7 +89,7 @@ private MMRResultDiversificationContext getRandomFloatContext(List expe ) ); - expectedDocIds.addAll(List.of(1, 3, 6)); + expectedDocIds.addAll(List.of(3, 4, 6)); return diversificationContext; } @@ -124,7 +124,7 @@ private MMRResultDiversificationContext getRandomByteContext(List expec ) ); - expectedDocIds.addAll(List.of(1, 3, 6)); + expectedDocIds.addAll(List.of(2, 3, 6)); return diversificationContext; }