From d2c22a67bf1fad24ec45fce52025630aedac0c3c Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 2 Apr 2025 14:03:02 -0400 Subject: [PATCH 1/7] Do highlighting in RankFeatureShardPhase --- .../subphase/highlight/HighlightBuilder.java | 2 +- .../highlight/SearchHighlightContext.java | 2 +- .../rank/feature/RankFeatureShardPhase.java | 16 +++++++++++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java index 1997601e73f6d..2a48e8338e377 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java @@ -399,7 +399,7 @@ public Field(String name) { this.name = name; } - private Field(Field template, QueryBuilder builder) { + Field(Field template, QueryBuilder builder) { super(template, builder); name = template.name; fragmentOffset = template.fragmentOffset; diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java index 631a75a355abf..a85ae92c24bcf 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java @@ -40,7 +40,7 @@ public static class Field { private final String field; private final FieldOptions fieldOptions; - Field(String field, FieldOptions fieldOptions) { + public Field(String field, FieldOptions fieldOptions) { assert field != null; assert fieldOptions != null; this.field = field; diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java index 4374c06da365d..251e5ef449d01 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java @@ -17,10 +17,13 @@ import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.fetch.subphase.FetchFieldsContext; import org.elasticsearch.search.fetch.subphase.FieldAndFormat; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.tasks.TaskCancelledException; +import java.io.IOException; import java.util.Arrays; import java.util.Collections; @@ -48,10 +51,21 @@ public static void prepareForFetch(SearchContext searchContext, RankFeatureShard RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext); if (rankFeaturePhaseRankShardContext != null) { - assert rankFeaturePhaseRankShardContext.getField() != null : "field must not be null"; + String field = rankFeaturePhaseRankShardContext.getField(); + assert field != null : "field must not be null"; searchContext.fetchFieldsContext( new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null))) ); + try { + SearchHighlightContext searchHighlightContext = new HighlightBuilder().field(field) + .numOfFragments(1) // TODO set num fragments + .preTags("") + .postTags("") + .build(searchContext.getSearchExecutionContext()); + searchContext.highlight(searchHighlightContext); + } catch (IOException e) { + throw new RuntimeException("Failed to create highlight context", e); + } searchContext.storedFieldsContext(StoredFieldsContext.fromList(Collections.singletonList(StoredFieldsContext._NONE_))); searchContext.addFetchResult(); Arrays.sort(request.getDocIds()); From 3f52ac7a4e50e0b51e00f646d12e77c1581f10a4 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 2 Apr 2025 16:56:30 -0400 Subject: [PATCH 2/7] Propagate to --- .../org/elasticsearch/TransportVersions.java | 1 + .../action/search/RankFeaturePhase.java | 4 +- ...ankFeaturePhaseRankCoordinatorContext.java | 9 +++- .../search/rank/feature/RankFeatureDoc.java | 18 ++++++- .../rank/feature/RankFeatureShardPhase.java | 18 ++++--- .../rank/feature/RankFeatureShardRequest.java | 20 ++++++- .../search/rank/feature/Snippets.java | 29 ++++++++++ ...nkingRankFeaturePhaseRankShardContext.java | 17 +++++- ...ankFeaturePhaseRankCoordinatorContext.java | 2 +- .../TextSimilarityRankBuilder.java | 20 ++++++- ...ankFeaturePhaseRankCoordinatorContext.java | 22 +++++--- .../TextSimilarityRankRetrieverBuilder.java | 54 ++++++++++++++++--- 12 files changed, 186 insertions(+), 28 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/rank/feature/Snippets.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index eececd187f11e..aaa7b8ae311dc 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -209,6 +209,7 @@ static TransportVersion def(int id) { public static final TransportVersion INDEX_STATS_AND_METADATA_INCLUDE_PEAK_WRITE_LOAD = def(9_041_0_00); public static final TransportVersion REPOSITORIES_METADATA_AS_PROJECT_CUSTOM = def(9_042_0_00); public static final TransportVersion BATCHED_QUERY_PHASE_VERSION = def(9_043_0_00); + public static final TransportVersion RERANK_SNIPPETS = def(9_044_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index 25238711c5c1c..df9e34bce9200 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -25,6 +25,7 @@ import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.search.rank.feature.Snippets; import org.elasticsearch.transport.Transport; import java.util.Arrays; @@ -172,7 +173,8 @@ public void onFailure(Exception e) { context.getOriginalIndices(queryResult.getShardIndex()), queryResult.getContextId(), queryResult.getShardSearchRequest(), - entry + entry, + rankFeaturePhaseRankCoordinatorContext.snippets() ), context.getTask(), listener diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java index 819d04e12eeeb..1e6bf1fe8e310 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.action.ActionListener; import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.Snippets; import java.util.Arrays; import java.util.Comparator; @@ -30,18 +31,24 @@ public abstract class RankFeaturePhaseRankCoordinatorContext { protected final int from; protected final int rankWindowSize; protected final boolean failuresAllowed; + protected final Snippets snippets; - public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean failuresAllowed) { + public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean failuresAllowed, Snippets snippets) { this.size = size < 0 ? DEFAULT_SIZE : size; this.from = from < 0 ? DEFAULT_FROM : from; this.rankWindowSize = rankWindowSize; this.failuresAllowed = failuresAllowed; + this.snippets = snippets; } public boolean failuresAllowed() { return failuresAllowed; } + public Snippets snippets() { + return snippets; + } + /** * Computes the updated scores for a list of features (i.e. document-based data). We also pass along an ActionListener * that should be called with the new scores, and will continue execution to the next phase diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java index cd8d9392aced8..9a699def4b2cd 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java @@ -10,12 +10,14 @@ package org.elasticsearch.search.rank.feature; import org.apache.lucene.search.Explanation; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; +import java.util.List; import java.util.Objects; /** @@ -27,6 +29,7 @@ public class RankFeatureDoc extends RankDoc { // TODO: update to support more than 1 fields; and not restrict to string data public String featureData; + public List snippets; public RankFeatureDoc(int doc, float score, int shardIndex) { super(doc, score, shardIndex); @@ -35,6 +38,9 @@ public RankFeatureDoc(int doc, float score, int shardIndex) { public RankFeatureDoc(StreamInput in) throws IOException { super(in); featureData = in.readOptionalString(); + if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + snippets = in.readOptionalStringCollectionAsList(); + } } @Override @@ -46,20 +52,27 @@ public void featureData(String featureData) { this.featureData = featureData; } + public void snippets(List snippets) { + this.snippets = snippets; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalString(featureData); + if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + out.writeOptionalStringCollection(snippets); + } } @Override protected boolean doEquals(RankDoc rd) { RankFeatureDoc other = (RankFeatureDoc) rd; - return Objects.equals(this.featureData, other.featureData); + return Objects.equals(this.featureData, other.featureData) && Objects.equals(this.snippets, other.snippets); } @Override protected int doHashCode() { - return Objects.hashCode(featureData); + return Objects.hash(featureData, snippets); } @Override @@ -70,5 +83,6 @@ public String getWriteableName() { @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { builder.field("featureData", featureData); + builder.array("snippets", snippets); } } diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java index 251e5ef449d01..ea33ae7ea429c 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java @@ -57,12 +57,18 @@ public static void prepareForFetch(SearchContext searchContext, RankFeatureShard new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null))) ); try { - SearchHighlightContext searchHighlightContext = new HighlightBuilder().field(field) - .numOfFragments(1) // TODO set num fragments - .preTags("") - .postTags("") - .build(searchContext.getSearchExecutionContext()); - searchContext.highlight(searchHighlightContext); + Snippets snippets = request.snippets(); + if (snippets != null) { + HighlightBuilder highlightBuilder = new HighlightBuilder().field(field).preTags("").postTags(""); + if (snippets.numFragments() != null) { + highlightBuilder.numOfFragments(snippets.numFragments()); + } + if (snippets.maxSize() != null) { + highlightBuilder.fragmentSize(snippets.maxSize()); + } + SearchHighlightContext searchHighlightContext = highlightBuilder.build(searchContext.getSearchExecutionContext()); + searchContext.highlight(searchHighlightContext); + } } catch (IOException e) { throw new RuntimeException("Failed to create highlight context", e); } diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java index 41b271a6beb53..b538333f850a3 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java @@ -9,12 +9,14 @@ package org.elasticsearch.search.rank.feature; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.tasks.TaskId; @@ -38,16 +40,20 @@ public class RankFeatureShardRequest extends TransportRequest implements Indices private final int[] docIds; + private final Snippets snippets; + public RankFeatureShardRequest( OriginalIndices originalIndices, ShardSearchContextId contextId, ShardSearchRequest shardSearchRequest, - List docIds + List docIds, + @Nullable Snippets snippets ) { this.originalIndices = originalIndices; this.shardSearchRequest = shardSearchRequest; this.docIds = docIds.stream().flatMapToInt(IntStream::of).toArray(); this.contextId = contextId; + this.snippets = snippets; } public RankFeatureShardRequest(StreamInput in) throws IOException { @@ -56,6 +62,11 @@ public RankFeatureShardRequest(StreamInput in) throws IOException { shardSearchRequest = in.readOptionalWriteable(ShardSearchRequest::new); docIds = in.readIntArray(); contextId = in.readOptionalWriteable(ShardSearchContextId::new); + if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + snippets = in.readOptionalWriteable(Snippets::new); + } else { + snippets = null; + } } @Override @@ -65,6 +76,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalWriteable(shardSearchRequest); out.writeIntArray(docIds); out.writeOptionalWriteable(contextId); + if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + out.writeOptionalWriteable(snippets); + } } @Override @@ -95,6 +109,10 @@ public ShardSearchContextId contextId() { return contextId; } + public Snippets snippets() { + return snippets; + } + @Override public SearchShardTask createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers); diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/Snippets.java b/server/src/main/java/org/elasticsearch/search/rank/feature/Snippets.java new file mode 100644 index 0000000000000..7e92b096341cc --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/Snippets.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.rank.feature; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; + +public record Snippets(Integer numFragments, Integer maxSize) implements Writeable { + + public Snippets(StreamInput in) throws IOException { + this(in.readOptionalVInt(), in.readOptionalVInt()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(numFragments); + out.writeOptionalVInt(maxSize); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java index 96867fb1d190b..57f81e45f5c11 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java @@ -12,14 +12,20 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.common.text.Text; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightField; import org.elasticsearch.search.rank.RankShardResult; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; /** * The {@code ReRankingRankFeaturePhaseRankShardContext} is handles the {@code SearchHits} generated from the {@code RankFeatureShardPhase} @@ -40,10 +46,19 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; for (int i = 0; i < hits.getHits().length; i++) { rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId); - DocumentField docField = hits.getHits()[i].field(field); + SearchHit hit = hits.getHits()[i]; + DocumentField docField = hit.field(field); if (docField != null) { rankFeatureDocs[i].featureData(docField.getValue().toString()); } + Map highlightFields = hit.getHighlightFields(); + if (highlightFields != null) { + HighlightField highlightField = highlightFields.get(field); + if (highlightField != null) { + List snippets = new ArrayList<>(Arrays.stream(highlightField.fragments()).map(Text::toString).toList()); + rankFeatureDocs[i].snippets(snippets); + } + } } return new RankFeatureShardResult(rankFeatureDocs); } catch (Exception ex) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java index de593e0943f42..202b7c5134c48 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java @@ -22,7 +22,7 @@ public class RandomRankFeaturePhaseRankCoordinatorContext extends RankFeaturePha private final Integer seed; public RandomRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, Integer seed) { - super(size, from, rankWindowSize, false); + super(size, from, rankWindowSize, false, null); this.seed = seed; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index cc23e8e3a337b..d39f4f8966529 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -23,6 +23,7 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.Snippets; import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.XContentBuilder; @@ -54,6 +55,7 @@ public class TextSimilarityRankBuilder extends RankBuilder { private final String field; private final Float minScore; private final boolean failuresAllowed; + private final Snippets snippets; public TextSimilarityRankBuilder( String field, @@ -61,7 +63,8 @@ public TextSimilarityRankBuilder( String inferenceText, int rankWindowSize, Float minScore, - boolean failuresAllowed + boolean failuresAllowed, + Snippets snippets ) { super(rankWindowSize); this.inferenceId = inferenceId; @@ -69,6 +72,7 @@ public TextSimilarityRankBuilder( this.field = field; this.minScore = minScore; this.failuresAllowed = failuresAllowed; + this.snippets = snippets; } public TextSimilarityRankBuilder(StreamInput in) throws IOException { @@ -83,6 +87,11 @@ public TextSimilarityRankBuilder(StreamInput in) throws IOException { } else { this.failuresAllowed = false; } + if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + this.snippets = in.readOptionalWriteable(Snippets::new); + } else { + this.snippets = null; + } } @Override @@ -105,6 +114,9 @@ public void doWriteTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.RERANKER_FAILURES_ALLOWED)) { out.writeBoolean(failuresAllowed); } + if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { + out.writeOptionalWriteable(snippets); + } } @Override @@ -120,6 +132,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (failuresAllowed) { builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), true); } + if (snippets != null) { + + } } @Override @@ -179,7 +194,8 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo inferenceId, inferenceText, minScore, - failuresAllowed + failuresAllowed, + snippets ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 7f245ae854eac..5fc287f7a3724 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -14,6 +14,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.Snippets; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; @@ -21,7 +22,6 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Map; @@ -44,9 +44,10 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext( String inferenceId, String inferenceText, Float minScore, - boolean failuresAllowed + boolean failuresAllowed, + Snippets snippets ) { - super(size, from, rankWindowSize, failuresAllowed); + super(size, from, rankWindowSize, failuresAllowed, snippets); this.client = client; this.inferenceId = inferenceId; this.inferenceText = inferenceText; @@ -64,7 +65,7 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener rankedDocs = ((RankedDocsResults) results).getRankedDocs(); - if (rankedDocs.size() != featureDocs.length) { + if (snippets == null && rankedDocs.size() != featureDocs.length) { l.onFailure( new IllegalStateException( "Reranker input document count and returned score count mismatch: [" @@ -111,8 +112,17 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList(); - InferenceAction.Request inferenceRequest = generateRequest(featureData); + List featureData = new ArrayList<>(); + List snippets = new ArrayList<>(); + for (RankFeatureDoc featureDoc : featureDocs) { + featureData.add(featureDoc.featureData); + if (featureDoc.snippets != null) { + snippets.addAll(featureDoc.snippets); + } + } + InferenceAction.Request inferenceRequest = snippets.isEmpty() == false + ? generateRequest(snippets) + : generateRequest(featureData); try { client.execute(InferenceAction.INSTANCE, inferenceRequest, inferenceListener); } finally { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index d6883d3743a1d..159fbea457994 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -14,6 +14,7 @@ import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.rank.feature.Snippets; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; @@ -45,6 +46,9 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text"); public static final ParseField FIELD_FIELD = new ParseField("field"); public static final ParseField FAILURES_ALLOWED_FIELD = new ParseField("allow_rerank_failures"); + public static final ParseField SNIPPETS_FIELD = new ParseField("snippets"); + public static final ParseField NUM_FRAGMENTS_FIELD = new ParseField("num_fragments"); + public static final ParseField MAX_SIZE_FIELD = new ParseField("max_size"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { @@ -54,6 +58,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder String field = (String) args[3]; int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; boolean failuresAllowed = args[5] != null && (Boolean) args[5]; + Snippets snippets = (Snippets) args[6]; return new TextSimilarityRankRetrieverBuilder( retrieverBuilder, @@ -61,10 +66,21 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder inferenceText, field, rankWindowSize, - failuresAllowed + failuresAllowed, + snippets ); }); + private static final ConstructingObjectParser SNIPPETS_PARSER = new ConstructingObjectParser<>( + SNIPPETS_FIELD.getPreferredName(), + true, + args -> { + Integer numFragments = (Integer) args[0]; + Integer maxSize = (Integer) args[1]; + return new Snippets(numFragments, maxSize); + } + ); + static { PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c); @@ -76,6 +92,9 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder PARSER.declareString(constructorArg(), FIELD_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); PARSER.declareBoolean(optionalConstructorArg(), FAILURES_ALLOWED_FIELD); + PARSER.declareObject(optionalConstructorArg(), SNIPPETS_PARSER, SNIPPETS_FIELD); + SNIPPETS_PARSER.declareInt(optionalConstructorArg(), NUM_FRAGMENTS_FIELD); + SNIPPETS_PARSER.declareInt(optionalConstructorArg(), MAX_SIZE_FIELD); RetrieverBuilder.declareBaseParserFields(PARSER); } @@ -95,6 +114,7 @@ public static TextSimilarityRankRetrieverBuilder fromXContent( private final String inferenceText; private final String field; private final boolean failuresAllowed; + private final Snippets snippets; public TextSimilarityRankRetrieverBuilder( RetrieverBuilder retrieverBuilder, @@ -102,13 +122,15 @@ public TextSimilarityRankRetrieverBuilder( String inferenceText, String field, int rankWindowSize, - boolean failuresAllowed + boolean failuresAllowed, + Snippets snippets ) { super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; this.failuresAllowed = failuresAllowed; + this.snippets = snippets; } public TextSimilarityRankRetrieverBuilder( @@ -120,7 +142,8 @@ public TextSimilarityRankRetrieverBuilder( Float minScore, boolean failuresAllowed, String retrieverName, - List preFilterQueryBuilders + List preFilterQueryBuilders, + Snippets snippets ) { super(retrieverSource, rankWindowSize); if (retrieverSource.size() != 1) { @@ -133,6 +156,7 @@ public TextSimilarityRankRetrieverBuilder( this.failuresAllowed = failuresAllowed; this.retrieverName = retrieverName; this.preFilterQueryBuilders = preFilterQueryBuilders; + this.snippets = snippets; } @Override @@ -149,7 +173,8 @@ protected TextSimilarityRankRetrieverBuilder clone( minScore, failuresAllowed, retrieverName, - newPreFilterQueryBuilders + newPreFilterQueryBuilders, + snippets ); } @@ -179,7 +204,7 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b @Override protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { sourceBuilder.rankBuilder( - new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed) + new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed, snippets) ); return sourceBuilder; } @@ -201,6 +226,10 @@ public boolean failuresAllowed() { return failuresAllowed; } + public Snippets snippets() { + return snippets; + } + @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever()); @@ -211,6 +240,16 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc if (failuresAllowed) { builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), failuresAllowed); } + if (snippets != null) { + builder.startObject(SNIPPETS_FIELD.getPreferredName()); + if (snippets.numFragments() != null) { + builder.field(NUM_FRAGMENTS_FIELD.getPreferredName(), snippets.numFragments()); + } + if (snippets.maxSize() != null) { + builder.field(MAX_SIZE_FIELD.getPreferredName(), snippets.maxSize()); + } + builder.endObject(); + } } @Override @@ -222,11 +261,12 @@ public boolean doEquals(Object other) { && Objects.equals(field, that.field) && rankWindowSize == that.rankWindowSize && Objects.equals(minScore, that.minScore) - && failuresAllowed == that.failuresAllowed; + && failuresAllowed == that.failuresAllowed + && Objects.equals(snippets, that.snippets); } @Override public int doHashCode() { - return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed); + return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed, snippets); } } From 0ffaf94c4fe3532cefa2debc817602a519f416e1 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 3 Apr 2025 15:16:37 -0400 Subject: [PATCH 3/7] Only rerank the first snippet --- .../action/search/RankFeaturePhase.java | 1 - ...ankFeaturePhaseRankCoordinatorContext.java | 14 +++++++---- .../rank/feature/RankFeatureShardPhase.java | 5 +++- .../rank/feature/RankFeatureShardRequest.java | 8 +++---- ...{Snippets.java => RerankSnippetInput.java} | 4 ++-- ...nkingRankFeaturePhaseRankShardContext.java | 8 +++---- .../inference/action/InferenceAction.java | 2 +- .../TextSimilarityRankBuilder.java | 8 +++---- ...ankFeaturePhaseRankCoordinatorContext.java | 21 ++++++++--------- .../TextSimilarityRankRetrieverBuilder.java | 23 ++++++++----------- 10 files changed, 49 insertions(+), 45 deletions(-) rename server/src/main/java/org/elasticsearch/search/rank/feature/{Snippets.java => RerankSnippetInput.java} (85%) diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java index df9e34bce9200..ec3c99e1363e8 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java @@ -25,7 +25,6 @@ import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; -import org.elasticsearch.search.rank.feature.Snippets; import org.elasticsearch.transport.Transport; import java.util.Arrays; diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java index 1e6bf1fe8e310..6afa4cc74e0a8 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java @@ -12,7 +12,7 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.action.ActionListener; import org.elasticsearch.search.rank.feature.RankFeatureDoc; -import org.elasticsearch.search.rank.feature.Snippets; +import org.elasticsearch.search.rank.feature.RerankSnippetInput; import java.util.Arrays; import java.util.Comparator; @@ -31,9 +31,15 @@ public abstract class RankFeaturePhaseRankCoordinatorContext { protected final int from; protected final int rankWindowSize; protected final boolean failuresAllowed; - protected final Snippets snippets; + protected final RerankSnippetInput snippets; - public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean failuresAllowed, Snippets snippets) { + public RankFeaturePhaseRankCoordinatorContext( + int size, + int from, + int rankWindowSize, + boolean failuresAllowed, + RerankSnippetInput snippets + ) { this.size = size < 0 ? DEFAULT_SIZE : size; this.from = from < 0 ? DEFAULT_FROM : from; this.rankWindowSize = rankWindowSize; @@ -45,7 +51,7 @@ public boolean failuresAllowed() { return failuresAllowed; } - public Snippets snippets() { + public RerankSnippetInput snippets() { return snippets; } diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java index ea33ae7ea429c..9ab3bd366fbb3 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java @@ -57,9 +57,12 @@ public static void prepareForFetch(SearchContext searchContext, RankFeatureShard new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null))) ); try { - Snippets snippets = request.snippets(); + RerankSnippetInput snippets = request.snippets(); if (snippets != null) { + // For POC purposes we're just stripping pre/post tags and deferring if/how we'd want to handle them for this use case. HighlightBuilder highlightBuilder = new HighlightBuilder().field(field).preTags("").postTags(""); + // Force sorting by score to ensure that the first snippet is always the highest score + highlightBuilder.order(HighlightBuilder.Order.SCORE); if (snippets.numFragments() != null) { highlightBuilder.numOfFragments(snippets.numFragments()); } diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java index b538333f850a3..cc053794925e2 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java @@ -40,14 +40,14 @@ public class RankFeatureShardRequest extends TransportRequest implements Indices private final int[] docIds; - private final Snippets snippets; + private final RerankSnippetInput snippets; public RankFeatureShardRequest( OriginalIndices originalIndices, ShardSearchContextId contextId, ShardSearchRequest shardSearchRequest, List docIds, - @Nullable Snippets snippets + @Nullable RerankSnippetInput snippets ) { this.originalIndices = originalIndices; this.shardSearchRequest = shardSearchRequest; @@ -63,7 +63,7 @@ public RankFeatureShardRequest(StreamInput in) throws IOException { docIds = in.readIntArray(); contextId = in.readOptionalWriteable(ShardSearchContextId::new); if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { - snippets = in.readOptionalWriteable(Snippets::new); + snippets = in.readOptionalWriteable(RerankSnippetInput::new); } else { snippets = null; } @@ -109,7 +109,7 @@ public ShardSearchContextId contextId() { return contextId; } - public Snippets snippets() { + public RerankSnippetInput snippets() { return snippets; } diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/Snippets.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippetInput.java similarity index 85% rename from server/src/main/java/org/elasticsearch/search/rank/feature/Snippets.java rename to server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippetInput.java index 7e92b096341cc..acdcf79764c7f 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/Snippets.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippetInput.java @@ -15,9 +15,9 @@ import java.io.IOException; -public record Snippets(Integer numFragments, Integer maxSize) implements Writeable { +public record RerankSnippetInput(Integer numFragments, Integer maxSize) implements Writeable { - public Snippets(StreamInput in) throws IOException { + public RerankSnippetInput(StreamInput in) throws IOException { this(in.readOptionalVInt(), in.readOptionalVInt()); } diff --git a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java index 57f81e45f5c11..1e3fb518924c7 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java @@ -21,7 +21,6 @@ import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureShardResult; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -53,9 +52,10 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) } Map highlightFields = hit.getHighlightFields(); if (highlightFields != null) { - HighlightField highlightField = highlightFields.get(field); - if (highlightField != null) { - List snippets = new ArrayList<>(Arrays.stream(highlightField.fragments()).map(Text::toString).toList()); + if (highlightFields.containsKey(field)) { + List snippets = Arrays.stream(highlightFields.get(field).fragments()) + .map(Text::string) + .collect(Collectors.toList()); rankFeatureDocs[i].snippets(snippets); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 7a14185e7b5dc..528367d7a370d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -107,7 +107,7 @@ public Request( String query, Boolean returnDocuments, Integer topN, - List input, + List input, // I think we need to add some metadata to the strings here and return this with each response Map taskSettings, InputType inputType, TimeValue inferenceTimeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index d39f4f8966529..f8b78e7d7a990 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -23,7 +23,7 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; -import org.elasticsearch.search.rank.feature.Snippets; +import org.elasticsearch.search.rank.feature.RerankSnippetInput; import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.XContentBuilder; @@ -55,7 +55,7 @@ public class TextSimilarityRankBuilder extends RankBuilder { private final String field; private final Float minScore; private final boolean failuresAllowed; - private final Snippets snippets; + private final RerankSnippetInput snippets; public TextSimilarityRankBuilder( String field, @@ -64,7 +64,7 @@ public TextSimilarityRankBuilder( int rankWindowSize, Float minScore, boolean failuresAllowed, - Snippets snippets + RerankSnippetInput snippets ) { super(rankWindowSize); this.inferenceId = inferenceId; @@ -88,7 +88,7 @@ public TextSimilarityRankBuilder(StreamInput in) throws IOException { this.failuresAllowed = false; } if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { - this.snippets = in.readOptionalWriteable(Snippets::new); + this.snippets = in.readOptionalWriteable(RerankSnippetInput::new); } else { this.snippets = null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 5fc287f7a3724..23addd17ba33a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -14,7 +14,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; -import org.elasticsearch.search.rank.feature.Snippets; +import org.elasticsearch.search.rank.feature.RerankSnippetInput; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; @@ -45,7 +45,7 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext( String inferenceText, Float minScore, boolean failuresAllowed, - Snippets snippets + RerankSnippetInput snippets ) { super(size, from, rankWindowSize, failuresAllowed, snippets); this.client = client; @@ -65,7 +65,7 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener rankedDocs = ((RankedDocsResults) results).getRankedDocs(); - if (snippets == null && rankedDocs.size() != featureDocs.length) { + if (rankedDocs.size() != featureDocs.length) { l.onFailure( new IllegalStateException( "Reranker input document count and returned score count mismatch: [" @@ -112,17 +112,16 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener featureData = new ArrayList<>(); - List snippets = new ArrayList<>(); + List inferenceInputs = new ArrayList<>(); for (RankFeatureDoc featureDoc : featureDocs) { - featureData.add(featureDoc.featureData); - if (featureDoc.snippets != null) { - snippets.addAll(featureDoc.snippets); + if (featureDoc.snippets != null && featureDoc.snippets.isEmpty() == false) { + // TODO support reranking multiple snippets + inferenceInputs.add(featureDoc.snippets.get(0)); + } else { + inferenceInputs.add(featureDoc.featureData); } } - InferenceAction.Request inferenceRequest = snippets.isEmpty() == false - ? generateRequest(snippets) - : generateRequest(featureData); + InferenceAction.Request inferenceRequest = generateRequest(inferenceInputs); try { client.execute(InferenceAction.INSTANCE, inferenceRequest, inferenceListener); } finally { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 159fbea457994..f7a2f48140e4a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -14,7 +14,7 @@ import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; -import org.elasticsearch.search.rank.feature.Snippets; +import org.elasticsearch.search.rank.feature.RerankSnippetInput; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; @@ -58,7 +58,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder String field = (String) args[3]; int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; boolean failuresAllowed = args[5] != null && (Boolean) args[5]; - Snippets snippets = (Snippets) args[6]; + RerankSnippetInput snippets = (RerankSnippetInput) args[6]; return new TextSimilarityRankRetrieverBuilder( retrieverBuilder, @@ -71,15 +71,12 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder ); }); - private static final ConstructingObjectParser SNIPPETS_PARSER = new ConstructingObjectParser<>( - SNIPPETS_FIELD.getPreferredName(), - true, - args -> { + private static final ConstructingObjectParser SNIPPETS_PARSER = + new ConstructingObjectParser<>(SNIPPETS_FIELD.getPreferredName(), true, args -> { Integer numFragments = (Integer) args[0]; Integer maxSize = (Integer) args[1]; - return new Snippets(numFragments, maxSize); - } - ); + return new RerankSnippetInput(numFragments, maxSize); + }); static { PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { @@ -114,7 +111,7 @@ public static TextSimilarityRankRetrieverBuilder fromXContent( private final String inferenceText; private final String field; private final boolean failuresAllowed; - private final Snippets snippets; + private final RerankSnippetInput snippets; public TextSimilarityRankRetrieverBuilder( RetrieverBuilder retrieverBuilder, @@ -123,7 +120,7 @@ public TextSimilarityRankRetrieverBuilder( String field, int rankWindowSize, boolean failuresAllowed, - Snippets snippets + RerankSnippetInput snippets ) { super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize); this.inferenceId = inferenceId; @@ -143,7 +140,7 @@ public TextSimilarityRankRetrieverBuilder( boolean failuresAllowed, String retrieverName, List preFilterQueryBuilders, - Snippets snippets + RerankSnippetInput snippets ) { super(retrieverSource, rankWindowSize); if (retrieverSource.size() != 1) { @@ -226,7 +223,7 @@ public boolean failuresAllowed() { return failuresAllowed; } - public Snippets snippets() { + public RerankSnippetInput snippets() { return snippets; } From f99c33ef8fe3cf286a2cfa7e3448d29f6bc8984b Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 21 May 2025 10:49:36 -0400 Subject: [PATCH 4/7] Notes --- ...tSimilarityRankFeaturePhaseRankCoordinatorContext.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 23addd17ba33a..2f478fa561492 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -56,6 +56,10 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext( @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + + // Reconcile the input strings with the documents that they belong to. Input size 6. + // Let's say we have 6 snippets that we reranked from 2 documents (3 snippets each) + // Wrap the provided rankListener to an ActionListener that would handle the response from the inference service // and then pass the results final ActionListener inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> { @@ -76,7 +80,7 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener ra for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) { scores[rankedDoc.index()] = rankedDoc.relevanceScore(); } - return scores; + return scores; // Return a float of size 2 (max score index per doc) } private static float normalizeScore(float score) { From 0196a7c31b5f4dd50f045c857a59a48557131a4c Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 21 May 2025 16:00:04 -0400 Subject: [PATCH 5/7] Support reranking based on max score of multiple snippets per document --- .../search/rank/feature/RankFeatureDoc.java | 14 +++++- ...nkingRankFeaturePhaseRankShardContext.java | 13 +++-- .../inference/action/InferenceAction.java | 2 +- .../TextSimilarityRankBuilder.java | 3 +- ...ankFeaturePhaseRankCoordinatorContext.java | 48 +++++++++++++++---- 5 files changed, 62 insertions(+), 18 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java index 9a699def4b2cd..6502ecac99b62 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java @@ -30,6 +30,7 @@ public class RankFeatureDoc extends RankDoc { // TODO: update to support more than 1 fields; and not restrict to string data public String featureData; public List snippets; + public List docIndices; public RankFeatureDoc(int doc, float score, int shardIndex) { super(doc, score, shardIndex); @@ -40,6 +41,7 @@ public RankFeatureDoc(StreamInput in) throws IOException { featureData = in.readOptionalString(); if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { snippets = in.readOptionalStringCollectionAsList(); + docIndices = in.readOptionalCollectionAsList(StreamInput::readVInt); } } @@ -56,23 +58,30 @@ public void snippets(List snippets) { this.snippets = snippets; } + public void docIndices(List docIndices) { + this.docIndices = docIndices; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalString(featureData); if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) { out.writeOptionalStringCollection(snippets); + out.writeOptionalCollection(docIndices, StreamOutput::writeVInt); } } @Override protected boolean doEquals(RankDoc rd) { RankFeatureDoc other = (RankFeatureDoc) rd; - return Objects.equals(this.featureData, other.featureData) && Objects.equals(this.snippets, other.snippets); + return Objects.equals(this.featureData, other.featureData) + && Objects.equals(this.snippets, other.snippets) + && Objects.equals(this.docIndices, other.docIndices); } @Override protected int doHashCode() { - return Objects.hash(featureData, snippets); + return Objects.hash(featureData, snippets, docIndices); } @Override @@ -84,5 +93,6 @@ public String getWriteableName() { protected void doToXContent(XContentBuilder builder, Params params) throws IOException { builder.field("featureData", featureData); builder.array("snippets", snippets); + builder.array("docIndices", docIndices); } } diff --git a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java index 1e3fb518924c7..95eea541752c9 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java @@ -21,10 +21,10 @@ import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; /** * The {@code ReRankingRankFeaturePhaseRankShardContext} is handles the {@code SearchHits} generated from the {@code RankFeatureShardPhase} @@ -43,6 +43,7 @@ public RerankingRankFeaturePhaseRankShardContext(String field) { public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { try { RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + int docIndex = 0; for (int i = 0; i < hits.getHits().length; i++) { rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId); SearchHit hit = hits.getHits()[i]; @@ -53,12 +54,16 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) Map highlightFields = hit.getHighlightFields(); if (highlightFields != null) { if (highlightFields.containsKey(field)) { - List snippets = Arrays.stream(highlightFields.get(field).fragments()) - .map(Text::string) - .collect(Collectors.toList()); + List snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList(); + List docIndices = new ArrayList<>(); + for (String snippet : snippets) { + docIndices.add(docIndex); + } rankFeatureDocs[i].snippets(snippets); + rankFeatureDocs[i].docIndices(docIndices); } } + docIndex++; } return new RankFeatureShardResult(rankFeatureDocs); } catch (Exception ex) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 528367d7a370d..7a14185e7b5dc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -107,7 +107,7 @@ public Request( String query, Boolean returnDocuments, Integer topN, - List input, // I think we need to add some metadata to the strings here and return this with each response + List input, Map taskSettings, InputType inputType, TimeValue inferenceTimeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index f8b78e7d7a990..c1fe2778fcdbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -36,6 +36,7 @@ import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_ID_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD; import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD; +import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.SNIPPETS_FIELD; /** * A {@code RankBuilder} that enables ranking with text similarity model inference. Supports parameters for configuring the inference call. @@ -133,7 +134,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), true); } if (snippets != null) { - + builder.field(SNIPPETS_FIELD.getPreferredName(), snippets); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 2f478fa561492..28b1d8b6be072 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -57,30 +58,32 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext( @Override protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { - // Reconcile the input strings with the documents that they belong to. Input size 6. - // Let's say we have 6 snippets that we reranked from 2 documents (3 snippets each) - // Wrap the provided rankListener to an ActionListener that would handle the response from the inference service // and then pass the results final ActionListener inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> { InferenceServiceResults results = r.getResults(); assert results instanceof RankedDocsResults; - // Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results List rankedDocs = ((RankedDocsResults) results).getRankedDocs(); + final float[] scores; + if (featureDocs.length > 0 && featureDocs[0].snippets != null) { + scores = extractScoresFromRankedSnippets(rankedDocs, featureDocs); + } else { + scores = extractScoresFromRankedDocs(rankedDocs); + } - if (rankedDocs.size() != featureDocs.length) { + // Ensure we get exactly as many final scores as the number of docs we passed, otherwise we may return incorrect results + if (scores.length != featureDocs.length) { l.onFailure( new IllegalStateException( "Reranker input document count and returned score count mismatch: [" + featureDocs.length + "] vs [" - + rankedDocs.size() + + scores.length + "]" ) ); } else { - float[] scores = extractScoresFromRankedDocs(rankedDocs); // Return is size 2 l.onResponse(scores); } }); @@ -119,8 +122,7 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener inferenceInputs = new ArrayList<>(); for (RankFeatureDoc featureDoc : featureDocs) { if (featureDoc.snippets != null && featureDoc.snippets.isEmpty() == false) { - // TODO support reranking multiple snippets - inferenceInputs.add(featureDoc.snippets.get(0)); + inferenceInputs.addAll(featureDoc.snippets); } else { inferenceInputs.add(featureDoc.featureData); } @@ -181,7 +183,33 @@ private float[] extractScoresFromRankedDocs(List ra for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) { scores[rankedDoc.index()] = rankedDoc.relevanceScore(); } - return scores; // Return a float of size 2 (max score index per doc) + return scores; + } + + private float[] extractScoresFromRankedSnippets(List rankedDocs, RankFeatureDoc[] featureDocs) { + int[] docMappings = Arrays.stream(featureDocs).flatMapToInt(f -> f.docIndices.stream().mapToInt(Integer::intValue)).toArray(); + + float[] scores = new float[featureDocs.length]; + boolean[] hasScore = new boolean[featureDocs.length]; + + for (int i = 0; i < rankedDocs.size(); i++) { + int docId = docMappings[i]; + float score = rankedDocs.get(i).relevanceScore(); + + if (hasScore[docId] == false) { + scores[docId] = score; + hasScore[docId] = true; + } else { + scores[docId] = Math.max(scores[docId], score); + } + } + + float[] result = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + result[i] = hasScore[i] ? normalizeScore(scores[i]) : 0f; + } + + return result; } private static float normalizeScore(float score) { From 1dc80a1ac45b7322e260cbfb0974b7ffb9c77902 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Thu, 22 May 2025 15:28:30 -0400 Subject: [PATCH 6/7] Fix compilation error --- .../rank/rerank/RerankingRankFeaturePhaseRankShardContext.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java index 95eea541752c9..3253b98b214fd 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java @@ -12,7 +12,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.common.document.DocumentField; -import org.elasticsearch.common.text.Text; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.fetch.subphase.highlight.HighlightField; @@ -20,6 +19,7 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import org.elasticsearch.xcontent.Text; import java.util.ArrayList; import java.util.Arrays; From f68b720c5b5e3e868a8c5a1db11acec15a523287 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 28 May 2025 16:39:04 -0400 Subject: [PATCH 7/7] Fix merge compile errors --- .../rank/rerank/RerankingRankFeaturePhaseRankShardContext.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java index 3253b98b214fd..95eea541752c9 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.common.text.Text; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.fetch.subphase.highlight.HighlightField; @@ -19,7 +20,6 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; import org.elasticsearch.search.rank.feature.RankFeatureShardResult; -import org.elasticsearch.xcontent.Text; import java.util.ArrayList; import java.util.Arrays;