Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ public Map<String, InferenceService> getServices() {
}

/**
 * Looks up a registered inference service by name.
 * The legacy name "elser" is resolved to the entry stored under "elasticsearch"; any other name is looked up as given.
 *
 * @param serviceName name of the requested service
 * @return the service found for the resolved name, or an empty Optional when the lookup returns null
 */
public Optional<InferenceService> getService(String serviceName) {
// NOTE(review): this unconditional return sits ahead of the alias branch below; it looks like the
// pre-change line retained by the diff rendering — confirm it is not part of the final method.
return Optional.ofNullable(services.get(serviceName));

if ("elser".equals(serviceName)) { // ElserService.NAME before removal
// here we are aliasing the elser service to use the elasticsearch service instead
return Optional.ofNullable(services.get("elasticsearch")); // ElasticsearchInternalService.NAME
} else {
return Optional.ofNullable(services.get(serviceName));
}
}

public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalService;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceTests;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettingsTests;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests;
import org.junit.Before;

import java.io.IOException;
Expand Down Expand Up @@ -118,10 +117,10 @@ public void testGetModel() throws Exception {

assertEquals(model.getConfigurations().getService(), modelHolder.get().service());

var elserService = new ElserInternalService(
var elserService = new ElasticsearchInternalService(
new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class))
);
ElserInternalModel roundTripModel = elserService.parsePersistedConfigWithSecrets(
ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets(
modelHolder.get().inferenceEntityId(),
modelHolder.get().taskType(),
modelHolder.get().settings(),
Expand Down Expand Up @@ -277,7 +276,17 @@ public void testGetModelWithSecrets() throws InterruptedException {
}

/**
 * Builds an ELSER model configuration for the given endpoint id, using randomly generated
 * service settings and task settings.
 * Only {@code SPARSE_EMBEDDING} is handled; any other task type raises IllegalArgumentException.
 *
 * @param inferenceEntityId id assigned to the created model
 * @param taskType task type of the created model
 * @return the model registered under the {@code ElasticsearchInternalService.NAME} service name
 */
private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) {
// NOTE(review): this delegating return precedes the switch below; it looks like the pre-change
// line retained by the diff rendering — confirm it is not part of the final method.
return ElserInternalServiceTests.randomModelConfig(inferenceEntityId, taskType);
return switch (taskType) {
case SPARSE_EMBEDDING -> new org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalModel(
inferenceEntityId,
taskType,
ElasticsearchInternalService.NAME,
ElserInternalServiceSettingsTests.createRandom(),
ElserMlNodeTaskSettingsTests.createRandom()
);
default -> throw new IllegalArgumentException("task type " + taskType + " is not supported");
};

}

protected <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)
Expand All @@ -300,7 +309,7 @@ private static Model buildModelWithUnknownField(String inferenceEntityId) {
new ModelWithUnknownField(
inferenceEntityId,
TaskType.SPARSE_EMBEDDING,
ElserInternalService.NAME,
ElasticsearchInternalService.NAME,
ElserInternalServiceSettingsTests.createRandom(),
ElserMlNodeTaskSettingsTests.createRandom()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalService;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
Expand Down Expand Up @@ -229,7 +228,6 @@ public void loadExtensions(ExtensionLoader loader) {

public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
return List.of(
ElserInternalService::new,
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.inference.InferenceService;
Expand Down Expand Up @@ -42,6 +43,7 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
private final InferenceServiceRegistry serviceRegistry;
private final InferenceStats inferenceStats;
private final StreamingTaskManager streamingTaskManager;
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(TransportInferenceAction.class);

@Inject
public TransportInferenceAction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;

public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
PutInferenceModelAction.Request,
Expand Down Expand Up @@ -110,6 +113,10 @@ protected void masterOperation(
return;
}

if (List.of(OLD_ELSER_SERVICE_NAME, ElasticsearchInternalService.NAME).contains(serviceName)) {
// required for BWC of elser service in elasticsearch service TODO remove when elser service deprecated
requestAsMap.put(ModelConfigurations.SERVICE, serviceName);
}
var service = serviceRegistry.getService(serviceName);
if (service.isEmpty()) {
listener.onFailure(new ElasticsearchStatusException("Unknown service [{}]", RestStatus.BAD_REQUEST, serviceName));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.elser.ElserModels;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;

import java.net.URI;
import java.net.URISyntaxException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.elser.ElserModels;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel;

import java.io.IOException;
import java.util.EnumSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@

package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
Expand Down Expand Up @@ -55,10 +58,13 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86;

public class ElasticsearchInternalService extends BaseElasticsearchInternalService {

public static final String NAME = "elasticsearch";
public static final String OLD_ELSER_SERVICE_NAME = "elser";

static final String MULTILINGUAL_E5_SMALL_MODEL_ID = ".multilingual-e5-small";
static final String MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 = ".multilingual-e5-small_linux-x86_64";
Expand All @@ -67,6 +73,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
);

private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);

/**
 * Creates the service, handing the factory context straight to the superclass constructor.
 *
 * @param context factory context forwarded to the base class
 */
public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
super(context);
}
Expand Down Expand Up @@ -94,19 +103,41 @@ public void parseRequestConfig(
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMap(config, ModelConfigurations.TASK_SETTINGS);
String serviceName = (String) config.remove(ModelConfigurations.SERVICE); // required for elser service in elasticsearch service

throwIfNotEmptyMap(config, name());

String modelId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.MODEL_ID);
if (modelId == null) {
throw new ValidationException().addValidationError("Error parsing request config, model id is missing");
}
if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) {
if (OLD_ELSER_SERVICE_NAME.equals(serviceName)) {
// TODO complete deprecation of null model ID
// throw new ValidationException().addValidationError("Error parsing request config, model id is missing");
DEPRECATION_LOGGER.critical(
DeprecationCategory.API,
"inference_api_null_model_id_in_elasticsearch_service",
"Putting elasticsearch service inference endpoints (including elser service) without a model_id field is"
+ " deprecated and will be removed in a future release. Please specify a model_id field."
);
platformArch.accept(
modelListener.delegateFailureAndWrap(
(delegate, arch) -> elserCase(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener)
)
);
} else {
throw new IllegalArgumentException("Error parsing service settings, model_id must be provided");
}
} else if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) {
platformArch.accept(
modelListener.delegateFailureAndWrap(
(delegate, arch) -> e5Case(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener)
)
);
} else if (ElserModels.isValidModel(modelId)) {
platformArch.accept(
modelListener.delegateFailureAndWrap(
(delegate, arch) -> elserCase(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener)
)
);
} else {
customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, modelListener);
}
Expand Down Expand Up @@ -239,7 +270,86 @@ static boolean modelVariantValidForArchitecture(Set<String> platformArchitecture
// platform agnostic model is always compatible
return true;
}
return modelId.equals(
selectDefaultModelVariantBasedOnClusterArchitecture(
platformArchitectures,
MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86,
MULTILINGUAL_E5_SMALL_MODEL_ID
)
);
}

/**
 * Builds an ELSER model from the request's service settings and hands it to {@code modelListener}.
 *
 * <p>The default model variant is chosen from the cluster's platform architectures. A request without a
 * model id is given that default; a request for the platform agnostic model on a cluster whose default
 * is a different variant is accepted with a warning; any other mismatch is rejected.
 *
 * @param inferenceEntityId id of the inference endpoint being created
 * @param taskType task type of the endpoint
 * @param config remaining top-level request config; passed to {@code throwIfNotEmptyMap} before the model is created
 * @param platformArchitectures architectures used to select the default model variant
 * @param serviceSettingsMap service settings from the request; also passed to {@code throwIfNotEmptyMap} after parsing
 * @param modelListener receives the created {@code ElserInternalModel}
 * @throws IllegalArgumentException if the requested model id is not usable on this platform
 */
private void elserCase(
    String inferenceEntityId,
    TaskType taskType,
    Map<String, Object> config,
    Set<String> platformArchitectures,
    Map<String, Object> serviceSettingsMap,
    ActionListener<Model> modelListener
) {
    var settingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap);
    final String defaultModelId = selectDefaultModelVariantBasedOnClusterArchitecture(
        platformArchitectures,
        ELSER_V2_MODEL_LINUX_X86,
        ELSER_V2_MODEL
    );

    final var requestedModelId = settingsBuilder.getModelId();
    if (requestedModelId == null) {
        // TODO remove this case once we remove the option to not pass model ID
        settingsBuilder.setModelId(defaultModelId);
    } else if (defaultModelId.equals(requestedModelId) == false) {
        // Only the platform agnostic model may differ from the default chosen for this cluster.
        if (requestedModelId.equals(ELSER_V2_MODEL) == false) {
            throw new IllegalArgumentException(
                "Error parsing request config, model id does not match any models available on this platform. Was ["
                    + requestedModelId
                    + "]. You may need to use a platform agnostic model."
            );
        }
        logger.warn(
            "The platform agnostic model [{}] was requested on Linux x86_64. "
                + "It is recommended to use the optimized model instead [{}]",
            ELSER_V2_MODEL,
            ELSER_V2_MODEL_LINUX_X86
        );
    }

    // NOTE(review): this warning is emitted on every call, including requests that named the
    // elasticsearch service with an ELSER model id — confirm that is intended.
    DEPRECATION_LOGGER.warn(
        DeprecationCategory.API,
        "inference_api_elser_service",
        "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead, with"
            + " [model_id] set to [{}] in the [service_settings]",
        OLD_ELSER_SERVICE_NAME,
        ElasticsearchInternalService.NAME,
        defaultModelId
    );

    // Re-read the id: it may have just been defaulted above.
    final var effectiveModelId = settingsBuilder.getModelId();
    if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, effectiveModelId)) {
        throw new IllegalArgumentException(
            "Error parsing request config, model id does not match any models available on this platform. Was ["
                + effectiveModelId
                + "]"
        );
    }

    // Presumably rejects any entries left unconsumed in the request maps — confirm against ServiceUtils.
    throwIfNotEmptyMap(config, name());
    throwIfNotEmptyMap(serviceSettingsMap, name());

    var elserSettings = new ElserInternalServiceSettings(settingsBuilder.build());
    var elserModel = new ElserInternalModel(inferenceEntityId, taskType, NAME, elserSettings, ElserMlNodeTaskSettings.DEFAULT);
    modelListener.onResponse(elserModel);
}

private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(
Set<String> platformArchitectures,
String modelId
) {
return modelId.equals(
selectDefaultModelVariantBasedOnClusterArchitecture(
platformArchitectures,
Expand Down Expand Up @@ -276,6 +386,14 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
NAME,
new MultilingualE5SmallInternalServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap))
);
} else if (ElserModels.isValidModel(modelId)) {
return new ElserInternalModel(
inferenceEntityId,
taskType,
NAME,
new ElserInternalServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)),
ElserMlNodeTaskSettings.DEFAULT
);
} else {
return createCustomElandModel(
inferenceEntityId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ protected static ElasticsearchInternalServiceSettings.Builder fromMap(
validationException
);

// model id is optional as the ELSER and E5 service will default it
// model id is optional as the ELSER service will default it. TODO make this a required field once the elser service is removed
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);

if (numAllocations == null && adaptiveAllocationsSettings == null) {
Expand Down
Loading
Loading