diff --git a/docs/reference/elasticsearch/rest-apis/retrievers/diversify-retriever.md b/docs/reference/elasticsearch/rest-apis/retrievers/diversify-retriever.md index 84dac1998b2df..eb35138864ca7 100644 --- a/docs/reference/elasticsearch/rest-apis/retrievers/diversify-retriever.md +++ b/docs/reference/elasticsearch/rest-apis/retrievers/diversify-retriever.md @@ -50,6 +50,14 @@ The ordering of results returned from the inner retriever is preserved. Query vector. Must have the same number of dimensions as the vector field you are searching against. Must be either an array of floats or a hex-encoded byte vector. + If you provide a `query_vector`, you cannot also provide a `query_vector_builder`. + +`query_vector_builder` +: (Optional, query vector builder object) + + Defines a [model](docs-content://solutions/search/vector/knn.md#knn-semantic-search) to build a query vector. + If you provide a `query_vector_builder`, you cannot also provide a `query_vector`. + `lambda` : (Required for `mmr`, float) diff --git a/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java index 03751d2767e64..320c05ab559fc 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java @@ -10,6 +10,7 @@ package org.elasticsearch.search.diversification; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionRequestValidationException; @@ -26,6 +27,7 @@ import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.vectors.QueryVectorBuilder; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; @@ -40,15 +42,16 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder { - public static final Float DEFAULT_LAMBDA_VALUE = 0.7f; public static final int DEFAULT_SIZE_VALUE = 10; public static final NodeFeature RETRIEVER_RESULT_DIVERSIFICATION_MMR_FEATURE = new NodeFeature("retriever.result_diversification_mmr"); @@ -58,6 +61,7 @@ public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder p.namedObject(QueryVectorBuilder.class, n, c), + QUERY_VECTOR_BUILDER_FIELD + ); PARSER.declareFloat(optionalConstructorArg(), LAMBDA_FIELD); PARSER.declareInt(optionalConstructorArg(), SIZE_FIELD); RetrieverBuilder.declareBaseParserFields(PARSER); @@ -120,10 +131,10 @@ public SearchHit hit() { private final ResultDiversificationType diversificationType; private final String diversificationField; - private final VectorData queryVector; + private final Supplier queryVector; + private final QueryVectorBuilder queryVectorBuilder; private final Float lambda; private final Integer size; - private ResultDiversificationContext diversificationContext = null; DiversifyRetrieverBuilder( RetrieverSource innerRetriever, @@ -132,12 +143,14 @@ public SearchHit hit() { int rankWindowSize, @Nullable Integer size, @Nullable VectorData queryVector, + @Nullable QueryVectorBuilder queryVectorBuilder, @Nullable Float lambda ) { super(List.of(innerRetriever), rankWindowSize); this.diversificationType = diversificationType; this.diversificationField = diversificationField; - this.queryVector = queryVector; + this.queryVector = queryVector != null ? () -> queryVector : null; + this.queryVectorBuilder = queryVectorBuilder; this.lambda = lambda; this.size = size == null ? Math.min(DEFAULT_SIZE_VALUE, rankWindowSize) : size; } @@ -148,7 +161,8 @@ public SearchHit hit() { String diversificationField, int rankWindowSize, @Nullable Integer size, - @Nullable VectorData queryVector, + @Nullable Supplier queryVector, + @Nullable QueryVectorBuilder queryVectorBuilder, @Nullable Float lambda ) { super(innerRetrievers, rankWindowSize); @@ -157,6 +171,7 @@ public SearchHit hit() { this.diversificationType = diversificationType; this.diversificationField = diversificationField; this.queryVector = queryVector; + this.queryVectorBuilder = queryVectorBuilder; this.lambda = lambda; this.size = size == null ? Math.min(DEFAULT_SIZE_VALUE, rankWindowSize) : size; } @@ -170,6 +185,7 @@ protected DiversifyRetrieverBuilder clone(List newChildRetrieve rankWindowSize, size, queryVector, + queryVectorBuilder, lambda ); } @@ -181,6 +197,19 @@ public ActionRequestValidationException validate( boolean isScroll, boolean allowPartialSearchResults ) { + if (queryVector != null && queryVectorBuilder != null) { + validationException = addValidationError( + String.format( + Locale.ROOT, + "[%s] MMR result diversification can have one of [%s] or [%s], but not both", + getName(), + QUERY_VECTOR_FIELD.getPreferredName(), + QUERY_VECTOR_BUILDER_FIELD.getPreferredName() + ), + validationException + ); + } + if (diversificationType.equals(ResultDiversificationType.MMR)) { validationException = validateMMRDiversification(validationException); } @@ -235,17 +264,37 @@ private ActionRequestValidationException validateMMRDiversification(ActionReques @Override protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { - if (diversificationType.equals(ResultDiversificationType.MMR)) { - // field vectors will be filled in during the combine - diversificationContext = new MMRResultDiversificationContext( + if (queryVectorBuilder != null) { + SetOnce toSet = new SetOnce<>(); + ctx.registerAsyncAction((c, l) -> { + queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> { + toSet.set(v == null ? null : new VectorData(v)); + if (v == null) { + ll.onFailure( + new IllegalArgumentException( + format( + "[%s] with name [%s] returned null query_vector", + QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), + queryVectorBuilder.getWriteableName() + ) + ) + ); + return; + } + ll.onResponse(null); + })); + }); + + return new DiversifyRetrieverBuilder( + innerRetrievers, + diversificationType, diversificationField, - lambda, - size == null ? DEFAULT_SIZE_VALUE : size, - queryVector + rankWindowSize, + size, + () -> toSet.get(), + null, + lambda ); - } else { - // should not happen - throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]"); } return this; @@ -281,13 +330,6 @@ protected Exception processInnerItemFailureException(Exception ex) { @Override protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean explain) { - if (diversificationContext == null) { - throw new ElasticsearchStatusException( - "diversificationContext is not set. \"doRewrite\" should have been called beforehand.", - RestStatus.INTERNAL_SERVER_ERROR - ); - } - if (rankResults.isEmpty()) { return new RankDoc[0]; } @@ -302,6 +344,8 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b return new RankDoc[0]; } + ResultDiversificationContext diversificationContext = getResultDiversificationContext(); + // gather and set the query vectors // and create our intermediate results set RankDoc[] results = new RankDoc[scoreDocs.length]; @@ -344,6 +388,15 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b } } + private ResultDiversificationContext getResultDiversificationContext() { + if (diversificationType.equals(ResultDiversificationType.MMR)) { + return new MMRResultDiversificationContext(diversificationField, lambda, size == null ? DEFAULT_SIZE_VALUE : size, queryVector); + } + + // should not happen + throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]"); + } + private void extractFieldVectorData(int docId, Object fieldValue, Map fieldVectors) { switch (fieldValue) { case float[] floatArray -> { @@ -427,7 +480,11 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); if (queryVector != null) { - builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get()); + } + + if (queryVectorBuilder != null) { + builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), queryVectorBuilder); } if (lambda != null) { @@ -451,6 +508,8 @@ public boolean doEquals(Object o) { && this.diversificationType.equals(other.diversificationType) && this.diversificationField.equals(other.diversificationField) && Objects.equals(this.lambda, other.lambda) - && Objects.equals(this.queryVector, other.queryVector); + && ((queryVector == null && other.queryVector == null) + || (queryVector != null && other.queryVector != null && Objects.equals(queryVector.get(), other.queryVector.get()))) + && Objects.equals(this.queryVectorBuilder, other.queryVectorBuilder); } } diff --git a/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java index b5b2e4cc812e4..09c1fd92e6335 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java @@ -14,14 +14,15 @@ import java.util.Map; import java.util.Set; +import java.util.function.Supplier; public abstract class ResultDiversificationContext { private final String field; private final int size; - private final VectorData queryVector; + private final Supplier queryVector; private Map fieldVectors = null; - protected ResultDiversificationContext(String field, int size, @Nullable VectorData queryVector) { + protected ResultDiversificationContext(String field, int size, @Nullable Supplier queryVector) { this.field = field; this.size = size; this.queryVector = queryVector; @@ -45,7 +46,7 @@ public void setFieldVectors(Map fieldVectors) { } public VectorData getQueryVector() { - return queryVector; + return queryVector == null ? null : queryVector.get(); } public VectorData getFieldVector(int rank) { diff --git a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationContext.java b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationContext.java index fba2909baf398..012ebeefd3564 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationContext.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationContext.java @@ -13,11 +13,13 @@ import org.elasticsearch.search.diversification.ResultDiversificationContext; import org.elasticsearch.search.vectors.VectorData; +import java.util.function.Supplier; + public class MMRResultDiversificationContext extends ResultDiversificationContext { private final float lambda; - public MMRResultDiversificationContext(String field, float lambda, int size, @Nullable VectorData queryVector) { + public MMRResultDiversificationContext(String field, float lambda, int size, @Nullable Supplier queryVector) { super(field, size, queryVector); this.lambda = lambda; } diff --git a/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderParsingTests.java index d01722d093d4d..d412ad122bf04 100644 --- a/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderParsingTests.java @@ -59,6 +59,7 @@ protected DiversifyRetrieverBuilder createTestInstance() { rankWindowSize, size, queryVector, + null, lambda ); } @@ -92,11 +93,7 @@ protected NamedXContentRegistry xContentRegistry() { private VectorData getRandomQueryVector() { if (randomBoolean()) { - float[] queryVector = new float[randomIntBetween(5, 256)]; - for (int i = 0; i < queryVector.length; i++) { - queryVector[i] = randomFloatBetween(0.0f, 1.0f, true); - } - return new VectorData(queryVector); + return new VectorData(getRandomFloatQueryVector()); } byte[] queryVector = new byte[randomIntBetween(5, 256)]; @@ -105,4 +102,12 @@ private VectorData getRandomQueryVector() { } return new VectorData(queryVector); } + + private float[] getRandomFloatQueryVector() { + float[] queryVector = new float[randomIntBetween(5, 256)]; + for (int i = 0; i < queryVector.length; i++) { + queryVector[i] = randomFloatBetween(0.0f, 1.0f, true); + } + return queryVector; + } } diff --git a/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderTests.java index 2b52eb0f3ce1b..b5b8b1fbd97ef 100644 --- a/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilderTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.MockResolvedIndices; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.ResolvedIndices; @@ -38,6 +39,7 @@ import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.TestRetrieverBuilder; +import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.RemoteClusterAware; @@ -63,7 +65,8 @@ public void testValidate() { "test_field", 10, 0, - getRandomQueryVector(), + new VectorData(getRandomFloatQueryVector()), + null, 0.3f ); var validationZeroSize = retrieverWithZeroSize.validate(source, null, false, false); @@ -79,7 +82,8 @@ public void testValidate() { "test_field", 10, -1, - getRandomQueryVector(), + new VectorData(getRandomFloatQueryVector()), + null, 0.3f ); var validationNegativeSize = retrieverWithNegativeSize.validate(source, null, false, false); @@ -95,7 +99,8 @@ public void testValidate() { "test_field", 10, 20, - getRandomQueryVector(), + new VectorData(getRandomFloatQueryVector()), + null, 0.3f ); var validationSize = retrieverWithLargeSize.validate(source, null, false, false); @@ -115,7 +120,8 @@ public void testValidate() { "test_field", rankWindowSize, size, - getRandomQueryVector(), + new VectorData(getRandomFloatQueryVector()), + null, 2.0f ); var validationLambda = retrieverHighLambda.validate(source, null, false, false); @@ -131,7 +137,8 @@ public void testValidate() { "test_field", rankWindowSize, size, - getRandomQueryVector(), + new VectorData(getRandomFloatQueryVector()), + null, -0.1f ); validationLambda = retrieverLowLambda.validate(source, null, false, false); @@ -147,7 +154,8 @@ public void testValidate() { "test_field", rankWindowSize, size, - getRandomQueryVector(), + new VectorData(getRandomFloatQueryVector()), + null, null ); validationLambda = retrieverNullLambda.validate(source, null, false, false); @@ -156,6 +164,23 @@ public void testValidate() { "[diversify] MMR result diversification must have a [lambda] between 0.0 and 1.0. The value provided was null", validationLambda.validationErrors().getFirst() ); + + var retrieverWithBothQueryVectorAndBuilder = new DiversifyRetrieverBuilder( + getInnerRetriever(), + ResultDiversificationType.MMR, + "test_field", + rankWindowSize, + size, + new VectorData(getRandomFloatQueryVector()), + new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(getRandomFloatQueryVector()), + 0.5f + ); + var validationQueryVectors = retrieverWithBothQueryVectorAndBuilder.validate(source, null, false, false); + assertEquals(1, validationQueryVectors.validationErrors().size()); + assertEquals( + "[diversify] MMR result diversification can have one of [query_vector] or [query_vector_builder], but not both", + validationQueryVectors.validationErrors().getFirst() + ); } public void testClone() { @@ -187,11 +212,47 @@ public void testDoRewrite() { assertSame(original, rewritten); assertCompoundRetriever(original, rewritten); - // will assert that the rewrite happened without assertion errors - List docs = new ArrayList<>(); - docs.add(new ScoreDoc[] {}); - var result = original.combineInnerRetrieverResults(docs, false); - assertEquals(0, result.length); + float[] queryVectorToUse = getRandomFloatQueryVector(256); + CompoundRetrieverBuilder.RetrieverSource innerRetriever = getInnerRetriever(); + + var withQueryVectorBuilder = new DiversifyRetrieverBuilder( + innerRetriever, + ResultDiversificationType.MMR, + "dense_vector_field", + 10, + 5, + null, + new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(queryVectorToUse), + 0.7f + ); + var builderRewritten = (DiversifyRetrieverBuilder) withQueryVectorBuilder.doRewrite(queryRewriteContext); + assertNotSame(withQueryVectorBuilder, builderRewritten); + // should not be equal as the query vector should now be set from the builder + assertNotEquals(withQueryVectorBuilder, builderRewritten); + assertCompoundRetriever(withQueryVectorBuilder, builderRewritten); + + queryRewriteContext.executeAsyncActions(new ActionListener() { + @Override + public void onResponse(Void unused) { + var withQueryVector = new DiversifyRetrieverBuilder( + innerRetriever, + ResultDiversificationType.MMR, + "dense_vector_field", + 10, + 5, + new VectorData(queryVectorToUse), + null, + 0.7f + ); + + assertEquals(withQueryVector, builderRewritten); + } + + @Override + public void onFailure(Exception e) { + fail(e); + } + }); } public void testMmrResultDiversification() { @@ -203,6 +264,7 @@ public void testMmrResultDiversification() { 10, 3, new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }), + null, 0.3f ); @@ -227,16 +289,10 @@ public void testMmrResultDiversification() { 10, 3, new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }), + null, 0.3f ); - ElasticsearchStatusException missingRewriteEx = assertThrows( - ElasticsearchStatusException.class, - () -> retrieverWithoutRewrite.combineInnerRetrieverResults(List.of(), false) - ); - assertEquals("diversificationContext is not set. \"doRewrite\" should have been called beforehand.", missingRewriteEx.getMessage()); - assertEquals(500, missingRewriteEx.status().getStatus()); - retrieverWithoutRewrite.doRewrite(queryRewriteContext); var emptyDocs = retrieverWithoutRewrite.combineInnerRetrieverResults(List.of(), false); @@ -262,6 +318,7 @@ public void testThrowsExceptionOnBadFieldTypes() { 10, 3, new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }), + null, 0.3f ); @@ -371,10 +428,9 @@ private static DiversifyRetrieverBuilder createRandomRetriever(@Nullable String int rankWindowSize = randomIntBetween(1, 20); Integer size = randomBoolean() ? null : randomIntBetween(1, 20); - // TODO - decide using float for byte here! VectorData queryVector = vectorDimensions == null - ? randomBoolean() ? getRandomQueryVector() : null - : getRandomQueryVector(vectorDimensions); + ? randomBoolean() ? new VectorData(getRandomFloatQueryVector()) : null + : new VectorData(getRandomFloatQueryVector(vectorDimensions)); Float lambda = randomFloatBetween(0.0f, 1.0f, true); CompoundRetrieverBuilder.RetrieverSource innerRetriever = getInnerRetriever(); return new DiversifyRetrieverBuilder( @@ -384,6 +440,7 @@ private static DiversifyRetrieverBuilder createRandomRetriever(@Nullable String rankWindowSize, size, queryVector, + null, lambda ); } @@ -392,17 +449,17 @@ private static CompoundRetrieverBuilder.RetrieverSource getInnerRetriever() { return new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null); } - private static VectorData getRandomQueryVector() { - return getRandomQueryVector(null); + private static float[] getRandomFloatQueryVector() { + return getRandomFloatQueryVector(null); } - private static VectorData getRandomQueryVector(@Nullable Integer vectorDimensions) { + private static float[] getRandomFloatQueryVector(@Nullable Integer vectorDimensions) { int vectorSize = vectorDimensions == null ? randomIntBetween(5, 256) : vectorDimensions; float[] queryVector = new float[vectorSize]; for (int i = 0; i < queryVector.length; i++) { queryVector[i] = randomFloatBetween(0.0f, 1.0f, true); } - return new VectorData(queryVector); + return queryVector; } private static ResolvedIndices createMockResolvedIndices(Map> localIndexDenseVectorFields) { diff --git a/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java index 7306045a28818..7f8eb3509613d 100644 --- a/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java +++ b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.function.Supplier; public class MMRResultDiversificationTests extends ESTestCase { @@ -70,7 +71,7 @@ private MMRResultDiversificationContext getRandomFloatContext(List expe DenseVectorFieldMapper.Builder builder = (DenseVectorFieldMapper.Builder) mapper.getMergeBuilder(); builder.elementType(DenseVectorFieldMapper.ElementType.FLOAT); - var queryVectorData = new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }); + Supplier queryVectorData = () -> new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }); var diversificationContext = new MMRResultDiversificationContext("dense_vector_field", 0.3f, 3, queryVectorData); diversificationContext.setFieldVectors( Map.of( @@ -105,7 +106,7 @@ private MMRResultDiversificationContext getRandomByteContext(List expec DenseVectorFieldMapper.Builder builder = (DenseVectorFieldMapper.Builder) mapper.getMergeBuilder(); builder.elementType(DenseVectorFieldMapper.ElementType.BYTE); - var queryVectorData = new VectorData(new byte[] { 0x50, 0x20, 0x40, 0x40 }); + Supplier queryVectorData = () -> new VectorData(new byte[] { 0x50, 0x20, 0x40, 0x40 }); var diversificationContext = new MMRResultDiversificationContext("dense_vector_field", 0.3f, 3, queryVectorData); diversificationContext.setFieldVectors( Map.of( @@ -141,7 +142,7 @@ public void testMMRDiversificationIfNoSearchHits() throws IOException { DenseVectorFieldMapper.Builder builder = (DenseVectorFieldMapper.Builder) mapper.getMergeBuilder(); builder.elementType(DenseVectorFieldMapper.ElementType.FLOAT); - var queryVectorData = new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }); + Supplier queryVectorData = () -> new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }); var diversificationContext = new MMRResultDiversificationContext("dense_vector_field", 0.6f, 10, queryVectorData); RankDoc[] emptyDocs = new RankDoc[0]; diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_semantic_text_diversifty_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_semantic_text_diversifty_retriever.yml new file mode 100644 index 0000000000000..b38dd31614d1b --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_semantic_text_diversifty_retriever.yml @@ -0,0 +1,127 @@ +setup: + - requires: + cluster_features: [ "retriever.result_diversification_mmr" ] + reason: "Added retriever for result diversification using MMR" + + - requires: + test_runner_features: + - contains + - headers + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 4, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-result-diversification-index + body: + settings: + number_of_shards: 1 + mappings: + properties: + textbody: + type: text + keywordfield: + type: keyword + textvector: + type: dense_vector + dims: 4 + + - do: + headers: + Content-Type: application/json + bulk: + index: test-result-diversification-index + refresh: true + body: | + {"index":{}} + {"textbody": "first text", "textvector": [0.4, 0.2, 0.4, 0.4], "keywordfield": "test1"} + {"index":{}} + {"textbody": "second text", "textvector": [0.4, 0.2, 0.3, 0.3], "keywordfield": "test2"} + {"index":{}} + {"textbody": "second text duplicate", "textvector": [0.4, 0.2, 0.3, 0.3], "keywordfield": "test2_dup"} + {"index":{}} + {"textbody": "third text", "textvector": [0.4, 0.1, 0.3, 0.3], "keywordfield": "test3"} + {"index":{}} + {"textbody": "third text duplicate", "textvector": [0.4, 0.1, 0.3, 0.3], "keywordfield": "test3_dup"} + {"index":{}} + {"textbody": "fourth text", "textvector": [0.1, 0.9, 0.5, 0.9], "keywordfield": "test4"} + {"index":{}} + {"textbody": "fifth text", "textvector": [0.1, 0.9, 0.5, 0.9], "keywordfield": "test5"} + {"index":{}} + {"textbody": "sixth text", "textvector": [0.05, 0.05, 0.05, 0.05], "keywordfield": "test6"} + {"index":{}} + {"textbody": "seventh text", "textvector": [0.1, 0.9, 0.5, 0.9], "keywordfield": "test7"} + {"index":{}} + {"textbody": "eighth text", "textvector": [0.1, 0.9, 0.5, 0.9], "keywordfield": "test8"} + {"index":{}} + {"textbody": "ninth text", "textvector": [0.05, 0.05, 0.05, 0.05], "keywordfield": "test9"} + +--- +"MMR diversification using a query_vector_builder": + - do: + search: + index: test-result-diversification-index + body: + retriever: + diversify: + type: "mmr" + field: "textvector" + lambda: 0.3 + query_vector_builder: + text_embedding: + model_id: dense-inference-id + model_text: test + size: 3 + retriever: + knn: + field: "textvector" + query_vector: [ 0.5, 0.2, 0.4, 0.4 ] + k: 10 + num_candidates: 10 + + - match: { hits.total.value: 10 } + - length: { hits.hits: 3 } + +--- +"Validate query_vector or query_vector_builder but not both": + + - do: + catch: /\[diversify\] MMR result diversification can have one of \[query_vector\] or \[query_vector_builder\], but not both/ + search: + index: test-result-diversification-index + body: + retriever: + diversify: + type: "mmr" + field: "textvector" + lambda: 0.3 + query_vector: [ 0.4, 0.2, 0.3, 0.3 ] + query_vector_builder: + text_embedding: + model_id: dense-inference-id + model_text: test + size: 3 + retriever: + knn: + field: "textvector" + query_vector: [ 0.5, 0.2, 0.4, 0.4 ] + k: 10 + num_candidates: 10 + + - match: { status: 400 } + - match: { error.type: action_request_validation_exception }