Skip to content

Commit d56726c

Browse files
committed
Consolidate snippet rank input
1 parent 6e8521a commit d56726c

File tree

8 files changed

+83
-135
lines changed

8 files changed

+83
-135
lines changed

server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippetConfig.java renamed to server/src/main/java/org/elasticsearch/search/rank/feature/RerankSnippetInput.java

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,57 @@
1717
import java.io.IOException;
1818
import java.util.Objects;
1919

20-
public class RerankSnippetConfig implements Writeable {
20+
public class RerankSnippetInput implements Writeable {
2121

2222
public final Integer numSnippets;
23+
private final String inferenceText;
24+
private final Integer tokenSizeLimit;
2325
public final QueryBuilder snippetQueryBuilder;
2426

2527
public static final int DEFAULT_NUM_SNIPPETS = 1;
2628

27-
public RerankSnippetConfig(StreamInput in) throws IOException {
29+
public RerankSnippetInput(StreamInput in) throws IOException {
2830
this.numSnippets = in.readOptionalVInt();
31+
this.inferenceText = in.readString();
32+
this.tokenSizeLimit = in.readOptionalVInt();
2933
this.snippetQueryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
3034
}
3135

32-
public RerankSnippetConfig(Integer numSnippets) {
33-
this(numSnippets, null);
36+
public RerankSnippetInput(Integer numSnippets) {
37+
this(numSnippets, null, null);
3438
}
3539

36-
public RerankSnippetConfig(Integer numSnippets, QueryBuilder snippetQueryBuilder) {
40+
public RerankSnippetInput(Integer numSnippets, String inferenceText, Integer tokenSizeLimit) {
41+
this(numSnippets, inferenceText, tokenSizeLimit, null);
42+
}
43+
44+
public RerankSnippetInput(Integer numSnippets, String inferenceText, Integer tokenSizeLimit, QueryBuilder snippetQueryBuilder) {
3745
this.numSnippets = numSnippets;
46+
this.inferenceText = inferenceText;
47+
this.tokenSizeLimit = tokenSizeLimit;
3848
this.snippetQueryBuilder = snippetQueryBuilder;
3949
}
4050

4151
@Override
4252
public void writeTo(StreamOutput out) throws IOException {
4353
out.writeOptionalVInt(numSnippets);
54+
out.writeString(inferenceText);
55+
out.writeOptionalVInt(tokenSizeLimit);
4456
out.writeOptionalNamedWriteable(snippetQueryBuilder);
4557
}
4658

4759
public Integer numSnippets() {
4860
return numSnippets;
4961
}
5062

63+
public String inferenceText() {
64+
return inferenceText;
65+
}
66+
67+
public Integer tokenSizeLimit() {
68+
return tokenSizeLimit;
69+
}
70+
5171
public QueryBuilder snippetQueryBuilder() {
5272
return snippetQueryBuilder;
5373
}
@@ -56,12 +76,15 @@ public QueryBuilder snippetQueryBuilder() {
5676
public boolean equals(Object o) {
5777
if (this == o) return true;
5878
if (o == null || getClass() != o.getClass()) return false;
59-
RerankSnippetConfig that = (RerankSnippetConfig) o;
60-
return Objects.equals(numSnippets, that.numSnippets) && Objects.equals(snippetQueryBuilder, that.snippetQueryBuilder);
79+
RerankSnippetInput that = (RerankSnippetInput) o;
80+
return Objects.equals(numSnippets, that.numSnippets)
81+
&& Objects.equals(inferenceText, that.inferenceText)
82+
&& Objects.equals(tokenSizeLimit, that.tokenSizeLimit)
83+
&& Objects.equals(snippetQueryBuilder, that.snippetQueryBuilder);
6184
}
6285

6386
@Override
6487
public int hashCode() {
65-
return Objects.hash(numSnippets, snippetQueryBuilder);
88+
return Objects.hash(numSnippets, inferenceText, tokenSizeLimit, snippetQueryBuilder);
6689
}
6790
}

server/src/main/java/org/elasticsearch/search/rank/feature/SnippetRankInput.java

Lines changed: 0 additions & 73 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
2828
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
2929
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
30-
import org.elasticsearch.search.rank.feature.RerankSnippetConfig;
31-
import org.elasticsearch.search.rank.feature.SnippetRankInput;
30+
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
3231
import org.elasticsearch.xcontent.XContentBuilder;
3332

3433
import java.io.IOException;
@@ -73,7 +72,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
7372
private final String field;
7473
private final Float minScore;
7574
private final boolean failuresAllowed;
76-
private final SnippetRankInput snippetRankInput;
75+
private final RerankSnippetInput rerankSnippetInput;
7776

7877
public TextSimilarityRankBuilder(
7978
String field,
@@ -82,15 +81,15 @@ public TextSimilarityRankBuilder(
8281
int rankWindowSize,
8382
Float minScore,
8483
boolean failuresAllowed,
85-
SnippetRankInput snippetRankInput
84+
RerankSnippetInput rerankSnippetInput
8685
) {
8786
super(rankWindowSize);
8887
this.inferenceId = inferenceId;
8988
this.inferenceText = inferenceText;
9089
this.field = field;
9190
this.minScore = minScore;
9291
this.failuresAllowed = failuresAllowed;
93-
this.snippetRankInput = snippetRankInput;
92+
this.rerankSnippetInput = rerankSnippetInput;
9493
}
9594

9695
public TextSimilarityRankBuilder(StreamInput in) throws IOException {
@@ -107,9 +106,9 @@ public TextSimilarityRankBuilder(StreamInput in) throws IOException {
107106
this.failuresAllowed = false;
108107
}
109108
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
110-
this.snippetRankInput = in.readOptionalWriteable(SnippetRankInput::new);
109+
this.rerankSnippetInput = in.readOptionalWriteable(RerankSnippetInput::new);
111110
} else {
112-
this.snippetRankInput = null;
111+
this.rerankSnippetInput = null;
113112
}
114113
}
115114

@@ -135,7 +134,7 @@ public void doWriteTo(StreamOutput out) throws IOException {
135134
out.writeBoolean(failuresAllowed);
136135
}
137136
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
138-
out.writeOptionalWriteable(snippetRankInput);
137+
out.writeOptionalWriteable(rerankSnippetInput);
139138
}
140139
}
141140

@@ -152,17 +151,16 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio
152151
if (failuresAllowed) {
153152
builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), true);
154153
}
155-
if (snippetRankInput != null) {
156-
builder.field(SNIPPETS_FIELD.getPreferredName(), snippetRankInput);
154+
if (rerankSnippetInput != null) {
155+
builder.field(SNIPPETS_FIELD.getPreferredName(), rerankSnippetInput);
157156
}
158157
}
159158

160159
@Override
161160
public RankBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOException {
162161
TextSimilarityRankBuilder rewritten = this;
163-
RerankSnippetConfig snippets = snippetRankInput != null ? snippetRankInput.snippets() : null;
164-
if (snippets != null) {
165-
QueryBuilder snippetQueryBuilder = snippets.snippetQueryBuilder();
162+
if (rerankSnippetInput != null) {
163+
QueryBuilder snippetQueryBuilder = rerankSnippetInput.snippetQueryBuilder();
166164
if (snippetQueryBuilder == null) {
167165
rewritten = new TextSimilarityRankBuilder(
168166
field,
@@ -171,10 +169,11 @@ public RankBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOExc
171169
rankWindowSize(),
172170
minScore,
173171
failuresAllowed,
174-
new SnippetRankInput(
175-
new RerankSnippetConfig(snippets.numSnippets(), new MatchQueryBuilder(field, inferenceText)),
176-
snippetRankInput.inferenceText(),
177-
snippetRankInput.tokenSizeLimit()
172+
new RerankSnippetInput(
173+
rerankSnippetInput.numSnippets(),
174+
rerankSnippetInput.inferenceText(),
175+
rerankSnippetInput.tokenSizeLimit(),
176+
new MatchQueryBuilder(field, inferenceText)
178177
)
179178
);
180179
} else {
@@ -187,10 +186,11 @@ public RankBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOExc
187186
rankWindowSize(),
188187
minScore,
189188
failuresAllowed,
190-
new SnippetRankInput(
191-
new RerankSnippetConfig(snippets.numSnippets(), rewrittenSnippetQueryBuilder),
192-
snippetRankInput.inferenceText(),
193-
snippetRankInput.tokenSizeLimit()
189+
new RerankSnippetInput(
190+
rerankSnippetInput.numSnippets(),
191+
rerankSnippetInput.inferenceText(),
192+
rerankSnippetInput.tokenSizeLimit(),
193+
rewrittenSnippetQueryBuilder
194194
)
195195
);
196196
}
@@ -244,7 +244,7 @@ public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int si
244244

245245
@Override
246246
public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
247-
return new TextSimilarityRerankingRankFeaturePhaseRankShardContext(field, snippetRankInput);
247+
return new TextSimilarityRerankingRankFeaturePhaseRankShardContext(field, rerankSnippetInput);
248248
}
249249

250250
@Override
@@ -258,7 +258,9 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo
258258
inferenceText,
259259
minScore,
260260
failuresAllowed,
261-
snippetRankInput != null ? new SnippetRankInput(snippetRankInput.snippets(), inferenceText, tokenSizeLimit(inferenceId)) : null
261+
rerankSnippetInput != null
262+
? new RerankSnippetInput(rerankSnippetInput.numSnippets, inferenceText, tokenSizeLimit(inferenceId))
263+
: null
262264
);
263265
}
264266

@@ -301,12 +303,12 @@ protected boolean doEquals(RankBuilder other) {
301303
&& Objects.equals(field, that.field)
302304
&& Objects.equals(minScore, that.minScore)
303305
&& failuresAllowed == that.failuresAllowed
304-
&& Objects.equals(snippetRankInput, that.snippetRankInput);
306+
&& Objects.equals(rerankSnippetInput, that.rerankSnippetInput);
305307
}
306308

307309
@Override
308310
protected int doHashCode() {
309-
return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed, snippetRankInput);
311+
return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed, rerankSnippetInput);
310312
}
311313

