diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java
index 644a5b46f4420..56fdac88dbfe1 100644
--- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java
+++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java
@@ -25,8 +25,7 @@ public enum FeatureFlag {
         "es.index_dimensions_tsid_optimization_feature_flag_enabled=true",
         Version.fromString("9.2.0"),
         null
-    ),
-    ELASTIC_RERANKER_CHUNKING("es.elastic_reranker_chunking_long_documents=true", Version.fromString("9.2.0"), null);
+    );
 
     public final String systemProperty;
     public final Version from;
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java
index 87feb19986583..2b31796e0c640 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunker.java
@@ -53,10 +53,13 @@ public List getChunkedInputs() {
         return chunkedInputs;
     }
 
-    public ActionListener parseChunkedRerankResultsListener(ActionListener listener) {
+    public ActionListener parseChunkedRerankResultsListener(
+        ActionListener listener,
+        boolean returnDocuments
+    ) {
         return ActionListener.wrap(results -> {
             if (results instanceof RankedDocsResults rankedDocsResults) {
-                listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults));
+                listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults, returnDocuments));
             } else {
                 listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
@@ -65,7 +68,7 @@ public ActionListener parseChunkedRerankResultsListener
         }, listener::onFailure);
     }
 
-    private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
+    private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults, boolean returnDocuments) {
         List topRankedDocs = new ArrayList<>();
         Set docIndicesSeen = new HashSet<>();
 
@@ -80,7 +83,7 @@ private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults ranke
                 RankedDocsResults.RankedDoc updatedRankedDoc = new RankedDocsResults.RankedDoc(
                     docIndex,
                     rankedDoc.relevanceScore(),
-                    inputs.get(docIndex)
+                    returnDocuments ? inputs.get(docIndex) : null
                 );
                 topRankedDocs.add(updatedRankedDoc);
                 docIndicesSeen.add(docIndex);
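Note (not part of the patch): a minimal sketch of how the two-argument listener is driven; names such as delegateListener, documents and maxChunksPerDoc are illustrative assumptions, not taken from the sources.

    // Chunk the documents once, then wrap the downstream listener so that chunk-level
    // scores are folded back into a single RankedDoc per original document.
    var chunker = new RerankRequestChunker(query, documents, maxChunksPerDoc);
    var chunkedInputs = chunker.getChunkedInputs();

    // With returnDocuments == false the reassembled RankedDoc entries carry a null text,
    // matching the new ternary in parseRankedDocResultsForChunks above.
    var wrappedListener = chunker.parseChunkedRerankResultsListener(delegateListener, false);
    // wrappedListener is then handed to the inference call that scores chunkedInputs.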
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java
index dbf7c5132c996..89ceaf2493dc1 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java
@@ -22,7 +22,6 @@
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
-import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID;
 
 public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {
@@ -102,30 +101,26 @@ public static ElasticRerankerServiceSettings fromMap(Map map) {
         ValidationException validationException = new ValidationException();
         var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);
 
-        LongDocumentStrategy longDocumentStrategy = null;
-        Integer maxChunksPerDoc = null;
-        if (ELASTIC_RERANKER_CHUNKING.isEnabled()) {
-            longDocumentStrategy = extractOptionalEnum(
-                map,
-                LONG_DOCUMENT_STRATEGY,
-                ModelConfigurations.SERVICE_SETTINGS,
-                LongDocumentStrategy::fromString,
-                EnumSet.allOf(LongDocumentStrategy.class),
-                validationException
+        LongDocumentStrategy longDocumentStrategy = extractOptionalEnum(
+            map,
+            LONG_DOCUMENT_STRATEGY,
+            ModelConfigurations.SERVICE_SETTINGS,
+            LongDocumentStrategy::fromString,
+            EnumSet.allOf(LongDocumentStrategy.class),
+            validationException
+        );
+
+        Integer maxChunksPerDoc = extractOptionalPositiveInteger(
+            map,
+            MAX_CHUNKS_PER_DOC,
+            ModelConfigurations.SERVICE_SETTINGS,
+            validationException
+        );
+
+        if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) {
+            validationException.addValidationError(
+                "The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]"
             );
-
-            maxChunksPerDoc = extractOptionalPositiveInteger(
-                map,
-                MAX_CHUNKS_PER_DOC,
-                ModelConfigurations.SERVICE_SETTINGS,
-                validationException
-            );
-
-            if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) {
-                validationException.addValidationError(
-                    "The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]"
-                );
-            }
         }
 
         if (validationException.validationErrors().isEmpty() == false) {
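For orientation, a sketch of what the now-unconditional validation accepts and rejects; the keys mirror the constants referenced above, while the concrete values (model id, thread count) and the map shape are assumptions. Uses java.util.Map and java.util.HashMap, since the extract* helpers typically need a mutable map.

    // Accepted: max_chunks_per_doc together with the "chunk" strategy.
    Map<String, Object> valid = new HashMap<>(
        Map.of("model_id", ".rerank-v1", "num_threads", 1, "long_document_strategy", "chunk", "max_chunks_per_doc", 5)
    );

    // Rejected: with the feature-flag guard gone, max_chunks_per_doc alongside a missing
    // or "truncate" strategy always adds the validation error above.
    Map<String, Object> invalid = new HashMap<>(
        Map.of("model_id", ".rerank-v1", "num_threads", 1, "long_document_strategy", "truncate", "max_chunks_per_doc", 5)
    );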
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
index 8bf8043a1ec0d..2c1ee96b519a3 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
@@ -16,7 +16,6 @@
 import org.elasticsearch.common.logging.DeprecationCategory;
 import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.FeatureFlag;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
@@ -116,8 +115,6 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
     private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
     private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
 
-    public static final FeatureFlag ELASTIC_RERANKER_CHUNKING = new FeatureFlag("elastic_reranker_chunking_long_documents");
-
     /**
      * Fix for https://github.com/elastic/elasticsearch/issues/124675
      * In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
@@ -698,26 +695,26 @@ public void inferRerank(
             }
         });
 
-        if (model instanceof ElasticRerankerModel elasticRerankerModel && ELASTIC_RERANKER_CHUNKING.isEnabled()) {
+        var returnDocs = Boolean.TRUE;
+        if (returnDocuments != null) {
+            returnDocs = returnDocuments;
+        } else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
+            var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
+            returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
+        }
+
+        if (model instanceof ElasticRerankerModel elasticRerankerModel) {
             var serviceSettings = elasticRerankerModel.getServiceSettings();
             var longDocumentStrategy = serviceSettings.getLongDocumentStrategy();
             if (longDocumentStrategy == ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK) {
                 var rerankChunker = new RerankRequestChunker(query, inputs, serviceSettings.getMaxChunksPerDoc());
                 inputs = rerankChunker.getChunkedInputs();
-                resultsListener = rerankChunker.parseChunkedRerankResultsListener(resultsListener);
+                resultsListener = rerankChunker.parseChunkedRerankResultsListener(resultsListener, returnDocs);
             }
         }
 
         var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
 
-        var returnDocs = Boolean.TRUE;
-        if (returnDocuments != null) {
-            returnDocs = returnDocuments;
-        } else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
-            var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
-            returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
-        }
-
         Function inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
 
         ActionListener mlResultsListener = resultsListener.delegateFailureAndWrap(
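A condensed sketch of the reordered flow after this hunk (method shape, listener type and the task-settings fallback are simplified assumptions, not the full service code):

    // 1. Resolve returnDocs up front: an explicit request flag wins, otherwise the
    //    endpoint's RerankTaskSettings decide, defaulting to true.
    // 2. Only then branch on the CHUNK strategy, so the chunk-aware listener can be
    //    built with the already-resolved flag.
    void rerank(ElasticRerankerModel model, String query, List<String> inputs, Boolean returnDocuments,
                ActionListener<InferenceServiceResults> resultsListener) {
        boolean returnDocs = returnDocuments != null ? returnDocuments : true; // task-settings fallback elided
        var effectiveInputs = inputs;
        var effectiveListener = resultsListener;

        if (model.getServiceSettings().getLongDocumentStrategy() == ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK) {
            var chunker = new RerankRequestChunker(query, inputs, model.getServiceSettings().getMaxChunksPerDoc());
            effectiveInputs = chunker.getChunkedInputs();
            effectiveListener = chunker.parseChunkedRerankResultsListener(resultsListener, returnDocs);
        }
        // ... build the TextSimilarityConfigUpdate request with effectiveInputs and effectiveListener
    }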
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java
index 5674fb3b73c98..31770f971dc8a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/RerankRequestChunkerTests.java
@@ -13,6 +13,7 @@
 import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
 import static java.lang.Math.max;
@@ -111,7 +112,8 @@ public void testParseChunkedRerankResultsListener_NonRankedDocsResults() {
             ActionListener.wrap(
                 results -> fail("Expected failure but got: " + results.getClass()),
                 e -> assertTrue(e instanceof IllegalArgumentException && e.getMessage().contains("Expected RankedDocsResults"))
-            )
+            ),
+            randomBoolean()
         );
 
         listener.onResponse(new InferenceServiceResults() {
@@ -124,120 +126,152 @@ public void testParseChunkedRerankResultsListener_EmptyInput() {
             assertThat(results, instanceOf(RankedDocsResults.class));
             var rankedDocResults = (RankedDocsResults) results;
             assertEquals(0, rankedDocResults.getRankedDocs().size());
-        }, e -> fail("Expected successful parsing but got failure: " + e)));
+        }, e -> fail("Expected successful parsing but got failure: " + e)), randomBoolean());
 
         listener.onResponse(new RankedDocsResults(List.of()));
     }
 
     public void testParseChunkedRerankResultsListener_SingleInputWithoutChunking() {
         var inputs = List.of(generateTestText(10));
         var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
+        var returnDocuments = randomBoolean();
+        var relevanceScores = generateRelevanceScores(1);
         var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
             assertThat(results, instanceOf(RankedDocsResults.class));
             var rankedDocResults = (RankedDocsResults) results;
-            assertEquals(1, rankedDocResults.getRankedDocs().size());
-        }, e -> fail("Expected successful parsing but got failure: " + e)));
+            var expectedRankedDocs = List.of(
+                new RankedDocsResults.RankedDoc(0, relevanceScores.get(0), returnDocuments ? inputs.get(0) : null)
+            );
+            assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs());
+        }, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);
 
         var chunkedInputs = chunker.getChunkedInputs();
         assertEquals(1, chunkedInputs.size());
-        listener.onResponse(new RankedDocsResults(List.of(new RankedDocsResults.RankedDoc(0, 1.0f, chunkedInputs.get(0)))));
+        listener.onResponse(
+            new RankedDocsResults(List.of(new RankedDocsResults.RankedDoc(0, relevanceScores.get(0), chunkedInputs.get(0))))
+        );
     }
 
-    public void testParseChunkedRerankResultsListener_SingleInputWithChunking() {
+    public void testParseChunkedRerankResultsListener_SingleInputWithChunkingWithFirstChunkRelevanceScoreHighest() {
         var inputs = List.of(generateTestText(100));
-        var relevanceScore1 = randomFloatBetween(0, 1, true);
-        var relevanceScore2 = randomFloatBetween(0, 1, true);
+        var relevanceScores = List.of(1f, 0.5f, 0.3f);
         var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
+        var returnDocuments = randomBoolean();
         var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
             assertThat(results, instanceOf(RankedDocsResults.class));
             var rankedDocResults = (RankedDocsResults) results;
-            assertEquals(1, rankedDocResults.getRankedDocs().size());
-            var expectedRankedDocs = List.of(new RankedDocsResults.RankedDoc(0, max(relevanceScore1, relevanceScore2), inputs.get(0)));
+            var expectedRankedDocs = List.of(
+                new RankedDocsResults.RankedDoc(0, relevanceScores.get(0), returnDocuments ? inputs.get(0) : null)
+            );
             assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs());
-        }, e -> fail("Expected successful parsing but got failure: " + e)));
+        }, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);
 
         var chunkedInputs = chunker.getChunkedInputs();
         assertEquals(3, chunkedInputs.size());
-        var rankedDocsResults = List.of(
-            new RankedDocsResults.RankedDoc(0, relevanceScore1, chunkedInputs.get(0)),
-            new RankedDocsResults.RankedDoc(1, relevanceScore2, chunkedInputs.get(1))
-        );
-        // TODO: Sort this so that the assumption that the results are in order holds
-        listener.onResponse(new RankedDocsResults(rankedDocsResults));
+        listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
+    }
+
+    public void testParseChunkedRerankResultsListener_SingleInputWithChunkingWithMiddleChunkRelevanceScoreHighest() {
+        var inputs = List.of(generateTestText(100));
+        var relevanceScores = List.of(0.5f, 1f, 0.5f);
+        var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
+        var returnDocuments = randomBoolean();
+        var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
+            assertThat(results, instanceOf(RankedDocsResults.class));
+            var rankedDocResults = (RankedDocsResults) results;
+            var expectedRankedDocs = List.of(
+                new RankedDocsResults.RankedDoc(0, relevanceScores.get(1), returnDocuments ? inputs.get(0) : null)
+            );
+            assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs());
+        }, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);
+
+        var chunkedInputs = chunker.getChunkedInputs();
+        assertEquals(3, chunkedInputs.size());
+        listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
+    }
+
+    public void testParseChunkedRerankResultsListener_SingleInputWithChunkingWithLastChunkRelevanceScoreHighest() {
+        var inputs = List.of(generateTestText(100));
+        var relevanceScores = List.of(0.5f, 0.3f, 1f);
+        var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
+        var returnDocuments = randomBoolean();
+        var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
+            assertThat(results, instanceOf(RankedDocsResults.class));
+            var rankedDocResults = (RankedDocsResults) results;
+            var expectedRankedDocs = List.of(
+                new RankedDocsResults.RankedDoc(0, relevanceScores.get(2), returnDocuments ? inputs.get(0) : null)
+            );
+            assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs());
+        }, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);
+
+        var chunkedInputs = chunker.getChunkedInputs();
+        assertEquals(3, chunkedInputs.size());
+        listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
     }
 
     public void testParseChunkedRerankResultsListener_MultipleInputsWithoutChunking() {
-        var inputs = List.of(generateTestText(10), generateTestText(10));
+        var inputs = List.of(generateTestText(10), generateTestText(20));
         var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
+        var returnDocuments = randomBoolean();
+        var relevanceScores = generateRelevanceScores(2);
         var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
             assertThat(results, instanceOf(RankedDocsResults.class));
             var rankedDocResults = (RankedDocsResults) results;
-            assertEquals(2, rankedDocResults.getRankedDocs().size());
-            var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs());
-            sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
-            assertEquals(sortedResults, rankedDocResults.getRankedDocs());
-        }, e -> fail("Expected successful parsing but got failure: " + e)));
+            var expectedRankedDocs = new ArrayList();
+            expectedRankedDocs.add(new RankedDocsResults.RankedDoc(0, relevanceScores.get(0), returnDocuments ? inputs.get(0) : null));
+            expectedRankedDocs.add(new RankedDocsResults.RankedDoc(1, relevanceScores.get(1), returnDocuments ? inputs.get(1) : null));
+            expectedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
+            assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs());
+        }, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);
 
         var chunkedInputs = chunker.getChunkedInputs();
         assertEquals(2, chunkedInputs.size());
-        listener.onResponse(
-            new RankedDocsResults(
-                List.of(
-                    new RankedDocsResults.RankedDoc(0, randomFloatBetween(0, 1, true), chunkedInputs.get(0)),
-                    new RankedDocsResults.RankedDoc(1, randomFloatBetween(0, 1, true), chunkedInputs.get(1))
-                )
-            )
-        );
+        listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
     }
 
     public void testParseChunkedRerankResultsListener_MultipleInputsWithSomeChunking() {
         var inputs = List.of(generateTestText(10), generateTestText(100));
         var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
+        var returnDocuments = randomBoolean();
+        var relevanceScores = generateRelevanceScores(4);
         var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
             assertThat(results, instanceOf(RankedDocsResults.class));
             var rankedDocResults = (RankedDocsResults) results;
-            assertEquals(2, rankedDocResults.getRankedDocs().size());
-            var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs());
-            sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
-            assertEquals(sortedResults, rankedDocResults.getRankedDocs());
-        }, e -> fail("Expected successful parsing but got failure: " + e)));
+            var expectedRankedDocs = new ArrayList();
+            expectedRankedDocs.add(new RankedDocsResults.RankedDoc(0, relevanceScores.get(0), returnDocuments ? inputs.get(0) : null));
+            expectedRankedDocs.add(
+                new RankedDocsResults.RankedDoc(1, Collections.max(relevanceScores.subList(1, 4)), returnDocuments ? inputs.get(1) : null)
+            );
+            expectedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
+            assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs());
+        }, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);
 
         var chunkedInputs = chunker.getChunkedInputs();
         assertEquals(4, chunkedInputs.size());
-        listener.onResponse(
-            new RankedDocsResults(
-                List.of(
-                    new RankedDocsResults.RankedDoc(0, randomFloatBetween(0, 1, true), chunkedInputs.get(0)),
-                    new RankedDocsResults.RankedDoc(1, randomFloatBetween(0, 1, true), chunkedInputs.get(1)),
-                    new RankedDocsResults.RankedDoc(2, randomFloatBetween(0, 1, true), chunkedInputs.get(2))
-                )
-            )
-        );
+        listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
     }
 
     public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiringChunking() {
-        var inputs = List.of(generateTestText(100), generateTestText(100));
+        var inputs = List.of(generateTestText(100), generateTestText(105));
         var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null);
+        var returnDocuments = randomBoolean();
+        var relevanceScores = generateRelevanceScores(6);
         var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> {
             assertThat(results, instanceOf(RankedDocsResults.class));
             var rankedDocResults = (RankedDocsResults) results;
-            assertEquals(2, rankedDocResults.getRankedDocs().size());
-            var sortedResults = new ArrayList<>(rankedDocResults.getRankedDocs());
-            sortedResults.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
-            assertEquals(sortedResults, rankedDocResults.getRankedDocs());
-        }, e -> fail("Expected successful parsing but got failure: " + e)));
+            var expectedRankedDocs = new ArrayList();
+            expectedRankedDocs.add(
+                new RankedDocsResults.RankedDoc(0, Collections.max(relevanceScores.subList(0, 3)), returnDocuments ? inputs.get(0) : null)
+            );
+            expectedRankedDocs.add(
+                new RankedDocsResults.RankedDoc(1, Collections.max(relevanceScores.subList(3, 6)), returnDocuments ? inputs.get(1) : null)
+            );
+            expectedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
+            assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs());
+        }, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments);
 
         var chunkedInputs = chunker.getChunkedInputs();
         assertEquals(6, chunkedInputs.size());
-        listener.onResponse(
-            new RankedDocsResults(
-                List.of(
-                    new RankedDocsResults.RankedDoc(0, randomFloatBetween(0, 1, true), chunkedInputs.get(0)),
-                    new RankedDocsResults.RankedDoc(1, randomFloatBetween(0, 1, true), chunkedInputs.get(1)),
-                    new RankedDocsResults.RankedDoc(2, randomFloatBetween(0, 1, true), chunkedInputs.get(2)),
-                    new RankedDocsResults.RankedDoc(3, randomFloatBetween(0, 1, true), chunkedInputs.get(3))
-                )
-            )
-        );
+        listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs)));
     }
 
     private String generateTestText(int numSentences) {
@@ -247,4 +281,20 @@ private String generateTestText(int numSentences) {
         }
         return sb.toString();
     }
+
+    private List generateRelevanceScores(int numScores) {
+        List scores = new ArrayList<>();
+        for (int i = 0; i < numScores; i++) {
+            scores.add(randomValueOtherThanMany(scores::contains, () -> randomFloatBetween(0, 1, true)));
+        }
+        return scores;
+    }
+
+    private List generateRankedDocs(List relevanceScores, List chunkedInputs) {
+        List rankedDocs = new ArrayList<>();
+        for (int i = 0; i < max(relevanceScores.size(), chunkedInputs.size()); i++) {
+            rankedDocs.add(new RankedDocsResults.RankedDoc(i, relevanceScores.get(i), chunkedInputs.get(i)));
+        }
+        return rankedDocs;
+    }
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java
index c9ee6a0543140..2a5816eff225f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettingsTests.java
@@ -20,7 +20,6 @@
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings.LONG_DOCUMENT_STRATEGY;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings.MAX_CHUNKS_PER_DOC;
-import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.MODEL_ID;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS;
@@ -39,12 +38,10 @@ public static ElasticRerankerServiceSettings createRandomWithChunkingConfigurati
     }
 
     public static ElasticRerankerServiceSettings createRandom() {
-        var longDocumentStrategy = ELASTIC_RERANKER_CHUNKING.isEnabled()
-            ? randomFrom(ElasticRerankerServiceSettings.LongDocumentStrategy.values())
+        var longDocumentStrategy = randomBoolean() ? randomFrom(ElasticRerankerServiceSettings.LongDocumentStrategy.values()) : null;
+        var maxChunksPerDoc = ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK.equals(longDocumentStrategy) && randomBoolean()
+            ? randomIntBetween(1, 10)
             : null;
-        var maxChunksPerDoc = ELASTIC_RERANKER_CHUNKING.isEnabled()
-            && ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK.equals(longDocumentStrategy)
-            && randomBoolean() ? randomIntBetween(1, 10) : null;
 
         return createRandom(longDocumentStrategy, maxChunksPerDoc);
     }
@@ -145,78 +142,7 @@ public void testFromMap_NumAllocationsAndAdaptiveAllocationsNull_ThrowsValidatio
         );
     }
 
-    public void testFromMap_ChunkingFeatureFlagDisabledAndLongDocumentStrategyProvided_CreatesSettingsIgnoringStrategy() {
-        assumeTrue(
-            "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled",
-            ELASTIC_RERANKER_CHUNKING.isEnabled() == false
-        );
-        var withAdaptiveAllocations = randomBoolean();
-        var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10);
-        var numThreads = randomIntBetween(1, 10);
-        var modelId = randomAlphaOfLength(8);
-        var adaptiveAllocationsSettings = withAdaptiveAllocations
-            ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5))
-            : null;
-        var longDocumentStrategy = ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE;
-
-        Map settingsMap = buildServiceSettingsMap(
-            withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations),
-            numThreads,
-            modelId,
-            withAdaptiveAllocations ? Optional.of(adaptiveAllocationsSettings) : Optional.empty(),
-            Optional.of(longDocumentStrategy),
-            Optional.empty()
-        );
-
-        ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap);
-        assertExpectedSettings(
-            settings,
-            Optional.ofNullable(numAllocations),
-            numThreads,
-            modelId,
-            Optional.ofNullable(adaptiveAllocationsSettings),
-            Optional.empty(),
-            Optional.empty()
-        );
-    }
-
-    public void testFromMap_ChunkingFeatureFlagDisabledAndMaxChunksPerDocProvided_CreatesSettingsIgnoringMaxChunksPerDoc() {
-        assumeTrue(
-            "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled",
-            ELASTIC_RERANKER_CHUNKING.isEnabled() == false
-        );
-        var withAdaptiveAllocations = randomBoolean();
-        var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10);
-        var numThreads = randomIntBetween(1, 10);
-        var modelId = randomAlphaOfLength(8);
-        var adaptiveAllocationsSettings = withAdaptiveAllocations
-            ? new AdaptiveAllocationsSettings(true, randomIntBetween(0, 2), randomIntBetween(2, 5))
-            : null;
-        var maxChunksPerDoc = randomIntBetween(1, 10);
-
-        Map settingsMap = buildServiceSettingsMap(
-            withAdaptiveAllocations ? Optional.empty() : Optional.of(numAllocations),
-            numThreads,
-            modelId,
-            withAdaptiveAllocations ? Optional.of(adaptiveAllocationsSettings) : Optional.empty(),
-            Optional.empty(),
-            Optional.of(maxChunksPerDoc)
-        );
-
-        ElasticRerankerServiceSettings settings = ElasticRerankerServiceSettings.fromMap(settingsMap);
-        assertExpectedSettings(
-            settings,
-            Optional.ofNullable(numAllocations),
-            numThreads,
-            modelId,
-            Optional.ofNullable(adaptiveAllocationsSettings),
-            Optional.empty(),
-            Optional.empty()
-        );
-    }
-
-    public void testFromMap_ChunkingFeatureFlagEnabledAndTruncateSelected_CreatesSettingsCorrectly() {
-        assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled());
+    public void testFromMap_TruncateLongDocumentStrategySelected_CreatesSettingsCorrectly() {
         var withAdaptiveAllocations = randomBoolean();
         var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10);
         var numThreads = randomIntBetween(1, 10);
@@ -247,8 +173,7 @@ public void testFromMap_ChunkingFeatureFlagEnabledAndTruncateSelected_CreatesSet
         );
     }
 
-    public void testFromMap_ChunkingFeatureFlagEnabledAndTruncateSelectedWithMaxChunksPerDoc_ThrowsValidationException() {
-        assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled());
+    public void testFromMap_TruncateLongDocumentStrategySelectedWithMaxChunksPerDoc_ThrowsValidationException() {
         var withAdaptiveAllocations = randomBoolean();
         var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10);
        var numThreads = randomIntBetween(1, 10);
@@ -278,8 +203,7 @@ public void testFromMap_ChunkingFeatureFlagEnabledAndTruncateSelectedWithMaxChun
         );
     }
 
-    public void testFromMap_ChunkingFeatureFlagEnabledAndChunkSelected_CreatesSettingsCorrectly() {
-        assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled());
+    public void testFromMap_ChunkLongDocumentStrategySelected_CreatesSettingsCorrectly() {
         var withAdaptiveAllocations = randomBoolean();
         var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10);
         var numThreads = randomIntBetween(1, 10);
@@ -311,8 +235,7 @@ public void testFromMap_ChunkingFeatureFlagEnabledAndChunkSelected_CreatesSettin
         );
     }
 
-    public void testFromMap_ChunkingFeatureFlagEnabledAndChunkSelectedWithMaxChunksPerDoc_CreatesSettingsCorrectly() {
-        assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled());
+    public void testFromMap_ChunkLongDocumentStrategySelectedWithMaxChunksPerDoc_CreatesSettingsCorrectly() {
         var withAdaptiveAllocations = randomBoolean();
         var numAllocations = withAdaptiveAllocations ? null : randomIntBetween(1, 10);
         var numThreads = randomIntBetween(1, 10);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
index e9f22f4848991..2cc0323e6d913 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
@@ -120,7 +120,6 @@
 import static org.elasticsearch.xpack.inference.Utils.mockClusterService;
 import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService.notElasticsearchModelException;
-import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME;
@@ -1011,49 +1010,7 @@ public void testInfer_ElasticRerankerSucceedsWithoutChunkingConfiguration() {
         testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100)));
     }
 
-    public void testInfer_ElasticRerankerFeatureFlagDisabledSucceedsWithTruncateConfiguration() {
-        assumeTrue(
-            "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled",
-            ELASTIC_RERANKER_CHUNKING.isEnabled() == false
-        );
-
-        var model = new ElasticRerankerModel(
-            randomAlphaOfLength(10),
-            TaskType.RERANK,
-            NAME,
-            ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration(
-                ElasticRerankerServiceSettings.LongDocumentStrategy.TRUNCATE,
-                null
-            ),
-            new RerankTaskSettings(randomBoolean())
-        );
-
-        testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100)));
-    }
-
-    public void testInfer_ElasticRerankerFeatureFlagDisabledSucceedsIgnoringChunkConfiguration() {
-        assumeTrue(
-            "Only if 'elastic_reranker_chunking_long_documents' feature flag is disabled",
-            ELASTIC_RERANKER_CHUNKING.isEnabled() == false
-        );
-
-        var model = new ElasticRerankerModel(
-            randomAlphaOfLength(10),
-            TaskType.RERANK,
-            NAME,
-            ElasticRerankerServiceSettingsTests.createRandomWithChunkingConfiguration(
-                ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK,
-                randomBoolean() ? randomIntBetween(1, 10) : null
-            ),
-            new RerankTaskSettings(randomBoolean())
-        );
-
-        testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100)));
-    }
-
-    public void testInfer_ElasticRerankerFeatureFlagEnabledAndSucceedsWithTruncateStrategy() {
-        assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled());
-
+    public void testInfer_SucceedsWithTruncateLongDocumentStrategy() {
         var model = new ElasticRerankerModel(
             randomAlphaOfLength(10),
             TaskType.RERANK,
@@ -1068,9 +1025,7 @@ public void testInfer_ElasticRerankerFeatureFlagEnabledAndSucceedsWithTruncateSt
         testInfer_ElasticReranker(model, generateTestDocs(randomIntBetween(2, 10), randomIntBetween(50, 100)));
     }
 
-    public void testInfer_ElasticRerankerFeatureFlagEnabledAndSucceedsWithChunkStrategy() {
-        assumeTrue("Only if 'elastic_reranker_chunking_long_documents' feature flag is enabled", ELASTIC_RERANKER_CHUNKING.isEnabled());
-
+    public void testInfer_SucceedsWithChunkLongDocumentStrategy() {
         var model = new ElasticRerankerModel(
             randomAlphaOfLength(10),
             TaskType.RERANK,
@@ -1090,8 +1045,7 @@ private void testInfer_ElasticReranker(ElasticRerankerModel model, List
         var query = randomAlphaOfLength(10);
         var mlTrainedModelResults = new ArrayList();
         var numResults = inputs.size();
-        if (ELASTIC_RERANKER_CHUNKING.isEnabled()
-            && ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK.equals(model.getServiceSettings().getLongDocumentStrategy())) {
+        if (ElasticRerankerServiceSettings.LongDocumentStrategy.CHUNK.equals(model.getServiceSettings().getLongDocumentStrategy())) {
             var rerankRequestChunker = new RerankRequestChunker(query, inputs, model.getServiceSettings().getMaxChunksPerDoc());
             numResults = rerankRequestChunker.getChunkedInputs().size();
         }
diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
index ef8831aa8c605..03ad7fe135b96 100644
--- a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
+++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java
@@ -15,7 +15,6 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.test.RetryRule;
 import org.elasticsearch.test.cluster.ElasticsearchCluster;
-import org.elasticsearch.test.cluster.FeatureFlag;
 import org.elasticsearch.test.cluster.local.distribution.DistributionType;
 import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
 import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
@@ -35,7 +34,6 @@ public class InferenceRestIT extends ESClientYamlSuiteTestCase {
         .setting("xpack.security.enabled", "false")
         .setting("xpack.security.http.ssl.enabled", "false")
         .setting("xpack.license.self_generated.type", "trial")
-        .feature(FeatureFlag.ELASTIC_RERANKER_CHUNKING)
         .plugin("inference-service-test")
         .distribution(DistributionType.DEFAULT)
         .build();