Skip to content

Commit 0ffaf94

Browse files
committed
Only rerank the first snippet
1 parent 3f52ac7 commit 0ffaf94

File tree

10 files changed

+49
-45
lines changed

10 files changed

+49
-45
lines changed

server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
2626
import org.elasticsearch.search.rank.feature.RankFeatureResult;
2727
import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
28-
import org.elasticsearch.search.rank.feature.Snippets;
2928
import org.elasticsearch.transport.Transport;
3029

3130
import java.util.Arrays;

server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import org.apache.lucene.search.ScoreDoc;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
15-
import org.elasticsearch.search.rank.feature.Snippets;
15+
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
1616

1717
import java.util.Arrays;
1818
import java.util.Comparator;
@@ -31,9 +31,15 @@ public abstract class RankFeaturePhaseRankCoordinatorContext {
3131
protected final int from;
3232
protected final int rankWindowSize;
3333
protected final boolean failuresAllowed;
34-
protected final Snippets snippets;
34+
protected final RerankSnippetInput snippets;
3535

36-
public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, boolean failuresAllowed, Snippets snippets) {
36+
public RankFeaturePhaseRankCoordinatorContext(
37+
int size,
38+
int from,
39+
int rankWindowSize,
40+
boolean failuresAllowed,
41+
RerankSnippetInput snippets
42+
) {
3743
this.size = size < 0 ? DEFAULT_SIZE : size;
3844
this.from = from < 0 ? DEFAULT_FROM : from;
3945
this.rankWindowSize = rankWindowSize;
@@ -45,7 +51,7 @@ public boolean failuresAllowed() {
4551
return failuresAllowed;
4652
}
4753

48-
public Snippets snippets() {
54+
public RerankSnippetInput snippets() {
4955
return snippets;
5056
}
5157

server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,12 @@ public static void prepareForFetch(SearchContext searchContext, RankFeatureShard
5757
new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null)))
5858
);
5959
try {
60-
Snippets snippets = request.snippets();
60+
RerankSnippetInput snippets = request.snippets();
6161
if (snippets != null) {
62+
// For POC purposes we're just stripping pre/post tags and deferring if/how we'd want to handle them for this use case.
6263
HighlightBuilder highlightBuilder = new HighlightBuilder().field(field).preTags("").postTags("");
64+
// Force sorting by score to ensure that the first snippet is always the highest score
65+
highlightBuilder.order(HighlightBuilder.Order.SCORE);
6366
if (snippets.numFragments() != null) {
6467
highlightBuilder.numOfFragments(snippets.numFragments());
6568
}

server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ public class RankFeatureShardRequest extends TransportRequest implements Indices
4040

4141
private final int[] docIds;
4242

43-
private final Snippets snippets;
43+
private final RerankSnippetInput snippets;
4444

4545
public RankFeatureShardRequest(
4646
OriginalIndices originalIndices,
4747
ShardSearchContextId contextId,
4848
ShardSearchRequest shardSearchRequest,
4949
List<Integer> docIds,
50-
@Nullable Snippets snippets
50+
@Nullable RerankSnippetInput snippets
5151
) {
5252
this.originalIndices = originalIndices;
5353
this.shardSearchRequest = shardSearchRequest;
@@ -63,7 +63,7 @@ public RankFeatureShardRequest(StreamInput in) throws IOException {
6363
docIds = in.readIntArray();
6464
contextId = in.readOptionalWriteable(ShardSearchContextId::new);
6565
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
66-
snippets = in.readOptionalWriteable(Snippets::new);
66+
snippets = in.readOptionalWriteable(RerankSnippetInput::new);
6767
} else {
6868
snippets = null;
6969
}
@@ -109,7 +109,7 @@ public ShardSearchContextId contextId() {
109109
return contextId;
110110
}
111111

112-
public Snippets snippets() {
112+
public RerankSnippetInput snippets() {
113113
return snippets;
114114
}
115115

server/src/main/java/org/elasticsearch/search/rank/feature/Snippets.java renamed to server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippetInput.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
import java.io.IOException;
1717

18-
public record Snippets(Integer numFragments, Integer maxSize) implements Writeable {
18+
public record RerankSnippetInput(Integer numFragments, Integer maxSize) implements Writeable {
1919

20-
public Snippets(StreamInput in) throws IOException {
20+
public RerankSnippetInput(StreamInput in) throws IOException {
2121
this(in.readOptionalVInt(), in.readOptionalVInt());
2222
}
2323

server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
2222
import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
2323

24-
import java.util.ArrayList;
2524
import java.util.Arrays;
2625
import java.util.List;
2726
import java.util.Map;
@@ -53,9 +52,10 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
5352
}
5453
Map<String, HighlightField> highlightFields = hit.getHighlightFields();
5554
if (highlightFields != null) {
56-
HighlightField highlightField = highlightFields.get(field);
57-
if (highlightField != null) {
58-
List<String> snippets = new ArrayList<>(Arrays.stream(highlightField.fragments()).map(Text::toString).toList());
55+
if (highlightFields.containsKey(field)) {
56+
List<String> snippets = Arrays.stream(highlightFields.get(field).fragments())
57+
.map(Text::string)
58+
.collect(Collectors.toList());
5959
rankFeatureDocs[i].snippets(snippets);
6060
}
6161
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public Request(
107107
String query,
108108
Boolean returnDocuments,
109109
Integer topN,
110-
List<String> input,
110+
List<String> input, // I think we need to add some metadata to the strings here and return this with each response
111111
Map<String, Object> taskSettings,
112112
InputType inputType,
113113
TimeValue inferenceTimeout,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
2424
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
2525
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
26-
import org.elasticsearch.search.rank.feature.Snippets;
26+
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
2727
import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext;
2828
import org.elasticsearch.xcontent.XContentBuilder;
2929

@@ -55,7 +55,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
5555
private final String field;
5656
private final Float minScore;
5757
private final boolean failuresAllowed;
58-
private final Snippets snippets;
58+
private final RerankSnippetInput snippets;
5959

6060
public TextSimilarityRankBuilder(
6161
String field,
@@ -64,7 +64,7 @@ public TextSimilarityRankBuilder(
6464
int rankWindowSize,
6565
Float minScore,
6666
boolean failuresAllowed,
67-
Snippets snippets
67+
RerankSnippetInput snippets
6868
) {
6969
super(rankWindowSize);
7070
this.inferenceId = inferenceId;
@@ -88,7 +88,7 @@ public TextSimilarityRankBuilder(StreamInput in) throws IOException {
8888
this.failuresAllowed = false;
8989
}
9090
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
91-
this.snippets = in.readOptionalWriteable(Snippets::new);
91+
this.snippets = in.readOptionalWriteable(RerankSnippetInput::new);
9292
} else {
9393
this.snippets = null;
9494
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import org.elasticsearch.inference.TaskType;
1515
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
1616
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
17-
import org.elasticsearch.search.rank.feature.Snippets;
17+
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
1818
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
1919
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2020
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
@@ -45,7 +45,7 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
4545
String inferenceText,
4646
Float minScore,
4747
boolean failuresAllowed,
48-
Snippets snippets
48+
RerankSnippetInput snippets
4949
) {
5050
super(size, from, rankWindowSize, failuresAllowed, snippets);
5151
this.client = client;
@@ -65,7 +65,7 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
6565
// Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results
6666
List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
6767

68-
if (snippets == null && rankedDocs.size() != featureDocs.length) {
68+
if (rankedDocs.size() != featureDocs.length) {
6969
l.onFailure(
7070
new IllegalStateException(
7171
"Reranker input document count and returned score count mismatch: ["
@@ -112,17 +112,16 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
112112
if (featureDocs.length == 0) {
113113
inferenceListener.onResponse(new InferenceAction.Response(new RankedDocsResults(List.of())));
114114
} else {
115-
List<String> featureData = new ArrayList<>();
116-
List<String> snippets = new ArrayList<>();
115+
List<String> inferenceInputs = new ArrayList<>();
117116
for (RankFeatureDoc featureDoc : featureDocs) {
118-
featureData.add(featureDoc.featureData);
119-
if (featureDoc.snippets != null) {
120-
snippets.addAll(featureDoc.snippets);
117+
if (featureDoc.snippets != null && featureDoc.snippets.isEmpty() == false) {
118+
// TODO support reranking multiple snippets
119+
inferenceInputs.add(featureDoc.snippets.get(0));
120+
} else {
121+
inferenceInputs.add(featureDoc.featureData);
121122
}
122123
}
123-
InferenceAction.Request inferenceRequest = snippets.isEmpty() == false
124-
? generateRequest(snippets)
125-
: generateRequest(featureData);
124+
InferenceAction.Request inferenceRequest = generateRequest(inferenceInputs);
126125
try {
127126
client.execute(InferenceAction.INSTANCE, inferenceRequest, inferenceListener);
128127
} finally {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import org.elasticsearch.license.XPackLicenseState;
1515
import org.elasticsearch.search.builder.SearchSourceBuilder;
1616
import org.elasticsearch.search.rank.RankDoc;
17-
import org.elasticsearch.search.rank.feature.Snippets;
17+
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
1818
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1919
import org.elasticsearch.search.retriever.RetrieverBuilder;
2020
import org.elasticsearch.search.retriever.RetrieverParserContext;
@@ -58,7 +58,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
5858
String field = (String) args[3];
5959
int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4];
6060
boolean failuresAllowed = args[5] != null && (Boolean) args[5];
61-
Snippets snippets = (Snippets) args[6];
61+
RerankSnippetInput snippets = (RerankSnippetInput) args[6];
6262

6363
return new TextSimilarityRankRetrieverBuilder(
6464
retrieverBuilder,
@@ -71,15 +71,12 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
7171
);
7272
});
7373

74-
private static final ConstructingObjectParser<Snippets, RetrieverParserContext> SNIPPETS_PARSER = new ConstructingObjectParser<>(
75-
SNIPPETS_FIELD.getPreferredName(),
76-
true,
77-
args -> {
74+
private static final ConstructingObjectParser<RerankSnippetInput, RetrieverParserContext> SNIPPETS_PARSER =
75+
new ConstructingObjectParser<>(SNIPPETS_FIELD.getPreferredName(), true, args -> {
7876
Integer numFragments = (Integer) args[0];
7977
Integer maxSize = (Integer) args[1];
80-
return new Snippets(numFragments, maxSize);
81-
}
82-
);
78+
return new RerankSnippetInput(numFragments, maxSize);
79+
});
8380

8481
static {
8582
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
@@ -114,7 +111,7 @@ public static TextSimilarityRankRetrieverBuilder fromXContent(
114111
private final String inferenceText;
115112
private final String field;
116113
private final boolean failuresAllowed;
117-
private final Snippets snippets;
114+
private final RerankSnippetInput snippets;
118115

119116
public TextSimilarityRankRetrieverBuilder(
120117
RetrieverBuilder retrieverBuilder,
@@ -123,7 +120,7 @@ public TextSimilarityRankRetrieverBuilder(
123120
String field,
124121
int rankWindowSize,
125122
boolean failuresAllowed,
126-
Snippets snippets
123+
RerankSnippetInput snippets
127124
) {
128125
super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
129126
this.inferenceId = inferenceId;
@@ -143,7 +140,7 @@ public TextSimilarityRankRetrieverBuilder(
143140
boolean failuresAllowed,
144141
String retrieverName,
145142
List<QueryBuilder> preFilterQueryBuilders,
146-
Snippets snippets
143+
RerankSnippetInput snippets
147144
) {
148145
super(retrieverSource, rankWindowSize);
149146
if (retrieverSource.size() != 1) {
@@ -226,7 +223,7 @@ public boolean failuresAllowed() {
226223
return failuresAllowed;
227224
}
228225

229-
public Snippets snippets() {
226+
public RerankSnippetInput snippets() {
230227
return snippets;
231228
}
232229

0 commit comments

Comments
 (0)