Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ The ordering of results returned from the inner retriever is preserved.

Query vector. Must have the same number of dimensions as the vector field you are searching against.
Must be either an array of floats or a hex-encoded byte vector.
If you provide a `query_vector`, you cannot also provide a `query_vector_builder`.

`query_vector_builder`
: (Optional, query vector builder object)

Defines a [model](docs-content://solutions/search/vector/knn.md#knn-semantic-search) to build a query vector.
If you provide a `query_vector_builder`, you cannot also provide a `query_vector`.


`lambda`
: (Required for `mmr`, float)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.search.diversification;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionRequestValidationException;
Expand All @@ -26,6 +27,7 @@
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
Expand All @@ -40,15 +42,16 @@
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder<DiversifyRetrieverBuilder> {

public static final Float DEFAULT_LAMBDA_VALUE = 0.7f;
public static final int DEFAULT_SIZE_VALUE = 10;

public static final NodeFeature RETRIEVER_RESULT_DIVERSIFICATION_MMR_FEATURE = new NodeFeature("retriever.result_diversification_mmr");
Expand All @@ -58,6 +61,7 @@ public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder<Di
public static final ParseField TYPE_FIELD = new ParseField("type");
public static final ParseField FIELD_FIELD = new ParseField("field");
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
public static final ParseField LAMBDA_FIELD = new ParseField("lambda");
public static final ParseField SIZE_FIELD = new ParseField("size");

Expand All @@ -83,8 +87,9 @@ public SearchHit hit() {
int rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[3];

VectorData queryVector = args[4] == null ? null : (VectorData) args[4];
Float lambda = args[5] == null ? null : (Float) args[5];
Integer size = args[6] == null ? null : (Integer) args[6];
QueryVectorBuilder queryVectorBuilder = args[5] == null ? null : (QueryVectorBuilder) args[5];
Float lambda = args[6] == null ? null : (Float) args[6];
Integer size = args[7] == null ? null : (Integer) args[7];

return new DiversifyRetrieverBuilder(
RetrieverSource.from((RetrieverBuilder) args[0]),
Expand All @@ -93,6 +98,7 @@ public SearchHit hit() {
rankWindowSize,
size,
queryVector,
queryVectorBuilder,
lambda
);
}
Expand All @@ -113,17 +119,22 @@ public SearchHit hit() {
QUERY_VECTOR_FIELD,
ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER
);
PARSER.declareNamedObject(
optionalConstructorArg(),
(p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
QUERY_VECTOR_BUILDER_FIELD
);
PARSER.declareFloat(optionalConstructorArg(), LAMBDA_FIELD);
PARSER.declareInt(optionalConstructorArg(), SIZE_FIELD);
RetrieverBuilder.declareBaseParserFields(PARSER);
}

private final ResultDiversificationType diversificationType;
private final String diversificationField;
private final VectorData queryVector;
private final Supplier<VectorData> queryVector;
private final QueryVectorBuilder queryVectorBuilder;
private final Float lambda;
private final Integer size;
private ResultDiversificationContext diversificationContext = null;

DiversifyRetrieverBuilder(
RetrieverSource innerRetriever,
Expand All @@ -132,12 +143,14 @@ public SearchHit hit() {
int rankWindowSize,
@Nullable Integer size,
@Nullable VectorData queryVector,
@Nullable QueryVectorBuilder queryVectorBuilder,
@Nullable Float lambda
) {
super(List.of(innerRetriever), rankWindowSize);
this.diversificationType = diversificationType;
this.diversificationField = diversificationField;
this.queryVector = queryVector;
this.queryVector = queryVector != null ? () -> queryVector : null;
this.queryVectorBuilder = queryVectorBuilder;
this.lambda = lambda;
this.size = size == null ? Math.min(DEFAULT_SIZE_VALUE, rankWindowSize) : size;
}
Expand All @@ -148,7 +161,8 @@ public SearchHit hit() {
String diversificationField,
int rankWindowSize,
@Nullable Integer size,
@Nullable VectorData queryVector,
@Nullable Supplier<VectorData> queryVector,
@Nullable QueryVectorBuilder queryVectorBuilder,
@Nullable Float lambda
) {
super(innerRetrievers, rankWindowSize);
Expand All @@ -157,6 +171,7 @@ public SearchHit hit() {
this.diversificationType = diversificationType;
this.diversificationField = diversificationField;
this.queryVector = queryVector;
this.queryVectorBuilder = queryVectorBuilder;
this.lambda = lambda;
this.size = size == null ? Math.min(DEFAULT_SIZE_VALUE, rankWindowSize) : size;
}
Expand All @@ -170,6 +185,7 @@ protected DiversifyRetrieverBuilder clone(List<RetrieverSource> newChildRetrieve
rankWindowSize,
size,
queryVector,
queryVectorBuilder,
lambda
);
}
Expand All @@ -181,6 +197,19 @@ public ActionRequestValidationException validate(
boolean isScroll,
boolean allowPartialSearchResults
) {
if (queryVector != null && queryVectorBuilder != null) {
validationException = addValidationError(
String.format(
Locale.ROOT,
"[%s] MMR result diversification can have one of [%s] or [%s], but not both",
getName(),
QUERY_VECTOR_FIELD.getPreferredName(),
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
),
validationException
);
}

if (diversificationType.equals(ResultDiversificationType.MMR)) {
validationException = validateMMRDiversification(validationException);
}
Expand Down Expand Up @@ -235,17 +264,37 @@ private ActionRequestValidationException validateMMRDiversification(ActionReques

@Override
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
if (diversificationType.equals(ResultDiversificationType.MMR)) {
// field vectors will be filled in during the combine
diversificationContext = new MMRResultDiversificationContext(
if (queryVectorBuilder != null) {
SetOnce<VectorData> toSet = new SetOnce<>();
ctx.registerAsyncAction((c, l) -> {
queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> {
toSet.set(v == null ? null : new VectorData(v));
if (v == null) {
ll.onFailure(
new IllegalArgumentException(
format(
"[%s] with name [%s] returned null query_vector",
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
queryVectorBuilder.getWriteableName()
)
)
);
return;
}
ll.onResponse(null);
}));
});

return new DiversifyRetrieverBuilder(
innerRetrievers,
diversificationType,
diversificationField,
lambda,
size == null ? DEFAULT_SIZE_VALUE : size,
queryVector
rankWindowSize,
size,
() -> toSet.get(),
null,
lambda
);
} else {
// should not happen
throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]");
}

return this;
Expand Down Expand Up @@ -281,13 +330,6 @@ protected Exception processInnerItemFailureException(Exception ex) {

@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
if (diversificationContext == null) {
throw new ElasticsearchStatusException(
"diversificationContext is not set. \"doRewrite\" should have been called beforehand.",
RestStatus.INTERNAL_SERVER_ERROR
);
}

if (rankResults.isEmpty()) {
return new RankDoc[0];
}
Expand All @@ -302,6 +344,8 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
return new RankDoc[0];
}

ResultDiversificationContext diversificationContext = getResultDiversificationContext();

// gather and set the query vectors
// and create our intermediate results set
RankDoc[] results = new RankDoc[scoreDocs.length];
Expand Down Expand Up @@ -344,6 +388,15 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
}
}

private ResultDiversificationContext getResultDiversificationContext() {
if (diversificationType.equals(ResultDiversificationType.MMR)) {
return new MMRResultDiversificationContext(diversificationField, lambda, size == null ? DEFAULT_SIZE_VALUE : size, queryVector);
}

// should not happen
throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]");
}

private void extractFieldVectorData(int docId, Object fieldValue, Map<Integer, VectorData> fieldVectors) {
switch (fieldValue) {
case float[] floatArray -> {
Expand Down Expand Up @@ -427,7 +480,11 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);

if (queryVector != null) {
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get());
}

if (queryVectorBuilder != null) {
builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), queryVectorBuilder);
}

if (lambda != null) {
Expand All @@ -451,6 +508,8 @@ public boolean doEquals(Object o) {
&& this.diversificationType.equals(other.diversificationType)
&& this.diversificationField.equals(other.diversificationField)
&& Objects.equals(this.lambda, other.lambda)
&& Objects.equals(this.queryVector, other.queryVector);
&& ((queryVector == null && other.queryVector == null)
|| (queryVector != null && other.queryVector != null && Objects.equals(queryVector.get(), other.queryVector.get())))
&& Objects.equals(this.queryVectorBuilder, other.queryVectorBuilder);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;

public abstract class ResultDiversificationContext {
private final String field;
private final int size;
private final VectorData queryVector;
private final Supplier<VectorData> queryVector;
private Map<Integer, VectorData> fieldVectors = null;

protected ResultDiversificationContext(String field, int size, @Nullable VectorData queryVector) {
protected ResultDiversificationContext(String field, int size, @Nullable Supplier<VectorData> queryVector) {
this.field = field;
this.size = size;
this.queryVector = queryVector;
Expand All @@ -45,7 +46,7 @@ public void setFieldVectors(Map<Integer, VectorData> fieldVectors) {
}

public VectorData getQueryVector() {
return queryVector;
return queryVector.get();
}

public VectorData getFieldVector(int rank) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
import org.elasticsearch.search.diversification.ResultDiversificationContext;
import org.elasticsearch.search.vectors.VectorData;

import java.util.function.Supplier;

public class MMRResultDiversificationContext extends ResultDiversificationContext {

private final float lambda;

public MMRResultDiversificationContext(String field, float lambda, int size, @Nullable VectorData queryVector) {
public MMRResultDiversificationContext(String field, float lambda, int size, @Nullable Supplier<VectorData> queryVector) {
super(field, size, queryVector);
this.lambda = lambda;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ protected DiversifyRetrieverBuilder createTestInstance() {
rankWindowSize,
size,
queryVector,
null,
lambda
);
}
Expand Down Expand Up @@ -92,11 +93,7 @@ protected NamedXContentRegistry xContentRegistry() {

private VectorData getRandomQueryVector() {
if (randomBoolean()) {
float[] queryVector = new float[randomIntBetween(5, 256)];
for (int i = 0; i < queryVector.length; i++) {
queryVector[i] = randomFloatBetween(0.0f, 1.0f, true);
}
return new VectorData(queryVector);
return new VectorData(getRandomFloatQueryVector());
}

byte[] queryVector = new byte[randomIntBetween(5, 256)];
Expand All @@ -105,4 +102,12 @@ private VectorData getRandomQueryVector() {
}
return new VectorData(queryVector);
}

private float[] getRandomFloatQueryVector() {
float[] queryVector = new float[randomIntBetween(5, 256)];
for (int i = 0; i < queryVector.length; i++) {
queryVector[i] = randomFloatBetween(0.0f, 1.0f, true);
}
return queryVector;
}
}
Loading
Loading