Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/137700.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 137700
summary: Add rerank chunking for JinaAI
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,9 @@ public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWo
var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount * WORDS_PER_TOKEN);
return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1);
}

/**
 * Builds sentence-boundary chunking settings for a reranker that has a fixed window size.
 * <p>
 * The maximum chunk size is whatever is left of the window once the query has been subtracted, but never less than
 * half of the window, so a long query cannot shrink chunks below {@code rerankerWindowSize / 2}.
 * <p>
 * NOTE(review): unlike {@code buildChunkingSettingsForElasticRerank}, no token-to-word conversion is applied here;
 * presumably {@code rerankerWindowSize} is already expressed in words - confirm.
 *
 * @param rerankerWindowSize size of the reranker window; compared directly against {@code queryWordCount}
 * @param queryWordCount number of words in the query
 * @return a {@link SentenceBoundaryChunkingSettings} with the computed maximum chunk size and a second argument of 1
 */
public static ChunkingSettings buildChunkingSettingsForRerank(int rerankerWindowSize, int queryWordCount) {
    final int halfWindow = rerankerWindowSize / 2;
    final int roomLeftAfterQuery = rerankerWindowSize - queryWordCount;
    final int maxChunkSize = roomLeftAfterQuery > halfWindow ? roomLeftAfterQuery : halfWindow;
    return new SentenceBoundaryChunkingSettings(maxChunkSize, 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ public RerankRequestChunker(String query, List<String> inputs, Integer maxChunks
this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query), maxChunksPerDoc);
}

public RerankRequestChunker(String query, List<String> inputs, int rerankerWindowSize, Integer maxChunksPerDoc) {
this.inputs = inputs;
this.rerankChunks = chunk(inputs, buildChunkingSettingsForRerank(rerankerWindowSize, query), maxChunksPerDoc);
}

private List<RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings, Integer maxChunksPerDoc) {
var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
var chunks = new ArrayList<RerankChunks>();
Expand All @@ -53,6 +58,25 @@ public List<String> getChunkedInputs() {
return chunkedInputs;
}

/**
 * Returns a listener that converts chunk-level rerank results into document-level results and, when {@code topN}
 * is provided, keeps at most the first {@code topN} of those document-level results.
 * <p>
 * The truncation is attached to the caller's listener, i.e. downstream of the two-argument
 * {@code parseChunkedRerankResultsListener}. Truncating the raw response first would cut the list while it still
 * holds one entry per chunk, which can drop documents before they have been aggregated.
 *
 * @param listener the caller's listener that receives the final results
 * @param returnDocuments forwarded to the two-argument overload
 * @param topN maximum number of ranked documents to return, or {@code null} for no limit
 * @return the listener that should receive the raw chunk-level response
 */
public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(
    ActionListener<InferenceServiceResults> listener,
    boolean returnDocuments,
    Integer topN
) {
    ActionListener<InferenceServiceResults> documentLevelListener = listener;
    if (topN != null) {
        documentLevelListener = listener.delegateFailureAndWrap((l, results) -> {
            if (results instanceof RankedDocsResults rankedDocsResults) {
                var rankedDocs = rankedDocsResults.getRankedDocs();
                // Clamp so that asking for more documents than exist returns all of them instead of failing.
                var limit = Math.min(topN, rankedDocs.size());
                l.onResponse(new RankedDocsResults(rankedDocs.subList(0, limit)));
            } else {
                l.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass().getName()));
            }
        });
    }
    return parseChunkedRerankResultsListener(documentLevelListener, returnDocuments);
}

public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(
ActionListener<InferenceServiceResults> listener,
boolean returnDocuments
Expand Down Expand Up @@ -101,4 +125,11 @@ private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) {
var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator);
return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount);
}

/**
 * Counts the words in {@code query} with a word {@link BreakIterator} and delegates to
 * {@link ChunkingSettingsBuilder#buildChunkingSettingsForRerank(int, int)}.
 *
 * @param rerankerWindowSize window size of the target reranker
 * @param query the rerank query
 * @return the chunking settings produced by {@link ChunkingSettingsBuilder}
 */
