Skip to content

Commit e6208ec

Browse files
committed
Add hardcoded max token length
1 parent 1684fba commit e6208ec

File tree

6 files changed

+40
-3
lines changed

6 files changed

+40
-3
lines changed

server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ public void onFailure(Exception e) {
173173
queryResult.getContextId(),
174174
queryResult.getShardSearchRequest(),
175175
entry,
176-
rankFeaturePhaseRankCoordinatorContext.snippets()
176+
rankFeaturePhaseRankCoordinatorContext.snippets(),
177+
rankFeaturePhaseRankCoordinatorContext.tokenSizeLimit()
177178
),
178179
context.getTask(),
179180
listener

server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ public RerankSnippetInput snippets() {
5555
return snippets;
5656
}
5757

58+
public abstract Integer tokenSizeLimit();
59+
5860
/**
5961
* Computes the updated scores for a list of features (i.e. document-based data). We also pass along an ActionListener
6062
* that should be called with the new scores, and will continue execution to the next phase

server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public static void prepareForFetch(SearchContext searchContext, RankFeatureShard
6666
if (snippets.numSnippets() != null) {
6767
highlightBuilder.numOfFragments(snippets.numSnippets());
6868
}
69-
highlightBuilder.fragmentSize(20); // TODO use model limit
69+
highlightBuilder.fragmentSize(request.getTokenSizeLimit());
7070
SearchHighlightContext searchHighlightContext = highlightBuilder.build(searchContext.getSearchExecutionContext());
7171
searchContext.highlight(searchHighlightContext);
7272
}

server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,22 @@ public class RankFeatureShardRequest extends AbstractTransportRequest implements
4141
private final int[] docIds;
4242

4343
private final RerankSnippetInput snippets;
44+
private final int tokenSizeLimit;
4445

4546
public RankFeatureShardRequest(
4647
OriginalIndices originalIndices,
4748
ShardSearchContextId contextId,
4849
ShardSearchRequest shardSearchRequest,
4950
List<Integer> docIds,
50-
@Nullable RerankSnippetInput snippets
51+
@Nullable RerankSnippetInput snippets,
52+
int tokenSizeLimit
5153
) {
5254
this.originalIndices = originalIndices;
5355
this.shardSearchRequest = shardSearchRequest;
5456
this.docIds = docIds.stream().flatMapToInt(IntStream::of).toArray();
5557
this.contextId = contextId;
5658
this.snippets = snippets;
59+
this.tokenSizeLimit = tokenSizeLimit;
5760
}
5861

5962
public RankFeatureShardRequest(StreamInput in) throws IOException {
@@ -64,8 +67,10 @@ public RankFeatureShardRequest(StreamInput in) throws IOException {
6467
contextId = in.readOptionalWriteable(ShardSearchContextId::new);
6568
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
6669
snippets = in.readOptionalWriteable(RerankSnippetInput::new);
70+
this.tokenSizeLimit = in.readVInt();
6771
} else {
6872
snippets = null;
73+
this.tokenSizeLimit = 0;
6974
}
7075
}
7176

@@ -78,6 +83,7 @@ public void writeTo(StreamOutput out) throws IOException {
7883
out.writeOptionalWriteable(contextId);
7984
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
8085
out.writeOptionalWriteable(snippets);
86+
out.writeVInt(tokenSizeLimit);
8187
}
8288
}
8389

@@ -113,6 +119,10 @@ public RerankSnippetInput snippets() {
113119
return snippets;
114120
}
115121

122+
public int getTokenSizeLimit() {
123+
return tokenSizeLimit;
124+
}
125+
116126
@Override
117127
public SearchShardTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
118128
return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ public RandomRankFeaturePhaseRankCoordinatorContext(int size, int from, int rank
2626
this.seed = seed;
2727
}
2828

29+
@Override
30+
public Integer tokenSizeLimit() {
31+
return 0;
32+
}
33+
2934
@Override
3035
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
3136
// Generate random scores seeded by doc

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,18 @@
2828

2929
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
3030
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
31+
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.DEFAULT_RERANK_ID;
32+
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID;
3133

3234
/**
3335
* A {@code RankFeaturePhaseRankCoordinatorContext} that performs a rerank inference call to determine relevance scores for documents within
3436
* the provided rank window.
3537
*/
3638
public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFeaturePhaseRankCoordinatorContext {
3739

40+
private static final int RERANK_TOKEN_SIZE_LIMIT = 512;
41+
private static final int DEFAULT_TOKEN_SIZE_LIMIT = 4096;
42+
3843
protected final Client client;
3944
protected final String inferenceId;
4045
protected final String inferenceText;
@@ -58,6 +63,20 @@ public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
5863
this.minScore = minScore;
5964
}
6065

66+
/**
67+
* @return The token size limit to apply to this rerank context.
68+
The model's actual token limit is not yet available, so the value is hardcoded for now.
69+
* See: https://github.com/elastic/ml-team/issues/1622
70+
*/
71+
@Override
72+
public Integer tokenSizeLimit() {
73+
if (inferenceId.equals(DEFAULT_RERANK_ID) || inferenceId.equals(RERANKER_ID)) {
74+
return RERANK_TOKEN_SIZE_LIMIT;
75+
}
76+
77+
return DEFAULT_TOKEN_SIZE_LIMIT;
78+
}
79+
6180
@Override
6281
protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
6382

0 commit comments

Comments
 (0)