diff --git a/docs/changelog/130485.yaml b/docs/changelog/130485.yaml new file mode 100644 index 0000000000000..b01cf904647e3 --- /dev/null +++ b/docs/changelog/130485.yaml @@ -0,0 +1,5 @@ +pr: 130485 +summary: Add `RerankRequestChunker` +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/resources/transport/definitions/referable/elastic_reranker_chunking_configuration.csv b/server/src/main/resources/transport/definitions/referable/elastic_reranker_chunking_configuration.csv new file mode 100644 index 0000000000000..9109265b1c299 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/elastic_reranker_chunking_configuration.csv @@ -0,0 +1 @@ +9180000 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index 2f27ba13c86cd..2afe25615ac23 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -ml_inference_google_model_garden_added,9179000 +elastic_reranker_chunking_configuration,9180000 diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index 56fdac88dbfe1..644a5b46f4420 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -25,7 +25,8 @@ public enum FeatureFlag { "es.index_dimensions_tsid_optimization_feature_flag_enabled=true", Version.fromString("9.2.0"), null - ); + ), + ELASTIC_RERANKER_CHUNKING("es.elastic_reranker_chunking_long_documents=true", Version.fromString("9.2.0"), null); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index 2f912d891ef60..12c41c7ea9470 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -17,6 +17,9 @@ public class ChunkingSettingsBuilder { public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1); // Old settings used for backward compatibility for endpoints created before 8.16 when default was changed public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); + public static final int ELASTIC_RERANKER_TOKEN_LIMIT = 512; + public static final int ELASTIC_RERANKER_EXTRA_TOKEN_COUNT = 3; + public static final float WORDS_PER_TOKEN = 0.75f; public static ChunkingSettings fromMap(Map settings) { return fromMap(settings, true); @@ -51,4 +54,17 @@ public static ChunkingSettings fromMap(Map settings, boolean ret case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings)); }; } + + public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWordCount) { + var queryTokenCount = Math.ceil(queryWordCount / WORDS_PER_TOKEN); + var chunkSizeTokenCountWithFullQuery = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount); + + var maxChunkSizeTokenCount = Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2); + if (chunkSizeTokenCountWithFullQuery > maxChunkSizeTokenCount) { + maxChunkSizeTokenCount = chunkSizeTokenCountWithFullQuery; + } + + var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount * WORDS_PER_TOKEN); + return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java new file mode 100644 index 0000000000000..87feb19986583 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import com.ibm.icu.text.BreakIterator; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class RerankRequestChunker { + private final List inputs; + private final List rerankChunks; + + public RerankRequestChunker(String query, List inputs, Integer maxChunksPerDoc) { + this.inputs = inputs; + this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query), maxChunksPerDoc); + } + + private List chunk(List inputs, ChunkingSettings chunkingSettings, Integer maxChunksPerDoc) { + var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); + var chunks = new ArrayList(); + for (int i = 0; i < inputs.size(); i++) { + var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings); + if (maxChunksPerDoc != null && chunksForInput.size() > maxChunksPerDoc) { + chunksForInput = chunksForInput.subList(0, maxChunksPerDoc); + } + + for (var chunk : chunksForInput) { + chunks.add(new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end()))); + } + } + return chunks; + } + + public List getChunkedInputs() { + List chunkedInputs = new ArrayList<>(); + for (RerankChunks chunk : rerankChunks) { + chunkedInputs.add(chunk.chunkString()); + } + + return chunkedInputs; + } + + public ActionListener parseChunkedRerankResultsListener(ActionListener listener) { + return ActionListener.wrap(results -> { + if (results instanceof RankedDocsResults rankedDocsResults) { + listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults)); + + } else { + listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass())); + } + + }, listener::onFailure); + } + + private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) { + List topRankedDocs = new ArrayList<>(); + Set docIndicesSeen = new HashSet<>(); + + List rankedDocs = new ArrayList<>(rankedDocsResults.getRankedDocs()); + rankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) { + int chunkIndex = rankedDoc.index(); + int docIndex = rerankChunks.get(chunkIndex).docIndex(); + + if (docIndicesSeen.contains(docIndex) == false) { + // Create a ranked doc with the full input string and the index for the document instead of the chunk + RankedDocsResults.RankedDoc updatedRankedDoc = new RankedDocsResults.RankedDoc( + docIndex, + rankedDoc.relevanceScore(), + inputs.get(docIndex) + ); + topRankedDocs.add(updatedRankedDoc); + docIndicesSeen.add(docIndex); + } + } + + return new RankedDocsResults(topRankedDocs); + } + + public record RerankChunks(int docIndex, String chunkString) {}; + + private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) { + var wordIterator = BreakIterator.getWordInstance(); + wordIterator.setText(query); + var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator); + return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java index 2b7904e615682..dbf7c5132c996 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java @@ -7,25 +7,51 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import java.io.IOException; +import java.util.EnumSet; +import java.util.Locale; import java.util.Map; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID; public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings { public static final String NAME = "elastic_reranker_service_settings"; + public static final String LONG_DOCUMENT_STRATEGY = "long_document_strategy"; + public static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc"; + + private static final TransportVersion ELASTIC_RERANKER_CHUNKING_CONFIGURATION = TransportVersion.fromName( + "elastic_reranker_chunking_configuration" + ); + + private final LongDocumentStrategy longDocumentStrategy; + private final Integer maxChunksPerDoc; + public static ElasticRerankerServiceSettings defaultEndpointSettings() { return new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)); } - public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) { + public ElasticRerankerServiceSettings( + ElasticsearchInternalServiceSettings other, + LongDocumentStrategy longDocumentStrategy, + Integer maxChunksPerDoc + ) { super(other); + this.longDocumentStrategy = longDocumentStrategy; + this.maxChunksPerDoc = maxChunksPerDoc; + } private ElasticRerankerServiceSettings( @@ -35,10 +61,32 @@ private ElasticRerankerServiceSettings( AdaptiveAllocationsSettings adaptiveAllocationsSettings ) { super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null); + this.longDocumentStrategy = null; + this.maxChunksPerDoc = null; + } + + protected ElasticRerankerServiceSettings( + Integer numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, + LongDocumentStrategy longDocumentStrategy, + Integer maxChunksPerDoc + ) { + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null); + this.longDocumentStrategy = longDocumentStrategy; + this.maxChunksPerDoc = maxChunksPerDoc; } public ElasticRerankerServiceSettings(StreamInput in) throws IOException { super(in); + if (in.getTransportVersion().supports(ELASTIC_RERANKER_CHUNKING_CONFIGURATION)) { + this.longDocumentStrategy = in.readOptionalEnum(LongDocumentStrategy.class); + this.maxChunksPerDoc = in.readOptionalInt(); + } else { + this.longDocumentStrategy = null; + this.maxChunksPerDoc = null; + } } /** @@ -48,21 +96,93 @@ public ElasticRerankerServiceSettings(StreamInput in) throws IOException { * {@link ValidationException} is thrown. * * @param map Source map containing the config - * @return The builder + * @return Parsed and validated service settings */ - public static Builder fromRequestMap(Map map) { + public static ElasticRerankerServiceSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException); + LongDocumentStrategy longDocumentStrategy = null; + Integer maxChunksPerDoc = null; + if (ELASTIC_RERANKER_CHUNKING.isEnabled()) { + longDocumentStrategy = extractOptionalEnum( + map, + LONG_DOCUMENT_STRATEGY, + ModelConfigurations.SERVICE_SETTINGS, + LongDocumentStrategy::fromString, + EnumSet.allOf(LongDocumentStrategy.class), + validationException + ); + + maxChunksPerDoc = extractOptionalPositiveInteger( + map, + MAX_CHUNKS_PER_DOC, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) { + validationException.addValidationError( + "The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]" + ); + } + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return baseSettings; + return new ElasticRerankerServiceSettings(baseSettings.build(), longDocumentStrategy, maxChunksPerDoc); + } + + public LongDocumentStrategy getLongDocumentStrategy() { + return longDocumentStrategy; + } + + public Integer getMaxChunksPerDoc() { + return maxChunksPerDoc; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + if (out.getTransportVersion().supports(ELASTIC_RERANKER_CHUNKING_CONFIGURATION)) { + out.writeOptionalEnum(longDocumentStrategy); + out.writeOptionalInt(maxChunksPerDoc); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + addInternalSettingsToXContent(builder, params); + if (longDocumentStrategy != null) { + builder.field(LONG_DOCUMENT_STRATEGY, longDocumentStrategy.strategyName); + } + if (maxChunksPerDoc != null) { + builder.field(MAX_CHUNKS_PER_DOC, maxChunksPerDoc); + } + builder.endObject(); + return builder; } @Override public String getWriteableName() { return ElasticRerankerServiceSettings.NAME; } + + public enum LongDocumentStrategy { + CHUNK("chunk"), + TRUNCATE("truncate"); + + public final String strategyName; + + LongDocumentStrategy(String strategyName) { + this.strategyName = strategyName; + } + + public static LongDocumentStrategy fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 89258d5716e8e..8bf8043a1ec0d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; @@ -58,6 +59,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -114,6 +116,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class); + public static final FeatureFlag ELASTIC_RERANKER_CHUNKING = new FeatureFlag("elastic_reranker_chunking_long_documents"); + /** * Fix for https://github.com/elastic/elasticsearch/issues/124675 * In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use @@ -349,19 +353,13 @@ private void rerankerCase( ActionListener modelListener ) { - var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap); + var serviceSettings = ElasticRerankerServiceSettings.fromMap(serviceSettingsMap); throwIfNotEmptyMap(config, name()); throwIfNotEmptyMap(serviceSettingsMap, name()); modelListener.onResponse( - new ElasticRerankerModel( - inferenceEntityId, - taskType, - NAME, - new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()), - RerankTaskSettings.fromMap(taskSettingsMap) - ) + new ElasticRerankerModel(inferenceEntityId, taskType, NAME, serviceSettings, RerankTaskSettings.fromMap(taskSettingsMap)) ); } @@ -535,7 +533,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M inferenceEntityId, taskType, NAME, - new ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)), + ElasticRerankerServiceSettings.fromMap(serviceSettingsMap), RerankTaskSettings.fromMap(taskSettingsMap) ); } else { @@ -688,6 +686,28 @@ public void inferRerank( Map requestTaskSettings, ActionListener listener ) { + ActionListener resultsListener = listener.delegateFailure((l, results) -> { + if (results instanceof RankedDocsResults rankedDocsResults) { + if (topN != null) { + l.onResponse(new RankedDocsResults(rankedDocsResults.getRankedDocs().subList(0, topN))); + } else { + l.onResponse(rankedDocsResults); + } + } else { + l.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass().getName())); + } + }); + + if (model instanceof ElasticRerankerModel elasticRerankerModel && ELASTIC_RERANKER_CHUNKING.isEnabled()) { + var serviceSettings = elasticRerankerModel.getServiceSettings(); + var longDocumentStrategy = serviceSettings.getLongDocumentStrategy(); + if (longDocumentStrategy == ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK) { + var rerankChunker = new RerankRequestChunker(query, inputs, serviceSettings.getMaxChunksPerDoc()); + inputs = rerankChunker.getChunkedInputs(); + resultsListener = rerankChunker.parseChunkedRerankResultsListener(resultsListener); + } + + } var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout); var returnDocs = Boolean.TRUE; @@ -700,10 +720,8 @@ public void inferRerank( Function inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null; - ActionListener mlResultsListener = listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse( - textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN) - ) + ActionListener mlResultsListener = resultsListener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)) ); var maybeDeployListener = mlResultsListener.delegateResponse( @@ -812,8 +830,7 @@ public List aliases() { private RankedDocsResults textSimilarityResultsToRankedDocs( List results, - Function inputSupplier, - @Nullable Integer topN + Function inputSupplier ) { List rankings = new ArrayList<>(results.size()); for (int i = 0; i < results.size(); i++) { @@ -840,7 +857,7 @@ private RankedDocsResults textSimilarityResultsToRankedDocs( } Collections.sort(rankings); - return new RankedDocsResults(topN != null ? rankings.subList(0, topN) : rankings); + return new RankedDocsResults(rankings); } public List defaultConfigIds() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java index 9e6dde60bc641..cc464a933481f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -15,6 +15,10 @@ import java.util.HashMap; import java.util.Map; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_EXTRA_TOKEN_COUNT; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.ELASTIC_RERANKER_TOKEN_LIMIT; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder.WORDS_PER_TOKEN; + public class ChunkingSettingsBuilderTests extends ESTestCase { public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1); @@ -47,6 +51,32 @@ public void testValidChunkingSettingsMap() { }); } + public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountLessThanHalfOfTokenLimit() { + // Generate a word count for a non-empty query that takes up less than half the token limit + int maxQueryTokenCount = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT) / 2; + int queryWordCount = randomIntBetween(1, (int) (maxQueryTokenCount * WORDS_PER_TOKEN)); + var queryTokenCount = Math.ceil(queryWordCount / WORDS_PER_TOKEN); + ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount); + assertTrue(chunkingSettings instanceof SentenceBoundaryChunkingSettings); + SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings = (SentenceBoundaryChunkingSettings) chunkingSettings; + int expectedMaxChunkSize = (int) ((ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount) + * WORDS_PER_TOKEN); + assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); + assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap); + } + + public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanHalfOfTokenLimit() { + // Generate a word count for a non-empty query that takes up more than half the token limit + int maxQueryTokenCount = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT) / 2; + int queryWordCount = randomIntBetween((int) (maxQueryTokenCount * WORDS_PER_TOKEN), Integer.MAX_VALUE); + ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount); + assertTrue(chunkingSettings instanceof SentenceBoundaryChunkingSettings); + SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings = (SentenceBoundaryChunkingSettings) chunkingSettings; + int expectedMaxChunkSize = (int) (Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2) * WORDS_PER_TOKEN); + assertEquals(expectedMaxChunkSize, sentenceBoundaryChunkingSettings.maxChunkSize); + assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap); + } + private Map, ChunkingSettings> chunkingSettingsMapToChunkingSettings() { var maxChunkSizeWordBoundaryChunkingSettings = randomIntBetween(10, 300); var overlap = randomIntBetween(1, maxChunkSizeWordBoundaryChunkingSettings / 2); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java new file mode 100644 index 0000000000000..5674fb3b73c98 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java @@ -0,0 +1,250 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.chunking; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; + +import java.util.ArrayList; +import java.util.List; + +import static java.lang.Math.max; +import static org.hamcrest.Matchers.instanceOf; + +public class RerankRequestChunkerTests extends ESTestCase { + private final String TEST_SENTENCE = "This is a test sentence that has ten total words. "; + + public void testGetChunkedInput_EmptyInput() { + var chunker = new RerankRequestChunker(TEST_SENTENCE, List.of(), null); + assertTrue(chunker.getChunkedInputs().isEmpty()); + } + + public void testGetChunkedInput_SingleInputWithoutChunkingRequired() { + var inputs = List.of(generateTestText(10)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomBoolean() ? null : randomIntBetween(1, 10)); + assertEquals(inputs, chunker.getChunkedInputs()); + } + + public void testGetChunkedInput_SingleInputWithChunkingRequired() { + var inputs = List.of(generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(3, chunkedInputs.size()); + } + + public void testGetChunkedInput_SingleInputWithChunkingRequiredWithMaxChunksPerDocLessThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(100)); + var maxChunksPerDoc = randomIntBetween(1, 2); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, maxChunksPerDoc); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(maxChunksPerDoc, chunkedInputs.size()); + } + + public void testGetChunkedInput_SingleInputWithChunkingRequiredWithMaxChunksPerDocGreaterThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomIntBetween(4, 10)); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(3, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithoutChunkingRequired() { + var inputs = List.of(generateTestText(10), generateTestText(10)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomBoolean() ? null : randomIntBetween(1, 10)); + assertEquals(inputs, chunker.getChunkedInputs()); + } + + public void testGetChunkedInput_MultipleInputsWithSomeChunkingRequired() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(randomAlphaOfLength(10), inputs, null); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(4, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithSomeChunkingRequiredWithMaxChunksPerDocLessThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var maxChunksPerDoc = randomIntBetween(1, 2); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, maxChunksPerDoc); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(1 + maxChunksPerDoc, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithSomeChunkingRequiredWithMaxChunksPerDocGreaterThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomIntBetween(3, 10)); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(4, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithAllRequiringChunking() { + var inputs = List.of(generateTestText(100), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(6, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithAllRequiringChunkingWithMaxChunksPerDocLessThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(100), generateTestText(100)); + var maxChunksPerDoc = randomIntBetween(1, 2); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, maxChunksPerDoc); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(2 * maxChunksPerDoc, chunkedInputs.size()); + } + + public void testGetChunkedInput_MultipleInputsWithAllRequiringChunkingWithMaxChunksPerDocGreaterThanTotalChunksGenerated() { + var inputs = List.of(generateTestText(100), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, randomIntBetween(4, 10)); + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(6, chunkedInputs.size()); + } + + public void testParseChunkedRerankResultsListener_NonRankedDocsResults() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var listener = chunker.parseChunkedRerankResultsListener( + ActionListener.wrap( + results -> fail("Expected failure but got: " + results.getClass()), + e -> assertTrue(e instanceof IllegalArgumentException && e.getMessage().contains("Expected RankedDocsResults")) + ) + ); + + listener.onResponse(new InferenceServiceResults() { + }); + } + + public void testParseChunkedRerankResultsListener_EmptyInput() { + var chunker = new RerankRequestChunker(TEST_SENTENCE, List.of(), null); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(0, rankedDocResults.getRankedDocs().size()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + listener.onResponse(new RankedDocsResults(List.of())); + } + + public void testParseChunkedRerankResultsListener_SingleInputWithoutChunking() { + var inputs = List.of(generateTestText(10)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(1, rankedDocResults.getRankedDocs().size()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(1, chunkedInputs.size()); + listener.onResponse(new RankedDocsResults(List.of(new RankedDocsResults.RankedDoc(0, 1.0f, chunkedInputs.get(0))))); + } + + public void testParseChunkedRerankResultsListener_SingleInputWithChunking() { + var inputs = List.of(generateTestText(100)); + var relevanceScore1 = randomFloatBetween(0, 1, true); + var relevanceScore2 = randomFloatBetween(0, 1, true); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(1, rankedDocResults.getRankedDocs().size()); + var expectedRankedDocs = List.of(new RankedDocsResults.RankedDoc(0, max(relevanceScore1, relevanceScore2), inputs.get(0))); + assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(3, chunkedInputs.size()); + var rankedDocsResults = List.of( + new RankedDocsResults.RankedDoc(0, relevanceScore1, chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, relevanceScore2, chunkedInputs.get(1)) + ); + // TODO: Sort this so that the assumption that the results are in order holds + listener.onResponse(new RankedDocsResults(rankedDocsResults)); + } + + public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking() { + var inputs = List.of(generateTestText(10), generateTestText(10)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(2, rankedDocResults.getRankedDocs().size()); + var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs()); + sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + assertEquals(sortedResults, rankedDocResults.getRankedDocs()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(2, chunkedInputs.size()); + listener.onResponse( + new RankedDocsResults( + List.of( + new RankedDocsResults.RankedDoc(0, randomFloatBetween(0, 1, true), chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, randomFloatBetween(0, 1, true), chunkedInputs.get(1)) + ) + ) + ); + } + + public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(2, rankedDocResults.getRankedDocs().size()); + var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs()); + sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + assertEquals(sortedResults, rankedDocResults.getRankedDocs()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(4, chunkedInputs.size()); + listener.onResponse( + new RankedDocsResults( + List.of( + new RankedDocsResults.RankedDoc(0, randomFloatBetween(0, 1, true), chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, randomFloatBetween(0, 1, true), chunkedInputs.get(1)), + new RankedDocsResults.RankedDoc(2, randomFloatBetween(0, 1, true), chunkedInputs.get(2)) + ) + ) + ); + } + + public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiringChunking() { + var inputs = List.of(generateTestText(100), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(2, rankedDocResults.getRankedDocs().size()); + var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs()); + sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + assertEquals(sortedResults, rankedDocResults.getRankedDocs()); + }, e -> fail("Expected successful parsing but got failure: " + e))); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(6, chunkedInputs.size()); + listener.onResponse( + new RankedDocsResults( + List.of( + new RankedDocsResults.RankedDoc(0, randomFloatBetween(0, 1, true), chunkedInputs.get(0)), + new RankedDocsResults.RankedDoc(1, randomFloatBetween(0, 1, true), chunkedInputs.get(1)), + new RankedDocsResults.RankedDoc(2, randomFloatBetween(0, 1, true), chunkedInputs.get(2)), + new RankedDocsResults.RankedDoc(3, randomFloatBetween(0, 1, true), chunkedInputs.get(3)) + ) + ) + ); + } + + private String generateTestText(int numSentences) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < numSentences; i++) { + sb.append(TEST_SENTENCE); + } + return sb.toString(); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java new file mode 100644 index 0000000000000..c9ee6a0543140 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java @@ -0,0 +1,402 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.junit.Assert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings.LONG_DOCUMENT_STRATEGY; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings.MAX_CHUNKS_PER_DOC; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_THREADS; + +public class ElasticRerankerServiceSettingsTests extends AbstractWireSerializingTestCase { + public static ElasticRerankerServiceSettings createRandomWithoutChunkingConfiguration() { + return createRandom(null, null); + } + + public static ElasticRerankerServiceSettings createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy longDocumentStrategy, + Integer maxChunksPerDoc + ) { + return createRandom(longDocumentStrategy, maxChunksPerDoc); + } + + public static ElasticRerankerServiceSettings createRandom() { + var longDocumentStrategy = ELASTIC_RERANKER_CHUNKING.isEnabled() + ? randomFrom(ElasticRerankerServiceSettings.LongDocumentStrategy.values()) + : null; + var maxChunksPerDoc = ELASTIC_RERANKER_CHUNKING.isEnabled() + && ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK.equals(longDocumentStrategy) + && randomBoolean() ? randomIntBetween(1, 10) : null; + return createRandom(longDocumentStrategy, maxChunksPerDoc); + } + + private static ElasticRerankerServiceSettings createRandom( + ElasticRerankerServiceSettings.LongDocumentStrategy longDocumentStrategy, + Integer maxChunksPerDoc + ) { + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + return new ElasticRerankerServiceSettings( + numAllocations, + numThreads, + modelId, + adaptiveAllocationsSettings, + longDocumentStrategy, + maxChunksPerDoc + ); + } + + public void testFromMap_NonAdaptiveAllocationsBaseSettings_CreatesSettingsCorrectly() { + var numAllocations = randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + + Map settingsMap = buildServiceSettingsMap( + Optional.of(numAllocations), + numThreads, + modelId, + Optional.empty(), + Optional.empty(), + Optional.empty() + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.of(numAllocations), + numThreads, + modelId, + Optional.empty(), + Optional.empty(), + Optional.empty() + ); + } + + public void testFromMap_AdaptiveAllocationsBaseSettings_CreatesSettingsCorrectly() { + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)); + + Map settingsMap = buildServiceSettingsMap( + Optional.empty(), + numThreads, + modelId, + Optional.of(adaptiveAllocationsSettings), + Optional.empty(), + Optional.empty() + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.empty(), + numThreads, + modelId, + Optional.of(adaptiveAllocationsSettings), + Optional.empty(), + Optional.empty() + ); + } + + public void testFromMap_NumAllocationsAndAdaptiveAllocationsNull_ThrowsValidationException() { + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + + Map settingsMap = buildServiceSettingsMap( + Optional.empty(), + numThreads, + modelId, + Optional.empty(), + Optional.empty(), + Optional.empty() + ); + + ValidationException exception = Assert.assertThrows( + ValidationException.class, + () -> ElasticRerankerServiceSettings.fromMap(settingsMap) + ); + + assertTrue( + exception.getMessage() + .contains("[service_settings] does not contain one of the required settings [num_allocations, adaptive_allocations]") + ); + } + + public void testFromMap_ChunkingFeatureFlagDisabledAndLongDocumentStrategyProvided_CreatesSettingsIgnoringStrategy() { + assumeTrue( + "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled", + ELASTIC_RERANKER_CHUNKING.isEnabled() == false + ); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE; + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.empty() + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.empty(), + Optional.empty() + ); + } + + public void testFromMap_ChunkingFeatureFlagDisabledAndMaxChunksPerDocProvided_CreatesSettingsIgnoringMaxChunksPerDoc() { + assumeTrue( + "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled", + ELASTIC_RERANKER_CHUNKING.isEnabled() == false + ); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var maxChunksPerDoc = randomIntBetween(1, 10); + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.empty(), + Optional.of(maxChunksPerDoc) + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.empty(), + Optional.empty() + ); + } + + public void testFromMap_ChunkingFeatureFlagEnabledAndTruncateSelected_CreatesSettingsCorrectly() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE; + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.empty() + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.of(longDocumentStrategy), + Optional.empty() + ); + } + + public void testFromMap_ChunkingFeatureFlagEnabledAndTruncateSelectedWithMaxChunksPerDoc_ThrowsValidationException() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE; + var maxChunksPerDoc = randomIntBetween(1, 10); + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.of(maxChunksPerDoc) + ); + + ValidationException exception = Assert.assertThrows( + ValidationException.class, + () -> ElasticRerankerServiceSettings.fromMap(settingsMap) + ); + + assertTrue( + exception.getMessage().contains("The [max_chunks_per_doc] setting requires [long_document_strategy] to be set to [chunk]") + ); + } + + public void testFromMap_ChunkingFeatureFlagEnabledAndChunkSelected_CreatesSettingsCorrectly() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK; + var maxChunksPerDoc = randomIntBetween(1, 10); + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.of(maxChunksPerDoc) + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.of(longDocumentStrategy), + Optional.of(maxChunksPerDoc) + ); + } + + public void testFromMap_ChunkingFeatureFlagEnabledAndChunkSelectedWithMaxChunksPerDoc_CreatesSettingsCorrectly() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + var withAdaptiveAllocations = randomBoolean(); + var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10); + var numThreads = randomIntBetween(1, 10); + var modelId = randomAlphaOfLength(8); + var adaptiveAllocationsSettings = withAdaptiveAllocations + ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5)) + : null; + var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK; + var maxChunksPerDoc = randomIntBetween(1, 10); + + Map settingsMap = buildServiceSettingsMap( + withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations), + numThreads, + modelId, + withAdaptiveAllocations ? Optional.of(adaptiveAllocationsSettings) : Optional.empty(), + Optional.of(longDocumentStrategy), + Optional.of(maxChunksPerDoc) + ); + + ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap); + assertExpectedSettings( + settings, + Optional.ofNullable(numAllocations), + numThreads, + modelId, + Optional.ofNullable(adaptiveAllocationsSettings), + Optional.of(longDocumentStrategy), + Optional.of(maxChunksPerDoc) + ); + } + + private Map buildServiceSettingsMap( + Optional numAllocations, + int numThreads, + String modelId, + Optional adaptiveAllocationsSettings, + Optional longDocumentStrategy, + Optional maxChunksPerDoc + ) { + var settingsMap = new HashMap(); + numAllocations.ifPresent(value -> settingsMap.put(NUM_ALLOCATIONS, value)); + settingsMap.put(NUM_THREADS, numThreads); + settingsMap.put(MODEL_ID, modelId); + adaptiveAllocationsSettings.ifPresent(settings -> { + var adaptiveMap = new HashMap(); + adaptiveMap.put(AdaptiveAllocationsSettings.ENABLED.getPreferredName(), settings.getEnabled()); + adaptiveMap.put(AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS.getPreferredName(), settings.getMinNumberOfAllocations()); + adaptiveMap.put(AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS.getPreferredName(), settings.getMaxNumberOfAllocations()); + settingsMap.put(ADAPTIVE_ALLOCATIONS, adaptiveMap); + }); + longDocumentStrategy.ifPresent(value -> settingsMap.put(LONG_DOCUMENT_STRATEGY, value.toString())); + maxChunksPerDoc.ifPresent(value -> settingsMap.put(MAX_CHUNKS_PER_DOC, value)); + return settingsMap; + } + + private void assertExpectedSettings( + ElasticRerankerServiceSettings settings, + Optional expectedNumAllocations, + int expectedNumThreads, + String expectedModelId, + Optional expectedAdaptiveAllocationsSettings, + Optional expectedLongDocumentStrategy, + Optional expectedMaxChunksPerDoc + ) { + assertEquals(expectedNumAllocations.orElse(null), settings.getNumAllocations()); + assertEquals(expectedNumThreads, settings.getNumThreads()); + assertEquals(expectedModelId, settings.modelId()); + assertEquals(expectedAdaptiveAllocationsSettings.orElse(null), settings.getAdaptiveAllocationsSettings()); + assertEquals(expectedLongDocumentStrategy.orElse(null), settings.getLongDocumentStrategy()); + assertEquals(expectedMaxChunksPerDoc.orElse(null), settings.getMaxChunksPerDoc()); + } + + @Override + protected Writeable.Reader instanceReader() { + return ElasticRerankerServiceSettings::new; + } + + @Override + protected ElasticRerankerServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected ElasticRerankerServiceSettings mutateInstance(ElasticRerankerServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, ElasticRerankerServiceSettingsTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 3af19bf46c62e..e9f22f4848991 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -53,6 +53,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.MachineLearningField; @@ -74,6 +75,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig; @@ -83,6 +85,7 @@ import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.inference.chunking.RerankRequestChunker; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.services.InferenceServiceTestCase; import org.elasticsearch.xpack.inference.services.ServiceFields; @@ -116,6 +119,8 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterService; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService.notElasticsearchModelException; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; @@ -143,6 +148,8 @@ public class ElasticsearchInternalServiceTests extends InferenceServiceTestCase private static ThreadPool threadPool; + private final String TEST_SENTENCE = "This is a test sentence that has ten total words. "; + @Before public void setUp() throws Exception { super.setUp(); @@ -980,6 +987,163 @@ public void testUpdateModelWithEmbeddingDetails_ElasticsearchInternalModelNotMod verifyNoMoreInteractions(model); } + public void testInfer_UnsupportedModel() { + var service = createService(mock(Client.class)); + var model = new Model(ModelConfigurationsTests.createRandomInstance()); + + ActionListener listener = ActionListener.wrap( + results -> fail("Expected infer to fail for unsupported model type"), + e -> assertEquals(e.getMessage(), notElasticsearchModelException(model).getMessage()) + ); + + service.infer(model, null, null, null, List.of(), randomBoolean(), Map.of(), InputType.INGEST, null, listener); + } + + public void testInfer_ElasticRerankerSucceedsWithoutChunkingConfiguration() { + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + ElasticRerankerServiceSettingsTests.createRandomWithoutChunkingConfiguration(), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + public void testInfer_ElasticRerankerFeatureFlagDisabledSucceedsWithTruncateConfiguration() { + assumeTrue( + "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled", + ELASTIC_RERANKER_CHUNKING.isEnabled() == false + ); + + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE, + null + ), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + public void testInfer_ElasticRerankerFeatureFlagDisabledSucceedsIgnoringChunkConfiguration() { + assumeTrue( + "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled", + ELASTIC_RERANKER_CHUNKING.isEnabled() == false + ); + + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK, + randomBoolean() ? randomIntBetween(1, 10) : null + ), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + public void testInfer_ElasticRerankerFeatureFlagEnabledAndSucceedsWithTruncateStrategy() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE, + null + ), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + public void testInfer_ElasticRerankerFeatureFlagEnabledAndSucceedsWithChunkStrategy() { + assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled()); + + var model = new ElasticRerankerModel( + randomAlphaOfLength(10), + TaskType.RERANK, + NAME, + ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration( + ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK, + randomBoolean() ? randomIntBetween(1, 10) : null + ), + new RerankTaskSettings(randomBoolean()) + ); + + testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100))); + } + + @SuppressWarnings("unchecked") + private void testInfer_ElasticReranker(ElasticRerankerModel model, List inputs) { + var query = randomAlphaOfLength(10); + var mlTrainedModelResults = new ArrayList(); + var numResults = inputs.size(); + if (ELASTIC_RERANKER_CHUNKING.isEnabled() + && ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK.equals(model.getServiceSettings().getLongDocumentStrategy())) { + var rerankRequestChunker = new RerankRequestChunker(query, inputs, model.getServiceSettings().getMaxChunksPerDoc()); + numResults = rerankRequestChunker.getChunkedInputs().size(); + } + for (int i = 0; i < numResults; i++) { + mlTrainedModelResults.add(TextSimilarityInferenceResultsTests.createRandomResults()); + } + var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); + + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + var service = createService(client); + var topN = randomBoolean() ? null : randomIntBetween(1, inputs.size()); + + ActionListener listener = ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocsResults = (RankedDocsResults) results; + assertEquals(topN == null ? inputs.size() : topN, rankedDocsResults.getRankedDocs().size()); + + }, ESTestCase::fail); + + service.infer( + model, + randomAlphaOfLength(10), + randomBoolean() ? null : randomBoolean(), + topN, + inputs, + false, + Map.of(), + InputType.INGEST, + null, + listener + ); + } + + private List generateTestDocs(int numDocs, int numSentencesPerDoc) { + var docs = new ArrayList(); + for (int docIndex = 0; docIndex < numDocs; docIndex++) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < numSentencesPerDoc; i++) { + sb.append(TEST_SENTENCE); + } + docs.add(sb.toString()); + } + return docs; + } + public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException { testChunkInfer_e5(null); } diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java index f39b3f2b01368..c87d7fb40f63b 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.FeatureFlag; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; @@ -31,6 +32,7 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase { .setting("xpack.security.enabled", "false") .setting("xpack.security.http.ssl.enabled", "false") .setting("xpack.license.self_generated.type", "trial") + .feature(FeatureFlag.ELASTIC_RERANKER_CHUNKING) .plugin("inference-service-test") .distribution(DistributionType.DEFAULT) .build();