private ChunkingSettings buildChunkingSettingsForRerank(int rerankerWindowSize, String query) {
    BreakIterator wordBoundaries = BreakIterator.getWordInstance();
    wordBoundaries.setText(query);
    int wordsInQuery = ChunkerUtils.countWords(0, query.length(), wordBoundaries);
    return ChunkingSettingsBuilder.buildChunkingSettingsForRerank(rerankerWindowSize, wordsInQuery);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,31 @@ public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanH
assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap());
}

/**
 * When the query uses less than half of the window, the chunk size is the window size minus the query word count.
 */
public void testBuildChunkingSettingsForRerank_QueryWordCountLessThanHalfOfRerankerWindowSize() {
    final int windowSize = randomIntBetween(20, 6000);
    final int wordsInQuery = randomIntBetween(1, windowSize / 2 - 1);

    var expected = new SentenceBoundaryChunkingSettings(windowSize - wordsInQuery, 1);
    var actual = ChunkingSettingsBuilder.buildChunkingSettingsForRerank(windowSize, wordsInQuery);

    assertEquals(expected, actual);
}

/**
 * When the query uses at least half of the window, the chunk size is floored at {@code rerankerWindowSize / 2}.
 */
public void testBuildChunkingSettingsForRerank_QueryWordCountMoreThanHalfOfRerankerWindowSize() {
    int rerankerWindowSize = randomIntBetween(20, 6000);
    // Use the ceiling of half the window as the lower bound. With the floor, an odd window (e.g. 21) and a query of
    // 10 words leaves 11 words for the chunk, which is larger than 21 / 2 = 10 and would fail the assertion below.
    int minQueryWordCount = rerankerWindowSize - rerankerWindowSize / 2;
    int queryWordCount = randomIntBetween(minQueryWordCount, Integer.MAX_VALUE);
    ChunkingSettings actualChunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForRerank(
        rerankerWindowSize,
        queryWordCount
    );
    SentenceBoundaryChunkingSettings expectedChunkingSettings = new SentenceBoundaryChunkingSettings(rerankerWindowSize / 2, 1);
    assertEquals(expectedChunkingSettings, actualChunkingSettings);
}

private Map<Map<String, Object>, ChunkingSettings> chunkingSettingsMapToChunkingSettings() {
var maxChunkSizeWordBoundaryChunkingSettings = randomIntBetween(10, 300);
var overlap = randomIntBetween(1, maxChunkSizeWordBoundaryChunkingSettings / 2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,52 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiring
listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
}

/**
 * With a {@code null} topN, every document is returned, ordered by descending relevance score.
 */
public void testParseChunkedRerankResultsListener_TopNSetToNull() {
    var documents = List.of(generateTestText(10), generateTestText(100));
    var requestChunker = new RerankRequestChunker(TEST_SENTENCE, documents, null);
    var includeDocumentText = randomBoolean();
    var chunkScores = generateRelevanceScores(4);
    var resultsListener = requestChunker.parseChunkedRerankResultsListener(ActionListener.wrap(response -> {
        assertThat(response, instanceOf(RankedDocsResults.class));
        var actualRankedDocs = ((RankedDocsResults) response).getRankedDocs();

        // The first document is expected to carry the score of chunk 0, the second the best score of chunks 1..3.
        var firstDoc = new RankedDocsResults.RankedDoc(0, chunkScores.get(0), includeDocumentText ? documents.get(0) : null);
        var secondDoc = new RankedDocsResults.RankedDoc(
            1,
            Collections.max(chunkScores.subList(1, 4)),
            includeDocumentText ? documents.get(1) : null
        );
        var expectedRankedDocs = new ArrayList<RankedDocsResults.RankedDoc>();
        expectedRankedDocs.add(firstDoc);
        expectedRankedDocs.add(secondDoc);
        expectedRankedDocs.sort((left, right) -> Float.compare(right.relevanceScore(), left.relevanceScore()));

        assertEquals(expectedRankedDocs, actualRankedDocs);
    }, failure -> fail("Expected successful parsing but got failure: " + failure)), includeDocumentText, null);

    var chunkedInputs = requestChunker.getChunkedInputs();
    assertEquals(4, chunkedInputs.size());
    resultsListener.onResponse(new RankedDocsResults(generateRankedDocs(chunkScores, chunkedInputs)));
}

/**
 * With topN set, only the {@code topN} best documents are returned, ordered by descending relevance score.
 */
public void testParseChunkedRerankResultsListener_TopNProvided() {
    var documents = List.of(generateTestText(10), generateTestText(100));
    var requestChunker = new RerankRequestChunker(TEST_SENTENCE, documents, null);
    var includeDocumentText = randomBoolean();
    var chunkScores = generateRelevanceScores(4);
    var limit = 1;
    var resultsListener = requestChunker.parseChunkedRerankResultsListener(ActionListener.wrap(response -> {
        assertThat(response, instanceOf(RankedDocsResults.class));
        var actualRankedDocs = ((RankedDocsResults) response).getRankedDocs();
        assertEquals(limit, actualRankedDocs.size());

        // The first document is expected to carry the score of chunk 0, the second the best score of chunks 1..3.
        var firstDoc = new RankedDocsResults.RankedDoc(0, chunkScores.get(0), includeDocumentText ? documents.get(0) : null);
        var secondDoc = new RankedDocsResults.RankedDoc(
            1,
            Collections.max(chunkScores.subList(1, 4)),
            includeDocumentText ? documents.get(1) : null
        );
        var allExpectedDocs = new ArrayList<RankedDocsResults.RankedDoc>();
        allExpectedDocs.add(firstDoc);
        allExpectedDocs.add(secondDoc);
        allExpectedDocs.sort((left, right) -> Float.compare(right.relevanceScore(), left.relevanceScore()));

        assertEquals(allExpectedDocs.subList(0, limit), actualRankedDocs);
    }, failure -> fail("Expected successful parsing but got failure: " + failure)), includeDocumentText, limit);

    var chunkedInputs = requestChunker.getChunkedInputs();
    assertEquals(4, chunkedInputs.size());
    resultsListener.onResponse(new RankedDocsResults(generateRankedDocs(chunkScores, chunkedInputs)));
}

private String generateTestText(int numSentences) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < numSentences; i++) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.xpack.core.inference.chunking.RerankRequestChunker;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.services.settings.LongDocumentStrategy;
import org.elasticsearch.xpack.inference.services.settings.RerankServiceSettings;

