From 849147e733b291e9f12a376aa0eb96fb6fc7779c Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Tue, 10 Jun 2025 10:26:05 -0400 Subject: [PATCH 1/8] Add RerankRequestChunker --- .../action/TransportInferenceAction.java | 2 + .../chunking/RerankRequestChunker.java | 99 +++++++++++++++++++ ...ankFeaturePhaseRankCoordinatorContext.java | 13 ++- 3 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 7d24b7766baa3..c28bd485f0e48 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -74,6 +74,8 @@ protected void doInference( InferenceService service, ActionListener listener ) { + // var rerankChunker = new RerankRequestChunker(request.getInput()); + service.infer( model, request.getQuery(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java new file mode 100644 index 0000000000000..322294034b8b9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class RerankRequestChunker { + + private final ChunkingSettings chunkingSettings; + private final List inputs; + private final Map rerankChunks; + + public RerankRequestChunker(List inputs) { + // TODO: Make chunking settings dependent on the model being used. + // There may be a way to do this dynamically knowing the max token size for the model/service and query size + // instead of hardcoding it ona model/service basis. + this.chunkingSettings = new WordBoundaryChunkingSettings(100, 10); + this.inputs = inputs; + this.rerankChunks = chunk(inputs, chunkingSettings); + } + + private Map chunk(List inputs, ChunkingSettings chunkingSettings) { + var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); + var chunks = new HashMap(); + var chunkIndex = 0; + for (int i = 0; i < inputs.size(); i++) { + var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings); + for (var chunk : chunksForInput) { + chunks.put(chunkIndex, new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end()))); + chunkIndex++; + } + } + return chunks; + } + + public List getChunkedInputs() { + List chunkedInputs = new ArrayList<>(); + for (RerankChunks chunk : rerankChunks.values()) { + chunkedInputs.add(chunk.chunkString()); + } + // TODO: Score the inputs here and only return the top N chunks for each document + return chunkedInputs; + } + + public ActionListener parseChunkedRerankResultsListener(ActionListener listener) { + return ActionListener.wrap(results -> { + if (results.getResults() instanceof RankedDocsResults 
rankedDocsResults) { + listener.onResponse(new InferenceAction.Response(parseRankedDocResultsForChunks(rankedDocsResults))); + // TODO: Figure out if the above correctly creates the response or if it loses any info + + } else { + listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass())); + } + + }, listener::onFailure); + } + + private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) { + Map bestRankedDocResultPerDoc = new HashMap<>(); + for (var rankedDoc : rankedDocsResults.getRankedDocs()) { + int chunkIndex = rankedDoc.index(); + int docIndex = rerankChunks.get(chunkIndex).docIndex(); + if (bestRankedDocResultPerDoc.containsKey(docIndex)) { + RankedDocsResults.RankedDoc existingDoc = bestRankedDocResultPerDoc.get(docIndex); + if (rankedDoc.relevanceScore() > existingDoc.relevanceScore()) { + bestRankedDocResultPerDoc.put( + docIndex, + new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex)) + ); + } + } else { + bestRankedDocResultPerDoc.put( + docIndex, + new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex)) + ); + } + } + var bestRankedDocResultPerDocList = new ArrayList<>(bestRankedDocResultPerDoc.values()); + bestRankedDocResultPerDocList.sort( + (RankedDocsResults.RankedDoc d1, RankedDocsResults.RankedDoc d2) -> Float.compare(d2.relevanceScore(), d1.relevanceScore()) + ); + return new RankedDocsResults(bestRankedDocResultPerDocList); + } + + public record RerankChunks(int docIndex, String chunkString) {}; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index 27221dc1f5caf..c231f48ceadac 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; @@ -119,9 +120,16 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList(); - InferenceAction.Request inferenceRequest = generateRequest(featureData); + RerankRequestChunker chunker = new RerankRequestChunker(featureData); + InferenceAction.Request inferenceRequest = generateRequest(chunker.getChunkedInputs()); try { - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceRequest, inferenceListener); + executeAsyncWithOrigin( + client, + INFERENCE_ORIGIN, + InferenceAction.INSTANCE, + inferenceRequest, + chunker.parseChunkedRerankResultsListener(inferenceListener) + ); } finally { inferenceRequest.decRef(); } @@ -156,6 +164,7 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rer } protected InferenceAction.Request generateRequest(List docFeatures) { + // TODO: Try running the RerankRequestChunker here. 
return new InferenceAction.Request( TaskType.RERANK, inferenceId, From da4c939c713032d5afb3d3930e68b1b11e1ef3f7 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Fri, 4 Jul 2025 13:49:03 -0400 Subject: [PATCH 2/8] Add chunking strategy generation --- .../action/TransportInferenceAction.java | 1 - .../chunking/ChunkingSettingsBuilder.java | 16 ++++ .../chunking/RerankRequestChunker.java | 78 +++++++++---------- ...ankFeaturePhaseRankCoordinatorContext.java | 12 +-- .../ElasticsearchInternalService.java | 24 ++++-- .../ChunkingSettingsBuilderTests.java | 30 +++++++ 6 files changed, 103 insertions(+), 58 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index c28bd485f0e48..8a0072287c022 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -74,7 +74,6 @@ protected void doInference( InferenceService service, ActionListener listener ) { - // var rerankChunker = new RerankRequestChunker(request.getInput()); service.infer( model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index 2f912d891ef60..8c32b0fe8dee9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -17,6 +17,9 @@ public class ChunkingSettingsBuilder { public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new 
SentenceBoundaryChunkingSettings(250, 1); // Old settings used for backward compatibility for endpoints created before 8.16 when default was changed public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); + public static final int ELASTIC_RERANKER_TOKEN_LIMIT = 512; + public static final int ELASTIC_RERANKER_EXTRA_TOKEN_COUNT = 3; + public static final float TOKENS_PER_WORD = 0.75f; public static ChunkingSettings fromMap(Map settings) { return fromMap(settings, true); @@ -51,4 +54,17 @@ public static ChunkingSettings fromMap(Map settings, boolean ret case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings)); }; } + + public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWordCount) { + var queryTokenCount = Math.ceil(queryWordCount * TOKENS_PER_WORD); + var chunkSizeTokenCountWithFullQuery = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount); + + var maxChunkSizeTokenCount = Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2); + if (chunkSizeTokenCountWithFullQuery > maxChunkSizeTokenCount) { + maxChunkSizeTokenCount = chunkSizeTokenCountWithFullQuery; + } + + var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount / TOKENS_PER_WORD); + return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java index 322294034b8b9..eebfa0622842a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java @@ -7,40 +7,34 @@ package org.elasticsearch.xpack.inference.chunking; +import com.ibm.icu.text.BreakIterator; + import 
org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import java.util.ArrayList; -import java.util.HashMap; +import java.util.HashSet; import java.util.List; -import java.util.Map; +import java.util.Set; public class RerankRequestChunker { - - private final ChunkingSettings chunkingSettings; private final List inputs; - private final Map rerankChunks; + private final List rerankChunks; - public RerankRequestChunker(List inputs) { - // TODO: Make chunking settings dependent on the model being used. - // There may be a way to do this dynamically knowing the max token size for the model/service and query size - // instead of hardcoding it ona model/service basis. - this.chunkingSettings = new WordBoundaryChunkingSettings(100, 10); + public RerankRequestChunker(String query, List inputs) { this.inputs = inputs; - this.rerankChunks = chunk(inputs, chunkingSettings); + this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query)); } - private Map chunk(List inputs, ChunkingSettings chunkingSettings) { + private List chunk(List inputs, ChunkingSettings chunkingSettings) { var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); - var chunks = new HashMap(); - var chunkIndex = 0; + var chunks = new ArrayList(); for (int i = 0; i < inputs.size(); i++) { var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings); for (var chunk : chunksForInput) { - chunks.put(chunkIndex, new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end()))); - chunkIndex++; + chunks.add(new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end()))); } } return chunks; @@ -48,18 +42,18 @@ private Map chunk(List inputs, ChunkingSettings c public List getChunkedInputs() { List 
chunkedInputs = new ArrayList<>(); - for (RerankChunks chunk : rerankChunks.values()) { + for (RerankChunks chunk : rerankChunks) { chunkedInputs.add(chunk.chunkString()); } + // TODO: Score the inputs here and only return the top N chunks for each document return chunkedInputs; } - public ActionListener parseChunkedRerankResultsListener(ActionListener listener) { + public ActionListener parseChunkedRerankResultsListener(ActionListener listener) { return ActionListener.wrap(results -> { - if (results.getResults() instanceof RankedDocsResults rankedDocsResults) { - listener.onResponse(new InferenceAction.Response(parseRankedDocResultsForChunks(rankedDocsResults))); - // TODO: Figure out if the above correctly creates the response or if it loses any info + if (results instanceof RankedDocsResults rankedDocsResults) { + listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults)); } else { listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass())); @@ -68,32 +62,36 @@ public ActionListener parseChunkedRerankResultsListene }, listener::onFailure); } + // TODO: Can we assume the rankeddocsresults are always sorted by relevance score? + // TODO: Should we short circuit if no chunking was done? 
private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) { - Map bestRankedDocResultPerDoc = new HashMap<>(); - for (var rankedDoc : rankedDocsResults.getRankedDocs()) { + List updatedRankedDocs = new ArrayList<>(); + Set docIndicesSeen = new HashSet<>(); + for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) { int chunkIndex = rankedDoc.index(); int docIndex = rerankChunks.get(chunkIndex).docIndex(); - if (bestRankedDocResultPerDoc.containsKey(docIndex)) { - RankedDocsResults.RankedDoc existingDoc = bestRankedDocResultPerDoc.get(docIndex); - if (rankedDoc.relevanceScore() > existingDoc.relevanceScore()) { - bestRankedDocResultPerDoc.put( - docIndex, - new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex)) - ); - } - } else { - bestRankedDocResultPerDoc.put( + + if (docIndicesSeen.contains(docIndex) == false) { + // Create a ranked doc with the full input string and the index for the document instead of the chunk + RankedDocsResults.RankedDoc updatedRankedDoc = new RankedDocsResults.RankedDoc( docIndex, - new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex)) + rankedDoc.relevanceScore(), + inputs.get(docIndex) ); + updatedRankedDocs.add(updatedRankedDoc); + docIndicesSeen.add(docIndex); } } - var bestRankedDocResultPerDocList = new ArrayList<>(bestRankedDocResultPerDoc.values()); - bestRankedDocResultPerDocList.sort( - (RankedDocsResults.RankedDoc d1, RankedDocsResults.RankedDoc d2) -> Float.compare(d2.relevanceScore(), d1.relevanceScore()) - ); - return new RankedDocsResults(bestRankedDocResultPerDocList); + + return new RankedDocsResults(updatedRankedDocs); } public record RerankChunks(int docIndex, String chunkString) {}; + + private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) { + var wordIterator = BreakIterator.getWordInstance(); + wordIterator.setText(query); + var queryWordCount = 
ChunkerUtils.countWords(0, query.length(), wordIterator); + return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index c231f48ceadac..2b90577bd5db4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -17,7 +17,6 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; -import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings; @@ -120,16 +119,9 @@ protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList(); - RerankRequestChunker chunker = new RerankRequestChunker(featureData); - InferenceAction.Request inferenceRequest = generateRequest(chunker.getChunkedInputs()); + InferenceAction.Request inferenceRequest = generateRequest(featureData); try { - executeAsyncWithOrigin( - client, - INFERENCE_ORIGIN, - InferenceAction.INSTANCE, - inferenceRequest, - 
chunker.parseChunkedRerankResultsListener(inferenceListener) - ); + executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceRequest, inferenceListener); } finally { inferenceRequest.decRef(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 4f2674179be67..084ef7ae61188 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -57,6 +57,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -686,7 +687,15 @@ public void inferRerank( Map requestTaskSettings, ActionListener listener ) { - var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout); + var rerankChunker = new RerankRequestChunker(query, inputs); + var chunkedInputs = rerankChunker.getChunkedInputs(); + var request = buildInferenceRequest( + model.mlNodeDeploymentId(), + new TextSimilarityConfigUpdate(query), + chunkedInputs, + inputType, + timeout + ); var returnDocs = Boolean.TRUE; if (returnDocuments != null) { @@ -696,13 +705,14 @@ public void inferRerank( returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments(); } - Function 
inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null; + Function inputSupplier = returnDocs == Boolean.TRUE ? chunkedInputs::get : i -> null; - ActionListener mlResultsListener = listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse( - textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN) - ) - ); + ActionListener mlResultsListener = rerankChunker.parseChunkedRerankResultsListener(listener) + .delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse( + textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN) + ) + ); var maybeDeployListener = mlResultsListener.delegateResponse( (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java index 9e6dde60bc641..69ebfbf6caa0c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -15,6 +15,10 @@ import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_EXTRA_TOKEN_COUNT; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_TOKEN_LIMIT; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.TOKENS_PER_WORD; + public class ChunkingSettingsBuilderTests extends ESTestCase { public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1); @@ -47,6 +51,32 @@ public void testValidChunkingSettingsMap() { }); } + public void 
testBuildChunkingSettingsForElasticReranker_QueryTokenCountLessThanHalfOfTokenLimit() { + // Generate a word count for a non-empty query that takes up less than half the token limit + int maxQueryTokenCount = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT) / 2; + int queryWordCount = randomIntBetween(1, (int) (maxQueryTokenCount / TOKENS_PER_WORD)); + var queryTokenCount = Math.ceil(queryWordCount * TOKENS_PER_WORD); + ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount); + assertTrue(chunkingSettings instanceof SentenceBoundaryChunkingSettings); + SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings = (SentenceBoundaryChunkingSettings) chunkingSettings; + int expectedMaxChunkSize = (int) ((ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount) + / TOKENS_PER_WORD); + assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); + assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap); + } + + public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanHalfOfTokenLimit() { + // Generate a word count for a non-empty query that takes up more than half the token limit + int maxQueryTokenCount = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT) / 2; + int queryWordCount = randomIntBetween((int) (maxQueryTokenCount / TOKENS_PER_WORD), Integer.MAX_VALUE); + ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount); + assertTrue(chunkingSettings instanceof SentenceBoundaryChunkingSettings); + SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings = (SentenceBoundaryChunkingSettings) chunkingSettings; + int expectedMaxChunkSize = (int) (Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2) / TOKENS_PER_WORD); + assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); + assertEquals(1, 
sentenceBoundaryChunkingSettings.sentenceOverlap); + } + private Map, ChunkingSettings> chunkingSettingsMapToChunkingSettings() { var maxChunkSizeWordBoundaryChunkingSettings = randomIntBetween(10, 300); var overlap = randomIntBetween(1, maxChunkSizeWordBoundaryChunkingSettings / 2); From 4ff8eb027ec9a508cb0a3fa353c53b5362f0a2fc Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Wed, 23 Jul 2025 13:49:19 -0400 Subject: [PATCH 3/8] Adding unit tests and fixing token/word ratio --- .../chunking/ChunkingSettingsBuilder.java | 6 +- .../ChunkingSettingsBuilderTests.java | 12 +- .../chunking/RerankRequestChunkerTests.java | 194 ++++++++++++++++++ 3 files changed, 203 insertions(+), 9 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index 8c32b0fe8dee9..12c41c7ea9470 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -19,7 +19,7 @@ public class ChunkingSettingsBuilder { public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); public static final int ELASTIC_RERANKER_TOKEN_LIMIT = 512; public static final int ELASTIC_RERANKER_EXTRA_TOKEN_COUNT = 3; - public static final float TOKENS_PER_WORD = 0.75f; + public static final float WORDS_PER_TOKEN = 0.75f; public static ChunkingSettings fromMap(Map settings) { return fromMap(settings, true); @@ -56,7 +56,7 @@ public static ChunkingSettings fromMap(Map settings, boolean ret } public static ChunkingSettings buildChunkingSettingsForElasticRerank(int 
queryWordCount) { - var queryTokenCount = Math.ceil(queryWordCount * TOKENS_PER_WORD); + var queryTokenCount = Math.ceil(queryWordCount / WORDS_PER_TOKEN); var chunkSizeTokenCountWithFullQuery = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount); var maxChunkSizeTokenCount = Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2); @@ -64,7 +64,7 @@ public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWo maxChunkSizeTokenCount = chunkSizeTokenCountWithFullQuery; } - var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount / TOKENS_PER_WORD); + var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount * WORDS_PER_TOKEN); return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java index 69ebfbf6caa0c..cc464a933481f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -17,7 +17,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_EXTRA_TOKEN_COUNT; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_TOKEN_LIMIT; -import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.TOKENS_PER_WORD; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.WORDS_PER_TOKEN; public class ChunkingSettingsBuilderTests extends ESTestCase { @@ -54,13 +54,13 @@ public void testValidChunkingSettingsMap() { public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountLessThanHalfOfTokenLimit() { // Generate a word count for a 
non-empty query that takes up less than half the token limit int maxQueryTokenCount = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT) / 2; - int queryWordCount = randomIntBetween(1, (int) (maxQueryTokenCount / TOKENS_PER_WORD)); - var queryTokenCount = Math.ceil(queryWordCount * TOKENS_PER_WORD); + int queryWordCount = randomIntBetween(1, (int) (maxQueryTokenCount * WORDS_PER_TOKEN)); + var queryTokenCount = Math.ceil(queryWordCount / WORDS_PER_TOKEN); ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount); assertTrue(chunkingSettings instanceof SentenceBoundaryChunkingSettings); SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings = (SentenceBoundaryChunkingSettings) chunkingSettings; int expectedMaxChunkSize = (int) ((ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount) - / TOKENS_PER_WORD); + * WORDS_PER_TOKEN); assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap); } @@ -68,11 +68,11 @@ public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountLessThanH public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanHalfOfTokenLimit() { // Generate a word count for a non-empty query that takes up more than half the token limit int maxQueryTokenCount = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT) / 2; - int queryWordCount = randomIntBetween((int) (maxQueryTokenCount / TOKENS_PER_WORD), Integer.MAX_VALUE); + int queryWordCount = randomIntBetween((int) (maxQueryTokenCount * WORDS_PER_TOKEN), Integer.MAX_VALUE); ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount); assertTrue(chunkingSettings instanceof SentenceBoundaryChunkingSettings); SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings = (SentenceBoundaryChunkingSettings) 
chunkingSettings; - int expectedMaxChunkSize = (int) (Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2) / TOKENS_PER_WORD); + int expectedMaxChunkSize = (int) (Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2) * WORDS_PER_TOKEN); assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java new file mode 100644 index 0000000000000..f882853e041b3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java @@ -0,0 +1,194 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; + +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; + +public class RerankRequestChunkerTests extends ESTestCase { + private static final String TEST_SENTENCE = "This is a test sentence that has ten total words. 
"; + + public void testGetChunkedInput_EmptyInput() { + var chunker = new RerankRequestChunker(TEST_SENTENCE, List.of()); + assertTrue(chunker.getChunkedInputs().isEmpty()); + } + + public void testGetChunkedInput_SingleInputWithoutChunkingRequired() { + var inputs = List.of(generateTestText(10)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + assertEquals(inputs, chunker.getChunkedInputs()); + } + + public void testGetChunkedInput_SingleInputWithChunkingRequired() { + var inputs = List.of(generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(3, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithoutChunkingRequired() { + var inputs = List.of(generateTestText(10), generateTestText(10)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + assertEquals(inputs, chunker.getChunkedInputs()); + } + + public void testGetChunkedInput_MultipleInputsWithSomeChunkingRequired() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(randomAlphaOfLength(10), inputs); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(4, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithAllRequiringChunking() { + var inputs = List.of(generateTestText(100), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(6, chunkedInputs.size()); + } + + public void testParseChunkedRerankResultsListener_NonRankedDocsResults() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var listener = chunker.parseChunkedRerankResultsListener( + ActionListener.wrap( + results -> fail("Expected failure but got: " + results.getClass()), + e -> assertTrue(e instanceof 
IllegalArgumentException && e.getMessage().contains("Expected RankedDocsResults")) + ) + ); + + listener.onResponse(new InferenceServiceResults() { + }); + } + + public void testParseChunkedRerankResultsListener_EmptyInput() { + var chunker = new RerankRequestChunker(TEST_SENTENCE, List.of()); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(0, rankedDocResults.getRankedDocs().size()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + listener.onResponse(new RankedDocsResults(List.of())); + } + + public void testParseChunkedRerankResultsListener_SingleInputWithoutChunking() { + var inputs = List.of(generateTestText(10)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(1, rankedDocResults.getRankedDocs().size()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(1, chunkedInputs.size()); + listener.onResponse(new RankedDocsResults(List.of(new RankedDocsResults.RankedDoc(0, 1.0f, chunkedInputs.get(0))))); + } + + public void testParseChunkedRerankResultsListener_SingleInputWithChunking() { + var inputs = List.of(generateTestText(100)); + var relevanceScore1 = randomFloatBetween(0, 1, true); + var relevanceScore2 = randomFloatBetween(0, 1, true); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(1, 
rankedDocResults.getRankedDocs().size()); + var expectedRankedDocs = List.of(new RankedDocsResults.RankedDoc(0, relevanceScore1, inputs.get(0))); + assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(3, chunkedInputs.size()); + var rankedDocsResults = List.of( + new RankedDocsResults.RankedDoc(0, relevanceScore1, chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, relevanceScore2, chunkedInputs.get(1)) + ); + // TODO: Sort this so that the assumption that the results are in order holds + listener.onResponse(new RankedDocsResults(rankedDocsResults)); + } + + public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking() { + var inputs = List.of(generateTestText(10), generateTestText(10)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(2, rankedDocResults.getRankedDocs().size()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(2, chunkedInputs.size()); + listener.onResponse( + new RankedDocsResults( + List.of( + new RankedDocsResults.RankedDoc(0, 1.0f, chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, 1.0f, chunkedInputs.get(1)) + ) + ) + ); + } + + public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) 
results; + assertEquals(2, rankedDocResults.getRankedDocs().size()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(4, chunkedInputs.size()); + listener.onResponse( + new RankedDocsResults( + List.of( + new RankedDocsResults.RankedDoc(0, 1.0f, chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, 1.0f, chunkedInputs.get(1)), + new RankedDocsResults.RankedDoc(2, 1.0f, chunkedInputs.get(2)) + ) + ) + ); + } + + public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiringChunking() { + var inputs = List.of(generateTestText(100), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(2, rankedDocResults.getRankedDocs().size()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(6, chunkedInputs.size()); + listener.onResponse( + new RankedDocsResults( + List.of( + new RankedDocsResults.RankedDoc(0, 1.0f, chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, 1.0f, chunkedInputs.get(1)), + new RankedDocsResults.RankedDoc(2, 1.0f, chunkedInputs.get(2)), + new RankedDocsResults.RankedDoc(3, 1.0f, chunkedInputs.get(3)) + ) + ) + ); + } + + private String generateTestText(int numSentences) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < numSentences; i++) { + sb.append(TEST_SENTENCE); + } + return sb.toString(); + } +} From 9ef89177dfbd718c07f18b1973b7d8b4c1198d38 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Mon, 8 Sep 2025 15:09:13 -0400 Subject: [PATCH 4/8] Add configurable values for long document handling strategy and maximum chunks per document --- .../chunking/ChunkingSettingsBuilder.java 
| 2 +- .../chunking/RerankRequestChunker.java | 13 ++- .../ElasticRerankerServiceSettings.java | 105 +++++++++++++++++- .../ElasticsearchInternalService.java | 37 +++--- .../ChunkingSettingsBuilderTests.java | 4 +- .../chunking/RerankRequestChunkerTests.java | 26 ++--- 6 files changed, 148 insertions(+), 39 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index 12c41c7ea9470..8857f01d26f4c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -65,6 +65,6 @@ public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWo } var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount * WORDS_PER_TOKEN); - return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1); + return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 0); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java index eebfa0622842a..2044d8ebef0b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java @@ -23,16 +23,23 @@ public class RerankRequestChunker { private final List inputs; private final List rerankChunks; - public RerankRequestChunker(String query, List inputs) { + public RerankRequestChunker(String query, List inputs, Integer maxChunksPerDoc) { this.inputs = inputs; - this.rerankChunks = chunk(inputs, 
buildChunkingSettingsForElasticRerank(query)); + this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query), maxChunksPerDoc); } - private List chunk(List inputs, ChunkingSettings chunkingSettings) { + private List chunk(List inputs, ChunkingSettings chunkingSettings, Integer maxChunksPerDoc) { var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); var chunks = new ArrayList(); for (int i = 0; i < inputs.size(); i++) { var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings); + if (maxChunksPerDoc != null && chunksForInput.size() > maxChunksPerDoc) { + var limitedChunks = chunksForInput.subList(0, maxChunksPerDoc - 1); + var lastChunk = limitedChunks.getLast(); + limitedChunks.add(new Chunker.ChunkOffset(lastChunk.end(), inputs.get(i).length())); + chunksForInput = limitedChunks; + } + for (var chunk : chunksForInput) { chunks.add(new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end()))); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java index 2b7904e615682..537c0fc0c307f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java @@ -9,23 +9,49 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import java.io.IOException; +import 
java.util.EnumSet; +import java.util.Locale; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID; public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings { public static final String NAME = "elastic_reranker_service_settings"; + private static final String LONG_DOCUMENT_HANDLING_STRATEGY = "long_document_handling_strategy"; + private static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc"; + + private final LongDocumentHandlingStrategy longDocumentHandlingStrategy; + private final Integer maxChunksPerDoc; + public static ElasticRerankerServiceSettings defaultEndpointSettings() { return new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)); } public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) { super(other); + this.longDocumentHandlingStrategy = null; + this.maxChunksPerDoc = null; + } + + public ElasticRerankerServiceSettings( + ElasticsearchInternalServiceSettings other, + LongDocumentHandlingStrategy longDocumentHandlingStrategy, + Integer maxChunksPerDoc + ) { + super(other); + this.longDocumentHandlingStrategy = longDocumentHandlingStrategy; + this.maxChunksPerDoc = maxChunksPerDoc; + } private ElasticRerankerServiceSettings( @@ -35,10 +61,15 @@ private ElasticRerankerServiceSettings( AdaptiveAllocationsSettings adaptiveAllocationsSettings ) { super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null); + this.longDocumentHandlingStrategy = null; + this.maxChunksPerDoc = null; } public ElasticRerankerServiceSettings(StreamInput in) throws IOException { super(in); + // TODO: Add transport version here + this.longDocumentHandlingStrategy = 
in.readOptionalEnum(LongDocumentHandlingStrategy.class); + this.maxChunksPerDoc = in.readOptionalInt(); } /** @@ -48,21 +79,89 @@ public ElasticRerankerServiceSettings(StreamInput in) throws IOException { * {@link ValidationException} is thrown. * * @param map Source map containing the config - * @return The builder + * @return Parsed and validated service settings */ - public static Builder fromRequestMap(Map map) { + public static ElasticRerankerServiceSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException); + LongDocumentHandlingStrategy longDocumentHandlingStrategy = extractOptionalEnum( + map, + LONG_DOCUMENT_HANDLING_STRATEGY, + ModelConfigurations.SERVICE_SETTINGS, + LongDocumentHandlingStrategy::fromString, + EnumSet.allOf(LongDocumentHandlingStrategy.class), + validationException + ); + + Integer maxChunksPerDoc = extractOptionalPositiveInteger( + map, + MAX_CHUNKS_PER_DOC, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + if (maxChunksPerDoc != null + && (longDocumentHandlingStrategy == null || longDocumentHandlingStrategy == LongDocumentHandlingStrategy.TRUNCATE)) { + validationException.addValidationError( + "The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_HANDLING_STRATEGY + "] to be set to [chunk]" + ); + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return baseSettings; + return new ElasticRerankerServiceSettings(baseSettings.build(), longDocumentHandlingStrategy, maxChunksPerDoc); + } + + public LongDocumentHandlingStrategy getLongDocumentHandlingStrategy() { + return longDocumentHandlingStrategy; + } + + public Integer getMaxChunksPerDoc() { + return maxChunksPerDoc; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + // TODO: Add transport version here + 
out.writeOptionalEnum(longDocumentHandlingStrategy); + out.writeOptionalInt(maxChunksPerDoc); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + addInternalSettingsToXContent(builder, params); + if (longDocumentHandlingStrategy != null) { + builder.field(LONG_DOCUMENT_HANDLING_STRATEGY, longDocumentHandlingStrategy.strategyName); + } + if (maxChunksPerDoc != null) { + builder.field(MAX_CHUNKS_PER_DOC, maxChunksPerDoc); + } + builder.endObject(); + return builder; } @Override public String getWriteableName() { return ElasticRerankerServiceSettings.NAME; } + + public enum LongDocumentHandlingStrategy { + CHUNK("chunk"), + TRUNCATE("truncate"); + + public final String strategyName; + + LongDocumentHandlingStrategy(String strategyName) { + this.strategyName = strategyName; + } + + public static LongDocumentHandlingStrategy fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 2eb9cbcd255b6..55def97679edb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -349,19 +349,13 @@ private void rerankerCase( ActionListener modelListener ) { - var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap); + var serviceSettings = ElasticRerankerServiceSettings.fromMap(serviceSettingsMap); throwIfNotEmptyMap(config, name()); throwIfNotEmptyMap(serviceSettingsMap, name()); modelListener.onResponse( - new 
ElasticRerankerModel( - inferenceEntityId, - taskType, - NAME, - new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()), - RerankTaskSettings.fromMap(taskSettingsMap) - ) + new ElasticRerankerModel(inferenceEntityId, taskType, NAME, serviceSettings, RerankTaskSettings.fromMap(taskSettingsMap)) ); } @@ -535,7 +529,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M inferenceEntityId, taskType, NAME, - new ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)), + ElasticRerankerServiceSettings.fromMap(serviceSettingsMap), RerankTaskSettings.fromMap(taskSettingsMap) ); } else { @@ -688,8 +682,18 @@ public void inferRerank( Map requestTaskSettings, ActionListener listener ) { - var rerankChunker = new RerankRequestChunker(query, inputs); - var chunkedInputs = rerankChunker.getChunkedInputs(); + var chunkedInputs = inputs; + var resultsListener = listener; + if (model instanceof ElasticRerankerModel elasticRerankerModel) { + var serviceSettings = elasticRerankerModel.getServiceSettings(); + var longDocumentHandlingStrategy = serviceSettings.getLongDocumentHandlingStrategy(); + if (longDocumentHandlingStrategy == ElasticRerankerServiceSettings.LongDocumentHandlingStrategy.CHUNK) { + var rerankChunker = new RerankRequestChunker(query, inputs, serviceSettings.getMaxChunksPerDoc()); + chunkedInputs = rerankChunker.getChunkedInputs(); + resultsListener = rerankChunker.parseChunkedRerankResultsListener(listener); + } + + } var request = buildInferenceRequest( model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), @@ -708,12 +712,11 @@ public void inferRerank( Function inputSupplier = returnDocs == Boolean.TRUE ? 
chunkedInputs::get : i -> null; - ActionListener mlResultsListener = rerankChunker.parseChunkedRerankResultsListener(listener) - .delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse( - textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN) - ) - ); + ActionListener mlResultsListener = resultsListener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse( + textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN) + ) + ); var maybeDeployListener = mlResultsListener.delegateResponse( (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java index cc464a933481f..00452e6beb313 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -62,7 +62,7 @@ public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountLessThanH int expectedMaxChunkSize = (int) ((ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount) * WORDS_PER_TOKEN); assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); - assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap); + assertEquals(0, sentenceBoundaryChunkingSettings.sentenceOverlap); } public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanHalfOfTokenLimit() { @@ -74,7 +74,7 @@ public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanH SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings = (SentenceBoundaryChunkingSettings) chunkingSettings; int 
expectedMaxChunkSize = (int) (Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2) * WORDS_PER_TOKEN); assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); - assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap); + assertEquals(0, sentenceBoundaryChunkingSettings.sentenceOverlap); } private Map, ChunkingSettings> chunkingSettingsMapToChunkingSettings() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java index f882853e041b3..5c4f4280f270d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java @@ -20,46 +20,46 @@ public class RerankRequestChunkerTests extends ESTestCase { private static final String TEST_SENTENCE = "This is a test sentence that has ten total words. 
"; public void testGetChunkedInput_EmptyInput() { - var chunker = new RerankRequestChunker(TEST_SENTENCE, List.of()); + var chunker = new RerankRequestChunker(TEST_SENTENCE, List.of(), null); assertTrue(chunker.getChunkedInputs().isEmpty()); } public void testGetChunkedInput_SingleInputWithoutChunkingRequired() { var inputs = List.of(generateTestText(10)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); assertEquals(inputs, chunker.getChunkedInputs()); } public void testGetChunkedInput_SingleInputWithChunkingRequired() { var inputs = List.of(generateTestText(100)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); var chunkedInputs = chunker.getChunkedInputs(); assertEquals(3, chunkedInputs.size()); } public void testGetChunkedInput_MultipleInputsWithoutChunkingRequired() { var inputs = List.of(generateTestText(10), generateTestText(10)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); assertEquals(inputs, chunker.getChunkedInputs()); } public void testGetChunkedInput_MultipleInputsWithSomeChunkingRequired() { var inputs = List.of(generateTestText(10), generateTestText(100)); - var chunker = new RerankRequestChunker(randomAlphaOfLength(10), inputs); + var chunker = new RerankRequestChunker(randomAlphaOfLength(10), inputs, null); var chunkedInputs = chunker.getChunkedInputs(); assertEquals(4, chunkedInputs.size()); } public void testGetChunkedInput_MultipleInputsWithAllRequiringChunking() { var inputs = List.of(generateTestText(100), generateTestText(100)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); var chunkedInputs = chunker.getChunkedInputs(); assertEquals(6, chunkedInputs.size()); } public void 
testParseChunkedRerankResultsListener_NonRankedDocsResults() { var inputs = List.of(generateTestText(10), generateTestText(100)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); var listener = chunker.parseChunkedRerankResultsListener( ActionListener.wrap( results -> fail("Expected failure but got: " + results.getClass()), @@ -72,7 +72,7 @@ public void testParseChunkedRerankResultsListener_NonRankedDocsResults() { } public void testParseChunkedRerankResultsListener_EmptyInput() { - var chunker = new RerankRequestChunker(TEST_SENTENCE, List.of()); + var chunker = new RerankRequestChunker(TEST_SENTENCE, List.of(), null); var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; @@ -83,7 +83,7 @@ public void testParseChunkedRerankResultsListener_EmptyInput() { public void testParseChunkedRerankResultsListener_SingleInputWithoutChunking() { var inputs = List.of(generateTestText(10)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; @@ -99,7 +99,7 @@ public void testParseChunkedRerankResultsListener_SingleInputWithChunking() { var inputs = List.of(generateTestText(100)); var relevanceScore1 = randomFloatBetween(0, 1, true); var relevanceScore2 = randomFloatBetween(0, 1, true); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { assertThat(results, instanceOf(RankedDocsResults.class)); var 
rankedDocResults = (RankedDocsResults) results; @@ -120,7 +120,7 @@ public void testParseChunkedRerankResultsListener_SingleInputWithChunking() { public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking() { var inputs = List.of(generateTestText(10), generateTestText(10)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; @@ -141,7 +141,7 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking( public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking() { var inputs = List.of(generateTestText(10), generateTestText(100)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; @@ -163,7 +163,7 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiringChunking() { var inputs = List.of(generateTestText(100), generateTestText(100)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; From 24497aec7b506b091413b509c0daa79c9c3cf051 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Thu, 11 Sep 2025 14:41:52 -0400 Subject: [PATCH 5/8] Adding back sentence overlap 
for rerank chunking strategy --- .../xpack/inference/chunking/ChunkingSettingsBuilder.java | 2 +- .../inference/chunking/ChunkingSettingsBuilderTests.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index 8857f01d26f4c..12c41c7ea9470 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -65,6 +65,6 @@ public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWo } var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount * WORDS_PER_TOKEN); - return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 0); + return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java index 00452e6beb313..cc464a933481f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -62,7 +62,7 @@ public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountLessThanH int expectedMaxChunkSize = (int) ((ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount) * WORDS_PER_TOKEN); assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); - assertEquals(0, sentenceBoundaryChunkingSettings.sentenceOverlap); + assertEquals(1, 
sentenceBoundaryChunkingSettings.sentenceOverlap); } public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanHalfOfTokenLimit() { @@ -74,7 +74,7 @@ public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanH SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings = (SentenceBoundaryChunkingSettings) chunkingSettings; int expectedMaxChunkSize = (int) (Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2) * WORDS_PER_TOKEN); assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); - assertEquals(0, sentenceBoundaryChunkingSettings.sentenceOverlap); + assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap); } private Map, ChunkingSettings> chunkingSettingsMapToChunkingSettings() { From 8b977115686bf9150c47e4750064adf0cd15830f Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Thu, 18 Sep 2025 15:25:52 -0400 Subject: [PATCH 6/8] Adding unit tests, transport version, and feature flag --- ...lastic_reranker_chunking_configuration.csv | 1 + .../resources/transport/upper_bounds/9.2.csv | 2 +- .../test/cluster/FeatureFlag.java | 3 +- .../action/TransportInferenceAction.java | 1 - .../chunking/RerankRequestChunker.java | 9 +- .../ElasticRerankerServiceSettings.java | 113 +++-- .../ElasticsearchInternalService.java | 33 +- .../chunking/RerankRequestChunkerTests.java | 82 +++- .../ElasticRerankerServiceSettingsTests.java | 402 ++++++++++++++++++ .../ElasticsearchInternalServiceTests.java | 164 +++++++ 10 files changed, 731 insertions(+), 79 deletions(-) create mode 100644 server/src/main/resources/transport/definitions/referable/elastic_reranker_chunking_configuration.csv create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java diff --git a/server/src/main/resources/transport/definitions/referable/elastic_reranker_chunking_configuration.csv 
b/server/src/main/resources/transport/definitions/referable/elastic_reranker_chunking_configuration.csv new file mode 100644 index 0000000000000..40081d05c7097 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/elastic_reranker_chunking_configuration.csv @@ -0,0 +1 @@ +9168000 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index bf1a90e5be4e9..13cfff23d5f33 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -index_request_include_tsid,9167000 +elastic_reranker_chunking_configuration,9168000 diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index 6aa61e2ed38e1..a2e47cf2d0425 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -28,7 +28,8 @@ public enum FeatureFlag { "es.index_dimensions_tsid_optimization_feature_flag_enabled=true", Version.fromString("9.2.0"), null - ); + ), + ELASTIC_RERANKER_CHUNKING("es.elastic_reranker_chunking_long_documents=true", Version.fromString("9.2.0"), null); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index e7f71d7d3beec..8d7a37ca52ea7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -70,7 +70,6 @@ protected void doInference( InferenceService 
service, ActionListener listener ) { - service.infer( model, request.getQuery(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java index 2044d8ebef0b9..05bd5bbde8741 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java @@ -34,10 +34,7 @@ private List chunk(List inputs, ChunkingSettings chunkingS for (int i = 0; i < inputs.size(); i++) { var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings); if (maxChunksPerDoc != null && chunksForInput.size() > maxChunksPerDoc) { - var limitedChunks = chunksForInput.subList(0, maxChunksPerDoc - 1); - var lastChunk = limitedChunks.getLast(); - limitedChunks.add(new Chunker.ChunkOffset(lastChunk.end(), inputs.get(i).length())); - chunksForInput = limitedChunks; + chunksForInput = chunksForInput.subList(0, maxChunksPerDoc); } for (var chunk : chunksForInput) { @@ -53,7 +50,6 @@ public List getChunkedInputs() { chunkedInputs.add(chunk.chunkString()); } - // TODO: Score the inputs here and only return the top N chunks for each document return chunkedInputs; } @@ -69,8 +65,6 @@ public ActionListener parseChunkedRerankResultsListener }, listener::onFailure); } - // TODO: Can we assume the rankeddocsresults are always sorted by relevance score? - // TODO: Should we short circuit if no chunking was done? 
private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) { List updatedRankedDocs = new ArrayList<>(); Set docIndicesSeen = new HashSet<>(); @@ -89,6 +83,7 @@ private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults ranke docIndicesSeen.add(docIndex); } } + updatedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); return new RankedDocsResults(updatedRankedDocs); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java index 537c0fc0c307f..dbf7c5132c996 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -21,35 +22,34 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID; public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings { public static final String NAME = "elastic_reranker_service_settings"; - private static final String 
LONG_DOCUMENT_HANDLING_STRATEGY = "long_document_handling_strategy"; - private static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc"; + public static final String LONG_DOCUMENT_STRATEGY = "long_document_strategy"; + public static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc"; - private final LongDocumentHandlingStrategy longDocumentHandlingStrategy; + private static final TransportVersion ELASTIC_RERANKER_CHUNKING_CONFIGURATION = TransportVersion.fromName( + "elastic_reranker_chunking_configuration" + ); + + private final LongDocumentStrategy longDocumentStrategy; private final Integer maxChunksPerDoc; public static ElasticRerankerServiceSettings defaultEndpointSettings() { return new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)); } - public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) { - super(other); - this.longDocumentHandlingStrategy = null; - this.maxChunksPerDoc = null; - } - public ElasticRerankerServiceSettings( ElasticsearchInternalServiceSettings other, - LongDocumentHandlingStrategy longDocumentHandlingStrategy, + LongDocumentStrategy longDocumentStrategy, Integer maxChunksPerDoc ) { super(other); - this.longDocumentHandlingStrategy = longDocumentHandlingStrategy; + this.longDocumentStrategy = longDocumentStrategy; this.maxChunksPerDoc = maxChunksPerDoc; } @@ -61,15 +61,32 @@ private ElasticRerankerServiceSettings( AdaptiveAllocationsSettings adaptiveAllocationsSettings ) { super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null); - this.longDocumentHandlingStrategy = null; + this.longDocumentStrategy = null; this.maxChunksPerDoc = null; } + protected ElasticRerankerServiceSettings( + Integer numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, + LongDocumentStrategy longDocumentStrategy, + Integer maxChunksPerDoc + ) { + super(numAllocations, numThreads, modelId, 
adaptiveAllocationsSettings, null); + this.longDocumentStrategy = longDocumentStrategy; + this.maxChunksPerDoc = maxChunksPerDoc; + } + public ElasticRerankerServiceSettings(StreamInput in) throws IOException { super(in); - // TODO: Add transport version here - this.longDocumentHandlingStrategy = in.readOptionalEnum(LongDocumentHandlingStrategy.class); - this.maxChunksPerDoc = in.readOptionalInt(); + if (in.getTransportVersion().supports(ELASTIC_RERANKER_CHUNKING_CONFIGURATION)) { + this.longDocumentStrategy = in.readOptionalEnum(LongDocumentStrategy.class); + this.maxChunksPerDoc = in.readOptionalInt(); + } else { + this.longDocumentStrategy = null; + this.maxChunksPerDoc = null; + } } /** @@ -85,38 +102,41 @@ public static ElasticRerankerServiceSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException); - LongDocumentHandlingStrategy longDocumentHandlingStrategy = extractOptionalEnum( - map, - LONG_DOCUMENT_HANDLING_STRATEGY, - ModelConfigurations.SERVICE_SETTINGS, - LongDocumentHandlingStrategy::fromString, - EnumSet.allOf(LongDocumentHandlingStrategy.class), - validationException - ); - - Integer maxChunksPerDoc = extractOptionalPositiveInteger( - map, - MAX_CHUNKS_PER_DOC, - ModelConfigurations.SERVICE_SETTINGS, - validationException - ); - - if (maxChunksPerDoc != null - && (longDocumentHandlingStrategy == null || longDocumentHandlingStrategy == LongDocumentHandlingStrategy.TRUNCATE)) { - validationException.addValidationError( - "The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_HANDLING_STRATEGY + "] to be set to [chunk]" + LongDocumentStrategy longDocumentStrategy = null; + Integer maxChunksPerDoc = null; + if (ELASTIC_RERANKER_CHUNKING.isEnabled()) { + longDocumentStrategy = extractOptionalEnum( + map, + LONG_DOCUMENT_STRATEGY, + ModelConfigurations.SERVICE_SETTINGS, + LongDocumentStrategy::fromString, + 
EnumSet.allOf(LongDocumentStrategy.class), + validationException ); + + maxChunksPerDoc = extractOptionalPositiveInteger( + map, + MAX_CHUNKS_PER_DOC, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) { + validationException.addValidationError( + "The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]" + ); + } } if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new ElasticRerankerServiceSettings(baseSettings.build(), longDocumentHandlingStrategy, maxChunksPerDoc); + return new ElasticRerankerServiceSettings(baseSettings.build(), longDocumentStrategy, maxChunksPerDoc); } - public LongDocumentHandlingStrategy getLongDocumentHandlingStrategy() { - return longDocumentHandlingStrategy; + public LongDocumentStrategy getLongDocumentStrategy() { + return longDocumentStrategy; } public Integer getMaxChunksPerDoc() { @@ -126,17 +146,18 @@ public Integer getMaxChunksPerDoc() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - // TODO: Add transport version here - out.writeOptionalEnum(longDocumentHandlingStrategy); - out.writeOptionalInt(maxChunksPerDoc); + if (out.getTransportVersion().supports(ELASTIC_RERANKER_CHUNKING_CONFIGURATION)) { + out.writeOptionalEnum(longDocumentStrategy); + out.writeOptionalInt(maxChunksPerDoc); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); addInternalSettingsToXContent(builder, params); - if (longDocumentHandlingStrategy != null) { - builder.field(LONG_DOCUMENT_HANDLING_STRATEGY, longDocumentHandlingStrategy.strategyName); + if (longDocumentStrategy != null) { + builder.field(LONG_DOCUMENT_STRATEGY, longDocumentStrategy.strategyName); } if (maxChunksPerDoc != null) { 
builder.field(MAX_CHUNKS_PER_DOC, maxChunksPerDoc); @@ -150,17 +171,17 @@ public String getWriteableName() { return ElasticRerankerServiceSettings.NAME; } - public enum LongDocumentHandlingStrategy { + public enum LongDocumentStrategy { CHUNK("chunk"), TRUNCATE("truncate"); public final String strategyName; - LongDocumentHandlingStrategy(String strategyName) { + LongDocumentStrategy(String strategyName) { this.strategyName = strategyName; } - public static LongDocumentHandlingStrategy fromString(String name) { + public static LongDocumentStrategy fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 432a71e8b9b79..042b056defb2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; @@ -115,6 +116,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class); + public static final FeatureFlag ELASTIC_RERANKER_CHUNKING = new 
FeatureFlag("elastic_reranker_chunking_long_documents"); + /** * Fix for https://github.com/elastic/elasticsearch/issues/124675 * In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use @@ -684,14 +687,25 @@ public void inferRerank( ActionListener listener ) { var chunkedInputs = inputs; - var resultsListener = listener; - if (model instanceof ElasticRerankerModel elasticRerankerModel) { + ActionListener resultsListener = listener.delegateFailure((l, results) -> { + if (results instanceof RankedDocsResults rankedDocsResults) { + if (topN != null) { + l.onResponse(new RankedDocsResults(rankedDocsResults.getRankedDocs().subList(0, topN))); + } else { + l.onResponse(rankedDocsResults); + } + } else { + l.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass().getName())); + } + }); + + if (model instanceof ElasticRerankerModel elasticRerankerModel && ELASTIC_RERANKER_CHUNKING.isEnabled()) { var serviceSettings = elasticRerankerModel.getServiceSettings(); - var longDocumentHandlingStrategy = serviceSettings.getLongDocumentHandlingStrategy(); - if (longDocumentHandlingStrategy == ElasticRerankerServiceSettings.LongDocumentHandlingStrategy.CHUNK) { + var longDocumentStrategy = serviceSettings.getLongDocumentStrategy(); + if (longDocumentStrategy == ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK) { var rerankChunker = new RerankRequestChunker(query, inputs, serviceSettings.getMaxChunksPerDoc()); chunkedInputs = rerankChunker.getChunkedInputs(); - resultsListener = rerankChunker.parseChunkedRerankResultsListener(listener); + resultsListener = rerankChunker.parseChunkedRerankResultsListener(resultsListener); } } @@ -714,9 +728,7 @@ public void inferRerank( Function inputSupplier = returnDocs == Boolean.TRUE ? 
chunkedInputs::get : i -> null; ActionListener mlResultsListener = resultsListener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse( - textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN) - ) + (l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)) ); var maybeDeployListener = mlResultsListener.delegateResponse( @@ -825,8 +837,7 @@ public List aliases() { private RankedDocsResults textSimilarityResultsToRankedDocs( List results, - Function inputSupplier, - @Nullable Integer topN + Function inputSupplier ) { List rankings = new ArrayList<>(results.size()); for (int i = 0; i < results.size(); i++) { @@ -853,7 +864,7 @@ private RankedDocsResults textSimilarityResultsToRankedDocs( } Collections.sort(rankings); - return new RankedDocsResults(topN != null ? rankings.subList(0, topN) : rankings); + return new RankedDocsResults(rankings); } public List defaultConfigIds() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java index 5c4f4280f270d..0fbee5a16109d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java @@ -14,10 +14,11 @@ import java.util.List; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; public class RerankRequestChunkerTests extends ESTestCase { - private static final String TEST_SENTENCE = "This is a test sentence that has ten total words. "; + private final String TEST_SENTENCE = "This is a test sentence that has ten total words. 
"; public void testGetChunkedInput_EmptyInput() { var chunker = new RerankRequestChunker(TEST_SENTENCE, List.of(), null); @@ -26,7 +27,7 @@ public void testGetChunkedInput_EmptyInput() { public void testGetChunkedInput_SingleInputWithoutChunkingRequired() { var inputs = List.of(generateTestText(10)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomBoolean() ? null : randomIntBetween(1, 10)); assertEquals(inputs, chunker.getChunkedInputs()); } @@ -37,9 +38,24 @@ public void testGetChunkedInput_SingleInputWithChunkingRequired() { assertEquals(3, chunkedInputs.size()); } + public void testGetChunkedInput_SingleInputWithChunkingRequiredWithMaxChunksPerDocLessThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(100)); + var maxChunksPerDoc = randomIntBetween(1, 2); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, maxChunksPerDoc); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(maxChunksPerDoc, chunkedInputs.size()); + } + + public void testGetChunkedInput_SingleInputWithChunkingRequiredWithMaxChunksPerDocGreaterThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomIntBetween(4, 10)); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(3, chunkedInputs.size()); + } + public void testGetChunkedInput_MultipleInputsWithoutChunkingRequired() { var inputs = List.of(generateTestText(10), generateTestText(10)); - var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomBoolean() ? 
null : randomIntBetween(1, 10)); assertEquals(inputs, chunker.getChunkedInputs()); } @@ -50,6 +66,21 @@ public void testGetChunkedInput_MultipleInputsWithSomeChunkingRequired() { assertEquals(4, chunkedInputs.size()); } + public void testGetChunkedInput_MultipleInputsWithSomeChunkingRequiredWithMaxChunksPerDocLessThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var maxChunksPerDoc = randomIntBetween(1, 2); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, maxChunksPerDoc); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(1 + maxChunksPerDoc, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithSomeChunkingRequiredWithMaxChunksPerDocGreaterThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomIntBetween(3, 10)); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(4, chunkedInputs.size()); + } + public void testGetChunkedInput_MultipleInputsWithAllRequiringChunking() { var inputs = List.of(generateTestText(100), generateTestText(100)); var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); @@ -57,6 +88,21 @@ public void testGetChunkedInput_MultipleInputsWithAllRequiringChunking() { assertEquals(6, chunkedInputs.size()); } + public void testGetChunkedInput_MultipleInputsWithAllRequiringChunkingWithMaxChunksPerDocLessThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(100), generateTestText(100)); + var maxChunksPerDoc = randomIntBetween(1, 2); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, maxChunksPerDoc); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(2 * maxChunksPerDoc, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithAllRequiringChunkingWithMaxChunksPerDocGreaterThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(100), 
generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomIntBetween(4, 10)); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(6, chunkedInputs.size()); + } + public void testParseChunkedRerankResultsListener_NonRankedDocsResults() { var inputs = List.of(generateTestText(10), generateTestText(100)); var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); @@ -125,6 +171,10 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking( assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; assertEquals(2, rankedDocResults.getRankedDocs().size()); + assertThat( + rankedDocResults.getRankedDocs().get(0).relevanceScore(), + greaterThanOrEqualTo(rankedDocResults.getRankedDocs().get(1).relevanceScore()) + ); }, e -> fail("Expected successful parsing but got failure: " + e))); var chunkedInputs = chunker.getChunkedInputs(); @@ -132,8 +182,8 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking( listener.onResponse( new RankedDocsResults( List.of( - new RankedDocsResults.RankedDoc(0, 1.0f, chunkedInputs.get(0)), - new RankedDocsResults.RankedDoc(1, 1.0f, chunkedInputs.get(1)) + new RankedDocsResults.RankedDoc(0, randomFloatBetween(0, 1, true), chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, randomFloatBetween(0, 1, true), chunkedInputs.get(1)) ) ) ); @@ -146,6 +196,10 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; assertEquals(2, rankedDocResults.getRankedDocs().size()); + assertThat( + rankedDocResults.getRankedDocs().get(0).relevanceScore(), + greaterThanOrEqualTo(rankedDocResults.getRankedDocs().get(1).relevanceScore()) + ); }, e -> fail("Expected successful parsing but got failure: " + e))); var chunkedInputs = chunker.getChunkedInputs(); @@ 
-153,9 +207,9 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking listener.onResponse( new RankedDocsResults( List.of( - new RankedDocsResults.RankedDoc(0, 1.0f, chunkedInputs.get(0)), - new RankedDocsResults.RankedDoc(1, 1.0f, chunkedInputs.get(1)), - new RankedDocsResults.RankedDoc(2, 1.0f, chunkedInputs.get(2)) + new RankedDocsResults.RankedDoc(0, randomFloatBetween(0, 1, true), chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, randomFloatBetween(0, 1, true), chunkedInputs.get(1)), + new RankedDocsResults.RankedDoc(2, randomFloatBetween(0, 1, true), chunkedInputs.get(2)) ) ) ); @@ -168,6 +222,10 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiring assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; assertEquals(2, rankedDocResults.getRankedDocs().size()); + assertThat( + rankedDocResults.getRankedDocs().get(0).relevanceScore(), + greaterThanOrEqualTo(rankedDocResults.getRankedDocs().get(1).relevanceScore()) + ); }, e -> fail("Expected successful parsing but got failure: " + e))); var chunkedInputs = chunker.getChunkedInputs(); @@ -175,10 +233,10 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiring listener.onResponse( new RankedDocsResults( List.of( - new RankedDocsResults.RankedDoc(0, 1.0f, chunkedInputs.get(0)), - new RankedDocsResults.RankedDoc(1, 1.0f, chunkedInputs.get(1)), - new RankedDocsResults.RankedDoc(2, 1.0f, chunkedInputs.get(2)), - new RankedDocsResults.RankedDoc(3, 1.0f, chunkedInputs.get(3)) + new RankedDocsResults.RankedDoc(0, randomFloatBetween(0, 1, true), chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, randomFloatBetween(0, 1, true), chunkedInputs.get(1)), + new RankedDocsResults.RankedDoc(2, randomFloatBetween(0, 1, true), chunkedInputs.get(2)), + new RankedDocsResults.RankedDoc(3, randomFloatBetween(0, 1, true), chunkedInputs.get(3)) ) ) ); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java new file mode 100644 index 0000000000000..c9ee6a0543140 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java @@ -0,0 +1,402 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.junit.Assert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings.LONG_DOCUMENT_STRATEGY; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings.MAX_CHUNKS_PER_DOC; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS; +import 
static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_THREADS; + +public class ElasticRerankerServiceSettingsTests extends AbstractWireSerializingTestCase { + public static ElasticRerankerServiceSettings createRandomWithoutChunkingConfiguration() { + return createRandom(null, null); + } + + public static ElasticRerankerServiceSettings createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy longDocumentStrategy, + Integer maxChunksPerDoc + ) { + return createRandom(longDocumentStrategy, maxChunksPerDoc); + } + + public static ElasticRerankerServiceSettings createRandom() { + var longDocumentStrategy = ELASTIC_RERANKER_CHUNKING.isEnabled() + ? randomFrom(ElasticRerankerServiceSettings.LongDocumentStrategy.values()) + : null; + var maxChunksPerDoc = ELASTIC_RERANKER_CHUNKING.isEnabled() + && ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK.equals(longDocumentStrategy) + && randomBoolean() ? randomIntBetween(1, 10) : null; + return createRandom(longDocumentStrategy, maxChunksPerDoc); + } + + private static ElasticRerankerServiceSettings createRandom( + ElasticRerankerServiceSettings.LongDocumentStrategy longDocumentStrategy, + Integer maxChunksPerDoc + ) { + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? 
new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + return new ElasticRerankerServiceSettings( + numAllocations, + numThreads, + modelId, + adaptiveAllocationsSettings, + longDocumentStrategy, + maxChunksPerDoc + ); + } + + public void testFromMap_NonAdaptiveAllocationsBaseSettings_CreatesSettingsCorrectly() { + var numAllocations = randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + + Map settingsMap = buildServiceSettingsMap( + Optional.of(numAllocations), + numThreads, + modelId, + Optional.empty(), + Optional.empty(), + Optional.empty() + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.of(numAllocations), + numThreads, + modelId, + Optional.empty(), + Optional.empty(), + Optional.empty() + ); + } + + public void testFromMap_AdaptiveAllocationsBaseSettings_CreatesSettingsCorrectly() { + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)); + + Map settingsMap = buildServiceSettingsMap( + Optional.empty(), + numThreads, + modelId, + Optional.of(adaptiveAllocationsSettings), + Optional.empty(), + Optional.empty() + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.empty(), + numThreads, + modelId, + Optional.of(adaptiveAllocationsSettings), + Optional.empty(), + Optional.empty() + ); + } + + public void testFromMap_NumAllocationsAndAdaptiveAllocationsNull_ThrowsValidationException() { + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + + Map settingsMap = buildServiceSettingsMap( + Optional.empty(), + numThreads, + modelId, + Optional.empty(), + Optional.empty(), + 
Optional.empty() + ); + + ValidationException exception = Assert.assertThrows( + ValidationException.class, + () -> ElasticRerankerServiceSettings.fromMap(settingsMap) + ); + + assertTrue( + exception.getMessage() + .contains("[service_settings] does not contain one of the required settings [num_allocations, adaptive_allocations]") + ); + } + + public void testFromMap_ChunkingFeatureFlagDisabledAndLongDocumentStrategyProvided_CreatesSettingsIgnoringStrategy() { + assumeTrue( + "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled", + ELASTIC_RERANKER_CHUNKING.isEnabled() == false + ); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE; + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? 
Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.empty() + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.empty(), + Optional.empty() + ); + } + + public void testFromMap_ChunkingFeatureFlagDisabledAndMaxChunksPerDocProvided_CreatesSettingsIgnoringMaxChunksPerDoc() { + assumeTrue( + "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled", + ELASTIC_RERANKER_CHUNKING.isEnabled() == false + ); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var maxChunksPerDoc = randomIntBetween(1, 10); + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? 
Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.empty(), + Optional.of(maxChunksPerDoc) + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.empty(), + Optional.empty() + ); + } + + public void testFromMap_ChunkingFeatureFlagEnabledAndTruncateSelected_CreatesSettingsCorrectly() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE; + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? 
Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.empty() + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.of(longDocumentStrategy), + Optional.empty() + ); + } + + public void testFromMap_ChunkingFeatureFlagEnabledAndTruncateSelectedWithMaxChunksPerDoc_ThrowsValidationException() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE; + var maxChunksPerDoc = randomIntBetween(1, 10); + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? 
Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.of(maxChunksPerDoc) + ); + + ValidationException exception = Assert.assertThrows( + ValidationException.class, + () -> ElasticRerankerServiceSettings.fromMap(settingsMap) + ); + + assertTrue( + exception.getMessage().contains("The [max_chunks_per_doc] setting requires [long_document_strategy] to be set to [chunk]") + ); + } + + public void testFromMap_ChunkingFeatureFlagEnabledAndChunkSelected_CreatesSettingsCorrectly() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK; + // max_chunks_per_doc is deliberately omitted: the CHUNK strategy must be accepted on its own. + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? 
Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.empty() + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.of(longDocumentStrategy), + Optional.empty() + ); + } + + public void testFromMap_ChunkingFeatureFlagEnabledAndChunkSelectedWithMaxChunksPerDoc_CreatesSettingsCorrectly() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK; + var maxChunksPerDoc = randomIntBetween(1, 10); + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? 
Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.of(maxChunksPerDoc) + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.of(longDocumentStrategy), + Optional.of(maxChunksPerDoc) + ); + } + + private Map buildServiceSettingsMap( + Optional numAllocations, + int numThreads, + String modelId, + Optional adaptiveAllocationsSettings, + Optional longDocumentStrategy, + Optional maxChunksPerDoc + ) { + var settingsMap = new HashMap(); + numAllocations.ifPresent(value -> settingsMap.put(NUM_ALLOCATIONS, value)); + settingsMap.put(NUM_THREADS, numThreads); + settingsMap.put(MODEL_ID, modelId); + adaptiveAllocationsSettings.ifPresent(settings -> { + var adaptiveMap = new HashMap(); + adaptiveMap.put(AdaptiveAllocationsSettings.ENABLED.getPreferredName(), settings.getEnabled()); + adaptiveMap.put(AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS.getPreferredName(), settings.getMinNumberOfAllocations()); + adaptiveMap.put(AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS.getPreferredName(), settings.getMaxNumberOfAllocations()); + settingsMap.put(ADAPTIVE_ALLOCATIONS, adaptiveMap); + }); + longDocumentStrategy.ifPresent(value -> settingsMap.put(LONG_DOCUMENT_STRATEGY, value.toString())); + maxChunksPerDoc.ifPresent(value -> settingsMap.put(MAX_CHUNKS_PER_DOC, value)); + return settingsMap; + } + + private void assertExpectedSettings( + ElasticRerankerServiceSettings settings, + Optional expectedNumAllocations, + int expectedNumThreads, + String expectedModelId, + Optional expectedAdaptiveAllocationsSettings, + Optional expectedLongDocumentStrategy, + Optional expectedMaxChunksPerDoc + ) { + assertEquals(expectedNumAllocations.orElse(null), settings.getNumAllocations()); + 
assertEquals(expectedNumThreads, settings.getNumThreads()); + assertEquals(expectedModelId, settings.modelId()); + assertEquals(expectedAdaptiveAllocationsSettings.orElse(null), settings.getAdaptiveAllocationsSettings()); + assertEquals(expectedLongDocumentStrategy.orElse(null), settings.getLongDocumentStrategy()); + assertEquals(expectedMaxChunksPerDoc.orElse(null), settings.getMaxChunksPerDoc()); + } + + @Override + protected Writeable.Reader instanceReader() { + return ElasticRerankerServiceSettings::new; + } + + @Override + protected ElasticRerankerServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected ElasticRerankerServiceSettings mutateInstance(ElasticRerankerServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, ElasticRerankerServiceSettingsTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 3af19bf46c62e..e9f22f4848991 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -53,6 +53,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import 
org.elasticsearch.xpack.core.ml.MachineLearningField; @@ -74,6 +75,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig; @@ -83,6 +85,7 @@ import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceFields; @@ -116,6 +119,8 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService.notElasticsearchModelException; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; import static 
org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; @@ -143,6 +148,8 @@ public class ElasticsearchInternalServiceTests extends InferenceServiceTestCase private static ThreadPool threadPool; + private static final String TEST_SENTENCE = "This is a test sentence that has ten total words. "; + @Before public void setUp() throws Exception { super.setUp(); @@ -980,6 +987,163 @@ public void testUpdateModelWithEmbeddingDetails_ElasticsearchInternalModelNotMod verifyNoMoreInteractions(model); } + public void testInfer_UnsupportedModel() { + var service = createService(mock(Client.class)); + var model = new Model(ModelConfigurationsTests.createRandomInstance()); + + ActionListener listener = ActionListener.wrap( + results -> fail("Expected infer to fail for unsupported model type"), + e -> assertEquals(notElasticsearchModelException(model).getMessage(), e.getMessage()) + ); + + service.infer(model, null, null, null, List.of(), randomBoolean(), Map.of(), InputType.INGEST, null, listener); + } + + public void testInfer_ElasticRerankerSucceedsWithoutChunkingConfiguration() { + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + ElasticRerankerServiceSettingsTests.createRandomWithoutChunkingConfiguration(), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + public void testInfer_ElasticRerankerFeatureFlagDisabledSucceedsWithTruncateConfiguration() { + assumeTrue( + "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled", + ELASTIC_RERANKER_CHUNKING.isEnabled() == false + ); + + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + 
ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE, + null + ), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + public void testInfer_ElasticRerankerFeatureFlagDisabledSucceedsIgnoringChunkConfiguration() { + assumeTrue( + "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled", + ELASTIC_RERANKER_CHUNKING.isEnabled() == false + ); + + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK, + randomBoolean() ? randomIntBetween(1, 10) : null + ), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + public void testInfer_ElasticRerankerFeatureFlagEnabledAndSucceedsWithTruncateStrategy() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE, + null + ), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + public void testInfer_ElasticRerankerFeatureFlagEnabledAndSucceedsWithChunkStrategy() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + 
ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK, + randomBoolean() ? randomIntBetween(1, 10) : null + ), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + @SuppressWarnings("unchecked") + private void testInfer_ElasticReranker(ElasticRerankerModel model, List inputs) { + var query = randomAlphaOfLength(10); + var mlTrainedModelResults = new ArrayList(); + var numResults = inputs.size(); + if (ELASTIC_RERANKER_CHUNKING.isEnabled() + && ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK.equals(model.getServiceSettings().getLongDocumentStrategy())) { + var rerankRequestChunker = new RerankRequestChunker(query, inputs, model.getServiceSettings().getMaxChunksPerDoc()); + numResults = rerankRequestChunker.getChunkedInputs().size(); + } + for (int i = 0; i < numResults; i++) { + mlTrainedModelResults.add(TextSimilarityInferenceResultsTests.createRandomResults()); + } + var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); + + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + var service = createService(client); + var topN = randomBoolean() ? null : randomIntBetween(1, inputs.size()); + + ActionListener listener = ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocsResults = (RankedDocsResults) results; + assertEquals(topN == null ? 
inputs.size() : topN, rankedDocsResults.getRankedDocs().size()); + + }, ESTestCase::fail); + + service.infer( + model, + query, + randomBoolean() ? null : randomBoolean(), + topN, + inputs, + false, + Map.of(), + InputType.INGEST, + null, + listener + ); + } + + private List generateTestDocs(int numDocs, int numSentencesPerDoc) { + var docs = new ArrayList(); + for (int docIndex = 0; docIndex < numDocs; docIndex++) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < numSentencesPerDoc; i++) { + sb.append(TEST_SENTENCE); + } + docs.add(sb.toString()); + } + return docs; + } + public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException { testChunkInfer_e5(null); } From 833ef027ef846911593c48f2da57862afce3651a Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Mon, 22 Sep 2025 13:18:37 -0400 Subject: [PATCH 7/8] Update docs/changelog/130485.yaml --- docs/changelog/130485.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/130485.yaml diff --git a/docs/changelog/130485.yaml b/docs/changelog/130485.yaml new file mode 100644 index 0000000000000..b01cf904647e3 --- /dev/null +++ b/docs/changelog/130485.yaml @@ -0,0 +1,5 @@ +pr: 130485 +summary: Add `RerankRequestChunker` +area: Machine Learning +type: enhancement +issues: [] From 344e121534a0ea550fd872aa127d8148de0b8466 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Thu, 25 Sep 2025 12:34:01 -0400 Subject: [PATCH 8/8] Adding unit tests and refactoring code with clearer naming conventions --- .../chunking/RerankRequestChunker.java | 12 +++++---- ...ankFeaturePhaseRankCoordinatorContext.java | 1 - .../ElasticsearchInternalService.java | 13 +++------- .../chunking/RerankRequestChunkerTests.java | 26 +++++++++---------- .../xpack/inference/InferenceRestIT.java | 2 ++ 5 files changed, 24 insertions(+), 30 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java index 05bd5bbde8741..87feb19986583 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java @@ -66,9 +66,12 @@ public ActionListener parseChunkedRerankResultsListener } private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) { - List updatedRankedDocs = new ArrayList<>(); + List topRankedDocs = new ArrayList<>(); Set docIndicesSeen = new HashSet<>(); - for (RankedDocsResults.RankedDoc rankedDoc : rankedDocsResults.getRankedDocs()) { + + List rankedDocs = new ArrayList<>(rankedDocsResults.getRankedDocs()); + rankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) { int chunkIndex = rankedDoc.index(); int docIndex = rerankChunks.get(chunkIndex).docIndex(); @@ -79,13 +82,12 @@ private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults ranke rankedDoc.relevanceScore(), inputs.get(docIndex) ); - updatedRankedDocs.add(updatedRankedDoc); + topRankedDocs.add(updatedRankedDoc); docIndicesSeen.add(docIndex); } } - updatedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); - return new RankedDocsResults(updatedRankedDocs); + return new RankedDocsResults(topRankedDocs); } public record RerankChunks(int docIndex, String chunkString) {}; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java index bc363d16c81cd..725443bc01e0d 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java @@ -178,7 +178,6 @@ protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs, boolean rer } protected InferenceAction.Request generateRequest(List docFeatures) { - // TODO: Try running the RerankRequestChunker here. return new InferenceAction.Request( TaskType.RERANK, inferenceId, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 042b056defb2f..8bf8043a1ec0d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -686,7 +686,6 @@ public void inferRerank( Map requestTaskSettings, ActionListener listener ) { - var chunkedInputs = inputs; ActionListener resultsListener = listener.delegateFailure((l, results) -> { if (results instanceof RankedDocsResults rankedDocsResults) { if (topN != null) { @@ -704,18 +703,12 @@ public void inferRerank( var longDocumentStrategy = serviceSettings.getLongDocumentStrategy(); if (longDocumentStrategy == ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK) { var rerankChunker = new RerankRequestChunker(query, inputs, serviceSettings.getMaxChunksPerDoc()); - chunkedInputs = rerankChunker.getChunkedInputs(); + inputs = rerankChunker.getChunkedInputs(); resultsListener = rerankChunker.parseChunkedRerankResultsListener(resultsListener); } } - var request = 
buildInferenceRequest( - model.mlNodeDeploymentId(), - new TextSimilarityConfigUpdate(query), - chunkedInputs, - inputType, - timeout - ); + var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout); var returnDocs = Boolean.TRUE; if (returnDocuments != null) { @@ -725,7 +718,7 @@ public void inferRerank( returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments(); } - Function inputSupplier = returnDocs == Boolean.TRUE ? chunkedInputs::get : i -> null; + Function inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null; ActionListener mlResultsListener = resultsListener.delegateFailureAndWrap( (l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java index 0fbee5a16109d..5674fb3b73c98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java @@ -12,9 +12,10 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import java.util.ArrayList; import java.util.List; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static java.lang.Math.max; import static org.hamcrest.Matchers.instanceOf; public class RerankRequestChunkerTests extends ESTestCase { @@ -150,7 +151,7 @@ public void testParseChunkedRerankResultsListener_SingleInputWithChunking() { assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; assertEquals(1, 
rankedDocResults.getRankedDocs().size()); - var expectedRankedDocs = List.of(new RankedDocsResults.RankedDoc(0, relevanceScore1, inputs.get(0))); + var expectedRankedDocs = List.of(new RankedDocsResults.RankedDoc(0, max(relevanceScore1, relevanceScore2), inputs.get(0))); assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs()); }, e -> fail("Expected successful parsing but got failure: " + e))); @@ -171,10 +172,9 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking( assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; assertEquals(2, rankedDocResults.getRankedDocs().size()); - assertThat( - rankedDocResults.getRankedDocs().get(0).relevanceScore(), - greaterThanOrEqualTo(rankedDocResults.getRankedDocs().get(1).relevanceScore()) - ); + var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs()); + sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + assertEquals(sortedResults, rankedDocResults.getRankedDocs()); }, e -> fail("Expected successful parsing but got failure: " + e))); var chunkedInputs = chunker.getChunkedInputs(); @@ -196,10 +196,9 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; assertEquals(2, rankedDocResults.getRankedDocs().size()); - assertThat( - rankedDocResults.getRankedDocs().get(0).relevanceScore(), - greaterThanOrEqualTo(rankedDocResults.getRankedDocs().get(1).relevanceScore()) - ); + var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs()); + sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + assertEquals(sortedResults, rankedDocResults.getRankedDocs()); }, e -> fail("Expected successful parsing but got failure: " + e))); var chunkedInputs = chunker.getChunkedInputs(); @@ -222,10 +221,9 @@ public void 
testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiring assertThat(results, instanceOf(RankedDocsResults.class)); var rankedDocResults = (RankedDocsResults) results; assertEquals(2, rankedDocResults.getRankedDocs().size()); - assertThat( - rankedDocResults.getRankedDocs().get(0).relevanceScore(), - greaterThanOrEqualTo(rankedDocResults.getRankedDocs().get(1).relevanceScore()) - ); + var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs()); + sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + assertEquals(sortedResults, rankedDocResults.getRankedDocs()); }, e -> fail("Expected successful parsing but got failure: " + e))); var chunkedInputs = chunker.getChunkedInputs(); diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index f39b3f2b01368..c87d7fb40f63b 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.FeatureFlag; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; @@ -31,6 +32,7 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase { .setting("xpack.security.enabled", "false") .setting("xpack.security.http.ssl.enabled", "false") .setting("xpack.license.self_generated.type", "trial") + .feature(FeatureFlag.ELASTIC_RERANKER_CHUNKING) .plugin("inference-service-test") 
.distribution(DistributionType.DEFAULT) .build();