-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Add query_vector_builder Support to Diversify Retriever
#139094
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a9aef94
fad45ae
693a432
17c1081
9ccb78c
d69af11
94eb716
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| package org.elasticsearch.search.diversification; | ||
|
|
||
| import org.apache.lucene.search.ScoreDoc; | ||
| import org.apache.lucene.util.SetOnce; | ||
| import org.elasticsearch.ElasticsearchException; | ||
| import org.elasticsearch.ElasticsearchStatusException; | ||
| import org.elasticsearch.action.ActionRequestValidationException; | ||
|
|
@@ -26,6 +27,7 @@ | |
| import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; | ||
| import org.elasticsearch.search.retriever.RetrieverBuilder; | ||
| import org.elasticsearch.search.retriever.RetrieverParserContext; | ||
| import org.elasticsearch.search.vectors.QueryVectorBuilder; | ||
| import org.elasticsearch.search.vectors.VectorData; | ||
| import org.elasticsearch.xcontent.ConstructingObjectParser; | ||
| import org.elasticsearch.xcontent.ObjectParser; | ||
|
|
@@ -40,15 +42,16 @@ | |
| import java.util.Locale; | ||
| import java.util.Map; | ||
| import java.util.Objects; | ||
| import java.util.function.Supplier; | ||
|
|
||
| import static org.elasticsearch.action.ValidateActions.addValidationError; | ||
| import static org.elasticsearch.common.Strings.format; | ||
| import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; | ||
| import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; | ||
| import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; | ||
|
|
||
| public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder<DiversifyRetrieverBuilder> { | ||
|
|
||
| public static final Float DEFAULT_LAMBDA_VALUE = 0.7f; | ||
| public static final int DEFAULT_SIZE_VALUE = 10; | ||
|
|
||
| public static final NodeFeature RETRIEVER_RESULT_DIVERSIFICATION_MMR_FEATURE = new NodeFeature("retriever.result_diversification_mmr"); | ||
|
|
@@ -58,6 +61,7 @@ public final class DiversifyRetrieverBuilder extends CompoundRetrieverBuilder<Di | |
| public static final ParseField TYPE_FIELD = new ParseField("type"); | ||
| public static final ParseField FIELD_FIELD = new ParseField("field"); | ||
| public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); | ||
| public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); | ||
| public static final ParseField LAMBDA_FIELD = new ParseField("lambda"); | ||
| public static final ParseField SIZE_FIELD = new ParseField("size"); | ||
|
|
||
|
|
@@ -83,8 +87,9 @@ public SearchHit hit() { | |
| int rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[3]; | ||
|
|
||
| VectorData queryVector = args[4] == null ? null : (VectorData) args[4]; | ||
| Float lambda = args[5] == null ? null : (Float) args[5]; | ||
| Integer size = args[6] == null ? null : (Integer) args[6]; | ||
| QueryVectorBuilder queryVectorBuilder = args[5] == null ? null : (QueryVectorBuilder) args[5]; | ||
| Float lambda = args[6] == null ? null : (Float) args[6]; | ||
| Integer size = args[7] == null ? null : (Integer) args[7]; | ||
|
|
||
| return new DiversifyRetrieverBuilder( | ||
| RetrieverSource.from((RetrieverBuilder) args[0]), | ||
|
|
@@ -93,6 +98,7 @@ public SearchHit hit() { | |
| rankWindowSize, | ||
| size, | ||
| queryVector, | ||
| queryVectorBuilder, | ||
| lambda | ||
| ); | ||
| } | ||
|
|
@@ -113,17 +119,22 @@ public SearchHit hit() { | |
| QUERY_VECTOR_FIELD, | ||
| ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER | ||
| ); | ||
| PARSER.declareNamedObject( | ||
| optionalConstructorArg(), | ||
| (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), | ||
| QUERY_VECTOR_BUILDER_FIELD | ||
| ); | ||
| PARSER.declareFloat(optionalConstructorArg(), LAMBDA_FIELD); | ||
| PARSER.declareInt(optionalConstructorArg(), SIZE_FIELD); | ||
| RetrieverBuilder.declareBaseParserFields(PARSER); | ||
| } | ||
|
|
||
| private final ResultDiversificationType diversificationType; | ||
| private final String diversificationField; | ||
| private final VectorData queryVector; | ||
| private final Supplier<VectorData> queryVector; | ||
| private final QueryVectorBuilder queryVectorBuilder; | ||
| private final Float lambda; | ||
| private final Integer size; | ||
| private ResultDiversificationContext diversificationContext = null; | ||
|
|
||
| DiversifyRetrieverBuilder( | ||
| RetrieverSource innerRetriever, | ||
|
|
@@ -132,12 +143,14 @@ public SearchHit hit() { | |
| int rankWindowSize, | ||
| @Nullable Integer size, | ||
| @Nullable VectorData queryVector, | ||
| @Nullable QueryVectorBuilder queryVectorBuilder, | ||
| @Nullable Float lambda | ||
| ) { | ||
| super(List.of(innerRetriever), rankWindowSize); | ||
| this.diversificationType = diversificationType; | ||
| this.diversificationField = diversificationField; | ||
| this.queryVector = queryVector; | ||
| this.queryVector = queryVector != null ? () -> queryVector : null; | ||
| this.queryVectorBuilder = queryVectorBuilder; | ||
| this.lambda = lambda; | ||
| this.size = size == null ? Math.min(DEFAULT_SIZE_VALUE, rankWindowSize) : size; | ||
| } | ||
|
|
@@ -148,7 +161,8 @@ public SearchHit hit() { | |
| String diversificationField, | ||
| int rankWindowSize, | ||
| @Nullable Integer size, | ||
| @Nullable VectorData queryVector, | ||
| @Nullable Supplier<VectorData> queryVector, | ||
| @Nullable QueryVectorBuilder queryVectorBuilder, | ||
| @Nullable Float lambda | ||
| ) { | ||
| super(innerRetrievers, rankWindowSize); | ||
|
|
@@ -157,6 +171,7 @@ public SearchHit hit() { | |
| this.diversificationType = diversificationType; | ||
| this.diversificationField = diversificationField; | ||
| this.queryVector = queryVector; | ||
| this.queryVectorBuilder = queryVectorBuilder; | ||
| this.lambda = lambda; | ||
| this.size = size == null ? Math.min(DEFAULT_SIZE_VALUE, rankWindowSize) : size; | ||
| } | ||
|
|
@@ -170,6 +185,7 @@ protected DiversifyRetrieverBuilder clone(List<RetrieverSource> newChildRetrieve | |
| rankWindowSize, | ||
| size, | ||
| queryVector, | ||
| queryVectorBuilder, | ||
| lambda | ||
| ); | ||
| } | ||
|
|
@@ -181,6 +197,19 @@ public ActionRequestValidationException validate( | |
| boolean isScroll, | ||
| boolean allowPartialSearchResults | ||
| ) { | ||
| if (queryVector != null && queryVectorBuilder != null) { | ||
| validationException = addValidationError( | ||
| String.format( | ||
| Locale.ROOT, | ||
| "[%s] MMR result diversification can have one of [%s] or [%s], but not both", | ||
| getName(), | ||
| QUERY_VECTOR_FIELD.getPreferredName(), | ||
| QUERY_VECTOR_BUILDER_FIELD.getPreferredName() | ||
| ), | ||
| validationException | ||
| ); | ||
| } | ||
|
|
||
| if (diversificationType.equals(ResultDiversificationType.MMR)) { | ||
| validationException = validateMMRDiversification(validationException); | ||
| } | ||
|
|
@@ -235,17 +264,37 @@ private ActionRequestValidationException validateMMRDiversification(ActionReques | |
|
|
||
| @Override | ||
| protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { | ||
| if (diversificationType.equals(ResultDiversificationType.MMR)) { | ||
| // field vectors will be filled in during the combine | ||
| diversificationContext = new MMRResultDiversificationContext( | ||
| if (queryVectorBuilder != null) { | ||
| SetOnce<VectorData> toSet = new SetOnce<>(); | ||
| ctx.registerAsyncAction((c, l) -> { | ||
| queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> { | ||
| toSet.set(v == null ? null : new VectorData(v)); | ||
| if (v == null) { | ||
| ll.onFailure( | ||
| new IllegalArgumentException( | ||
| format( | ||
| "[%s] with name [%s] returned null query_vector", | ||
| QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), | ||
| queryVectorBuilder.getWriteableName() | ||
| ) | ||
| ) | ||
| ); | ||
| return; | ||
| } | ||
| ll.onResponse(null); | ||
| })); | ||
| }); | ||
|
|
||
| return new DiversifyRetrieverBuilder( | ||
| innerRetrievers, | ||
| diversificationType, | ||
| diversificationField, | ||
| lambda, | ||
| size == null ? DEFAULT_SIZE_VALUE : size, | ||
| queryVector | ||
| rankWindowSize, | ||
| size, | ||
| () -> toSet.get(), | ||
| null, | ||
| lambda | ||
| ); | ||
| } else { | ||
| // should not happen | ||
| throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]"); | ||
| } | ||
|
|
||
| return this; | ||
|
|
@@ -281,13 +330,6 @@ protected Exception processInnerItemFailureException(Exception ex) { | |
|
|
||
| @Override | ||
| protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) { | ||
| if (diversificationContext == null) { | ||
| throw new ElasticsearchStatusException( | ||
| "diversificationContext is not set. \"doRewrite\" should have been called beforehand.", | ||
| RestStatus.INTERNAL_SERVER_ERROR | ||
| ); | ||
| } | ||
|
|
||
| if (rankResults.isEmpty()) { | ||
| return new RankDoc[0]; | ||
| } | ||
|
|
@@ -302,6 +344,8 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b | |
| return new RankDoc[0]; | ||
| } | ||
|
|
||
| ResultDiversificationContext diversificationContext = getResultDiversificationContext(); | ||
|
|
||
| // gather and set the query vectors | ||
| // and create our intermediate results set | ||
| RankDoc[] results = new RankDoc[scoreDocs.length]; | ||
|
|
@@ -344,6 +388,15 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b | |
| } | ||
| } | ||
|
|
||
| private ResultDiversificationContext getResultDiversificationContext() { | ||
| if (diversificationType.equals(ResultDiversificationType.MMR)) { | ||
| return new MMRResultDiversificationContext(diversificationField, lambda, size == null ? DEFAULT_SIZE_VALUE : size, queryVector); | ||
| } | ||
|
|
||
| // should not happen | ||
| throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]"); | ||
| } | ||
|
|
||
| private void extractFieldVectorData(int docId, Object fieldValue, Map<Integer, VectorData> fieldVectors) { | ||
| switch (fieldValue) { | ||
| case float[] floatArray -> { | ||
|
|
@@ -427,7 +480,11 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc | |
| builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); | ||
|
|
||
| if (queryVector != null) { | ||
| builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); | ||
| builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get()); | ||
| } | ||
|
|
||
| if (queryVectorBuilder != null) { | ||
| builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), queryVectorBuilder); | ||
| } | ||
|
|
||
| if (lambda != null) { | ||
|
|
@@ -451,6 +508,8 @@ public boolean doEquals(Object o) { | |
| && this.diversificationType.equals(other.diversificationType) | ||
| && this.diversificationField.equals(other.diversificationField) | ||
| && Objects.equals(this.lambda, other.lambda) | ||
| && Objects.equals(this.queryVector, other.queryVector); | ||
| && ((queryVector == null && other.queryVector == null) | ||
|
Member
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't
Contributor
Author
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not necessarily in this case, as the |
||
| || (queryVector != null && other.queryVector != null && Objects.equals(queryVector.get(), other.queryVector.get()))) | ||
| && Objects.equals(this.queryVectorBuilder, other.queryVectorBuilder); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please don't throw 🍅 at me 🙈
I think we should verify that we want the same `query_vector_builder` API here. There's some legacy API crust, for example the API asks for a `model_id` etc., and there's opportunity to modernize it with `inference_id`. This would be more work, though, so I'm not raising it as a blocker to the PR, more as a discussion of whether that's something we want to bite off or if this is OK.

If we decide to go forward with `query_vector_builder`, perhaps not for this PR, but the docs could use a bit of work here to define allowed parameters, and have something directly linkable.

Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with what you're saying - but I think it's a larger, and more discrete PR (as it also touches other users of the QueryVectorBuilder). The docs however are probably a good idea.