/**
 * Static helper that executes a rerank request and, when the model's {@link RerankServiceSettings} select
 * {@link LongDocumentStrategy#CHUNK}, splits the input documents into chunks before the request is sent and folds
 * the chunk-level results back into document-level results afterwards.
 */
public final class RerankInferenceProcessor {

    private RerankInferenceProcessor() {
        // static utility, not meant to be instantiated
    }

    /**
     * Executes {@code action} for a rerank model, chunking the documents first when chunking is enabled.
     *
     * @param service the calling service; must be a {@link RerankingInferenceService} when chunking is enabled
     * @param model the model; its service settings must be {@link RerankServiceSettings}
     * @param action the action that performs the request
     * @param inputs the request inputs; must be {@link QueryAndDocsInputs} when chunking is enabled
     * @param timeout timeout handed to {@code action}
     * @param listener receives the (document-level) results
     * @throws IllegalArgumentException if one of the type requirements above is not met
     */
    public static void doInfer(
        SenderService service,
        Model model,
        ExecutableAction action,
        InferenceInputs inputs,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        var serviceSettings = model.getServiceSettings();
        if (serviceSettings instanceof RerankServiceSettings == false) {
            throw new IllegalArgumentException("RerankInferenceProcessor can only process models with RerankServiceSettings");
        }

        var rerankServiceSettings = (RerankServiceSettings) serviceSettings;
        if (LongDocumentStrategy.CHUNK.equals(rerankServiceSettings.getLongDocumentStrategy()) == false) {
            // No chunking requested: pass the request through untouched.
            action.execute(inputs, timeout, listener);
            return;
        }

        if (inputs instanceof QueryAndDocsInputs == false) {
            throw new IllegalArgumentException("RerankInferenceProcessor can only process QueryAndDocsInputs when chunking is enabled");
        }
        var queryAndDocsInputs = (QueryAndDocsInputs) inputs;

        if (service instanceof RerankingInferenceService == false) {
            throw new IllegalArgumentException(
                "RerankInferenceProcessor can only process RerankingInferenceService when chunking is enabled"
            );
        }

        var rerankRequestChunker = new RerankRequestChunker(
            queryAndDocsInputs.getQuery(),
            queryAndDocsInputs.getChunks(),
            ((RerankingInferenceService) service).rerankerWindowSize(serviceSettings.modelId()),
            rerankServiceSettings.getMaxChunksPerDoc()
        );

        // The request now ranks chunks rather than documents. Forwarding the caller's topN would limit the number of
        // chunks that come back, which can drop documents before they are aggregated, so no limit is sent with the
        // chunked request and topN is applied by the results listener instead.
        // NOTE(review): a topN configured in task settings may still be applied when the request is built - confirm.
        final Integer chunkLevelTopN = null;

        var chunkedInputs = new QueryAndDocsInputs(
            queryAndDocsInputs.getQuery(),
            rerankRequestChunker.getChunkedInputs(),
            queryAndDocsInputs.getReturnDocuments(), // TODO: Check if we want this from inputs or from task settings
            chunkLevelTopN,
            queryAndDocsInputs.stream()
        );

        // An unset return_documents value is treated as true.
        var returnDocuments = queryAndDocsInputs.getReturnDocuments() == null || queryAndDocsInputs.getReturnDocuments();
        var chunkAwareListener = rerankRequestChunker.parseChunkedRerankResultsListener(
            listener,
            returnDocuments,
            queryAndDocsInputs.getTopN()
        );

        action.execute(chunkedInputs, timeout, chunkAwareListener);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand Down Expand Up @@ -150,13 +151,25 @@ public void chunkedInfer(
}).addListener(listener);
}

