Skip to content

Commit d8dbaab

Browse files
committed
Remove docIndices from RankFeatureDoc
1 parent ce8364b commit d8dbaab

File tree

3 files changed

+8
-21
lines changed

3 files changed

+8
-21
lines changed

server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ public class RankFeatureDoc extends RankDoc {
2929

3030
// TODO: update to support more than 1 fields; and not restrict to string data
3131
public List<String> featureData;
32-
public List<Integer> docIndices;
3332

3433
public RankFeatureDoc(int doc, float score, int shardIndex) {
3534
super(doc, score, shardIndex);
@@ -39,7 +38,6 @@ public RankFeatureDoc(StreamInput in) throws IOException {
3938
super(in);
4039
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
4140
featureData = in.readOptionalStringCollectionAsList();
42-
docIndices = in.readOptionalCollectionAsList(StreamInput::readVInt);
4341
} else {
4442
String featureDataString = in.readOptionalString();
4543
featureData = featureDataString == null ? null : List.of(featureDataString);
@@ -55,15 +53,14 @@ public void featureData(List<String> featureData) {
5553
this.featureData = featureData;
5654
}
5755

58-
public void docIndices(List<Integer> docIndices) {
59-
this.docIndices = docIndices;
56+
public void featureData(String featureData) {
57+
this.featureData = List.of(featureData);
6058
}
6159

6260
@Override
6361
protected void doWriteTo(StreamOutput out) throws IOException {
6462
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
6563
out.writeOptionalStringCollection(featureData);
66-
out.writeOptionalCollection(docIndices, StreamOutput::writeVInt);
6764
} else {
6865
out.writeOptionalString(featureData.get(0));
6966
}
@@ -72,12 +69,12 @@ protected void doWriteTo(StreamOutput out) throws IOException {
7269
@Override
7370
protected boolean doEquals(RankDoc rd) {
7471
RankFeatureDoc other = (RankFeatureDoc) rd;
75-
return Objects.equals(this.featureData, other.featureData) && Objects.equals(this.docIndices, other.docIndices);
72+
return Objects.equals(this.featureData, other.featureData);
7673
}
7774

7875
@Override
7976
protected int doHashCode() {
80-
return Objects.hash(featureData, docIndices);
77+
return Objects.hash(featureData);
8178
}
8279

8380
@Override
@@ -88,6 +85,5 @@ public String getWriteableName() {
8885
@Override
8986
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
9087
builder.array("featureData", featureData);
91-
builder.array("docIndices", docIndices);
9288
}
9389
}

server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
2323
import org.elasticsearch.xcontent.Text;
2424

25-
import java.util.ArrayList;
2625
import java.util.Arrays;
2726
import java.util.List;
2827
import java.util.Map;
@@ -50,27 +49,20 @@ public RerankingRankFeaturePhaseRankShardContext(String field, RerankSnippetInpu
5049
public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
5150
try {
5251
RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
53-
int docIndex = 0;
5452
for (int i = 0; i < hits.getHits().length; i++) {
5553
rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
5654
SearchHit hit = hits.getHits()[i];
5755
DocumentField docField = hit.field(field);
5856
if (docField != null && snippets == null) {
59-
rankFeatureDocs[i].featureData(List.of(docField.getValue().toString()));
57+
rankFeatureDocs[i].featureData(docField.getValue().toString());
6058
}
6159
Map<String, HighlightField> highlightFields = hit.getHighlightFields();
6260
if (highlightFields != null) {
6361
if (highlightFields.containsKey(field)) {
6462
List<String> snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList();
65-
List<Integer> docIndices = new ArrayList<>();
66-
for (String s : snippets) {
67-
docIndices.add(docIndex);
68-
}
6963
rankFeatureDocs[i].featureData(snippets);
70-
rankFeatureDocs[i].docIndices(docIndices);
7164
}
7265
}
73-
docIndex++;
7466
}
7567
return new RankFeatureShardResult(rankFeatureDocs);
7668
} catch (Exception ex) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
2424

2525
import java.util.ArrayList;
26-
import java.util.Arrays;
2726
import java.util.List;
2827
import java.util.Map;
2928

@@ -192,13 +191,13 @@ private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> ra
192191
}
193192

194193
private float[] extractScoresFromRankedSnippets(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) {
195-
int[] docMappings = Arrays.stream(featureDocs).flatMapToInt(f -> f.docIndices.stream().mapToInt(Integer::intValue)).toArray();
196-
194+
int numSnippets = rankedDocs.size() / featureDocs.length;
197195
float[] scores = new float[featureDocs.length];
198196
boolean[] hasScore = new boolean[featureDocs.length];
199197

200198
for (int i = 0; i < rankedDocs.size(); i++) {
201-
int docId = docMappings[i];
199+
// TODO this naively assumes that we always get the requested number of snippets per ranked document
200+
int docId = i / numSnippets;
202201
float score = rankedDocs.get(i).relevanceScore();
203202

204203
if (hasScore[docId] == false) {

0 commit comments

Comments
 (0)