diff --git a/docs/changelog/137700.yaml b/docs/changelog/137700.yaml new file mode 100644 index 0000000000000..bf75c0b6ffdb1 --- /dev/null +++ b/docs/changelog/137700.yaml @@ -0,0 +1,5 @@ +pr: 137700 +summary: Add rerank chunking for JinaAI +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilder.java index 91016de86a55d..e6d921e6adbba 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilder.java @@ -67,4 +67,9 @@ public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWo var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount * WORDS_PER_TOKEN); return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1); } + + public static ChunkingSettings buildChunkingSettingsForRerank(int rerankerWindowSize, int queryWordCount) { + var chunkSizeWordCountWithFullQuery = rerankerWindowSize - queryWordCount; + return new SentenceBoundaryChunkingSettings(Math.max(chunkSizeWordCountWithFullQuery, rerankerWindowSize / 2), 1); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunker.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunker.java index 524a5650a3b8a..cb7fb71c6ae25 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunker.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunker.java @@ -28,6 +28,11 @@ public RerankRequestChunker(String query, List inputs, Integer maxChunks this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query), maxChunksPerDoc); } + public RerankRequestChunker(String query, List inputs, int rerankerWindowSize, Integer maxChunksPerDoc) { + this.inputs = inputs; + this.rerankChunks = chunk(inputs, buildChunkingSettingsForRerank(rerankerWindowSize, query), maxChunksPerDoc); + } + private List chunk(List inputs, ChunkingSettings chunkingSettings, Integer maxChunksPerDoc) { var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); var chunks = new ArrayList(); @@ -53,6 +58,25 @@ public List getChunkedInputs() { return chunkedInputs; } + public ActionListener parseChunkedRerankResultsListener( + ActionListener listener, + boolean returnDocuments, + Integer topN + ) { + listener = parseChunkedRerankResultsListener(listener, returnDocuments); + if (topN != null) { + return listener.delegateFailureAndWrap((l, results) -> { + if (results instanceof RankedDocsResults rankedDocsResults) { + l.onResponse(new RankedDocsResults(rankedDocsResults.getRankedDocs().subList(0, topN))); + } else { + l.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass().getName())); + } + }); + } else { + return listener; + } + } + public ActionListener parseChunkedRerankResultsListener( ActionListener listener, boolean returnDocuments @@ -101,4 +125,11 @@ private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) { var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator); return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount); } + + private ChunkingSettings buildChunkingSettingsForRerank(int rerankerWindowSize, String query) { + var wordIterator = BreakIterator.getWordInstance(); + wordIterator.setText(query); + var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator); + return ChunkingSettingsBuilder.buildChunkingSettingsForRerank(rerankerWindowSize, queryWordCount); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java index a7a6f78e26b7f..993e49d5887b0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/ChunkingSettingsBuilderTests.java @@ -77,6 +77,31 @@ public void testBuildChunkingSettingsForElasticReranker_QueryTokenCountMoreThanH assertEquals(1, sentenceBoundaryChunkingSettings.sentenceOverlap()); } + public void testBuildChunkingSettingsForRerank_QueryWordCountLessThanHalfOfRerankerWindowSize() { + int rerankerWindowSize = randomIntBetween(20, 6000); + int queryWordCount = randomIntBetween(1, rerankerWindowSize / 2 - 1); + ChunkingSettings actualChunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForRerank( + rerankerWindowSize, + queryWordCount + ); + SentenceBoundaryChunkingSettings expectedChunkingSettings = new SentenceBoundaryChunkingSettings( + rerankerWindowSize - queryWordCount, + 1 + ); + assertEquals(expectedChunkingSettings, actualChunkingSettings); + } + + public void testBuildChunkingSettingsForRerank_QueryWordCountMoreThanHalfOfRerankerWindowSize() { + int rerankerWindowSize = randomIntBetween(20, 6000); + int queryWordCount = randomIntBetween(rerankerWindowSize / 2, Integer.MAX_VALUE); + ChunkingSettings actualChunkingSettings = ChunkingSettingsBuilder.buildChunkingSettingsForRerank( + rerankerWindowSize, + queryWordCount + ); + SentenceBoundaryChunkingSettings expectedChunkingSettings = new SentenceBoundaryChunkingSettings(rerankerWindowSize / 2, 1); + assertEquals(expectedChunkingSettings, actualChunkingSettings); + } + private Map, ChunkingSettings> chunkingSettingsMapToChunkingSettings() { var maxChunkSizeWordBoundaryChunkingSettings = randomIntBetween(10, 300); var overlap = randomIntBetween(1, maxChunkSizeWordBoundaryChunkingSettings / 2); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunkerTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunkerTests.java index 698798cf0b6e2..cd2a8f5aca4d0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunkerTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/chunking/RerankRequestChunkerTests.java @@ -274,6 +274,52 @@ public void testParseChunkedRerankResultsListener_MultipleInputsWithAllRequiring listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs))); } + public void testParseChunkedRerankResultsListener_TopNSetToNull() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var returnDocuments = randomBoolean(); + var relevanceScores = generateRelevanceScores(4); + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + var expectedRankedDocs = new ArrayList(); + expectedRankedDocs.add(new RankedDocsResults.RankedDoc(0, relevanceScores.get(0), returnDocuments ? inputs.get(0) : null)); + expectedRankedDocs.add( + new RankedDocsResults.RankedDoc(1, Collections.max(relevanceScores.subList(1, 4)), returnDocuments ? inputs.get(1) : null) + ); + expectedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + assertEquals(expectedRankedDocs, rankedDocResults.getRankedDocs()); + }, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments, null); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(4, chunkedInputs.size()); + listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs))); + } + + public void testParseChunkedRerankResultsListener_TopNProvided() { + var inputs = List.of(generateTestText(10), generateTestText(100)); + var chunker = new RerankRequestChunker(TEST_SENTENCE, inputs, null); + var returnDocuments = randomBoolean(); + var relevanceScores = generateRelevanceScores(4); + var topN = 1; + var listener = chunker.parseChunkedRerankResultsListener(ActionListener.wrap(results -> { + assertThat(results, instanceOf(RankedDocsResults.class)); + var rankedDocResults = (RankedDocsResults) results; + assertEquals(topN, rankedDocResults.getRankedDocs().size()); + var expectedRankedDocs = new ArrayList(); + expectedRankedDocs.add(new RankedDocsResults.RankedDoc(0, relevanceScores.get(0), returnDocuments ? inputs.get(0) : null)); + expectedRankedDocs.add( + new RankedDocsResults.RankedDoc(1, Collections.max(relevanceScores.subList(1, 4)), returnDocuments ? inputs.get(1) : null) + ); + expectedRankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore())); + assertEquals(expectedRankedDocs.subList(0, topN), rankedDocResults.getRankedDocs()); + }, e -> fail("Expected successful parsing but got failure: " + e)), returnDocuments, topN); + + var chunkedInputs = chunker.getChunkedInputs(); + assertEquals(4, chunkedInputs.size()); + listener.onResponse(new RankedDocsResults(generateRankedDocs(relevanceScores, chunkedInputs))); + } + private String generateTestText(int numSentences) { StringBuilder sb = new StringBuilder(); for (int i = 0; i < numSentences; i++) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/RerankInferenceProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/RerankInferenceProcessor.java new file mode 100644 index 0000000000000..4f640989cdfbb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/RerankInferenceProcessor.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.xpack.core.inference.chunking.RerankRequestChunker; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.services.settings.LongDocumentStrategy; +import org.elasticsearch.xpack.inference.services.settings.RerankServiceSettings; + +public class RerankInferenceProcessor { + public static void doInfer( + SenderService service, + Model model, + ExecutableAction action, + InferenceInputs inputs, + TimeValue timeout, + ActionListener listener + ) { + var serviceSettings = model.getServiceSettings(); + if (serviceSettings instanceof RerankServiceSettings == false) { + throw new IllegalArgumentException("RerankInferenceProcessor can only process models with RerankServiceSettings"); + } + + var rerankServiceSettings = (RerankServiceSettings) serviceSettings; + if (LongDocumentStrategy.CHUNK.equals(rerankServiceSettings.getLongDocumentStrategy())) { + if (inputs instanceof QueryAndDocsInputs == false) { + throw new IllegalArgumentException("RerankInferenceProcessor can only process QueryAndDocsInputs when chunking is enabled"); + } + var queryAndDocsInputs = (QueryAndDocsInputs) inputs; + + if (service instanceof RerankingInferenceService == false) { + throw new IllegalArgumentException( + "RerankInferenceProcessor can only process RerankingInferenceService when chunking is enabled" + ); + } + + var rerankRequestChunker = new RerankRequestChunker( + queryAndDocsInputs.getQuery(), + queryAndDocsInputs.getChunks(), + ((RerankingInferenceService) service).rerankerWindowSize(serviceSettings.modelId()), + rerankServiceSettings.getMaxChunksPerDoc() + ); + + inputs = new QueryAndDocsInputs( + queryAndDocsInputs.getQuery(), + rerankRequestChunker.getChunkedInputs(), + queryAndDocsInputs.getReturnDocuments(), // TODO: Check if we want this from inputs or from task settings + queryAndDocsInputs.getTopN(), // TODO: Check if we want this from inputs or from task settings + queryAndDocsInputs.stream() + ); + + listener = rerankRequestChunker.parseChunkedRerankResultsListener( + listener, + queryAndDocsInputs.getReturnDocuments() == null || queryAndDocsInputs.getReturnDocuments(), + queryAndDocsInputs.getTopN() + ); + } + action.execute(inputs, timeout, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 655f70df89f32..74d0be6132bb0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -25,6 +25,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -150,13 +151,25 @@ public void chunkedInfer( }).addListener(listener); } - protected abstract void doInfer( + protected void doInfer( Model model, InferenceInputs inputs, Map taskSettings, TimeValue timeout, ActionListener listener - ); + ) { + var action = createAction(model, taskSettings); + + if (Objects.requireNonNull(model.getTaskType()) == TaskType.RERANK) { + RerankInferenceProcessor.doInfer(this, model, action, inputs, timeout, listener); + } else { + action.execute(inputs, timeout, listener); + } + } + + protected ExecutableAction createAction(Model model, Map taskSettings) { + return null; + } protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 0cd3d5ad65782..d8d66ff93059e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -32,9 +32,9 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; @@ -235,23 +235,14 @@ protected void doUnifiedCompletionInfer( } @Override - public void doInfer( - Model model, - InferenceInputs inputs, - Map taskSettings, - TimeValue timeout, - ActionListener listener - ) { - if (model instanceof JinaAIModel == false) { - listener.onFailure(createInvalidModelException(model)); - return; - } - - JinaAIModel jinaaiModel = (JinaAIModel) model; + protected ExecutableAction createAction(Model model, Map taskSettings) { var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); - var action = jinaaiModel.accept(actionCreator, taskSettings); - action.execute(inputs, timeout, listener); + if (model instanceof JinaAIModel jinaaiModel) { + return jinaaiModel.accept(actionCreator, taskSettings); + } else { + throw createInvalidModelException(model); + } } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettings.java index ceb7281dd1156..c62d241c4efd1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettings.java @@ -19,14 +19,14 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIRateLimitServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; -import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RerankServiceSettings; import java.io.IOException; import java.util.Map; import java.util.Objects; -public class JinaAIRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, JinaAIRateLimitServiceSettings { +public class JinaAIRerankServiceSettings extends RerankServiceSettings implements ServiceSettings, JinaAIRateLimitServiceSettings { public static final String NAME = "jinaai_rerank_service_settings"; private static final Logger logger = LogManager.getLogger(JinaAIRerankServiceSettings.class); @@ -38,18 +38,26 @@ public static JinaAIRerankServiceSettings fromMap(Map map, Confi throw validationException; } + var rerankServiceSettings = RerankServiceSettings.fromMap(map); var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context); - return new JinaAIRerankServiceSettings(commonServiceSettings); + return new JinaAIRerankServiceSettings(rerankServiceSettings, commonServiceSettings); } private final JinaAIServiceSettings commonSettings; + public JinaAIRerankServiceSettings(RerankServiceSettings rerankSettings, JinaAIServiceSettings commonSettings) { + super(rerankSettings); + this.commonSettings = commonSettings; + } + public JinaAIRerankServiceSettings(JinaAIServiceSettings commonSettings) { + super(null, null); this.commonSettings = commonSettings; } public JinaAIRerankServiceSettings(StreamInput in) throws IOException { + super(in); this.commonSettings = new JinaAIServiceSettings(in); } @@ -76,6 +84,7 @@ public String getWriteableName() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); + builder = super.toXContentFragmentOfExposedFields(builder, params); builder = commonSettings.toXContentFragment(builder, params); builder.endObject(); @@ -83,7 +92,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + super.toXContentFragmentOfExposedFields(builder, params); commonSettings.toXContentFragmentOfExposedFields(builder, params); return builder; } @@ -95,6 +105,7 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); commonSettings.writeTo(out); } @@ -103,11 +114,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; JinaAIRerankServiceSettings that = (JinaAIRerankServiceSettings) o; - return Objects.equals(commonSettings, that.commonSettings); + return super.equals(o) && Objects.equals(commonSettings, that.commonSettings); } @Override public int hashCode() { return Objects.hash(commonSettings); + // TODO: include super.hashcode? } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/LongDocumentStrategy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/LongDocumentStrategy.java new file mode 100644 index 0000000000000..6a88d98f4f2aa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/LongDocumentStrategy.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.settings; + +import java.util.Locale; + +public enum LongDocumentStrategy { + CHUNK("chunk"), + TRUNCATE("truncate"); + + public final String strategyName; + + LongDocumentStrategy(String strategyName) { + this.strategyName = strategyName; + } + + public static LongDocumentStrategy fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RerankServiceSettings.java new file mode 100644 index 0000000000000..1d848f848f588 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RerankServiceSettings.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.settings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.VersionedNamedWriteable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; + +public class RerankServiceSettings extends FilteredXContentObject implements VersionedNamedWriteable { + private final String NAME = "rerank_service_settings"; + + public static final String LONG_DOCUMENT_STRATEGY = "long_document_strategy"; + public static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc"; + + private final LongDocumentStrategy longDocumentStrategy; + private final Integer maxChunksPerDoc; + + public static RerankServiceSettings fromMap(Map map) { + var validationException = new ValidationException(); + + LongDocumentStrategy longDocumentStrategy = extractOptionalEnum( + map, + LONG_DOCUMENT_STRATEGY, + ModelConfigurations.SERVICE_SETTINGS, + LongDocumentStrategy::fromString, + EnumSet.allOf(LongDocumentStrategy.class), + validationException + ); + + Integer maxChunksPerDoc = extractOptionalPositiveInteger( + map, + MAX_CHUNKS_PER_DOC, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + if ((longDocumentStrategy == null || LongDocumentStrategy.TRUNCATE.equals(longDocumentStrategy)) && maxChunksPerDoc != null) { + validationException.addValidationError( + "Setting [" + MAX_CHUNKS_PER_DOC + "] cannot be set without setting [" + LONG_DOCUMENT_STRATEGY + "]" + ); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new RerankServiceSettings(longDocumentStrategy, maxChunksPerDoc); + } + + public RerankServiceSettings(LongDocumentStrategy longDocumentStrategy, Integer maxChunksPerDoc) { + this.longDocumentStrategy = longDocumentStrategy; + this.maxChunksPerDoc = maxChunksPerDoc; + } + + public RerankServiceSettings(RerankServiceSettings other) { + this.longDocumentStrategy = other.longDocumentStrategy; + this.maxChunksPerDoc = other.maxChunksPerDoc; + } + + public RerankServiceSettings(StreamInput in) throws IOException { + longDocumentStrategy = in.readOptionalEnum(LongDocumentStrategy.class); + maxChunksPerDoc = in.readOptionalInt(); + // TODO: Add transport version + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(longDocumentStrategy); + out.writeOptionalInt(maxChunksPerDoc); + // TODO: Add transport version + } + + public LongDocumentStrategy getLongDocumentStrategy() { + return longDocumentStrategy; + } + + public Integer getMaxChunksPerDoc() { + return maxChunksPerDoc; + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException { + if (longDocumentStrategy != null) { + builder.field(LONG_DOCUMENT_STRATEGY, longDocumentStrategy.strategyName); + } + if (maxChunksPerDoc != null) { + builder.field(MAX_CHUNKS_PER_DOC, maxChunksPerDoc); + } + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return null; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RerankServiceSettings that = (RerankServiceSettings) o; + return longDocumentStrategy == that.longDocumentStrategy && java.util.Objects.equals(maxChunksPerDoc, that.maxChunksPerDoc); + } + + @Override + public int hashCode() { + return Objects.hash(longDocumentStrategy, maxChunksPerDoc); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/RerankInferenceProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/RerankInferenceProcessorTests.java new file mode 100644 index 0000000000000..32fc125e1f3a6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/RerankInferenceProcessorTests.java @@ -0,0 +1,231 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.RerankingInferenceService; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.chunking.RerankRequestChunker; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.LongDocumentStrategy; +import org.elasticsearch.xpack.inference.services.settings.RerankServiceSettings; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.util.List; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class RerankInferenceProcessorTests extends ESTestCase { + @Mock + private SenderService mockRerankingSenderService; + + @Mock + private Model mockModel; + + @Mock + private ServiceSettings mockServiceSettings; + + @Mock + private ExecutableAction mockExecutableAction; + + @Mock + private QueryAndDocsInputs mockQueryAndDocsInputs; + + @Mock + private TimeValue mockTimeValue; + + @Mock + private ActionListener mockListener; + + @SuppressWarnings("unchecked") + @Before + public void init() { + MockitoAnnotations.openMocks(this); + mockRerankingSenderService = mock(JinaAIService.class); + mockModel = mock(Model.class); + mockServiceSettings = mock(JinaAIRerankServiceSettings.class); + mockExecutableAction = mock(ExecutableAction.class); + mockQueryAndDocsInputs = mock(QueryAndDocsInputs.class); + mockTimeValue = mock(TimeValue.class); + mockListener = mock(ActionListener.class); + + when(mockModel.getServiceSettings()).thenReturn(mockServiceSettings); + } + + public void testDoInfer_ModelDoesNotHaveRerankServiceSettings_ThrowsIllegalArgumentException() { + when(mockModel.getServiceSettings()).thenReturn(mock(ServiceSettings.class)); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + RerankInferenceProcessor.doInfer( + mockRerankingSenderService, + mockModel, + mockExecutableAction, + mockQueryAndDocsInputs, + mockTimeValue, + mockListener + ); + }); + assertEquals("RerankInferenceProcessor can only process models with RerankServiceSettings", exception.getMessage()); + + verify(mockModel).getServiceSettings(); + verifyNoMoreInteractionsOnMocks(); + } + + public void testDoInfer_LongDocumentStrategyNull_ExecutesActionWithoutChunking() { + when(((RerankServiceSettings) mockServiceSettings).getLongDocumentStrategy()).thenReturn(null); + + RerankInferenceProcessor.doInfer( + mockRerankingSenderService, + mockModel, + mockExecutableAction, + mockQueryAndDocsInputs, + mockTimeValue, + mockListener + ); + verify(mockModel).getServiceSettings(); + verify(mockExecutableAction).execute(mockQueryAndDocsInputs, mockTimeValue, mockListener); + verifyNoMoreInteractionsOnMocks(); + } + + public void testDoInfer_LongDocumentStrategyTruncate_ExecutesActionWithoutChunk() { + when(((RerankServiceSettings) mockServiceSettings).getLongDocumentStrategy()).thenReturn(LongDocumentStrategy.TRUNCATE); + + RerankInferenceProcessor.doInfer( + mockRerankingSenderService, + mockModel, + mockExecutableAction, + mockQueryAndDocsInputs, + mockTimeValue, + mockListener + ); + verify(mockModel).getServiceSettings(); + verify(mockExecutableAction).execute(mockQueryAndDocsInputs, mockTimeValue, mockListener); + verifyNoMoreInteractionsOnMocks(); + } + + public void testDoInfer_LongDocumentStrategyChunkAndInputsAreNotQueryAndDocsInputs_ThrowIllegalArgumentException() { + when(((RerankServiceSettings) mockServiceSettings).getLongDocumentStrategy()).thenReturn(LongDocumentStrategy.CHUNK); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + RerankInferenceProcessor.doInfer( + mockRerankingSenderService, + mockModel, + mockExecutableAction, + mock(InferenceInputs.class), + mockTimeValue, + mockListener + ); + }); + assertEquals("RerankInferenceProcessor can only process QueryAndDocsInputs when chunking is enabled", exception.getMessage()); + + verify(mockModel).getServiceSettings(); + verifyNoMoreInteractionsOnMocks(); + } + + public void testDoInfer_LongDocumentStrategyChunkWithQueryAndDocsInputsAndNonRerankingInferenceService_ThrowIllegalArgumentException() { + when(((RerankServiceSettings) mockServiceSettings).getLongDocumentStrategy()).thenReturn(LongDocumentStrategy.CHUNK); + + var mockQueryAndDocsInputs = mock(QueryAndDocsInputs.class); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + RerankInferenceProcessor.doInfer( + mock(SenderService.class), + mockModel, + mockExecutableAction, + mockQueryAndDocsInputs, + mockTimeValue, + mockListener + ); + }); + assertEquals( + "RerankInferenceProcessor can only process RerankingInferenceService when chunking is enabled", + exception.getMessage() + ); + + verify(mockModel).getServiceSettings(); + verifyNoMoreInteractionsOnMocks(); + } + + @SuppressWarnings("unchecked") + public void testDoInfer_LongDocumentStrategyChunkWithQueryAndDocsInputsAndRerankingInferenceService_ExecutesActionWithChunking() { + when(((RerankServiceSettings) mockServiceSettings).getLongDocumentStrategy()).thenReturn(LongDocumentStrategy.CHUNK); + var maxChunksPerDoc = randomIntBetween(1, 5); + when(((RerankServiceSettings) mockServiceSettings).getMaxChunksPerDoc()).thenReturn(maxChunksPerDoc); + QueryAndDocsInputs queryAndDocsInputs = new QueryAndDocsInputs( + "This is a test query", + List.of("This is the first chunk", "This is the second chunk", "This is the third chunk"), + randomBoolean() ? randomBoolean() : null, + randomBoolean() ? randomIntBetween(1, 10) : null, + false + ); + + var rerankerWindowSize = randomIntBetween(512, 3000); + when(((RerankingInferenceService) mockRerankingSenderService).rerankerWindowSize(any())).thenReturn(rerankerWindowSize); + RerankInferenceProcessor.doInfer( + mockRerankingSenderService, + mockModel, + mockExecutableAction, + queryAndDocsInputs, + mockTimeValue, + mockListener + ); + + verify(mockModel).getServiceSettings(); + verify((RerankingInferenceService) mockRerankingSenderService).rerankerWindowSize(any()); + + RerankRequestChunker rerankRequestChunker = new RerankRequestChunker( + queryAndDocsInputs.getQuery(), + queryAndDocsInputs.getChunks(), + rerankerWindowSize, + maxChunksPerDoc + ); + + QueryAndDocsInputs expectedChunkedInputs = new QueryAndDocsInputs( + queryAndDocsInputs.getQuery(), + rerankRequestChunker.getChunkedInputs(), + queryAndDocsInputs.getReturnDocuments(), + queryAndDocsInputs.getTopN(), + queryAndDocsInputs.stream() + ); + ArgumentCaptor inputsCaptor = ArgumentCaptor.forClass(QueryAndDocsInputs.class); + verify(mockExecutableAction).execute(inputsCaptor.capture(), eq(mockTimeValue), any(ActionListener.class)); + QueryAndDocsInputs capturedInputs = inputsCaptor.getValue(); + assertEquals(expectedChunkedInputs.getQuery(), capturedInputs.getQuery()); + assertEquals(expectedChunkedInputs.getChunks(), capturedInputs.getChunks()); + assertEquals(expectedChunkedInputs.getReturnDocuments(), capturedInputs.getReturnDocuments()); + assertEquals(expectedChunkedInputs.getTopN(), capturedInputs.getTopN()); + verifyNoMoreInteractionsOnMocks(); + } + + private void verifyNoMoreInteractionsOnMocks() { + verifyNoMoreInteractions( + mockRerankingSenderService, + mockModel, + mockExecutableAction, + mockQueryAndDocsInputs, + mockTimeValue, + mockListener + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 552482c104f97..b4dc2c018577d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -189,6 +190,83 @@ protected void doInfer( } } + public void testDoInfer_NonRerankTaskType_ExecutesAction() throws IOException { + var sender = createMockSender(); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockExecutableAction = mock(ExecutableAction.class); + var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()) { + @Override + protected ExecutableAction createAction(Model model, Map taskSettings) { + return mockExecutableAction; + } + }; + + try (testService) { + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(randomValueOtherThan(TaskType.RERANK, () -> randomFrom(TaskType.values()))); + var mockInferenceInputs = mock(InferenceInputs.class); + PlainActionFuture listener = new PlainActionFuture<>(); + + testService.doInfer(model, mockInferenceInputs, Map.of(), TIMEOUT, listener); + verify(mockExecutableAction).execute(mockInferenceInputs, TIMEOUT, listener); + } + + verify(sender, times(1)).close(); + verify(factory).createSender(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + verifyNoMoreInteractions(mockExecutableAction); + } + + public void testDoInfer_RerankTaskType_ExecutesRerankProcessor() throws IOException { + var sender = createMockSender(); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockExecutableAction = mock(ExecutableAction.class); + var testService = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()) { + @Override + protected ExecutableAction createAction(Model model, Map taskSettings) { + return mockExecutableAction; + } + }; + + try (testService) { + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(TaskType.RERANK); + var mockInferenceInputs = mock(InferenceInputs.class); + PlainActionFuture listener = new PlainActionFuture<>(); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testService.doInfer(model, mockInferenceInputs, Map.of(), TIMEOUT, listener) + ); + assertEquals("RerankInferenceProcessor can only process models with RerankServiceSettings", exception.getMessage()); + } + + verify(sender, times(1)).close(); + verify(factory).createSender(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + verifyNoMoreInteractions(mockExecutableAction); + } + + public void testCreateAction_ReturnsNull() throws IOException { + var sender = createMockSender(); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var action = service.createAction(mock(Model.class), Map.of()); + assertNull(action); + } + + verify(sender, times(1)).close(); + verify(factory, times(1)).createSender(); + } + public static Sender createMockSender() { var sender = mock(Sender.class); doAnswer(invocationOnMock -> { @@ -205,17 +283,6 @@ private static class TestSenderService extends SenderService { super(factory, serviceComponents, clusterService); } - @Override - protected void doInfer( - Model model, - InferenceInputs inputs, - Map taskSettings, - TimeValue timeout, - ActionListener listener - ) { - - } - @Override protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 8158bb0403222..3037a5ba7f7d6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -1603,6 +1603,89 @@ public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings } } + // TODO: Do we still need the infer tests above? They should be covered by SenderService tests + public void testCreateAction_WithNonJinaAIModel_ThrowsElasticsearchStatusException() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var mockModel = getInvalidModel("model_id", "service_name"); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> service.createAction(mockModel, null)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + } + } + + public void testCreateAction_CreatesActionForEmbeddingsModelWithNullOverrideTaskSettings() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + randomNonNegativeInt(), + randomNonNegativeInt(), + randomAlphaOfLength(10), + null, + JinaAIEmbeddingType.FLOAT + ); + + testCreateAction(model, null); + } + + public void testCreateAction_CreatesActionForEmbeddingsModelWithNonNullOverrideTaskSettings() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + randomNonNegativeInt(), + randomNonNegativeInt(), + randomAlphaOfLength(10), + null, + JinaAIEmbeddingType.FLOAT + ); + Map overrideTaskSettings = new HashMap<>(); + overrideTaskSettings.put("input_type", InputType.INGEST.toString()); + + testCreateAction(model, overrideTaskSettings); + } + + public void testCreateAction_CreatesActionForRerankModelWithNullOverrideTaskSettings() throws IOException { + var model = JinaAIRerankModelTests.createModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + null, + false + ); + + testCreateAction(model, null); + } + + public void testCreateAction_CreatesActionForRerankModelWithNonNullOverrideTaskSettings() throws IOException { + var model = JinaAIRerankModelTests.createModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + null, + false + ); + Map overrideTaskSettings = new HashMap<>(); + overrideTaskSettings.put("top_n", randomIntBetween(1, 10)); + + testCreateAction(model, overrideTaskSettings); + } + + private void testCreateAction(JinaAIModel model, Map taskSettings) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var action = service.createAction(model, taskSettings); + assertNotNull(action); + } + } + public void test_Embedding_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { var model = JinaAIEmbeddingsModelTests.createModel( getUrl(webServer), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java index 47f67bd8cefb8..bd321f6514ee6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java @@ -9,25 +9,33 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.LongDocumentStrategy; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RerankServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RerankServiceSettingsTests; import java.io.IOException; import java.util.HashMap; import java.util.Map; import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; +import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS; +import static org.hamcrest.Matchers.containsString; public class JinaAIRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase { public static JinaAIRerankServiceSettings createRandom() { return new JinaAIRerankServiceSettings( + RerankServiceSettingsTests.createRandom(), new JinaAIServiceSettings( randomFrom(new String[] { null, Strings.format("http://%s.com", randomAlphaOfLength(8)) }), randomAlphaOfLength(10), @@ -36,25 +44,109 @@ public static JinaAIRerankServiceSettings createRandom() { ); } + public void testFromMap_RerankServiceSettingsInvalid_ThrowsValidationException() { + var url = randomAlphaOfLength(10); + var model = randomAlphaOfLength(10); + + Map map = getServiceSettingsMap(null, randomNonNegativeInt(), url, model); + + ValidationException exception = expectThrows( + ValidationException.class, + () -> JinaAIRerankServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST) + ); + assertThat( + exception.getMessage(), + containsString("Setting [max_chunks_per_doc] cannot be set without setting [long_document_strategy]") + ); + } + + public void testFromMap_JinaAiServiceSettingsInvalid_ThrowsValidationException() { + var rerankServiceSettings = RerankServiceSettingsTests.createRandom(); + var longDocumentStrategy = rerankServiceSettings.getLongDocumentStrategy(); + var maxChunksPerDoc = rerankServiceSettings.getMaxChunksPerDoc(); + var url = "http://www.abc.com"; + + Map map = getServiceSettingsMap(longDocumentStrategy, maxChunksPerDoc, url, null); + + ValidationException exception = expectThrows( + ValidationException.class, + () -> JinaAIRerankServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST) + ); + assertThat(exception.getMessage(), containsString("[service_settings] does not contain the required setting [model_id]")); + } + + public void testFromMap_AllValuesProvided_CreatesJinaAIRerankServiceSettings() { + var longDocumentStrategy = LongDocumentStrategy.CHUNK; + var maxChunksPerDoc = randomNonNegativeInt(); + var url = "http://www.abc.com"; + var model = randomAlphaOfLength(10); + + Map map = getServiceSettingsMap(longDocumentStrategy, maxChunksPerDoc, url, model); + + var serviceSettings = JinaAIRerankServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST); + assertExpectedServiceSettings(serviceSettings, longDocumentStrategy, maxChunksPerDoc, url, model); + } + + private Map getServiceSettingsMap( + LongDocumentStrategy longDocumentStrategy, + Integer maxChunksPerDoc, + String url, + String model + ) { + Map map = new HashMap<>(); + if (longDocumentStrategy != null) { + map.put(RerankServiceSettings.LONG_DOCUMENT_STRATEGY, longDocumentStrategy.strategyName); + } + + if (maxChunksPerDoc != null) { + map.put(RerankServiceSettings.MAX_CHUNKS_PER_DOC, maxChunksPerDoc); + } + + map.putAll(JinaAIServiceSettingsTests.getServiceSettingsMap(url, model)); + + return map; + } + + private void assertExpectedServiceSettings( + JinaAIRerankServiceSettings serviceSettings, + LongDocumentStrategy longDocumentStrategy, + Integer maxChunksPerDoc, + String url, + String model + ) { + assertEquals(longDocumentStrategy, serviceSettings.getLongDocumentStrategy()); + assertEquals(maxChunksPerDoc, serviceSettings.getMaxChunksPerDoc()); + JinaAIServiceSettings commonSettings = serviceSettings.getCommonSettings(); + assertEquals(url, commonSettings.uri().toString()); + assertEquals(model, commonSettings.modelId()); + } + public void testToXContent_WritesAllValues() throws IOException { + var longDocumentStrategy = randomFrom(LongDocumentStrategy.values()); + var maxChunksPerDoc = randomNonNegativeInt(); var url = "http://www.abc.com"; var model = "model"; - var serviceSettings = new JinaAIRerankServiceSettings(new JinaAIServiceSettings(url, model, null)); + var serviceSettings = new JinaAIRerankServiceSettings( + new RerankServiceSettings(longDocumentStrategy, maxChunksPerDoc), + new JinaAIServiceSettings(url, model, null) + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(Strings.format(""" { - "url":"http://www.abc.com", - "model_id":"model", + "long_document_strategy":"%s", + "max_chunks_per_doc":%d, + "url":"%s", + "model_id":"%s", "rate_limit": { "requests_per_minute": 2000 } } - """)); + """, longDocumentStrategy.strategyName, maxChunksPerDoc, url, model, DEFAULT_RATE_LIMIT_SETTINGS.requestsPerTimeUnit()))); } @Override @@ -74,6 +166,7 @@ protected JinaAIRerankServiceSettings mutateInstance(JinaAIRerankServiceSettings @Override protected JinaAIRerankServiceSettings mutateInstanceForVersion(JinaAIRerankServiceSettings instance, TransportVersion version) { + // TODO: Update this once transport version is added to rerank service settings return instance; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RerankServiceSettingsTests.java new file mode 100644 index 0000000000000..d58152fa996de --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RerankServiceSettingsTests.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.settings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class RerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + + public static RerankServiceSettings createRandom() { + LongDocumentStrategy longDocumentStrategy = randomBoolean() ? null : randomFrom(LongDocumentStrategy.values()); + Integer maxChunksPerDoc = longDocumentStrategy == null ? null : randomBoolean() ? null : randomNonNegativeInt(); + return new RerankServiceSettings(longDocumentStrategy, maxChunksPerDoc); + } + + public void testFromMap_MapDoesNotIncludeValues_CreatesRerankServiceSettings() { + Map map = new HashMap<>(); + RerankServiceSettings settings = RerankServiceSettings.fromMap(map); + assertNull(settings.getLongDocumentStrategy()); + assertNull(settings.getMaxChunksPerDoc()); + } + + public void testFromMap_AllValuesSetToNull_CreatesRerankServiceSettings() { + Map map = new HashMap<>(); + map.put(RerankServiceSettings.LONG_DOCUMENT_STRATEGY, null); + map.put(RerankServiceSettings.MAX_CHUNKS_PER_DOC, null); + RerankServiceSettings settings = RerankServiceSettings.fromMap(map); + assertNull(settings.getLongDocumentStrategy()); + assertNull(settings.getMaxChunksPerDoc()); + } + + public void testFromMap_LongDocumentStrategyNullAndMaxChunksPerDocProvided_ThrowsValidationException() { + Map map = new HashMap<>(); + map.put(RerankServiceSettings.LONG_DOCUMENT_STRATEGY, null); + map.put(RerankServiceSettings.MAX_CHUNKS_PER_DOC, randomNonNegativeInt()); + ValidationException exception = expectThrows(ValidationException.class, () -> RerankServiceSettings.fromMap(map)); + assertThat( + exception.getMessage(), + containsString("Setting [max_chunks_per_doc] cannot be set without setting [long_document_strategy]") + ); + } + + public void testFromMap_LongDocumentStrategyTruncate_CreatesRerankServiceSettings() { + Map map = new HashMap<>(); + map.put(RerankServiceSettings.LONG_DOCUMENT_STRATEGY, LongDocumentStrategy.TRUNCATE.strategyName); + RerankServiceSettings settings = RerankServiceSettings.fromMap(map); + assertEquals(LongDocumentStrategy.TRUNCATE, settings.getLongDocumentStrategy()); + assertNull(settings.getMaxChunksPerDoc()); + } + + public void testFromMap_LongDocumentStrategyTruncateAndMaxChunksPerDocProvided_ThrowsValidationException() { + Map map = new HashMap<>(); + map.put(RerankServiceSettings.LONG_DOCUMENT_STRATEGY, LongDocumentStrategy.TRUNCATE.strategyName); + map.put(RerankServiceSettings.MAX_CHUNKS_PER_DOC, 5); + ValidationException exception = expectThrows(ValidationException.class, () -> RerankServiceSettings.fromMap(map)); + assertThat( + exception.getMessage(), + containsString("Setting [max_chunks_per_doc] cannot be set without setting [long_document_strategy]") + ); + } + + public void testFromMap_LongDocumentStrategyWithMaxChunksNotProvided_CreatesRerankServiceSettings() { + Map map = new HashMap<>(); + map.put(RerankServiceSettings.LONG_DOCUMENT_STRATEGY, LongDocumentStrategy.CHUNK.strategyName); + RerankServiceSettings settings = RerankServiceSettings.fromMap(map); + assertEquals(LongDocumentStrategy.CHUNK, settings.getLongDocumentStrategy()); + assertNull(settings.getMaxChunksPerDoc()); + } + + public void testFromMap_LongDocumentStrategyWithMaxChunksProvided_CreatesRerankServiceSettings() { + Map map = new HashMap<>(); + map.put(RerankServiceSettings.LONG_DOCUMENT_STRATEGY, LongDocumentStrategy.CHUNK.strategyName); + int maxChunksPerDoc = randomNonNegativeInt(); + map.put(RerankServiceSettings.MAX_CHUNKS_PER_DOC, maxChunksPerDoc); + RerankServiceSettings settings = RerankServiceSettings.fromMap(map); + assertEquals(LongDocumentStrategy.CHUNK, settings.getLongDocumentStrategy()); + assertEquals(maxChunksPerDoc, settings.getMaxChunksPerDoc().intValue()); + } + + @Override + protected RerankServiceSettings mutateInstanceForVersion(RerankServiceSettings instance, TransportVersion version) { + // TODO: Update this when transport version is added. + // No changes across versions yet + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return RerankServiceSettings::new; + } + + @Override + protected RerankServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected RerankServiceSettings mutateInstance(RerankServiceSettings instance) throws IOException { + LongDocumentStrategy longDocumentStrategy = instance.getLongDocumentStrategy(); + Integer maxChunksPerDoc = instance.getMaxChunksPerDoc(); + switch (randomInt(1)) { + case 0 -> { + longDocumentStrategy = randomValueOtherThan(longDocumentStrategy, () -> randomFrom(LongDocumentStrategy.values())); + } + case 1 -> { + maxChunksPerDoc = randomValueOtherThan(maxChunksPerDoc, ESTestCase::randomNonNegativeInt); + } + default -> throw new AssertionError("Illegal randomisation branch"); + } + + return new RerankServiceSettings(longDocumentStrategy, maxChunksPerDoc); + } +}