/**
 * Default inference implementation: obtains the action from {@code createAction(model, taskSettings)} and executes it.
 * Models whose task type is {@code RERANK} are routed through {@code RerankInferenceProcessor.doInfer}; every other
 * task type executes the action directly with the given inputs, timeout and listener.
 */
protected abstract void doInfer(
protected void doInfer(
Model model,
InferenceInputs inputs,
Map<String, Object> taskSettings,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
);
) {
var action = createAction(model, taskSettings);

// requireNonNull makes a model without a task type fail here with a NullPointerException.
if (Objects.requireNonNull(model.getTaskType()) == TaskType.RERANK) {
RerankInferenceProcessor.doInfer(this, model, action, inputs, timeout, listener);
} else {
action.execute(inputs, timeout, listener);
}
}

/**
 * Creates the {@link ExecutableAction} used by the default {@code doInfer} implementation.
 * <p>
 * Services that rely on the default {@code doInfer} must override this method. The base implementation fails with
 * a descriptive exception instead of returning {@code null}, which would otherwise surface later as a
 * {@link NullPointerException} when the action is executed.
 *
 * @param model the model the request is for
 * @param taskSettings request-level task settings
 * @return the action to execute
 * @throws UnsupportedOperationException if the service does not override this method
 */
protected ExecutableAction createAction(Model model, Map<String, Object> taskSettings) {
    throw new UnsupportedOperationException(
        getClass().getSimpleName() + " must override createAction(Model, Map) to use the default doInfer implementation"
    );
}

protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
Expand Down Expand Up @@ -235,23 +235,14 @@ protected void doUnifiedCompletionInfer(
}