312314
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import org.elasticsearch.inference.TaskType;
1616
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
1717
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
18-
import org.elasticsearch.search.rank.feature.SnippetRankInput;
18+
import org.elasticsearch.search.rank.feature.RerankSnippetInput;
1919
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
2020
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2121
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
@@ -41,7 +41,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
4141
protected final String inferenceId;
4242
protected final String inferenceText;
4343
protected final Float minScore;
44-
protected final SnippetRankInput snippetRankInput;
44+
protected final RerankSnippetInput rerankSnippetInput;
4545

4646
public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
4747
int size,
@@ -52,14 +52,14 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
5252
String inferenceText,
5353
Float minScore,
5454
boolean failuresAllowed,
55-
@Nullable SnippetRankInput snippetRankInput
55+
@Nullable RerankSnippetInput rerankSnippetInput
5656
) {
5757
super(size, from, rankWindowSize, failuresAllowed);
5858
this.client = client;
5959
this.inferenceId = inferenceId;
6060
this.inferenceText = inferenceText;
6161
this.minScore = minScore;
62-
this.snippetRankInput = snippetRankInput;
62+
this.rerankSnippetInput = rerankSnippetInput;
6363
}
6464

6565
@Override
@@ -81,7 +81,7 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[
8181
l.onResponse(originalScores);
8282
} else {
8383
final float[] scores;
84-
if (this.snippetRankInput != null) {
84+
if (this.rerankSnippetInput != null) {
8585
scores = extractScoresFromRankedSnippets(rankedDocs, featureDocs);
8686
} else {
8787
scores = extractScoresFromRankedDocs(rankedDocs);

0 commit comments

Comments
 (0)