diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 16acfaa1af430..cc81d8215a74d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -95,7 +95,12 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); - var chunkingSettings = extractChunkingSettings(config, taskType); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } CustomModel model = createModel( inferenceEntityId, @@ -146,7 +151,14 @@ private static RequestParameters createParameters(CustomModel model) { }; } - private static ChunkingSettings extractChunkingSettings(Map config, TaskType taskType) { + private static ChunkingSettings extractPersistentChunkingSettings(Map config, TaskType taskType) { + /* + * There's a sutle difference between how the chunking settings are parsed for the request context vs the persistent context. + * For persistent context, to support backwards compatibility, if the chunking settings are not present, removeFromMap will + * return null which results in the older word boundary chunking settings being used as the default. + * For request context, removeFromMapOrDefaultEmpty returns an empty map which results in the newer sentence boundary chunking + * settings being used as the default. + */ if (TaskType.TEXT_EMBEDDING.equals(taskType)) { return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } @@ -219,7 +231,7 @@ public CustomModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); - var chunkingSettings = extractChunkingSettings(config, taskType); + var chunkingSettings = extractPersistentChunkingSettings(config, taskType); return createModelWithoutLoggingDeprecations( inferenceEntityId, @@ -236,7 +248,7 @@ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskT Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - var chunkingSettings = extractChunkingSettings(config, taskType); + var chunkingSettings = extractPersistentChunkingSettings(config, taskType); return createModelWithoutLoggingDeprecations( inferenceEntityId, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 6ddb4ff71eeb3..d33b58ca2f943 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -53,7 +54,9 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -312,6 +315,93 @@ private static CustomServiceSettings.TextEmbeddingSettings getDefaultTextEmbeddi : CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS; } + public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception { + var chunkingSettingsMap = createRandomChunkingSettingsMap(); + + try (var service = createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + createServiceSettingsMap(TaskType.TEXT_EMBEDDING), + createTaskSettingsMap(), + chunkingSettingsMap, + createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener); + var model = listener.actionGet(TIMEOUT); + + assertModel(model, TaskType.TEXT_EMBEDDING); + + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception { + try (var service = createService(threadPool, clientManager)) { + var config = getRequestConfigMap( + createServiceSettingsMap(TaskType.TEXT_EMBEDDING), + createTaskSettingsMap(), + createSecretSettingsMap() + ); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener); + var model = listener.actionGet(TIMEOUT); + + assertModel(model, TaskType.TEXT_EMBEDDING); + + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(Map.of()); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception { + var chunkingSettingsMap = createRandomChunkingSettingsMap(); + + try (var service = createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( + createServiceSettingsMap(TaskType.TEXT_EMBEDDING), + createTaskSettingsMap(), + chunkingSettingsMap, + createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfigMap.config(), + persistedConfigMap.secrets() + ); + + assertModel(model, TaskType.TEXT_EMBEDDING); + + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception { + try (var service = createService(threadPool, clientManager)) { + var persistedConfigMap = getPersistedConfigMap( + createServiceSettingsMap(TaskType.TEXT_EMBEDDING), + createTaskSettingsMap(), + createSecretSettingsMap() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfigMap.config(), + persistedConfigMap.secrets() + ); + assertModel(model, TaskType.TEXT_EMBEDDING); + + var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(null); + assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings)); + } + } + public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOException { try (var service = createService(threadPool, clientManager)) { String responseJson = "error";