@Override
public void doInfer(
Model model,
InferenceInputs inputs,
Map<String, Object> taskSettings,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
// A model that is not a JinaAIModel is reported through the listener rather than thrown.
if (model instanceof JinaAIModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}

JinaAIModel jinaaiModel = (JinaAIModel) model;
// createAction: builds a JinaAIActionCreator from this service's sender and service components and lets the
// JinaAIModel create its action with the given task settings.
protected ExecutableAction createAction(Model model, Map<String, Object> taskSettings) {
var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents());

var action = jinaaiModel.accept(actionCreator, taskSettings);
action.execute(inputs, timeout, listener);
if (model instanceof JinaAIModel jinaaiModel) {
return jinaaiModel.accept(actionCreator, taskSettings);
} else {
// Any other model type is rejected by throwing the exception built by createInvalidModelException.
throw createInvalidModelException(model);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.RerankServiceSettings;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class JinaAIRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, JinaAIRateLimitServiceSettings {
public class JinaAIRerankServiceSettings extends RerankServiceSettings implements ServiceSettings, JinaAIRateLimitServiceSettings {
public static final String NAME = "jinaai_rerank_service_settings";

private static final Logger logger = LogManager.getLogger(JinaAIRerankServiceSettings.class);
Expand All @@ -38,18 +38,26 @@ public static JinaAIRerankServiceSettings fromMap(Map<String, Object> map, Confi
throw validationException;
}

var rerankServiceSettings = RerankServiceSettings.fromMap(map);
var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context);

return new JinaAIRerankServiceSettings(commonServiceSettings);
return new JinaAIRerankServiceSettings(rerankServiceSettings, commonServiceSettings);
}

private final JinaAIServiceSettings commonSettings;

/**
 * Creates settings from already-parsed rerank settings and the common JinaAI service settings.
 *
 * @param rerankSettings passed to the {@link RerankServiceSettings} superclass constructor
 * @param commonSettings stored as this instance's common JinaAI service settings
 */
public JinaAIRerankServiceSettings(RerankServiceSettings rerankSettings, JinaAIServiceSettings commonSettings) {
super(rerankSettings);
this.commonSettings = commonSettings;
}

/**
 * Creates settings that carry only the common JinaAI service settings.
 *
 * @param commonSettings stored as this instance's common JinaAI service settings
 */
public JinaAIRerankServiceSettings(JinaAIServiceSettings commonSettings) {
// NOTE(review): both superclass arguments are null; presumably these are the long-document strategy and the
// maximum chunks per document, i.e. chunking is left unconfigured - confirm against RerankServiceSettings.
super(null, null);
this.commonSettings = commonSettings;
}

/**
 * Reads the settings from a stream: first the superclass state, then the common JinaAI service settings.
 * This is the same order in which {@code writeTo} writes them.
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public JinaAIRerankServiceSettings(StreamInput in) throws IOException {
// NOTE(review): the superclass state is read unconditionally; a stream written without it (e.g. by a node that
// predates this change) would presumably be misread - confirm whether a transport version check is needed.
super(in);
this.commonSettings = new JinaAIServiceSettings(in);
}

Expand All @@ -76,14 +84,16 @@ public String getWriteableName() {
/**
 * Writes these settings as a single object: the superclass's exposed fields followed by the fragment produced by
 * the common JinaAI service settings.
 *
 * @param builder the builder to write to
 * @param params serialization parameters, passed through to both fragments
 * @return the builder after the object has been closed
 * @throws IOException if writing fails
 */
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startObject();
    var afterRerankFields = super.toXContentFragmentOfExposedFields(builder, params);
    var afterCommonFields = commonSettings.toXContentFragment(afterRerankFields, params);
    afterCommonFields.endObject();
    return afterCommonFields;
}

// Writes the exposed fields of the superclass followed by the exposed fields of the common JinaAI service settings
// into the given builder and returns that same builder.
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
super.toXContentFragmentOfExposedFields(builder, params);
commonSettings.toXContentFragmentOfExposedFields(builder, params);
return builder;
}
Expand All @@ -95,6 +105,7 @@ public TransportVersion getMinimalSupportedVersion() {

/**
 * Writes the superclass state followed by the common JinaAI service settings. The order must match the order in
 * which the {@code StreamInput} constructor reads them.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
// NOTE(review): the superclass state is written unconditionally, which changes the wire format; presumably this
// needs to be gated on a transport version for mixed-version clusters - confirm.
super.writeTo(out);
commonSettings.writeTo(out);
}

Expand All @@ -103,11 +114,12 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
JinaAIRerankServiceSettings that = (JinaAIRerankServiceSettings) o;
return Objects.equals(commonSettings, that.commonSettings);
return super.equals(o) && Objects.equals(commonSettings, that.commonSettings);
}

/**
 * Hash code over the superclass state and the common JinaAI service settings.
 * <p>
 * {@code equals} compares the superclass state through {@code super.equals(o)}, so the superclass hash is mixed in
 * as well; otherwise instances that differ only in their rerank settings would always share a hash code.
 * This assumes the superclass overrides {@code hashCode} consistently with its {@code equals}.
 */
@Override
public int hashCode() {
    return Objects.hash(super.hashCode(), commonSettings);
}
}
Loading