diff --git a/docs/changelog/121041.yaml b/docs/changelog/121041.yaml new file mode 100644 index 0000000000000..44a51a966c0a1 --- /dev/null +++ b/docs/changelog/121041.yaml @@ -0,0 +1,5 @@ +pr: 121041 +summary: Support configurable chunking in `semantic_text` fields +area: Relevance +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index bd12813ee4b7c..5720b6f2d9938 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -202,6 +202,7 @@ static TransportVersion def(int id) { public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15); public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16); public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17); + public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG_8_19 = def(8_841_0_18); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 8917d5a9cbbb5..ac4a3330823f9 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -12,6 +12,7 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.Diff; import org.elasticsearch.cluster.SimpleDiffable; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.ToXContentFragment; @@ -22,8 +23,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Objects; +import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG_8_19; + /** * Contains inference field data for fields. * As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need @@ -35,21 +39,30 @@ public final class InferenceFieldMetadata implements SimpleDiffable chunkingSettings; - public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) { - this(name, inferenceId, inferenceId, sourceFields); + public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields, Map chunkingSettings) { + this(name, inferenceId, inferenceId, sourceFields, chunkingSettings); } - public InferenceFieldMetadata(String name, String inferenceId, String searchInferenceId, String[] sourceFields) { + public InferenceFieldMetadata( + String name, + String inferenceId, + String searchInferenceId, + String[] sourceFields, + Map chunkingSettings + ) { this.name = Objects.requireNonNull(name); this.inferenceId = Objects.requireNonNull(inferenceId); this.searchInferenceId = Objects.requireNonNull(searchInferenceId); this.sourceFields = Objects.requireNonNull(sourceFields); + this.chunkingSettings = chunkingSettings != null ? Map.copyOf(chunkingSettings) : null; } public InferenceFieldMetadata(StreamInput input) throws IOException { @@ -61,6 +74,11 @@ public InferenceFieldMetadata(StreamInput input) throws IOException { this.searchInferenceId = this.inferenceId; } this.sourceFields = input.readStringArray(); + if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG_8_19)) { + this.chunkingSettings = input.readGenericMap(); + } else { + this.chunkingSettings = null; + } } @Override @@ -71,6 +89,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(searchInferenceId); } out.writeStringArray(sourceFields); + if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG_8_19)) { + out.writeGenericMap(chunkingSettings); + } } @Override @@ -81,16 +102,22 @@ public boolean equals(Object o) { return Objects.equals(name, that.name) && Objects.equals(inferenceId, that.inferenceId) && Objects.equals(searchInferenceId, that.searchInferenceId) - && Arrays.equals(sourceFields, that.sourceFields); + && Arrays.equals(sourceFields, that.sourceFields) + && Objects.equals(chunkingSettings, that.chunkingSettings); } @Override public int hashCode() { - int result = Objects.hash(name, inferenceId, searchInferenceId); + int result = Objects.hash(name, inferenceId, searchInferenceId, chunkingSettings); result = 31 * result + Arrays.hashCode(sourceFields); return result; } + @Override + public String toString() { + return Strings.toString(this); + } + public String getName() { return name; } @@ -107,6 +134,10 @@ public String[] getSourceFields() { return sourceFields; } + public Map getChunkingSettings() { + return chunkingSettings; + } + public static Diff readDiffFrom(StreamInput in) throws IOException { return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in); } @@ -119,6 +150,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId); } builder.array(SOURCE_FIELDS_FIELD, sourceFields); + if (chunkingSettings != null) { + builder.startObject(CHUNKING_SETTINGS_FIELD); + builder.mapContents(chunkingSettings); + builder.endObject(); + } return builder.endObject(); } @@ -131,6 +167,7 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws String currentFieldName = null; String inferenceId = null; String searchInferenceId = null; + Map chunkingSettings = null; List inputFields = new ArrayList<>(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -151,6 +188,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws } } } + } else if (CHUNKING_SETTINGS_FIELD.equals(currentFieldName)) { + chunkingSettings = parser.map(); } else { parser.skipChildren(); } @@ -159,7 +198,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws name, inferenceId, searchInferenceId == null ? inferenceId : searchInferenceId, - inputFields.toArray(String[]::new) + inputFields.toArray(String[]::new), + chunkingSettings ); } } diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java new file mode 100644 index 0000000000000..8e25e0e55f08c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.core.Nullable; + +import java.util.List; + +public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) { + + public ChunkInferenceInput(String input) { + this(input, null); + } + + public static List inputs(List chunkInferenceInputs) { + return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList(); + } +} diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java index 2e9072626b0a8..34b3e5a6d58ee 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java @@ -12,6 +12,10 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; +import java.util.Map; + public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable { ChunkingStrategy getChunkingStrategy(); + + Map asMap(); } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 72a7c22e4b39a..d85acb021506a 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -133,18 +133,18 @@ void unifiedCompletionInfer( /** * Chunk long text. * - * @param model The model - * @param query Inference query, mainly for re-ranking - * @param input Inference input - * @param taskSettings Settings in the request to override the model's defaults - * @param inputType For search, ingest etc - * @param timeout The timeout for the request - * @param listener Chunked Inference result listener + * @param model The model + * @param query Inference query, mainly for re-ranking + * @param input Inference input + * @param taskSettings Settings in the request to override the model's defaults + * @param inputType For search, ingest etc + * @param timeout The timeout for the request + * @param listener Chunked Inference result listener */ void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 4abd0c4a9d469..a976e37ee2cf1 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -694,7 +694,8 @@ private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) name, randomIdentifier(), randomIdentifier(), - randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new) + randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new), + InferenceFieldMetadataTests.generateRandomChunkingSettings() ); } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index 2d5805696320d..f0c61b68226e1 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -15,8 +15,10 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.Map; import java.util.function.Predicate; +import static org.elasticsearch.cluster.metadata.InferenceFieldMetadata.CHUNKING_SETTINGS_FIELD; import static org.hamcrest.Matchers.equalTo; public class InferenceFieldMetadataTests extends AbstractXContentTestCase { @@ -37,11 +39,6 @@ protected InferenceFieldMetadata createTestInstance() { return createTestItem(); } - @Override - protected Predicate getRandomFieldsExcludeFilter() { - return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field - } - @Override protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException { if (parser.nextToken() == XContentParser.Token.START_OBJECT) { @@ -58,18 +55,57 @@ protected boolean supportsUnknownFields() { return true; } + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // do not add elements at the top-level as any element at this level is parsed as a new inference field, + // and do not add additional elements to chunking maps as they will fail parsing with extra data + return field -> field.equals("") || field.contains(CHUNKING_SETTINGS_FIELD); + } + private static InferenceFieldMetadata createTestItem() { String name = randomAlphaOfLengthBetween(3, 10); String inferenceId = randomIdentifier(); String searchInferenceId = randomIdentifier(); String[] inputFields = generateRandomStringArray(5, 10, false, false); - return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields); + Map chunkingSettings = generateRandomChunkingSettings(); + return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields, chunkingSettings); + } + + public static Map generateRandomChunkingSettings() { + if (randomBoolean()) { + return null; // Defaults to model chunking settings + } + return randomBoolean() ? generateRandomWordBoundaryChunkingSettings() : generateRandomSentenceBoundaryChunkingSettings(); + } + + private static Map generateRandomWordBoundaryChunkingSettings() { + return Map.of("strategy", "word_boundary", "max_chunk_size", randomIntBetween(20, 100), "overlap", randomIntBetween(1, 50)); + } + + private static Map generateRandomSentenceBoundaryChunkingSettings() { + return Map.of( + "strategy", + "sentence_boundary", + "max_chunk_size", + randomIntBetween(20, 100), + "sentence_overlap", + randomIntBetween(0, 1) + ); } public void testNullCtorArgsThrowException() { - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null)); + assertThrows( + NullPointerException.class, + () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0], Map.of()) + ); + assertThrows( + NullPointerException.class, + () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0], Map.of()) + ); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0], Map.of())); + assertThrows( + NullPointerException.class, + () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null, Map.of()) + ); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java index 755b83e8eb7ad..93ac31c9ba582 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.Query; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadataTests; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; @@ -102,7 +103,13 @@ private static class TestInferenceFieldMapper extends FieldMapper implements Inf @Override public InferenceFieldMetadata getMetadata(Set sourcePaths) { - return new InferenceFieldMetadata(fullPath(), INFERENCE_ID, SEARCH_INFERENCE_ID, sourcePaths.toArray(new String[0])); + return new InferenceFieldMetadata( + fullPath(), + INFERENCE_ID, + SEARCH_INFERENCE_ID, + sourcePaths.toArray(new String[0]), + InferenceFieldMetadataTests.generateRandomChunkingSettings() + ); } @Override diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 3c29cef47d628..7d4a120668a8b 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -13,6 +13,9 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -22,14 +25,20 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunker; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Random; public abstract class AbstractTestInferenceService implements InferenceService { + protected record ChunkedInput(String input, int startOffset, int endOffset) {} + protected static final Random random = new Random( System.getProperty("tests.seed") == null ? System.currentTimeMillis() @@ -105,6 +114,34 @@ public void start(Model model, TimeValue timeout, ActionListener listen @Override public void close() throws IOException {} + protected List chunkInputs(ChunkInferenceInput input) { + ChunkingSettings chunkingSettings = input.chunkingSettings(); + String inputText = input.input(); + if (chunkingSettings == null) { + return List.of(new ChunkedInput(inputText, 0, inputText.length())); + } + + List chunkedInputs = new ArrayList<>(); + if (chunkingSettings.getChunkingStrategy() == ChunkingStrategy.WORD) { + WordBoundaryChunker chunker = new WordBoundaryChunker(); + WordBoundaryChunkingSettings wordBoundaryChunkingSettings = (WordBoundaryChunkingSettings) chunkingSettings; + List offsets = chunker.chunk( + inputText, + wordBoundaryChunkingSettings.maxChunkSize(), + wordBoundaryChunkingSettings.overlap() + ); + for (WordBoundaryChunker.ChunkOffset offset : offsets) { + chunkedInputs.add(new ChunkedInput(inputText.substring(offset.start(), offset.end()), offset.start(), offset.end())); + } + + } else { + // Won't implement till we need it + throw new UnsupportedOperationException("Test inference service only supports word chunking strategies"); + } + + return chunkedInputs; + } + public static class TestServiceModel extends Model { public TestServiceModel( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index ad6f1b88de328..044af0ab1d37d 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -35,7 +36,6 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import java.io.IOException; @@ -147,7 +147,7 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, @@ -176,21 +176,20 @@ private TextEmbeddingFloatResults makeResults(List input, ServiceSetting return new TextEmbeddingFloatResults(embeddings); } - private List makeChunkedResults(List input, ServiceSettings serviceSettings) { - TextEmbeddingFloatResults nonChunkedResults = makeResults(input, serviceSettings); - + private List makeChunkedResults(List inputs, ServiceSettings serviceSettings) { var results = new ArrayList(); - for (int i = 0; i < input.size(); i++) { - results.add( - new ChunkedInferenceEmbedding( - List.of( - new EmbeddingResults.Chunk( - nonChunkedResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, input.get(i).length()) - ) + for (ChunkInferenceInput input : inputs) { + List chunkedInput = chunkInputs(input); + List chunks = chunkedInput.stream() + .map( + c -> new TextEmbeddingFloatResults.Chunk( + makeResults(List.of(c.input()), serviceSettings).embeddings().get(0), + new ChunkedInference.TextOffset(c.startOffset(), c.endOffset()) ) ) - ); + .toList(); + ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); + results.add(chunkedInferenceEmbedding); } return results; } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index d4e3642affddb..989726443ecf4 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -136,7 +137,7 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index e103761ab8863..3df62ff142bb5 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -33,9 +34,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; @@ -46,6 +45,7 @@ import java.util.Map; public class TestSparseInferenceServiceExtension implements InferenceServiceExtension { + @Override public List getInferenceServiceFactories() { return List.of(TestInferenceService::new); @@ -114,8 +114,7 @@ public void infer( ActionListener listener ) { switch (model.getConfigurations().getTaskType()) { - case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeSparseEmbeddingResults(input)); - case TEXT_EMBEDDING -> listener.onResponse(makeTextEmbeddingResults(input)); + case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input)); default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), @@ -139,7 +138,7 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, @@ -156,7 +155,7 @@ public void chunkedInfer( } } - private SparseEmbeddingResults makeSparseEmbeddingResults(List input) { + private SparseEmbeddingResults makeResults(List input) { var embeddings = new ArrayList(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); @@ -168,35 +167,20 @@ private SparseEmbeddingResults makeSparseEmbeddingResults(List input) { return new SparseEmbeddingResults(embeddings); } - private TextEmbeddingFloatResults makeTextEmbeddingResults(List input) { - var embeddings = new ArrayList(); - for (int i = 0; i < input.size(); i++) { - var values = new float[5]; - for (int j = 0; j < 5; j++) { - values[j] = random.nextFloat(); - } - embeddings.add(new TextEmbeddingFloatResults.Embedding(values)); - } - return new TextEmbeddingFloatResults(embeddings); - } - - private List makeChunkedResults(List input) { + private List makeChunkedResults(List inputs) { List results = new ArrayList<>(); - for (int i = 0; i < input.size(); i++) { - var tokens = new ArrayList(); - for (int j = 0; j < 5; j++) { - tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j))); - } - results.add( - new ChunkedInferenceEmbedding( - List.of( - new EmbeddingResults.Chunk( - new SparseEmbeddingResults.Embedding(tokens, false), - new ChunkedInference.TextOffset(0, input.get(i).length()) - ) - ) - ) - ); + for (ChunkInferenceInput chunkInferenceInput : inputs) { + List chunkedInput = chunkInputs(chunkInferenceInput); + List chunks = chunkedInput.stream().map(c -> { + var tokens = new ArrayList(); + for (int i = 0; i < 5; i++) { + tokens.add(new WeightedToken("feature_" + i, generateEmbedding(c.input(), i))); + } + var embeddings = new SparseEmbeddingResults.Embedding(tokens, false); + return new SparseEmbeddingResults.Chunk(embeddings, new ChunkedInference.TextOffset(c.startOffset(), c.endOffset())); + }).toList(); + ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); + results.add(chunkedInferenceEmbedding); } return results; } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 33dabd95ce41a..ecfd20461554a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -292,7 +293,7 @@ public Iterator toXContentChunked(ToXContent.Params params public void chunkedInfer( Model model, String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 1c2240e8c5217..ce4bbe92774fa 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -42,6 +42,7 @@ exports org.elasticsearch.xpack.inference.services; exports org.elasticsearch.xpack.inference; exports org.elasticsearch.xpack.inference.action.task; + exports org.elasticsearch.xpack.inference.chunking; provides org.elasticsearch.features.FeatureSpecification with org.elasticsearch.xpack.inference.InferenceFeatures; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 5695e239380f3..33778c98ccf7f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -16,6 +16,7 @@ import java.util.Set; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG; import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_FILTER_FIX; import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED; import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED; @@ -63,7 +64,8 @@ public Set getTestFeatures() { SEMANTIC_KNN_FILTER_FIX, TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE, SemanticTextFieldMapper.SEMANTIC_TEXT_BIT_VECTOR_SUPPORT, - SemanticTextFieldMapper.SEMANTIC_TEXT_HANDLE_EMPTY_INPUT + SemanticTextFieldMapper.SEMANTIC_TEXT_HANDLE_EMPTY_INPUT, + SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index c2ad057d0a256..ed5342b2ccbf1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -34,7 +34,9 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; @@ -53,6 +55,7 @@ import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.InferenceException; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; @@ -170,6 +173,7 @@ private record InferenceProvider(InferenceService service, Model model) {} * @param input The input to run inference on. * @param inputOrder The original order of the input. * @param offsetAdjustment The adjustment to apply to the chunk text offsets. + * @param chunkingSettings Additional explicitly specified chunking settings, or null to use model defaults */ private record FieldInferenceRequest( int bulkItemIndex, @@ -177,7 +181,8 @@ private record FieldInferenceRequest( String sourceField, String input, int inputOrder, - int offsetAdjustment + int offsetAdjustment, + ChunkingSettings chunkingSettings ) {} /** @@ -353,7 +358,10 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } - final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + final List inputs = requests.stream() + .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) + .collect(Collectors.toList()); + ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { @@ -448,6 +456,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); String inferenceId = entry.getInferenceId(); + ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(entry.getChunkingSettings(), false); if (useLegacyFormat) { var originalFieldValue = XContentMapValues.extractValue(field, docMap); @@ -522,7 +531,9 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) ); } else { - requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment)); + requests.add( + new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings) + ); } // When using the inference metadata fields format, all the input values are concatenated so that the @@ -603,6 +614,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons new SemanticTextField.InferenceResult( inferenceFieldMetadata.getInferenceId(), model != null ? new MinimalServiceSettings(model) : null, + ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings(), false), chunkMap ), indexRequest.getContentType() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index 2ede1684e315b..25553a4c760f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -10,6 +10,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ChunkingStrategy; +import java.util.HashMap; import java.util.Map; public class ChunkingSettingsBuilder { @@ -18,13 +19,24 @@ public class ChunkingSettingsBuilder { public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); public static ChunkingSettings fromMap(Map settings) { - if (settings == null) { - return OLD_DEFAULT_SETTINGS; - } + return fromMap(settings, true); + } - if (settings.isEmpty()) { - return DEFAULT_SETTINGS; + public static ChunkingSettings fromMap(Map settings, boolean returnDefaultValues) { + + if (returnDefaultValues) { + if (settings == null) { + return OLD_DEFAULT_SETTINGS; + } + if (settings.isEmpty()) { + return DEFAULT_SETTINGS; + } + } else { + if (settings == null || settings.isEmpty()) { + return null; + } } + if (settings.containsKey(ChunkingSettingsOptions.STRATEGY.toString()) == false) { throw new IllegalArgumentException("Can't generate Chunker without ChunkingStrategy provided"); } @@ -33,8 +45,8 @@ public static ChunkingSettings fromMap(Map settings) { settings.get(ChunkingSettingsOptions.STRATEGY.toString()).toString() ); return switch (chunkingStrategy) { - case WORD -> WordBoundaryChunkingSettings.fromMap(settings); - case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(settings); + case WORD -> WordBoundaryChunkingSettings.fromMap(new HashMap<>(settings)); + case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(new HashMap<>(settings)); }; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 70dc845b40628..b19129350282f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -10,8 +10,10 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; @@ -21,6 +23,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.Supplier; @@ -40,9 +44,9 @@ public class EmbeddingRequestChunker> { // Visible for testing - record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { - String chunkText() { - return inputs.get(inputIndex).substring(chunk.start(), chunk.end()); + record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { + public String chunkText() { + return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end()); } } @@ -54,8 +58,7 @@ public Supplier> inputs() { public record BatchRequestAndListener(BatchRequest batch, ActionListener listener) {} - private static final int DEFAULT_WORDS_PER_CHUNK = 250; - private static final int DEFAULT_CHUNK_OVERLAP = 100; + private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100); // The maximum number of chunks that is stored for any input text. // If the configured chunker chunks the text into more chunks, each @@ -72,28 +75,44 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener resultsErrors; private ActionListener> finalListener; - public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch) { + public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch) { this(inputs, maxNumberOfInputsPerBatch, null); } - public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) { + public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) { this(inputs, maxNumberOfInputsPerBatch, new WordBoundaryChunkingSettings(wordsPerChunk, chunkOverlap)); } - public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, ChunkingSettings chunkingSettings) { + public EmbeddingRequestChunker( + List inputs, + int maxNumberOfInputsPerBatch, + ChunkingSettings defaultChunkingSettings + ) { this.resultEmbeddings = new ArrayList<>(inputs.size()); this.resultOffsetStarts = new ArrayList<>(inputs.size()); this.resultOffsetEnds = new ArrayList<>(inputs.size()); this.resultsErrors = new AtomicArray<>(inputs.size()); - if (chunkingSettings == null) { - chunkingSettings = new WordBoundaryChunkingSettings(DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP); + if (defaultChunkingSettings == null) { + defaultChunkingSettings = DEFAULT_CHUNKING_SETTINGS; } - Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); + + Map chunkers = inputs.stream() + .map(ChunkInferenceInput::chunkingSettings) + .filter(Objects::nonNull) + .map(ChunkingSettings::getChunkingStrategy) + .distinct() + .collect(Collectors.toMap(chunkingStrategy -> chunkingStrategy, ChunkerBuilder::fromChunkingStrategy)); + Chunker defaultChunker = ChunkerBuilder.fromChunkingStrategy(defaultChunkingSettings.getChunkingStrategy()); List allRequests = new ArrayList<>(); for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) { - List chunks = chunker.chunk(inputs.get(inputIndex), chunkingSettings); + ChunkingSettings chunkingSettings = inputs.get(inputIndex).chunkingSettings(); + if (chunkingSettings == null) { + chunkingSettings = defaultChunkingSettings; + } + Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker); + List chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); resultOffsetStarts.add(new ArrayList<>(resultCount)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index 0033cc9ee2bef..b87e164089d31 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -54,6 +55,18 @@ public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { } } + @Override + public Map asMap() { + return Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + STRATEGY.toString().toLowerCase(Locale.ROOT), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), + sentenceOverlap + ); + } + public static SentenceBoundaryChunkingSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); @@ -141,4 +154,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(maxChunkSize, sentenceOverlap); } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 7e0378d5b0cd1..97f8aa49ef4d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -48,6 +49,26 @@ public WordBoundaryChunkingSettings(StreamInput in) throws IOException { overlap = in.readInt(); } + @Override + public Map asMap() { + return Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + STRATEGY.toString().toLowerCase(Locale.ROOT), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.OVERLAP.toString(), + overlap + ); + } + + public int maxChunkSize() { + return maxChunkSize; + } + + public int overlap() { + return overlap; + } + public static WordBoundaryChunkingSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); @@ -130,4 +151,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(maxChunkSize, overlap); } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index 3f59386082d73..1e188d0f7bf5b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -8,11 +8,14 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import java.util.List; import java.util.Objects; import java.util.function.Supplier; +import java.util.stream.Collectors; public class EmbeddingsInput extends InferenceInputs { @@ -24,30 +27,42 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) { return (EmbeddingsInput) inferenceInputs; } - private final Supplier> listSupplier; + private final Supplier> listSupplier; private final InputType inputType; - public EmbeddingsInput(List input, @Nullable InputType inputType) { + public EmbeddingsInput(List input, @Nullable InputType inputType) { this(input, inputType, false); } - public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { + public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { super(false); this.listSupplier = Objects.requireNonNull(inputSupplier); this.inputType = inputType; } - public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { + public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) { + this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).collect(Collectors.toList()), inputType, false); + } + + public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { super(stream); Objects.requireNonNull(input); this.listSupplier = () -> input; this.inputType = inputType; } - public List getInputs() { + public List getInputs() { return this.listSupplier.get(); } + public static EmbeddingsInput fromStrings(List input, @Nullable InputType inputType) { + return new EmbeddingsInput(input, null, inputType); + } + + public List getStringInputs() { + return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList()); + } + public InputType getInputType() { return this.inputType; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java index 4a485f87858aa..c39387d647f77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java @@ -52,7 +52,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs(); var truncatedInput = truncate(docsInput, maxInputTokens); var request = requestCreator.apply(truncatedInput); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index be7588abbbc8d..312b4e9cb9cde 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.DeprecationHandler; @@ -27,6 +28,7 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import java.io.IOException; import java.util.ArrayList; @@ -69,8 +71,14 @@ public record SemanticTextField( static final String CHUNKED_START_OFFSET_FIELD = "start_offset"; static final String CHUNKED_END_OFFSET_FIELD = "end_offset"; static final String MODEL_SETTINGS_FIELD = "model_settings"; + static final String CHUNKING_SETTINGS_FIELD = "chunking_settings"; - public record InferenceResult(String inferenceId, MinimalServiceSettings modelSettings, Map> chunks) {} + public record InferenceResult( + String inferenceId, + MinimalServiceSettings modelSettings, + ChunkingSettings chunkingSettings, + Map> chunks + ) {} public record Chunk(@Nullable String text, int startOffset, int endOffset, BytesReference rawEmbeddings) {} @@ -120,6 +128,18 @@ static MinimalServiceSettings parseModelSettingsFromMap(Object node) { } } + static ChunkingSettings parseChunkingSettingsFromMap(Object node) { + if (node == null) { + return null; + } + try { + Map map = XContentMapValues.nodeMapValue(node, CHUNKING_SETTINGS_FIELD); + return ChunkingSettingsBuilder.fromMap(map, false); + } catch (Exception exc) { + throw new ElasticsearchException(exc); + } + } + @Override public List originalValues() { return originalValues != null ? originalValues : Collections.emptyList(); @@ -135,6 +155,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(INFERENCE_FIELD); builder.field(INFERENCE_ID_FIELD, inference.inferenceId); builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings); + if (inference.chunkingSettings != null) { + builder.field(CHUNKING_SETTINGS_FIELD, inference.chunkingSettings); + } + if (useLegacyFormat) { builder.startArray(CHUNKS_FIELD); } else { @@ -178,6 +202,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private static final ConstructingObjectParser SEMANTIC_TEXT_FIELD_PARSER = new ConstructingObjectParser<>(SemanticTextFieldMapper.CONTENT_TYPE, true, (args, context) -> { List originalValues = (List) args[0]; + InferenceResult inferenceResult = (InferenceResult) args[1]; if (context.useLegacyFormat() == false) { if (originalValues != null && originalValues.isEmpty() == false) { throw new IllegalArgumentException("Unknown field [" + TEXT_FIELD + "]"); @@ -188,7 +213,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws context.useLegacyFormat(), context.fieldName(), originalValues, - (InferenceResult) args[1], + inferenceResult, context.xContentType() ); }); @@ -197,7 +222,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws private static final ConstructingObjectParser INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>( INFERENCE_FIELD, true, - args -> new InferenceResult((String) args[0], (MinimalServiceSettings) args[1], (Map>) args[2]) + args -> { + String inferenceId = (String) args[0]; + MinimalServiceSettings modelSettings = (MinimalServiceSettings) args[1]; + Map chunkingSettings = (Map) args[2]; + Map> chunks = (Map>) args[3]; + return new InferenceResult(inferenceId, modelSettings, ChunkingSettingsBuilder.fromMap(chunkingSettings, false), chunks); + } ); private static final ConstructingObjectParser CHUNKS_PARSER = new ConstructingObjectParser<>( @@ -218,11 +249,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD)); INFERENCE_RESULT_PARSER.declareObjectOrNull( - constructorArg(), + optionalConstructorArg(), (p, c) -> MinimalServiceSettings.parse(p), null, new ParseField(MODEL_SETTINGS_FIELD) ); + INFERENCE_RESULT_PARSER.declareObjectOrNull( + optionalConstructorArg(), + (p, c) -> p.map(), + null, + new ParseField(CHUNKING_SETTINGS_FIELD) + ); INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> { if (c.useLegacyFormat()) { return Map.of(c.fieldName, parseChunksArrayLegacy(p, c)); @@ -297,7 +334,7 @@ public static List toSemanticTextFieldChunksLegacy(String input, ChunkedI return chunks; } - public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) { + public static Chunk toSemanticTextFieldChunkLegacy(String input, org.elasticsearch.inference.ChunkedInference.Chunk chunk) { var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end()); return new Chunk(text, -1, -1, chunk.bytesReference()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 15bdbdd27694a..f78bcc9106979 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -58,6 +58,7 @@ import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; @@ -94,6 +95,7 @@ import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_OFFSET_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKING_SETTINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; @@ -122,6 +124,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie public static final NodeFeature SEMANTIC_TEXT_HANDLE_EMPTY_INPUT = new NodeFeature("semantic_text.handle_empty_input"); public static final NodeFeature SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS = new NodeFeature("semantic_text.skip_inference_fields"); public static final NodeFeature SEMANTIC_TEXT_BIT_VECTOR_SUPPORT = new NodeFeature("semantic_text.bit_vector_support"); + public static final NodeFeature SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG = new NodeFeature("semantic_text.support_chunking_config"); public static final String CONTENT_TYPE = "semantic_text"; public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID; @@ -179,6 +182,17 @@ public static class Builder extends FieldMapper.Builder { Objects::toString ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings); + @SuppressWarnings("unchecked") + private final Parameter chunkingSettings = new Parameter<>( + CHUNKING_SETTINGS_FIELD, + true, + () -> null, + (n, c, o) -> SemanticTextField.parseChunkingSettingsFromMap(o), + mapper -> ((SemanticTextFieldType) mapper.fieldType()).chunkingSettings, + XContentBuilder::field, + Objects::toString + ).acceptsNull(); + private final Parameter> meta = Parameter.metaParam(); private Function inferenceFieldBuilder; @@ -221,9 +235,14 @@ public Builder setModelSettings(MinimalServiceSettings value) { return this; } + public Builder setChunkingSettings(ChunkingSettings value) { + this.chunkingSettings.setValue(value); + return this; + } + @Override protected Parameter[] getParameters() { - return new Parameter[] { inferenceId, searchInferenceId, modelSettings, meta }; + return new Parameter[] { inferenceId, searchInferenceId, modelSettings, chunkingSettings, meta }; } @Override @@ -265,6 +284,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { inferenceId.getValue(), searchInferenceId.getValue(), modelSettings.getValue(), + chunkingSettings.getValue(), inferenceField, useLegacyFormat, meta.getValue() @@ -523,7 +543,10 @@ public InferenceFieldMetadata getMetadata(Set sourcePaths) { String[] copyFields = sourcePaths.toArray(String[]::new); // ensure consistent order Arrays.sort(copyFields); - return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields); + ChunkingSettings fieldTypeChunkingSettings = fieldType().getChunkingSettings(); + Map asMap = fieldTypeChunkingSettings != null ? fieldTypeChunkingSettings.asMap() : null; + + return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields, asMap); } @Override @@ -552,6 +575,7 @@ public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; private final String searchInferenceId; private final MinimalServiceSettings modelSettings; + private final ChunkingSettings chunkingSettings; private final ObjectMapper inferenceField; private final boolean useLegacyFormat; @@ -560,6 +584,7 @@ public SemanticTextFieldType( String inferenceId, String searchInferenceId, MinimalServiceSettings modelSettings, + ChunkingSettings chunkingSettings, ObjectMapper inferenceField, boolean useLegacyFormat, Map meta @@ -568,6 +593,7 @@ public SemanticTextFieldType( this.inferenceId = inferenceId; this.searchInferenceId = searchInferenceId; this.modelSettings = modelSettings; + this.chunkingSettings = chunkingSettings; this.inferenceField = inferenceField; this.useLegacyFormat = useLegacyFormat; } @@ -603,6 +629,10 @@ public MinimalServiceSettings getModelSettings() { return modelSettings; } + public ChunkingSettings getChunkingSettings() { + return chunkingSettings; + } + public ObjectMapper getInferenceField() { return inferenceField; } @@ -875,7 +905,7 @@ public List fetchValues(Source source, int doc, List ignoredValu useLegacyFormat, name(), null, - new SemanticTextField.InferenceResult(inferenceId, modelSettings, chunkMap), + new SemanticTextField.InferenceResult(inferenceId, modelSettings, chunkingSettings, chunkMap), source.sourceContentType() ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index a35f08f51a81c..27ae9235aa720 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; @@ -70,29 +71,31 @@ public void infer( ActionListener listener ) { init(); - var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); + var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList(); + var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream); doInfer(model, inferenceInput, taskSettings, timeout, listener); } private static InferenceInputs createInput( SenderService service, Model model, - List input, + List input, InputType inputType, @Nullable String query, @Nullable Boolean returnDocuments, @Nullable Integer topN, boolean stream ) { + List textInput = ChunkInferenceInput.inputs(input); return switch (model.getTaskType()) { - case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); + case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream); case RERANK -> { ValidationException validationException = new ValidationException(); service.validateRerankParameters(returnDocuments, topN, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream); + yield new QueryAndDocsInputs(query, textInput, returnDocuments, topN, stream); } case TEXT_EMBEDDING, SPARSE_EMBEDDING -> { ValidationException validationException = new ValidationException(); @@ -122,7 +125,7 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, @@ -136,7 +139,7 @@ public void chunkedInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java index bb9a9728f1212..f11cd41b25aa0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java @@ -72,7 +72,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 1ccee3c442ca2..7897317736c72 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -343,7 +343,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java index 082bbbbbbe323..acce3b9a1d6ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java @@ -72,7 +72,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java index c07e5a8462e12..06910611e0a96 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java @@ -57,7 +57,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var serviceSettings = embeddingsModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 520b5c6b91549..b0b4b7eed1a72 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -169,7 +169,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java index e206d9b0ca8de..abb9b26a80b0c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java @@ -51,7 +51,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 33b66253c5973..04883f23b947f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -140,7 +140,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java index 10a8fbae3474d..e98bf731210d7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java @@ -64,7 +64,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e9c5a310322ed..e9ff97c1ba725 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -293,7 +293,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = azureOpenAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java index 8e53678c936bf..e721c3e46cecf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java @@ -53,7 +53,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 550c6ad756ac6..bf6a0bd03122b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -306,7 +306,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = cohereModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 04ec7596f8e47..2910048cbd0a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -425,7 +425,7 @@ private ElasticInferenceServiceModel createModelFromPersistent( private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = EmbeddingsInput.of(inputs).getInputs(); + var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index a9ad10b3ba198..ecb452d5f4d78 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -68,7 +68,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 9be2f6a0b11ed..8232240b2c9ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceResults; @@ -685,22 +686,11 @@ public void inferRerank( client.execute(InferModelAction.INSTANCE, request, maybeDeployListener); } - public void chunkedInfer( - Model model, - List input, - Map taskSettings, - InputType inputType, - TimeValue timeout, - ActionListener> listener - ) { - chunkedInfer(model, null, input, taskSettings, inputType, timeout, listener); - } - @Override public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java index e285f2235f27c..13e54b9e3e17b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java @@ -57,7 +57,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index d334d4569275f..9841ea64370c3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -355,7 +355,13 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - doInfer(model, new EmbeddingsInput(request.batch().inputs(), inputType), taskSettings, timeout, request.listener()); + doInfer( + model, + EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), + taskSettings, + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java index be9262302b99a..2dc60ef114459 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java @@ -65,7 +65,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index c4ce885717d39..e966ebc8d9e9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -243,7 +243,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java index c5f8736c27641..b09cf8d98b7f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java @@ -62,7 +62,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getInputs(); + List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); var truncatedInput = truncate(docsInput, model.getTokenLimit()); var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 4a2e395009b06..f2a53520e18e6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -128,7 +128,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 60a7f0bfd44b1..8116eaf86e74a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -112,7 +113,7 @@ protected void doChunkedInfer( private static List translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { - validateInputSizeAgainstEmbeddings(inputs.getInputs(), textEmbeddingResults.embeddings().size()); + validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs.getInputs()), textEmbeddingResults.embeddings().size()); var results = new ArrayList(inputs.getInputs().size()); @@ -122,7 +123,7 @@ private static List translateToChunkedResults(EmbeddingsInput List.of( new EmbeddingResults.Chunk( textEmbeddingResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length()) + new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).input().length()) ) ) ) @@ -130,7 +131,7 @@ private static List translateToChunkedResults(EmbeddingsInput } return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = EmbeddingsInput.of(inputs).getInputs(); + var inputsAsList = ChunkInferenceInput.inputs(EmbeddingsInput.of(inputs).getInputs()); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java index b02efc60133e6..b7c679d3cda54 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java @@ -55,7 +55,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getInputs(); + List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); execute( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index ac674e65cc757..7dfb0002bb062 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -308,11 +308,11 @@ protected void doChunkedInfer( var batchedRequests = new EmbeddingRequestChunker<>( input.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - model.getConfigurations().getChunkingSettings() + ibmWatsonxModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java index d3a00f2de57ce..083690a894c00 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java @@ -53,7 +53,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index ec7014e511765..c2e88cb6cdc7c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -287,7 +287,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = jinaaiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java index 52e388d594886..73342edaaff15 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java @@ -61,7 +61,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getInputs(); + List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index e293fbd59a918..558b7e255f2b4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -122,7 +122,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 358e9ca94e9f4..4bb717e0bf7e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -346,7 +346,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = openAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index bb7cb2da5aa2b..3a048d170372b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -307,7 +307,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = voyageaiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java index 5bf9bd66def2f..03753835177cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java @@ -57,7 +57,7 @@ public ExecutableAction create(VoyageAIEmbeddingsModel model, Map new VoyageAIEmbeddingsRequest( - embeddingsInput.getInputs(), + embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), overriddenModel ), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index 073bf8f5afb9a..270cdba6d3469 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -56,7 +56,7 @@ public void cleanup() { public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY); @@ -67,7 +67,7 @@ public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOEx public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java index 47705c14d5941..6987ef33ed63d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java @@ -52,7 +52,7 @@ public void cleanup() { public void testMatchQueryOnInferenceFieldIsInterceptedAndRewrittenToSemanticQuery() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = createTestQueryBuilder(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java index 1adad1df7b29b..075955766a0a9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java @@ -54,7 +54,7 @@ public void cleanup() { public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); @@ -78,7 +78,7 @@ public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() thr public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 4ddefc52abbd8..99a79b81cabac 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -135,7 +136,7 @@ public void testFilterNoop() throws Exception { new BulkItemRequest[0] ); request.setInferenceFieldMap( - Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false))) + Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false), null)) ); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); @@ -168,7 +169,7 @@ public void testLicenseInvalidForInference() throws InterruptedException { Map inferenceFieldMap = Map.of( "obj.field1", - new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }, null) ); BulkItemRequest[] items = new BulkItemRequest[1]; items[0] = new BulkItemRequest(0, new IndexRequest("test").source("obj.field1", "Test")); @@ -208,11 +209,11 @@ public void testInferenceNotFound() throws Exception { Map inferenceFieldMap = Map.of( "field1", - new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }, null), "field2", - new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }), + new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }, null), "field3", - new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }) + new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }, null) ); BulkItemRequest[] items = new BulkItemRequest[10]; for (int i = 0; i < items.length; i++) { @@ -280,7 +281,7 @@ public void testItemFailures() throws Exception { Map inferenceFieldMap = Map.of( "field1", - new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }, null) ); BulkItemRequest[] items = new BulkItemRequest[3]; items[0] = new BulkItemRequest(0, new IndexRequest("index").source("field1", "I am a failure")); @@ -348,7 +349,7 @@ public void testExplicitNull() throws Exception { Map inferenceFieldMap = Map.of( "obj.field1", - new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }, null) ); Map sourceWithNull = new HashMap<>(); sourceWithNull.put("field1", null); @@ -404,7 +405,7 @@ public void testHandleEmptyInput() throws Exception { Task task = mock(Task.class); Map inferenceFieldMap = Map.of( "semantic_text_field", - new InferenceFieldMetadata("semantic_text_field", model.getInferenceEntityId(), new String[] { "semantic_text_field" }) + new InferenceFieldMetadata("semantic_text_field", model.getInferenceEntityId(), new String[] { "semantic_text_field" }, null) ); BulkItemRequest[] items = new BulkItemRequest[3]; @@ -431,7 +432,7 @@ public void testManyRandomDocs() throws Exception { for (int i = 0; i < numInferenceFields; i++) { String field = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomFrom(inferenceModelMap.keySet()); - inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field })); + inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field }, null)); } int numRequests = atLeast(100); @@ -505,12 +506,12 @@ private static ShardBulkInferenceActionFilter createFilter( InferenceService inferenceService = mock(InferenceService.class); Answer chunkedInferAnswer = invocationOnMock -> { StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; - List inputs = (List) invocationOnMock.getArguments()[2]; + List inputs = (List) invocationOnMock.getArguments()[2]; ActionListener> listener = (ActionListener>) invocationOnMock.getArguments()[6]; Runnable runnable = () -> { List results = new ArrayList<>(); - for (String input : inputs) { - results.add(model.getResults(input)); + for (ChunkInferenceInput input : inputs) { + results.add(model.getResults(input.input())); } listener.onResponse(results); }; @@ -601,13 +602,14 @@ private static BulkItemRequest[] randomBulkItemRequest( useLegacyFormat, field, model, + null, List.of(inputText), results, requestContentType ); } else { Map> inputTextMap = Map.of(field, List.of(inputText)); - semanticTextField = randomSemanticText(useLegacyFormat, field, model, List.of(inputText), requestContentType); + semanticTextField = randomSemanticText(useLegacyFormat, field, model, null, List.of(inputText), requestContentType); model.putResult(inputText, toChunkedResult(useLegacyFormat, inputTextMap, semanticTextField)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java index 4a284e0a84ff5..9e6dde60bc641 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -21,14 +21,18 @@ public class ChunkingSettingsBuilderTests extends ESTestCase { public void testNullChunkingSettingsMap() { ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(null); - assertEquals(ChunkingSettingsBuilder.OLD_DEFAULT_SETTINGS, chunkingSettings); + + ChunkingSettings chunkingSettingsOrNull = ChunkingSettingsBuilder.fromMap(null, false); + assertNull(chunkingSettingsOrNull); } public void testEmptyChunkingSettingsMap() { ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(Collections.emptyMap()); - assertEquals(DEFAULT_SETTINGS, chunkingSettings); + + ChunkingSettings chunkingSettingsOrNull = ChunkingSettingsBuilder.fromMap(Map.of(), false); + assertNull(chunkingSettingsOrNull); } public void testChunkingStrategyNotProvided() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 01890e0d0a356..f864e1757a62c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; @@ -46,22 +47,27 @@ public void testEmptyInput_SentenceChunker() { } public void testWhitespaceInput_SentenceChunker() { - var batches = new EmbeddingRequestChunker<>(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1)) - .batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>( + List.of(new ChunkInferenceInput(" ")), + 10, + new SentenceBoundaryChunkingSettings(250, 1) + ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(" ")); } public void testBlankInput_WordChunker() { - var batches = new EmbeddingRequestChunker<>(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("")), 100, 100, 10).batchRequestsWithListeners( + testListener() + ); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("")); } public void testBlankInput_SentenceChunker() { - var batches = new EmbeddingRequestChunker<>(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1)) + var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("")), 10, new SentenceBoundaryChunkingSettings(250, 1)) .batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); @@ -69,36 +75,45 @@ public void testBlankInput_SentenceChunker() { } public void testInputThatDoesNotChunk_WordChunker() { - var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("ABBAABBA")), 100, 100, 10).batchRequestsWithListeners( + testListener() + ); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA")); } public void testInputThatDoesNotChunk_SentenceChunker() { - var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1)) - .batchRequestsWithListeners(testListener()); + var batches = new EmbeddingRequestChunker<>( + List.of(new ChunkInferenceInput("ABBAABBA")), + 10, + new SentenceBoundaryChunkingSettings(250, 1) + ).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); assertThat(batches.get(0).batch().inputs().get(), hasSize(1)); assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA")); } public void testShortInputsAreSingleBatch() { - String input = "one chunk"; + ChunkInferenceInput input = new ChunkInferenceInput("one chunk"); var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); - assertThat(batches.get(0).batch().inputs().get(), contains(input)); + assertThat(batches.get(0).batch().inputs().get(), contains(input.input())); } public void testMultipleShortInputsAreSingleBatch() { - List inputs = List.of("1st small", "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); EmbeddingRequestChunker.BatchRequest batch = batches.get(0).batch(); - assertEquals(batch.inputs().get(), inputs); + assertEquals(batch.inputs().get(), ChunkInferenceInput.inputs(inputs)); for (int i = 0; i < inputs.size(); i++) { var request = batch.requests().get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i))); + assertThat(request.chunkText(), equalTo(inputs.get(i).input())); assertEquals(i, request.inputIndex()); assertEquals(0, request.chunkIndex()); } @@ -107,10 +122,10 @@ public void testMultipleShortInputsAreSingleBatch() { public void testManyInputsMakeManyBatches() { int maxNumInputsPerBatch = 10; int numInputs = maxNumInputsPerBatch * 3 + 1; // requires 4 batches - var inputs = new ArrayList(); + var inputs = new ArrayList(); for (int i = 0; i < numInputs; i++) { - inputs.add("input " + i); + inputs.add(new ChunkInferenceInput("input " + i)); } var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener()); @@ -133,7 +148,7 @@ public void testManyInputsMakeManyBatches() { List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i))); + assertThat(request.chunkText(), equalTo(inputs.get(i).input())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -142,10 +157,10 @@ public void testManyInputsMakeManyBatches() { public void testChunkingSettingsProvided() { int maxNumInputsPerBatch = 10; int numInputs = maxNumInputsPerBatch * 3 + 1; // requires 4 batches - var inputs = new ArrayList(); + var inputs = new ArrayList(); for (int i = 0; i < numInputs; i++) { - inputs.add("input " + i); + inputs.add(new ChunkInferenceInput("input " + i)); } var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings()) @@ -169,7 +184,7 @@ public void testChunkingSettingsProvided() { List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i))); + assertThat(request.chunkText(), equalTo(inputs.get(i).input())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -188,7 +203,12 @@ public void testLongInputChunkedOverMultipleBatches() { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener()); @@ -244,7 +264,11 @@ public void testVeryLongInput_Sparse() { passageBuilder.append("word").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small") + ); var finalListener = testListener(); List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) @@ -278,7 +302,7 @@ public void testVeryLongInput_Sparse() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); SparseEmbeddingResults.Embedding embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 1 / 16384f))); @@ -294,16 +318,19 @@ public void testVeryLongInput_Sparse() { // The first merged chunk consists of 20 small chunks (so 400 words) and the max // weight is the weight of the 20th small chunk (so 21/16384). - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 21 / 16384f))); // The last merged chunk consists of 19 small chunks (so 380 words) and the max // weight is the weight of the 10000th small chunk (so 10001/16384). - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + startsWith(" word199620 word199621 ") + ); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10001 / 16384f))); @@ -313,7 +340,7 @@ public void testVeryLongInput_Sparse() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10002 / 16384f))); @@ -329,7 +356,11 @@ public void testVeryLongInput_Float() { passageBuilder.append("word").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small") + ); var finalListener = testListener(); List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) @@ -362,7 +393,7 @@ public void testVeryLongInput_Float() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f })); @@ -378,16 +409,19 @@ public void testVeryLongInput_Float() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2/16384 ... 21/16384. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) })); // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983/16384 ... 10001/16384. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + startsWith(" word199620 word199621 ") + ); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) })); @@ -397,7 +431,7 @@ public void testVeryLongInput_Float() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f })); @@ -413,7 +447,11 @@ public void testVeryLongInput_Byte() { passageBuilder.append("word").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small") + ); var finalListener = testListener(); List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) @@ -446,7 +484,7 @@ public void testVeryLongInput_Byte() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 1 })); @@ -462,8 +500,8 @@ public void testVeryLongInput_Byte() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2 ... 21, so 11.5, which is rounded to 12. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 12 })); @@ -471,8 +509,11 @@ public void testVeryLongInput_Byte() { // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so // the average of -1, 0, 1, ... , 17, so 8. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + startsWith(" word199620 word199621 ") + ); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 8 })); @@ -482,7 +523,7 @@ public void testVeryLongInput_Byte() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 18 })); @@ -500,7 +541,12 @@ public void testMergingListener_Float() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -529,7 +575,7 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -537,26 +583,29 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(5).offset()), + startsWith(" passage_input100 ") + ); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -572,7 +621,12 @@ public void testMergingListener_Byte() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -601,7 +655,7 @@ public void testMergingListener_Byte() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -609,26 +663,26 @@ public void testMergingListener_Byte() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -644,7 +698,12 @@ public void testMergingListener_Bit() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -673,7 +732,7 @@ public void testMergingListener_Bit() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -681,26 +740,26 @@ public void testMergingListener_Bit() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -716,7 +775,12 @@ public void testMergingListener_Sparse() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString()); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small"), + new ChunkInferenceInput(passageBuilder.toString()) + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -752,21 +816,21 @@ public void testMergingListener_Sparse() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); } { var chunkedResult = finalListener.results.get(1); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(1), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); } { // this is the large input split in multiple chunks @@ -774,14 +838,24 @@ public void testMergingListener_Sparse() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each - assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(1).offset()), startsWith(" passage_input10 ")); - assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(8).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat( + getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(1).offset()), + startsWith(" passage_input10 ") + ); + assertThat( + getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(8).offset()), + startsWith(" passage_input80 ") + ); } } public void testListenerErrorsWithWrongNumberOfResponses() { - List inputs = List.of("1st small", "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var failureMessage = new AtomicReference(); var listener = new ActionListener>() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index 6f335ab32f01c..0bfa640b0cded 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -64,7 +65,7 @@ public void testOneInputIsValid() { public void testMoreThanOneInput() { var badInput = mock(EmbeddingsInput.class); - var input = List.of("one", "two"); + var input = List.of(new ChunkInferenceInput("one"), new ChunkInferenceInput("two")); when(badInput.getInputs()).thenReturn(input); when(badInput.isSingleInput()).thenReturn(false); var actualException = new AtomicReference(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 9fd4176ffa02b..2c78b7358e9ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -119,7 +120,7 @@ public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception PlainActionFuture listener = new PlainActionFuture<>(); sender.send( OpenAiEmbeddingsRequestManagerTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), - new EmbeddingsInput(List.of("abc"), null), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index 2aa9658eba99f..78028f93eebdd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.Scheduler; @@ -61,7 +62,7 @@ public void testExecuting_DoesNotCallOnFailureForTimeout_AfterIllegalArgumentExc var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener @@ -81,7 +82,7 @@ public void testRequest_ReturnsTimeoutException() { PlainActionFuture listener = new PlainActionFuture<>(); var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -104,7 +105,7 @@ public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exceptio var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -133,7 +134,7 @@ public void testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception { var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -160,7 +161,7 @@ public void testRequest_DoesNotCallOnFailureForTimeout_AfterAlreadyCallingOnResp var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index d30a4362e19d9..add130da2d368 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.index.mapper.SourceToParse; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -42,6 +43,7 @@ import java.util.List; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingByte; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingFloat; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; @@ -51,12 +53,14 @@ public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase { private final Model model1; private final Model model2; + private final ChunkingSettings chunkingSettings; private final boolean useSynthetic; private final boolean useIncludesExcludes; public SemanticInferenceMetadataFieldsRecoveryTests(boolean useSynthetic, boolean useIncludesExcludes) { this.model1 = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT)); this.model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + this.chunkingSettings = generateRandomChunkingSettings(); this.useSynthetic = useSynthetic; this.useIncludesExcludes = useIncludesExcludes; } @@ -105,6 +109,10 @@ protected String defaultMapping() { builder.field("element_type", model1.getServiceSettings().elementType().name()); builder.field("service", model1.getConfigurations().getService()); builder.endObject(); + if (chunkingSettings != null) { + builder.field("chunking_settings"); + chunkingSettings.toXContent(builder, null); + } builder.endObject(); builder.startObject("semantic_2"); @@ -114,6 +122,10 @@ protected String defaultMapping() { builder.field("task_type", model2.getTaskType().name()); builder.field("service", model2.getConfigurations().getService()); builder.endObject(); + if (chunkingSettings != null) { + builder.field("chunking_settings"); + chunkingSettings.toXContent(builder, null); + } builder.endObject(); builder.endObject(); @@ -229,8 +241,8 @@ private BytesReference randomSource() throws IOException { false, builder, List.of( - randomSemanticText(false, "semantic_2", model2, randomInputs(), XContentType.JSON), - randomSemanticText(false, "semantic_1", model1, randomInputs(), XContentType.JSON) + randomSemanticText(false, "semantic_2", model2, chunkingSettings, randomInputs(), XContentType.JSON), + randomSemanticText(false, "semantic_1", model1, chunkingSettings, randomInputs(), XContentType.JSON) ) ); builder.endObject(); @@ -241,6 +253,7 @@ private static SemanticTextField randomSemanticText( boolean useLegacyFormat, String fieldName, Model model, + ChunkingSettings chunkingSettings, List inputs, XContentType contentType ) throws IOException { @@ -252,7 +265,15 @@ private static SemanticTextField randomSemanticText( case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; - return semanticTextFieldFromChunkedInferenceResults(useLegacyFormat, fieldName, model, inputs, results, contentType); + return semanticTextFieldFromChunkedInferenceResults( + useLegacyFormat, + fieldName, + model, + chunkingSettings, + inputs, + results, + contentType + ); } private static List randomInputs() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index f6037c5e02ccb..f8c24c7e4bb2c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -53,6 +53,7 @@ import org.elasticsearch.index.mapper.vectors.XFeatureField; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; @@ -89,6 +90,8 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_ELSER_2_INFERENCE_ID; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettingsOtherThan; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -173,7 +176,7 @@ protected IngestScriptSupport ingestScriptSupport() { @Override public MappedFieldType getMappedFieldType() { - return new SemanticTextFieldMapper.SemanticTextFieldType("field", "fake-inference-id", null, null, null, false, Map.of()); + return new SemanticTextFieldMapper.SemanticTextFieldType("field", "fake-inference-id", null, null, null, null, false, Map.of()); } @Override @@ -561,6 +564,15 @@ public void testUpdateSearchInferenceId() throws IOException { } private static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { + assertSemanticTextField(mapperService, fieldName, expectedModelSettings, null); + } + + private static void assertSemanticTextField( + MapperService mapperService, + String fieldName, + boolean expectedModelSettings, + ChunkingSettings expectedChunkingSettings + ) { Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); assertNotNull(mapper); assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); @@ -612,6 +624,13 @@ private static void assertSemanticTextField(MapperService mapperService, String } else { assertNull(semanticFieldMapper.fieldType().getModelSettings()); } + + if (expectedChunkingSettings != null) { + assertNotNull(semanticFieldMapper.fieldType().getChunkingSettings()); + assertEquals(expectedChunkingSettings, semanticFieldMapper.fieldType().getChunkingSettings()); + } else { + assertNull(semanticFieldMapper.fieldType().getChunkingSettings()); + } } private static void assertInferenceEndpoints( @@ -637,9 +656,22 @@ public void testSuccessfulParse() throws IOException { Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + ChunkingSettings chunkingSettings = null; // Some chunking settings configs can produce different Lucene docs counts XContentBuilder mapping = mapping(b -> { - addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); - addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); + addSemanticTextMapping( + b, + fieldName1, + model1.getInferenceEntityId(), + setSearchInferenceId ? searchInferenceId : null, + chunkingSettings + ); + addSemanticTextMapping( + b, + fieldName2, + model2.getInferenceEntityId(), + setSearchInferenceId ? searchInferenceId : null, + chunkingSettings + ); }); MapperService mapperService = createMapperService(mapping, useLegacyFormat); @@ -665,8 +697,15 @@ public void testSuccessfulParse() throws IOException { useLegacyFormat, b, List.of( - randomSemanticText(useLegacyFormat, fieldName1, model1, List.of("a b", "c"), XContentType.JSON), - randomSemanticText(useLegacyFormat, fieldName2, model2, List.of("d e f"), XContentType.JSON) + randomSemanticText( + useLegacyFormat, + fieldName1, + model1, + chunkingSettings, + List.of("a b", "c"), + XContentType.JSON + ), + randomSemanticText(useLegacyFormat, fieldName2, model2, chunkingSettings, List.of("d e f"), XContentType.JSON) ) ) ) @@ -747,7 +786,7 @@ public void testSuccessfulParse() throws IOException { public void testMissingInferenceId() throws IOException { final MapperService mapperService = createMapperService( - mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), + mapping(b -> addSemanticTextMapping(b, "field", "my_id", null, null)), useLegacyFormat ); @@ -773,8 +812,11 @@ public void testMissingInferenceId() throws IOException { assertThat(ex.getCause().getMessage(), containsString("Required [inference_id]")); } - public void testMissingModelSettings() throws IOException { - MapperService mapperService = createMapperService(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), useLegacyFormat); + public void testMissingModelSettingsAndChunks() throws IOException { + MapperService mapperService = createMapperService( + mapping(b -> addSemanticTextMapping(b, "field", "my_id", null, null)), + useLegacyFormat + ); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, @@ -786,11 +828,15 @@ public void testMissingModelSettings() throws IOException { ) ) ); - assertThat(ex.getCause().getMessage(), containsString("Required [model_settings, chunks]")); + // Model settings may be null here so we only error on chunks + assertThat(ex.getCause().getMessage(), containsString("Required [chunks]")); } public void testMissingTaskType() throws IOException { - MapperService mapperService = createMapperService(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), useLegacyFormat); + MapperService mapperService = createMapperService( + mapping(b -> addSemanticTextMapping(b, "field", "my_id", null, null)), + useLegacyFormat + ); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, @@ -849,10 +895,43 @@ public void testDenseVectorElementType() throws IOException { assertMapperService.accept(byteMapperService, DenseVectorFieldMapper.ElementType.BYTE); } + public void testSettingAndUpdatingChunkingSettings() throws IOException { + Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + final ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false); + String fieldName = "field"; + + SemanticTextField randomSemanticText = randomSemanticText( + useLegacyFormat, + fieldName, + model, + chunkingSettings, + List.of("a"), + XContentType.JSON + ); + + MapperService mapperService = createMapperService( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, chunkingSettings)), + useLegacyFormat + ); + assertSemanticTextField(mapperService, fieldName, false, chunkingSettings); + + ChunkingSettings newChunkingSettings = generateRandomChunkingSettingsOtherThan(chunkingSettings); + merge(mapperService, mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, newChunkingSettings))); + assertSemanticTextField(mapperService, fieldName, false, newChunkingSettings); + } + public void testModelSettingsRequiredWithChunks() throws IOException { // Create inference results where model settings are set to null and chunks are provided Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - SemanticTextField randomSemanticText = randomSemanticText(useLegacyFormat, "field", model, List.of("a"), XContentType.JSON); + ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false); + SemanticTextField randomSemanticText = randomSemanticText( + useLegacyFormat, + "field", + model, + chunkingSettings, + List.of("a"), + XContentType.JSON + ); SemanticTextField inferenceResults = new SemanticTextField( randomSemanticText.useLegacyFormat(), randomSemanticText.fieldName(), @@ -860,13 +939,14 @@ public void testModelSettingsRequiredWithChunks() throws IOException { new SemanticTextField.InferenceResult( randomSemanticText.inference().inferenceId(), null, + randomSemanticText.inference().chunkingSettings(), randomSemanticText.inference().chunks() ), randomSemanticText.contentType() ); MapperService mapperService = createMapperService( - mapping(b -> addSemanticTextMapping(b, "field", model.getInferenceEntityId(), null)), + mapping(b -> addSemanticTextMapping(b, "field", model.getInferenceEntityId(), null, chunkingSettings)), useLegacyFormat ); SourceToParse source = source(b -> addSemanticTextInferenceResults(useLegacyFormat, b, List.of(inferenceResults))); @@ -905,7 +985,7 @@ private MapperService mapperServiceForFieldWithModelSettings( useLegacyFormat, fieldName, List.of(), - new SemanticTextField.InferenceResult(inferenceId, modelSettings, Map.of()), + new SemanticTextField.InferenceResult(inferenceId, modelSettings, generateRandomChunkingSettings(), Map.of()), XContentType.JSON ); XContentBuilder builder = JsonXContent.contentBuilder().startObject(); @@ -977,7 +1057,8 @@ private static void addSemanticTextMapping( XContentBuilder mappingBuilder, String fieldName, String inferenceId, - String searchInferenceId + String searchInferenceId, + ChunkingSettings chunkingSettings ) throws IOException { mappingBuilder.startObject(fieldName); mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); @@ -985,6 +1066,11 @@ private static void addSemanticTextMapping( if (searchInferenceId != null) { mappingBuilder.field("search_inference_id", searchInferenceId); } + if (chunkingSettings != null) { + mappingBuilder.startObject("chunking_settings"); + mappingBuilder.mapContents(chunkingSettings.asMap()); + mappingBuilder.endObject(); + } mappingBuilder.endObject(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 0bde2b275c82d..b4ac5c475d425 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; @@ -29,6 +30,8 @@ import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.core.utils.FloatConversionUtils; +import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.model.TestModel; import java.io.IOException; @@ -71,7 +74,7 @@ protected void assertEqualInstances(SemanticTextField expectedInstance, Semantic assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues())); assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); - + assertThat(newInstance.inference().chunkingSettings(), equalTo(expectedInstance.inference().chunkingSettings())); MinimalServiceSettings modelSettings = newInstance.inference().modelSettings(); for (var entry : newInstance.inference().chunks().entrySet()) { var expectedChunks = expectedInstance.inference().chunks().get(entry.getKey()); @@ -119,6 +122,7 @@ protected SemanticTextField createTestInstance() { useLegacyFormat, NAME, TestModel.createRandomInstance(), + generateRandomChunkingSettings(), rawValues, randomFrom(XContentType.values()) ); @@ -248,6 +252,7 @@ public static SemanticTextField randomSemanticText( boolean useLegacyFormat, String fieldName, Model model, + ChunkingSettings chunkingSettings, List inputs, XContentType contentType ) throws IOException { @@ -259,13 +264,22 @@ public static SemanticTextField randomSemanticText( case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; - return semanticTextFieldFromChunkedInferenceResults(useLegacyFormat, fieldName, model, inputs, results, contentType); + return semanticTextFieldFromChunkedInferenceResults( + useLegacyFormat, + fieldName, + model, + chunkingSettings, + inputs, + results, + contentType + ); } public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( boolean useLegacyFormat, String fieldName, Model model, + ChunkingSettings chunkingSettings, List inputs, ChunkedInference results, XContentType contentType @@ -300,12 +314,30 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( new SemanticTextField.InferenceResult( model.getInferenceEntityId(), new MinimalServiceSettings(model), + chunkingSettings, Map.of(fieldName, chunks) ), contentType ); } + public static ChunkingSettings generateRandomChunkingSettings() { + return generateRandomChunkingSettings(true); + } + + public static ChunkingSettings generateRandomChunkingSettings(boolean allowNull) { + if (allowNull && randomBoolean()) { + return null; // Use model defaults + } + return randomBoolean() + ? new WordBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(1, 10)) + : new SentenceBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 1)); + } + + public static ChunkingSettings generateRandomChunkingSettingsOtherThan(ChunkingSettings chunkingSettings) { + return randomValueOtherThan(chunkingSettings, () -> generateRandomChunkingSettings(false)); + } + /** * Returns a randomly generated object for Semantic Text tests purpose. */ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index ded2f5775e60e..ce84046d7cc5f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -369,7 +369,7 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults( useLegacyFormat, SEMANTIC_TEXT_FIELD, null, - new SemanticTextField.InferenceResult(INFERENCE_ID, modelSettings, Map.of(SEMANTIC_TEXT_FIELD, List.of())), + new SemanticTextField.InferenceResult(INFERENCE_ID, modelSettings, null, Map.of(SEMANTIC_TEXT_FIELD, List.of())), XContentType.JSON ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 47345b5eb56ff..bab9d5d8595be 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -447,7 +448,7 @@ public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSet() throws IOEx } private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException { - var input = List.of("foo", "bar"); + var input = List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java index 3ac706df819b4..97bac8582c1fc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java @@ -122,7 +122,7 @@ public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatComplet PlainActionFuture listener = new PlainActionFuture<>(); assertThrows(IllegalArgumentException.class, () -> { action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), InputType.INGEST), + new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, InputType.INGEST), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 0d37859cb3690..3ce36f9f6acfa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1305,7 +1306,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java index 145a2e6078360..8b55d5b78f397 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java @@ -74,7 +74,7 @@ public void testEmbeddingsRequestAction_Titan() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -112,7 +112,7 @@ public void testEmbeddingsRequestAction_Cohere() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); @@ -145,7 +145,7 @@ public void testEmbeddingsRequestAction_HandlesException() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java index 232ed2d23c367..797d50878a0b7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -83,7 +84,7 @@ public void send( ) { sendCounter++; if (inferenceInputs instanceof EmbeddingsInput docsInput) { - inputs.add(docsInput.getInputs()); + inputs.add(ChunkInferenceInput.inputs(docsInput.getInputs())); if (docsInput.getInputType() != null) { inputTypes.add(docsInput.getInputType()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java index b676951e7fd29..34b28b642df79 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -83,7 +84,7 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws threadPool, new TimeValue(30, TimeUnit.SECONDS) ); - sender.send(requestManager, new EmbeddingsInput(List.of("abc"), null), null, listener); + sender.send(requestManager, new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index e8016b3ef2d04..21f686836588c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1098,7 +1099,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java index 29658db5a3d9d..9896286f503f3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java @@ -115,7 +115,7 @@ public void testEmbeddingsRequestAction() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 918fa738295f4..82504a1153d01 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -996,7 +997,7 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java index 5287907a2ce76..5f3bbd5af0a16 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -170,7 +170,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOExcepti var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -222,7 +222,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [data]"; var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -296,7 +296,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -373,7 +373,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -433,7 +433,11 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("super long input"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of("super long input"), null, inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java index f9c85af00d4da..7d59ea225ab22 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -116,7 +117,11 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -144,7 +149,11 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -165,7 +174,11 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -186,7 +199,11 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -201,7 +218,11 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index b0cd517683464..2c5ec28a9a32e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1271,7 +1272,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { // 2 inputs service.chunkedInfer( model, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1369,7 +1370,7 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { // 2 inputs service.chunkedInfer( model, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index da9e2e872981d..b56a19c0af0f1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index 22b23ac0679b9..13d5191577d4c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -127,7 +127,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -216,7 +216,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I ); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -271,7 +271,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -291,7 +291,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -311,7 +311,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -325,7 +325,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -339,7 +339,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 4b7239482b69f..f33d9f5c09ada 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -605,7 +606,7 @@ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException service.chunkedInfer( model, null, - List.of("input text"), + List.of(new ChunkInferenceInput("input text")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -732,7 +733,7 @@ public void testChunkedInfer_PassesThrough() throws IOException { PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, - List.of("input text"), + List.of(new ChunkInferenceInput("input text")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index a2966488742ef..fadf4a899e45d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -96,7 +96,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -157,7 +157,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -214,7 +214,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -278,7 +278,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 9263025755fc5..9b76d5d8b02b6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -994,7 +995,7 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1066,7 +1067,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1138,7 +1139,7 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1183,7 +1184,8 @@ public void testChunkInferSetsTokenization() { expectedWindowSize.set(null); service.chunkedInfer( model, - List.of("foo", "bar"), + null, + List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1194,7 +1196,8 @@ public void testChunkInferSetsTokenization() { expectedWindowSize.set(256); service.chunkedInfer( model, - List.of("foo", "bar"), + null, + List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1246,7 +1249,7 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { service.chunkedInfer( model, null, - List.of("foo", "bar", "baz"), + List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar"), new ChunkInferenceInput("baz")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1272,7 +1275,7 @@ public void testChunkingLargeDocument() throws InterruptedException { // build a doc with enough words to make numChunks of chunks int wordsPerChunk = 10; int numWords = numChunks * wordsPerChunk; - var input = "word ".repeat(numWords); + var input = new ChunkInferenceInput("word ".repeat(numWords), null); Client client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 5c3b0ce1b27eb..48ddc7dd98b7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -892,7 +893,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbeddingsModel model) throws IOException { - var input = List.of("a", "bb"); + var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -937,7 +938,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -952,7 +953,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -976,7 +977,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - Map.of("parts", List.of(Map.of("text", input.get(0)))), + Map.of("parts", List.of(Map.of("text", input.get(0).input()))), "taskType", "RETRIEVAL_DOCUMENT" ), @@ -984,7 +985,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - Map.of("parts", List.of(Map.of("text", input.get(1)))), + Map.of("parts", List.of(Map.of("text", input.get(1).input()))), "taskType", "RETRIEVAL_DOCUMENT" ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java index f2077998d1797..2668f7f8f7c27 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java @@ -108,7 +108,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of(input), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of(input), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -187,7 +187,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -205,7 +205,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java index 8c6e5d31f59c4..5acd78930637b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java @@ -75,7 +75,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -99,7 +99,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -117,7 +117,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java index 0e575ed045711..760858b5a1261 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java @@ -71,7 +71,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -91,7 +91,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -105,7 +105,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index ec73fb841b2f7..ab4c0777940ef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InputType; @@ -95,7 +96,7 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, - List.of("abc"), + List.of(new ChunkInferenceInput("abc")), new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index d9d7bba058989..8111da40934df 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -725,7 +726,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, - List.of("abc"), + List.of(new ChunkInferenceInput("abc")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -776,7 +777,7 @@ public void testChunkedInfer() throws IOException { PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, - List.of("abc"), + List.of(new ChunkInferenceInput("abc")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index eabeeff9b7b1a..09619cd076ff5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -95,7 +95,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -166,7 +166,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -218,7 +218,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -280,7 +280,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -339,7 +339,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -401,7 +401,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("123456"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("123456"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java index f3bedf04e056f..cfa0f0bb2198b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java @@ -63,7 +63,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderThrows() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -87,7 +87,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -108,7 +108,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 2178073dd2f1f..4013009a086a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -730,7 +731,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { } private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException { - var input = List.of("a", "bb"); + var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -785,7 +786,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -800,7 +801,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java index d9f3ed0c394db..9376e4da76261 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java @@ -114,7 +114,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(input), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(input), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -144,7 +144,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -173,7 +173,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 67af588f548e7..b4a39be58b245 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1674,7 +1675,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 04c93eb5d081e..fffb48de52198 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -683,7 +684,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, - List.of("abc", "def"), + List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index c3b19596ff913..3569389347c23 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1519,7 +1520,7 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java index 102cdbec77d74..d7c72cf98e267 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java @@ -109,7 +109,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -166,7 +166,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -222,7 +222,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -285,7 +285,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -625,7 +625,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -712,7 +712,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -784,7 +784,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("super long input"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("super long input"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java index 4a1609c7a27df..17c08dee34e5c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; @@ -114,7 +115,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -154,7 +155,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -178,7 +179,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -202,7 +203,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -220,7 +221,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -238,7 +239,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index aa76cee1b3936..238300f2ccdca 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1635,7 +1636,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java index 5ccf9951d5bc7..8e5e14e21575d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -110,7 +111,11 @@ public void testCreate_VoyageAIEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java index a268db7f336f7..58953a82eb10c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -120,7 +121,11 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -217,7 +222,11 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -314,7 +323,11 @@ public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -383,7 +396,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -407,7 +420,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -425,7 +438,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -448,7 +461,7 @@ private ExecutableAction createAction( threadPool, model, EMBEDDINGS_HANDLER, - (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getInputs(), embeddingsInput.getInputType(), model), + (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), model), EmbeddingsInput.class ); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml new file mode 100644 index 0000000000000..a6ff307f0ef4a --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -0,0 +1,523 @@ +setup: + - requires: + cluster_features: "semantic_text.support_chunking_config" + reason: semantic_text chunking configuration added in 8.19 + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: default-chunking-sparse + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.create: + index: default-chunking-dense + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + + - do: + indices.create: + index: custom-chunking-sparse + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: custom-chunking-dense + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: default-chunking-sparse + id: doc_1 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-sparse + id: doc_2 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: default-chunking-dense + id: doc_3 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-dense + id: doc_4 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + +--- +"We return chunking configurations with mappings": + + - do: + indices.get_mapping: + index: default-chunking-sparse + + - not_exists: default-chunking-sparse.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-sparse + + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.get_mapping: + index: default-chunking-dense + + - not_exists: default-chunking-dense.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-dense + + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + +--- +"We do not set custom chunking settings for null or empty specified chunking settings": + + - do: + indices.create: + index: null-chunking + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: null + + - do: + indices.get_mapping: + index: null-chunking + + - not_exists: null-chunking.mappings.properties.inference_field.chunking_settings + + + - do: + indices.create: + index: empty-chunking + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: { } + + - do: + indices.get_mapping: + index: empty-chunking + + - not_exists: empty-chunking.mappings.properties.inference_field.chunking_settings + +--- +"We return different chunks based on configured chunking overrides or model defaults for sparse embeddings": + + - do: + search: + index: default-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_2" } + - length: { hits.hits.0.highlight.inference_field: 3 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } + +--- +"We return different chunks based on configured chunking overrides or model defaults for dense embeddings": + + - do: + search: + index: default-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_4" } + - length: { hits.hits.0.highlight.inference_field: 3 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } + +--- +"We respect multiple semantic_text fields with different chunking configurations": + + - do: + indices.create: + index: mixed-chunking + body: + mappings: + properties: + keyword_field: + type: keyword + default_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + customized_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: mixed-chunking + id: doc_1 + body: + default_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + customized_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + search: + index: mixed-chunking + body: + query: + bool: + should: + - match: + default_chunked_inference_field: "What is Elasticsearch?" + - match: + customized_chunked_inference_field: "What is Elasticsearch?" + highlight: + fields: + default_chunked_inference_field: + type: "semantic" + number_of_fragments: 3 + customized_chunked_inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.default_chunked_inference_field: 1 } + - match: { hits.hits.0.highlight.default_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.customized_chunked_inference_field: 3 } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.2: " enjoys all the features it provides." } + +--- +"Bulk requests are handled appropriately": + + - do: + indices.create: + index: index1 + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: index2 + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + bulk: + refresh: true + body: | + { "index": { "_index": "index1", "_id": "doc_1" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index2", "_id": "doc_2" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index1", "_id": "doc_3" }} + { "inference_field": "Elasticsearch is a free, open-source search engine and analytics tool that stores and indexes data." } + + - do: + search: + index: index1,index2 + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 3 } + + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is a free, open-source search engine and analytics" } + - match: { hits.hits.0.highlight.inference_field.1: " analytics tool that stores and indexes data." } + + - match: { hits.hits.1._id: "doc_1" } + - length: { hits.hits.1.highlight.inference_field: 3 } + - match: { hits.hits.1.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.1.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.1.highlight.inference_field.2: " enjoys all the features it provides." } + + - match: { hits.hits.2._id: "doc_2" } + - length: { hits.hits.2.highlight.inference_field: 1 } + - match: { hits.hits.2.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + +--- +"Invalid chunking settings will result in an error": + + - do: + catch: /chunking settings can not have the following settings/ + indices.create: + index: invalid-chunking-extra-stuff + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + extra: stuff + + - do: + catch: /\[chunking_settings\] does not contain the required setting \[max_chunk_size\]/ + indices.create: + index: invalid-chunking-missing-required-settings + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + + - do: + catch: /Invalid chunkingStrategy/ + indices.create: + index: invalid-chunking-invalid-strategy + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: invalid + +--- +"We can update chunking settings": + + - do: + indices.create: + index: chunking-update + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.get_mapping: + index: chunking-update + + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml new file mode 100644 index 0000000000000..f189d5535bb77 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml @@ -0,0 +1,550 @@ +setup: + - requires: + cluster_features: "semantic_text.support_chunking_config" + reason: semantic_text chunking configuration added in 8.19 + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: default-chunking-sparse + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.create: + index: default-chunking-dense + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + + - do: + indices.create: + index: custom-chunking-sparse + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: custom-chunking-dense + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: default-chunking-sparse + id: doc_1 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-sparse + id: doc_2 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: default-chunking-dense + id: doc_3 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-dense + id: doc_4 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + +--- +"We return chunking configurations with mappings": + + - do: + indices.get_mapping: + index: default-chunking-sparse + + - not_exists: default-chunking-sparse.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-sparse + + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.get_mapping: + index: default-chunking-dense + + - not_exists: default-chunking-dense.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-dense + + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + +--- +"We do not set custom chunking settings for null or empty specified chunking settings": + + - do: + indices.create: + index: null-chunking + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: null + + - do: + indices.get_mapping: + index: null-chunking + + - not_exists: null-chunking.mappings.properties.inference_field.chunking_settings + + + - do: + indices.create: + index: empty-chunking + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: { } + + - do: + indices.get_mapping: + index: empty-chunking + + - not_exists: empty-chunking.mappings.properties.inference_field.chunking_settings + +--- +"We return different chunks based on configured chunking overrides or model defaults for sparse embeddings": + + - do: + search: + index: default-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_2" } + - length: { hits.hits.0.highlight.inference_field: 3 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } + +--- +"We return different chunks based on configured chunking overrides or model defaults for dense embeddings": + + - do: + search: + index: default-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_4" } + - length: { hits.hits.0.highlight.inference_field: 3 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } + +--- +"We respect multiple semantic_text fields with different chunking configurations": + + - do: + indices.create: + index: mixed-chunking + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + default_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + customized_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: mixed-chunking + id: doc_1 + body: + default_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + customized_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + search: + index: mixed-chunking + body: + query: + bool: + should: + - match: + default_chunked_inference_field: "What is Elasticsearch?" + - match: + customized_chunked_inference_field: "What is Elasticsearch?" + highlight: + fields: + default_chunked_inference_field: + type: "semantic" + number_of_fragments: 3 + customized_chunked_inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.default_chunked_inference_field: 1 } + - match: { hits.hits.0.highlight.default_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.customized_chunked_inference_field: 3 } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.2: " enjoys all the features it provides." } + + +--- +"Bulk requests are handled appropriately": + + - do: + indices.create: + index: index1 + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: index2 + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + bulk: + refresh: true + body: | + { "index": { "_index": "index1", "_id": "doc_1" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index2", "_id": "doc_2" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index1", "_id": "doc_3" }} + { "inference_field": "Elasticsearch is a free, open-source search engine and analytics tool that stores and indexes data." } + + - do: + search: + index: index1,index2 + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 3 } + + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is a free, open-source search engine and analytics" } + - match: { hits.hits.0.highlight.inference_field.1: " analytics tool that stores and indexes data." } + + - match: { hits.hits.1._id: "doc_1" } + - length: { hits.hits.1.highlight.inference_field: 3 } + - match: { hits.hits.1.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.1.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.1.highlight.inference_field.2: " enjoys all the features it provides." } + + - match: { hits.hits.2._id: "doc_2" } + - length: { hits.hits.2.highlight.inference_field: 1 } + - match: { hits.hits.2.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + +--- +"Invalid chunking settings will result in an error": + + - do: + catch: /chunking settings can not have the following settings/ + indices.create: + index: invalid-chunking-extra-stuff + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + extra: stuff + + - do: + catch: /\[chunking_settings\] does not contain the required setting \[max_chunk_size\]/ + indices.create: + index: invalid-chunking-missing-required-settings + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + + - do: + catch: /Invalid chunkingStrategy/ + indices.create: + index: invalid-chunking-invalid-strategy + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: invalid + +--- +"We can update chunking settings": + + - do: + indices.create: + index: chunking-update + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.get_mapping: + index: chunking-update + + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml index 27c405f6c23bf..35e472e72b06d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml @@ -79,7 +79,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -104,7 +104,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -129,7 +129,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -152,7 +152,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } # We can't directly check that the embeddings are different since there isn't a "does not match" assertion in the # YAML test framework. Check that the start and end offsets change as expected as a proxy. @@ -179,7 +179,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -202,7 +202,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -254,7 +254,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -283,7 +283,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -320,7 +320,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -367,7 +367,7 @@ setup: index: test-index id: doc_1 body: - doc: { "sparse_field": [{"key": "value"}], "dense_field": [{"key": "value"}] } + doc: { "sparse_field": [ { "key": "value" } ], "dense_field": [ { "key": "value" } ] } - match: { error.type: "status_exception" } - match: { error.reason: "/Invalid\\ format\\ for\\ field\\ \\[(dense|sparse)_field\\].+/" } @@ -415,7 +415,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -448,7 +448,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -509,7 +509,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -540,7 +540,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq }