Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

var chunkingSettings = extractChunkingSettings(config, taskType);
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}

CustomModel model = createModel(
inferenceEntityId,
Expand Down Expand Up @@ -146,7 +151,14 @@ private static RequestParameters createParameters(CustomModel model) {
};
}

private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
private static ChunkingSettings extractPersistentChunkingSettings(Map<String, Object> config, TaskType taskType) {
/*
* There's a sutle difference between how the chunking settings are parsed for the request context vs the persistent context.
* For persistent context, to support backwards compatibility, if the chunking settings are not present, removeFromMap will
* return null which results in the older word boundary chunking settings being used as the default.
* For request context, removeFromMapOrDefaultEmpty returns an empty map which results in the newer sentence boundary chunking
* settings being used as the default.
*/
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}
Expand Down Expand Up @@ -219,7 +231,7 @@ public CustomModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

var chunkingSettings = extractChunkingSettings(config, taskType);
var chunkingSettings = extractPersistentChunkingSettings(config, taskType);

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
Expand All @@ -236,7 +248,7 @@ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

var chunkingSettings = extractChunkingSettings(config, taskType);
var chunkingSettings = extractPersistentChunkingSettings(config, taskType);

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
Expand All @@ -53,7 +54,9 @@
import java.util.Map;

import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
Expand Down Expand Up @@ -312,6 +315,93 @@ private static CustomServiceSettings.TextEmbeddingSettings getDefaultTextEmbeddi
: CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS;
}

public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception {
var chunkingSettingsMap = createRandomChunkingSettingsMap();

try (var service = createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
createTaskSettingsMap(),
chunkingSettingsMap,
createSecretSettingsMap()
);

var listener = new PlainActionFuture<Model>();
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
var model = listener.actionGet(TIMEOUT);

assertModel(model, TaskType.TEXT_EMBEDDING);

var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
}
}

public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception {
try (var service = createService(threadPool, clientManager)) {
var config = getRequestConfigMap(
createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
createTaskSettingsMap(),
createSecretSettingsMap()
);

var listener = new PlainActionFuture<Model>();
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
var model = listener.actionGet(TIMEOUT);

assertModel(model, TaskType.TEXT_EMBEDDING);

var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(Map.of());
assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
}
}

public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception {
var chunkingSettingsMap = createRandomChunkingSettingsMap();

try (var service = createService(threadPool, clientManager)) {
var persistedConfigMap = getPersistedConfigMap(
createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
createTaskSettingsMap(),
chunkingSettingsMap,
createSecretSettingsMap()
);

var model = service.parsePersistedConfigWithSecrets(
"id",
TaskType.TEXT_EMBEDDING,
persistedConfigMap.config(),
persistedConfigMap.secrets()
);

assertModel(model, TaskType.TEXT_EMBEDDING);

var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
}
}

public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception {
try (var service = createService(threadPool, clientManager)) {
var persistedConfigMap = getPersistedConfigMap(
createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
createTaskSettingsMap(),
createSecretSettingsMap()
);

var model = service.parsePersistedConfigWithSecrets(
"id",
TaskType.TEXT_EMBEDDING,
persistedConfigMap.config(),
persistedConfigMap.secrets()
);
assertModel(model, TaskType.TEXT_EMBEDDING);

var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(null);
assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
}
}

public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOException {
try (var service = createService(threadPool, clientManager)) {
String responseJson = "error";
Expand Down