diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index d008b885a40a8..4cfd2d812e415 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -296,6 +296,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion SEARCH_LOAD_PER_INDEX_STATS = def(9_095_0_00);
     public static final TransportVersion HEAP_USAGE_IN_CLUSTER_INFO = def(9_096_0_00);
     public static final TransportVersion NONE_CHUNKING_STRATEGY = def(9_097_0_00);
+    public static final TransportVersion RERANK_SNIPPETS = def(9_098_0_00);

     /*
      * STOP! READ THIS FIRST! No, really,
diff --git a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
index 25238711c5c1c..ec3c99e1363e8 100644
--- a/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
+++ b/server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
@@ -172,7 +172,8 @@ public void onFailure(Exception e) {
                     context.getOriginalIndices(queryResult.getShardIndex()),
                     queryResult.getContextId(),
                     queryResult.getShardSearchRequest(),
-                    entry
+                    entry,
+                    rankFeaturePhaseRankCoordinatorContext.snippets()
                 ),
                 context.getTask(),
                 listener
diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java
index 1997601e73f6d..2a48e8338e377 100644
--- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java
+++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java
@@ -399,7 +399,7 @@ public Field(String name) {
             this.name = name;
         }

-        private Field(Field template, QueryBuilder builder) {
+        Field(Field template, QueryBuilder builder) {
             super(template, builder);
             name = template.name;
             fragmentOffset = template.fragmentOffset;
diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java
index 631a75a355abf..a85ae92c24bcf 100644
--- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java
+++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java
@@ -40,7 +40,7 @@ public static class Field {
         private final String field;
         private final FieldOptions fieldOptions;

-        Field(String field, FieldOptions fieldOptions) {
+        public Field(String field, FieldOptions fieldOptions) {
             assert field != null;
             assert fieldOptions != null;
             this.field = field;
diff --git a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java
index 819d04e12eeeb..6afa4cc74e0a8 100644
--- a/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java
+++ b/server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java
@@ -12,6 +12,7 @@
 import org.apache.lucene.search.ScoreDoc;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RerankSnippetInput;

 import java.util.Arrays;
 import java.util.Comparator;
@@ -30,18 +31,30 @@ public abstract class RankFeaturePhaseRankCoordinatorContext {
     protected final int from;
     protected final int rankWindowSize;
     protected final boolean failuresAllowed;
+    protected final RerankSnippetInput snippets;

-    public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean failuresAllowed) {
+    public RankFeaturePhaseRankCoordinatorContext(
+        int size,
+        int from,
+        int rankWindowSize,
+        boolean failuresAllowed,
+        RerankSnippetInput snippets
+    ) {
         this.size = size < 0 ? DEFAULT_SIZE : size;
         this.from = from < 0 ? DEFAULT_FROM : from;
         this.rankWindowSize = rankWindowSize;
         this.failuresAllowed = failuresAllowed;
+        this.snippets = snippets;
     }

     public boolean failuresAllowed() {
         return failuresAllowed;
     }

+    public RerankSnippetInput snippets() {
+        return snippets;
+    }
+
     /**
      * Computes the updated scores for a list of features (i.e. document-based data). We also pass along an ActionListener
      * that should be called with the new scores, and will continue execution to the next phase
diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java
index cd8d9392aced8..6502ecac99b62 100644
--- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java
+++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java
@@ -10,12 +10,14 @@
 package org.elasticsearch.search.rank.feature;

 import org.apache.lucene.search.Explanation;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.xcontent.XContentBuilder;

 import java.io.IOException;
+import java.util.List;
 import java.util.Objects;

 /**
@@ -27,6 +29,8 @@ public class RankFeatureDoc extends RankDoc {

     // TODO: update to support more than 1 fields; and not restrict to string data
     public String featureData;
+    public List<String> snippets;
+    public List<Integer> docIndices;

     public RankFeatureDoc(int doc, float score, int shardIndex) {
         super(doc, score, shardIndex);
@@ -35,6 +39,10 @@ public RankFeatureDoc(int doc, float score, int shardIndex) {
     public RankFeatureDoc(StreamInput in) throws IOException {
         super(in);
         featureData = in.readOptionalString();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            snippets = in.readOptionalStringCollectionAsList();
+            docIndices = in.readOptionalCollectionAsList(StreamInput::readVInt);
+        }
     }

     @Override
@@ -46,20 +54,34 @@ public void featureData(String featureData) {
         this.featureData = featureData;
     }

+    public void snippets(List<String> snippets) {
+        this.snippets = snippets;
+    }
+
+    public void docIndices(List<Integer> docIndices) {
+        this.docIndices = docIndices;
+    }
+
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
         out.writeOptionalString(featureData);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            out.writeOptionalStringCollection(snippets);
+            out.writeOptionalCollection(docIndices, StreamOutput::writeVInt);
+        }
     }

     @Override
     protected boolean doEquals(RankDoc rd) {
         RankFeatureDoc other = (RankFeatureDoc) rd;
-        return Objects.equals(this.featureData, other.featureData);
+        return Objects.equals(this.featureData, other.featureData)
+            && Objects.equals(this.snippets, other.snippets)
+            && Objects.equals(this.docIndices, other.docIndices);
     }

     @Override
     protected int doHashCode() {
-        return Objects.hashCode(featureData);
+        return Objects.hash(featureData, snippets, docIndices);
     }

     @Override
@@ -70,5 +92,7 @@ public String getWriteableName() {
     @Override
     protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
         builder.field("featureData", featureData);
+        builder.array("snippets", snippets);
+        builder.array("docIndices", docIndices);
     }
 }
diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java
index 4374c06da365d..9ab3bd366fbb3 100644
--- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java
+++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java
@@ -17,10 +17,13 @@
 import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.fetch.subphase.FetchFieldsContext;
 import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
+import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
+import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
 import org.elasticsearch.tasks.TaskCancelledException;

+import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;

@@ -48,10 +51,30 @@ public static void prepareForFetch(SearchContext searchContext, RankFeatureShard
         RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext);
         if (rankFeaturePhaseRankShardContext != null) {
-            assert rankFeaturePhaseRankShardContext.getField() != null : "field must not be null";
+            String field = rankFeaturePhaseRankShardContext.getField();
+            assert field != null : "field must not be null";
             searchContext.fetchFieldsContext(
                 new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null)))
             );
+            try {
+                RerankSnippetInput snippets = request.snippets();
+                if (snippets != null) {
+                    // For POC purposes we're just stripping pre/post tags and deferring if/how we'd want to handle them for this use case.
+                    HighlightBuilder highlightBuilder = new HighlightBuilder().field(field).preTags("").postTags("");
+                    // Force sorting by score to ensure that the first snippet is always the highest score
+                    highlightBuilder.order(HighlightBuilder.Order.SCORE);
+                    if (snippets.numFragments() != null) {
+                        highlightBuilder.numOfFragments(snippets.numFragments());
+                    }
+                    if (snippets.maxSize() != null) {
+                        highlightBuilder.fragmentSize(snippets.maxSize());
+                    }
+                    SearchHighlightContext searchHighlightContext = highlightBuilder.build(searchContext.getSearchExecutionContext());
+                    searchContext.highlight(searchHighlightContext);
+                }
+            } catch (IOException e) {
+                throw new RuntimeException("Failed to create highlight context", e);
+            }
             searchContext.storedFieldsContext(StoredFieldsContext.fromList(Collections.singletonList(StoredFieldsContext._NONE_)));
             searchContext.addFetchResult();
             Arrays.sort(request.getDocIds());
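Note: when a snippets option is present on the shard request, prepareForFetch() above reuses the stock highlighter to produce the reranker inputs rather than introducing a new fragment extractor. A minimal sketch of the resulting per-field configuration, assuming num_fragments=2 and max_size=150; the standalone helper is illustrative only, while the setters are the existing HighlightBuilder API:

    // Roughly what the block above configures for the rerank field.
    HighlightBuilder sketchHighlighter(String field, RerankSnippetInput snippets) {
        HighlightBuilder hb = new HighlightBuilder().field(field)
            .preTags("")                                // POC: strip markup around matched terms
            .postTags("")
            .order(HighlightBuilder.Order.SCORE);       // best-scoring fragment first
        if (snippets.numFragments() != null) {
            hb.numOfFragments(snippets.numFragments()); // "num_fragments" -> fragment count
        }
        if (snippets.maxSize() != null) {
            hb.fragmentSize(snippets.maxSize());        // "max_size" -> fragment length in characters
        }
        return hb;
    }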
diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java
index d6c10f15adf80..dc688cef11efa 100644
--- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java
+++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java
@@ -9,12 +9,14 @@
 package org.elasticsearch.search.rank.feature;

+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.IndicesRequest;
 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.action.search.SearchShardTask;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.search.internal.ShardSearchContextId;
 import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.tasks.TaskId;
@@ -38,16 +40,20 @@ public class RankFeatureShardRequest extends AbstractTransportRequest implements
     private final int[] docIds;

+    private final RerankSnippetInput snippets;
+
     public RankFeatureShardRequest(
         OriginalIndices originalIndices,
         ShardSearchContextId contextId,
         ShardSearchRequest shardSearchRequest,
-        List<Integer> docIds
+        List<Integer> docIds,
+        @Nullable RerankSnippetInput snippets
     ) {
         this.originalIndices = originalIndices;
         this.shardSearchRequest = shardSearchRequest;
         this.docIds = docIds.stream().flatMapToInt(IntStream::of).toArray();
         this.contextId = contextId;
+        this.snippets = snippets;
     }

     public RankFeatureShardRequest(StreamInput in) throws IOException {
@@ -56,6 +62,11 @@ public RankFeatureShardRequest(StreamInput in) throws IOException {
         shardSearchRequest = in.readOptionalWriteable(ShardSearchRequest::new);
         docIds = in.readIntArray();
         contextId = in.readOptionalWriteable(ShardSearchContextId::new);
+        if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            snippets = in.readOptionalWriteable(RerankSnippetInput::new);
+        } else {
+            snippets = null;
+        }
     }

     @Override
@@ -65,6 +76,9 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeOptionalWriteable(shardSearchRequest);
         out.writeIntArray(docIds);
         out.writeOptionalWriteable(contextId);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            out.writeOptionalWriteable(snippets);
+        }
     }

     @Override
@@ -95,6 +109,10 @@ public ShardSearchContextId contextId() {
         return contextId;
     }

+    public RerankSnippetInput snippets() {
+        return snippets;
+    }
+
     @Override
     public SearchShardTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
         return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers);
diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippetInput.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippetInput.java
new file mode 100644
index 0000000000000..acdcf79764c7f
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippetInput.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.rank.feature;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+
+public record RerankSnippetInput(Integer numFragments, Integer maxSize) implements Writeable {
+
+    public RerankSnippetInput(StreamInput in) throws IOException {
+        this(in.readOptionalVInt(), in.readOptionalVInt());
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalVInt(numFragments);
+        out.writeOptionalVInt(maxSize);
+    }
+}
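Note: RerankSnippetInput travels over the wire as two optional vints, so each value can be absent independently and the read side must consume them in the same order they were written. A hedged round-trip sketch in test style (not part of the patch; BytesStreamOutput is the usual stream test utility):

    RerankSnippetInput original = new RerankSnippetInput(2, null);      // num_fragments set, max_size absent
    try (BytesStreamOutput out = new BytesStreamOutput()) {
        original.writeTo(out);                                          // writeOptionalVInt x2
        try (StreamInput in = out.bytes().streamInput()) {
            RerankSnippetInput copy = new RerankSnippetInput(in);       // readOptionalVInt x2, same order
            assert original.equals(copy);                               // record equality covers both components
        }
    }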
diff --git a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java
index 96867fb1d190b..95eea541752c9 100644
--- a/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java
+++ b/server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java
@@ -12,14 +12,19 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.common.document.DocumentField;
+import org.elasticsearch.common.text.Text;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
 import org.elasticsearch.search.rank.RankShardResult;
 import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
 import org.elasticsearch.search.rank.feature.RankFeatureShardResult;

+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
+import java.util.Map;

 /**
  * The {@code ReRankingRankFeaturePhaseRankShardContext} is handles the {@code SearchHits} generated from the {@code RankFeatureShardPhase}
@@ -38,12 +43,27 @@ public RerankingRankFeaturePhaseRankShardContext(String field) {
     public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
         try {
             RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+            int docIndex = 0;
             for (int i = 0; i < hits.getHits().length; i++) {
                 rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
-                DocumentField docField = hits.getHits()[i].field(field);
+                SearchHit hit = hits.getHits()[i];
+                DocumentField docField = hit.field(field);
                 if (docField != null) {
                     rankFeatureDocs[i].featureData(docField.getValue().toString());
                 }
+                Map<String, HighlightField> highlightFields = hit.getHighlightFields();
+                if (highlightFields != null) {
+                    if (highlightFields.containsKey(field)) {
+                        List<String> snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList();
+                        List<Integer> docIndices = new ArrayList<>();
+                        for (String snippet : snippets) {
+                            docIndices.add(docIndex);
+                        }
+                        rankFeatureDocs[i].snippets(snippets);
+                        rankFeatureDocs[i].docIndices(docIndices);
+                    }
+                }
+                docIndex++;
             }
             return new RankFeatureShardResult(rankFeatureDocs);
         } catch (Exception ex) {
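Note: for each hit that produced highlight fragments, the shard context above fills two parallel lists: the fragment strings and, repeated once per fragment, the hit's position in this shard's result order. The coordinator depends on the two lists staying the same length when it later flattens snippets into one inference request. A small worked example (values invented, single shard assumed):

    // Hit 0 has two fragments, hit 1 has one, hit 2 has none and falls back to featureData.
    // rankFeatureDocs[0].snippets   = ["... rust forms when iron ...", "... oxidation accelerates ..."]
    // rankFeatureDocs[0].docIndices = [0, 0]
    // rankFeatureDocs[1].snippets   = ["... galvanic protection ..."]
    // rankFeatureDocs[1].docIndices = [1]
    // rankFeatureDocs[2].snippets   = null
    // rankFeatureDocs[2].docIndices = null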
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java
index de593e0943f42..202b7c5134c48 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java
@@ -22,7 +22,7 @@ public class RandomRankFeaturePhaseRankCoordinatorContext extends RankFeaturePha
     private final Integer seed;

     public RandomRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, Integer seed) {
-        super(size, from, rankWindowSize, false);
+        super(size, from, rankWindowSize, false, null);
         this.seed = seed;
     }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java
index 71f3465f8dfc8..04809ba7b3dca 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java
@@ -23,6 +23,7 @@
 import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
 import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RerankSnippetInput;
 import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext;
 import org.elasticsearch.xcontent.XContentBuilder;

@@ -35,6 +36,7 @@
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_ID_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD;
+import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.SNIPPETS_FIELD;

 /**
  * A {@code RankBuilder} that enables ranking with text similarity model inference. Supports parameters for configuring the inference call.
@@ -54,6 +56,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
     private final String field;
     private final Float minScore;
     private final boolean failuresAllowed;
+    private final RerankSnippetInput snippets;

     public TextSimilarityRankBuilder(
         String field,
@@ -61,7 +64,8 @@ public TextSimilarityRankBuilder(
         String inferenceText,
         int rankWindowSize,
         Float minScore,
-        boolean failuresAllowed
+        boolean failuresAllowed,
+        RerankSnippetInput snippets
     ) {
         super(rankWindowSize);
         this.inferenceId = inferenceId;
@@ -69,6 +73,7 @@ public TextSimilarityRankBuilder(
         this.field = field;
         this.minScore = minScore;
         this.failuresAllowed = failuresAllowed;
+        this.snippets = snippets;
     }

     public TextSimilarityRankBuilder(StreamInput in) throws IOException {
@@ -84,6 +89,11 @@ public TextSimilarityRankBuilder(StreamInput in) throws IOException {
         } else {
             this.failuresAllowed = false;
         }
+        if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            this.snippets = in.readOptionalWriteable(RerankSnippetInput::new);
+        } else {
+            this.snippets = null;
+        }
     }

     @Override
@@ -107,6 +117,9 @@ public void doWriteTo(StreamOutput out) throws IOException {
             || out.getTransportVersion().onOrAfter(TransportVersions.RERANKER_FAILURES_ALLOWED)) {
             out.writeBoolean(failuresAllowed);
         }
+        if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            out.writeOptionalWriteable(snippets);
+        }
     }

     @Override
@@ -122,6 +135,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
         if (failuresAllowed) {
             builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), true);
         }
+        if (snippets != null) {
+            builder.field(SNIPPETS_FIELD.getPreferredName(), snippets);
+        }
     }

     @Override
@@ -181,7 +197,8 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
             inferenceId,
             inferenceText,
             minScore,
-            failuresAllowed
+            failuresAllowed,
+            snippets
         );
     }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java
index 27221dc1f5caf..f90d96766bd3c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java
@@ -14,6 +14,7 @@
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RerankSnippetInput;
 import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
@@ -48,9 +49,10 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
         String inferenceId,
         String inferenceText,
         Float minScore,
-        boolean failuresAllowed
+        boolean failuresAllowed,
+        RerankSnippetInput snippets
     ) {
-        super(size, from, rankWindowSize, failuresAllowed);
+        super(size, from, rankWindowSize, failuresAllowed, snippets);
         this.client = client;
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
@@ -59,27 +61,33 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(

     @Override
     protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+
         // Wrap the provided rankListener to an ActionListener that would handle the response from the inference service
         // and then pass the results
         final ActionListener<InferenceAction.Response> inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> {
             InferenceServiceResults results = r.getResults();
             assert results instanceof RankedDocsResults;

-            // Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results
             List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
+            final float[] scores;
+            if (featureDocs.length > 0 && featureDocs[0].snippets != null) {
+                scores = extractScoresFromRankedSnippets(rankedDocs, featureDocs);
+            } else {
+                scores = extractScoresFromRankedDocs(rankedDocs);
+            }

-            if (rankedDocs.size() != featureDocs.length) {
+            // Ensure we get exactly as many final scores as the number of docs we passed, otherwise we may return incorrect results
+            if (scores.length != featureDocs.length) {
                 l.onFailure(
                     new IllegalStateException(
                         "Reranker input document count and returned score count mismatch: ["
                             + featureDocs.length
                             + "] vs ["
-                            + rankedDocs.size()
+                            + scores.length
                             + "]"
                     )
                 );
             } else {
-                float[] scores = extractScoresFromRankedDocs(rankedDocs);
                 l.onResponse(scores);
             }
         });
@@ -118,8 +126,15 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener
-        List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
-        InferenceAction.Request inferenceRequest = generateRequest(featureData);
+        List<String> inferenceInputs = new ArrayList<>();
+        for (RankFeatureDoc featureDoc : featureDocs) {
+            if (featureDoc.snippets != null && featureDoc.snippets.isEmpty() == false) {
+                inferenceInputs.addAll(featureDoc.snippets);
+            } else {
+                inferenceInputs.add(featureDoc.featureData);
+            }
+        }
+        InferenceAction.Request inferenceRequest = generateRequest(inferenceInputs);
         try {
             executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceRequest, inferenceListener);
         } finally {
@@ -178,6 +193,32 @@ private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> ra
         return scores;
     }

+    private float[] extractScoresFromRankedSnippets(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) {
+        int[] docMappings = Arrays.stream(featureDocs).flatMapToInt(f -> f.docIndices.stream().mapToInt(Integer::intValue)).toArray();
+
+        float[] scores = new float[featureDocs.length];
+        boolean[] hasScore = new boolean[featureDocs.length];
+
+        for (int i = 0; i < rankedDocs.size(); i++) {
+            int docId = docMappings[i];
+            float score = rankedDocs.get(i).relevanceScore();
+
+            if (hasScore[docId] == false) {
+                scores[docId] = score;
+                hasScore[docId] = true;
+            } else {
+                scores[docId] = Math.max(scores[docId], score);
+            }
+        }
+
+        float[] result = new float[featureDocs.length];
+        for (int i = 0; i < featureDocs.length; i++) {
+            result[i] = hasScore[i] ? normalizeScore(scores[i]) : 0f;
+        }
+
+        return result;
+    }
+
     private static float normalizeScore(float score) {
         // As some models might produce negative scores, we want to ensure that all scores will be positive
         // so we will make use of the following normalization formula:
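Note: on the coordinator the snippet path is a flatten / score / fold-back pipeline: every snippet becomes its own inference input, each returned score is attributed to its owning document through the flattened docIndices, and a document keeps the maximum of its snippet scores (the patch additionally applies normalizeScore() to kept scores and leaves unscored documents at 0f). A self-contained sketch of just the aggregation step, assuming scores come back in the same order the inputs were sent; the method name is illustrative, not part of the class:

    static float[] maxScorePerDocument(float[] snippetScores, int[] snippetToDoc, int docCount) {
        float[] best = new float[docCount];
        boolean[] seen = new boolean[docCount];
        for (int i = 0; i < snippetScores.length; i++) {
            int doc = snippetToDoc[i];                                   // flattened docIndices across featureDocs
            best[doc] = seen[doc] ? Math.max(best[doc], snippetScores[i]) : snippetScores[i];
            seen[doc] = true;
        }
        return best;                                                     // unseen documents stay at 0f
    }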
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java
index 0a6ff009f367e..22ba3303eb0c9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java
@@ -14,6 +14,7 @@
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.rank.RankDoc;
+import org.elasticsearch.search.rank.feature.RerankSnippetInput;
 import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverParserContext;
@@ -47,6 +48,9 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
     public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text");
     public static final ParseField FIELD_FIELD = new ParseField("field");
     public static final ParseField FAILURES_ALLOWED_FIELD = new ParseField("allow_rerank_failures");
+    public static final ParseField SNIPPETS_FIELD = new ParseField("snippets");
+    public static final ParseField NUM_FRAGMENTS_FIELD = new ParseField("num_fragments");
+    public static final ParseField MAX_SIZE_FIELD = new ParseField("max_size");

     public static final ConstructingObjectParser<TextSimilarityRankRetrieverBuilder, RetrieverParserContext> PARSER =
         new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> {
@@ -56,6 +60,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
             String field = (String) args[3];
             int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4];
             boolean failuresAllowed = args[5] != null && (Boolean) args[5];
+            RerankSnippetInput snippets = (RerankSnippetInput) args[6];

             return new TextSimilarityRankRetrieverBuilder(
                 retrieverBuilder,
@@ -63,10 +68,18 @@ public class TextSimilarityRankRetrieverBuilder
                 inferenceText,
                 field,
                 rankWindowSize,
-                failuresAllowed
+                failuresAllowed,
+                snippets
             );
         });

+    private static final ConstructingObjectParser<RerankSnippetInput, Void> SNIPPETS_PARSER =
+        new ConstructingObjectParser<>(SNIPPETS_FIELD.getPreferredName(), true, args -> {
+            Integer numFragments = (Integer) args[0];
+            Integer maxSize = (Integer) args[1];
+            return new RerankSnippetInput(numFragments, maxSize);
+        });
+
     static {
         PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
             RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
@@ -78,6 +91,9 @@ public class TextSimilarityRankRetrieverBuilder
         PARSER.declareString(constructorArg(), FIELD_FIELD);
         PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
         PARSER.declareBoolean(optionalConstructorArg(), FAILURES_ALLOWED_FIELD);
+        PARSER.declareObject(optionalConstructorArg(), SNIPPETS_PARSER, SNIPPETS_FIELD);
+        SNIPPETS_PARSER.declareInt(optionalConstructorArg(), NUM_FRAGMENTS_FIELD);
+        SNIPPETS_PARSER.declareInt(optionalConstructorArg(), MAX_SIZE_FIELD);
         RetrieverBuilder.declareBaseParserFields(PARSER);
     }

@@ -97,6 +113,7 @@ public static TextSimilarityRankRetrieverBuilder fromXContent(
     private final String inferenceText;
     private final String field;
     private final boolean failuresAllowed;
+    private final RerankSnippetInput snippets;

     public TextSimilarityRankRetrieverBuilder(
         RetrieverBuilder retrieverBuilder,
@@ -104,13 +121,15 @@ public TextSimilarityRankRetrieverBuilder(
         String inferenceText,
         String field,
         int rankWindowSize,
-        boolean failuresAllowed
+        boolean failuresAllowed,
+        RerankSnippetInput snippets
     ) {
         super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
         this.field = field;
         this.failuresAllowed = failuresAllowed;
+        this.snippets = snippets;
     }

     public TextSimilarityRankRetrieverBuilder(
@@ -122,7 +141,8 @@ public TextSimilarityRankRetrieverBuilder(
         Float minScore,
         boolean failuresAllowed,
         String retrieverName,
-        List<QueryBuilder> preFilterQueryBuilders
+        List<QueryBuilder> preFilterQueryBuilders,
+        RerankSnippetInput snippets
     ) {
         super(retrieverSource, rankWindowSize);
         if (retrieverSource.size() != 1) {
@@ -135,6 +155,7 @@ public TextSimilarityRankRetrieverBuilder(
         this.failuresAllowed = failuresAllowed;
         this.retrieverName = retrieverName;
         this.preFilterQueryBuilders = preFilterQueryBuilders;
+        this.snippets = snippets;
     }

     @Override
@@ -151,7 +172,8 @@ protected TextSimilarityRankRetrieverBuilder clone(
             minScore,
             failuresAllowed,
             retrieverName,
-            newPreFilterQueryBuilders
+            newPreFilterQueryBuilders,
+            snippets
         );
     }

@@ -179,7 +201,7 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
     @Override
     protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
         sourceBuilder.rankBuilder(
-            new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed)
+            new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed, snippets)
         );
         return sourceBuilder;
     }
@@ -197,6 +219,10 @@ public boolean failuresAllowed() {
         return failuresAllowed;
     }

+    public RerankSnippetInput snippets() {
+        return snippets;
+    }
+
     @Override
     protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
         builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever());
@@ -207,6 +233,16 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc
         if (failuresAllowed) {
             builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), failuresAllowed);
         }
+        if (snippets != null) {
+            builder.startObject(SNIPPETS_FIELD.getPreferredName());
+            if (snippets.numFragments() != null) {
+                builder.field(NUM_FRAGMENTS_FIELD.getPreferredName(), snippets.numFragments());
+            }
+            if (snippets.maxSize() != null) {
+                builder.field(MAX_SIZE_FIELD.getPreferredName(), snippets.maxSize());
+            }
+            builder.endObject();
+        }
     }

     @Override
@@ -218,11 +254,12 @@ public boolean doEquals(Object other) {
             && Objects.equals(field, that.field)
             && rankWindowSize == that.rankWindowSize
             && Objects.equals(minScore, that.minScore)
-            && failuresAllowed == that.failuresAllowed;
+            && failuresAllowed == that.failuresAllowed
+            && Objects.equals(snippets, that.snippets);
     }

     @Override
     public int doHashCode() {
-        return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed);
+        return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed, snippets);
     }
 }
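Note: end to end, the retriever accepts an optional "snippets" object with "num_fragments" and "max_size", stores it on TextSimilarityRankRetrieverBuilder, and finalizeSourceBuilder() forwards it to the rank builder attached to the search source. A hedged construction sketch with placeholder argument values (the constructor order matches the call in finalizeSourceBuilder() above):

    RerankSnippetInput snippets = new RerankSnippetInput(2, 300);        // num_fragments=2, max_size=300
    TextSimilarityRankBuilder rankBuilder = new TextSimilarityRankBuilder(
        "content",                // field whose value, or its snippets, is sent to the reranker
        "my-rerank-endpoint",     // inference_id (placeholder)
        "what causes rust?",      // inference_text (placeholder)
        100,                      // rank_window_size
        null,                     // min_score
        false,                    // allow_rerank_failures
        snippets
    );
    // doToXContent() on the retriever renders this as: "snippets": { "num_fragments": 2, "max_size": 300 }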