Commit 0196a7c

Support reranking based on max score of multiple snippets per document
1 parent f99c33e commit 0196a7c
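
In short: each document can now contribute several highlighter snippets as reranker inputs, and the document's final relevance is the maximum score across its snippets. A minimal standalone sketch of that aggregation (illustrative names and values, not the actual Elasticsearch code):

// Sketch: fold per-snippet reranker scores back into one score per document by taking the max.
// docIndexPerSnippet[i] says which document snippet i came from.
static float[] maxScorePerDoc(float[] snippetScores, int[] docIndexPerSnippet, int docCount) {
    float[] docScores = new float[docCount];
    boolean[] seen = new boolean[docCount];
    for (int i = 0; i < snippetScores.length; i++) {
        int doc = docIndexPerSnippet[i];
        docScores[doc] = seen[doc] ? Math.max(docScores[doc], snippetScores[i]) : snippetScores[i];
        seen[doc] = true;
    }
    return docScores;
}

// Example: 2 documents with 3 snippets each -> 6 reranker inputs, 2 final scores.
// maxScorePerDoc({0.2f, 0.9f, 0.4f, 0.1f, 0.3f, 0.5f}, {0, 0, 0, 1, 1, 1}, 2) -> [0.9, 0.5]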

File tree

5 files changed: +62 additions, -18 deletions

server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java
server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java

Lines changed: 12 additions & 2 deletions
@@ -30,6 +30,7 @@ public class RankFeatureDoc extends RankDoc {
     // TODO: update to support more than 1 fields; and not restrict to string data
     public String featureData;
     public List<String> snippets;
+    public List<Integer> docIndices;

     public RankFeatureDoc(int doc, float score, int shardIndex) {
         super(doc, score, shardIndex);
@@ -40,6 +41,7 @@ public RankFeatureDoc(StreamInput in) throws IOException {
         featureData = in.readOptionalString();
         if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
             snippets = in.readOptionalStringCollectionAsList();
+            docIndices = in.readOptionalCollectionAsList(StreamInput::readVInt);
         }
     }

@@ -56,23 +58,30 @@ public void snippets(List<String> snippets) {
         this.snippets = snippets;
     }

+    public void docIndices(List<Integer> docIndices) {
+        this.docIndices = docIndices;
+    }
+
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
         out.writeOptionalString(featureData);
         if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
             out.writeOptionalStringCollection(snippets);
+            out.writeOptionalCollection(docIndices, StreamOutput::writeVInt);
         }
     }

     @Override
     protected boolean doEquals(RankDoc rd) {
         RankFeatureDoc other = (RankFeatureDoc) rd;
-        return Objects.equals(this.featureData, other.featureData) && Objects.equals(this.snippets, other.snippets);
+        return Objects.equals(this.featureData, other.featureData)
+            && Objects.equals(this.snippets, other.snippets)
+            && Objects.equals(this.docIndices, other.docIndices);
     }

     @Override
     protected int doHashCode() {
-        return Objects.hash(featureData, snippets);
+        return Objects.hash(featureData, snippets, docIndices);
     }

     @Override
@@ -84,5 +93,6 @@ public String getWriteableName() {
     protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
         builder.field("featureData", featureData);
         builder.array("snippets", snippets);
+        builder.array("docIndices", docIndices);
     }
 }

server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java

Lines changed: 9 additions & 4 deletions
@@ -21,10 +21,10 @@
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
 import org.elasticsearch.search.rank.feature.RankFeatureShardResult;

+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
-import java.util.stream.Collectors;

 /**
  * The {@code ReRankingRankFeaturePhaseRankShardContext} is handles the {@code SearchHits} generated from the {@code RankFeatureShardPhase}
@@ -43,6 +43,7 @@ public RerankingRankFeaturePhaseRankShardContext(String field) {
     public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
         try {
             RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+            int docIndex = 0;
             for (int i = 0; i < hits.getHits().length; i++) {
                 rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
                 SearchHit hit = hits.getHits()[i];
@@ -53,12 +54,16 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
                 Map<String, HighlightField> highlightFields = hit.getHighlightFields();
                 if (highlightFields != null) {
                     if (highlightFields.containsKey(field)) {
-                        List<String> snippets = Arrays.stream(highlightFields.get(field).fragments())
-                            .map(Text::string)
-                            .collect(Collectors.toList());
+                        List<String> snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList();
+                        List<Integer> docIndices = new ArrayList<>();
+                        for (String snippet : snippets) {
+                            docIndices.add(docIndex);
+                        }
                         rankFeatureDocs[i].snippets(snippets);
+                        rankFeatureDocs[i].docIndices(docIndices);
                     }
                 }
+                docIndex++;
             }
             return new RankFeatureShardResult(rankFeatureDocs);
         } catch (Exception ex) {
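
On the shard side, each hit's highlight fragments become its snippets, and docIndices gets one entry per snippet, all equal to the hit's position in the shard result. A hedged illustration of the resulting parallel lists for a hypothetical hit (values are made up):

import java.util.Collections;
import java.util.List;

class SnippetIndicesExample {
    public static void main(String[] args) {
        // Hypothetical: the hit at position 1 produced three highlight fragments on the
        // configured field, so it carries three snippets and docIndices of [1, 1, 1].
        List<String> snippets = List.of("...fragment a...", "...fragment b...", "...fragment c...");
        List<Integer> docIndices = Collections.nCopies(snippets.size(), 1);
        System.out.println(snippets.size() + " snippets -> docIndices " + docIndices);
    }
}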

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ public Request(
         String query,
         Boolean returnDocuments,
         Integer topN,
-        List<String> input, // I think we need to add some metadata to the strings here and return this with each response
+        List<String> input,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue inferenceTimeout,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java

Lines changed: 2 additions & 1 deletion
@@ -36,6 +36,7 @@
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_ID_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD;
+import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.SNIPPETS_FIELD;

 /**
  * A {@code RankBuilder} that enables ranking with text similarity model inference. Supports parameters for configuring the inference call.
@@ -133,7 +134,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
             builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), true);
         }
         if (snippets != null) {
-
+            builder.field(SNIPPETS_FIELD.getPreferredName(), snippets);
         }
     }


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 38 additions & 10 deletions
@@ -22,6 +22,7 @@
 import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;

 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;

@@ -57,30 +58,32 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
     @Override
     protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {

-        // Reconcile the input strings with the documents that they belong to. Input size 6.
-        // Let's say we have 6 snippets that we reranked from 2 documents (3 snippets each)
-
         // Wrap the provided rankListener to an ActionListener that would handle the response from the inference service
         // and then pass the results
        final ActionListener<InferenceAction.Response> inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> {
             InferenceServiceResults results = r.getResults();
             assert results instanceof RankedDocsResults;

-            // Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results
             List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
+            final float[] scores;
+            if (featureDocs.length > 0 && featureDocs[0].snippets != null) {
+                scores = extractScoresFromRankedSnippets(rankedDocs, featureDocs);
+            } else {
+                scores = extractScoresFromRankedDocs(rankedDocs);
+            }

-            if (rankedDocs.size() != featureDocs.length) {
+            // Ensure we get exactly as many final scores as the number of docs we passed, otherwise we may return incorrect results
+            if (scores.length != featureDocs.length) {
                 l.onFailure(
                     new IllegalStateException(
                         "Reranker input document count and returned score count mismatch: ["
                             + featureDocs.length
                             + "] vs ["
-                            + rankedDocs.size()
+                            + scores.length
                             + "]"
                     )
                 );
             } else {
-                float[] scores = extractScoresFromRankedDocs(rankedDocs); // Return is size 2
                 l.onResponse(scores);
             }
         });
@@ -119,8 +122,7 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
         List<String> inferenceInputs = new ArrayList<>();
         for (RankFeatureDoc featureDoc : featureDocs) {
             if (featureDoc.snippets != null && featureDoc.snippets.isEmpty() == false) {
-                // TODO support reranking multiple snippets
-                inferenceInputs.add(featureDoc.snippets.get(0));
+                inferenceInputs.addAll(featureDoc.snippets);
             } else {
                 inferenceInputs.add(featureDoc.featureData);
             }
@@ -181,7 +183,33 @@ private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> ra
         for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
             scores[rankedDoc.index()] = rankedDoc.relevanceScore();
         }
-        return scores; // Return a float of size 2 (max score index per doc)
+        return scores;
+    }
+
+    private float[] extractScoresFromRankedSnippets(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) {
+        int[] docMappings = Arrays.stream(featureDocs).flatMapToInt(f -> f.docIndices.stream().mapToInt(Integer::intValue)).toArray();
+
+        float[] scores = new float[featureDocs.length];
+        boolean[] hasScore = new boolean[featureDocs.length];
+
+        for (int i = 0; i < rankedDocs.size(); i++) {
+            int docId = docMappings[i];
+            float score = rankedDocs.get(i).relevanceScore();
+
+            if (hasScore[docId] == false) {
+                scores[docId] = score;
+                hasScore[docId] = true;
+            } else {
+                scores[docId] = Math.max(scores[docId], score);
+            }
+        }
+
+        float[] result = new float[featureDocs.length];
+        for (int i = 0; i < featureDocs.length; i++) {
+            result[i] = hasScore[i] ? normalizeScore(scores[i]) : 0f;
+        }
+
+        return result;
     }

     private static float normalizeScore(float score) {
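
The coordinator relies on the inference inputs and the flattened docIndices lining up positionally: the inputs are the concatenation of every doc's snippets in featureDocs order, and docMappings is the concatenation of the matching docIndices. A hedged worked example of that mapping and the max-aggregation (assumes the reranker returns one score per input in input order; names and values here are illustrative):

import java.util.Arrays;
import java.util.List;

class DocMappingExample {
    public static void main(String[] args) {
        // Hypothetical: featureDocs[0] contributed 2 snippets and featureDocs[1] contributed 3,
        // so the flattened mapping has one entry per inference input.
        List<List<Integer>> perDocIndices = List.of(List.of(0, 0), List.of(1, 1, 1));
        int[] docMappings = perDocIndices.stream()
            .flatMapToInt(l -> l.stream().mapToInt(Integer::intValue))
            .toArray();                                   // [0, 0, 1, 1, 1]

        // Per-snippet relevance scores, in the same order as the inputs.
        float[] snippetScores = { 0.2f, 0.7f, 0.9f, 0.1f, 0.4f };

        float[] docScores = new float[2];
        Arrays.fill(docScores, Float.NEGATIVE_INFINITY);
        for (int i = 0; i < snippetScores.length; i++) {
            docScores[docMappings[i]] = Math.max(docScores[docMappings[i]], snippetScores[i]);
        }
        // docScores == [0.7, 0.9]; the commit additionally passes each max through normalizeScore(...).
        System.out.println(Arrays.toString(docScores));
    }
}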
