diff --git a/docs/changelog/130092.yaml b/docs/changelog/130092.yaml new file mode 100644 index 0000000000000..0e54e5f013d23 --- /dev/null +++ b/docs/changelog/130092.yaml @@ -0,0 +1,5 @@ +pr: 130092 +summary: "Added Llama provider support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: [] diff --git a/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java b/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java index be13207702627..c3b322db0e3a5 100644 --- a/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java +++ b/libs/x-content/src/main/java/org/elasticsearch/xcontent/ConstructingObjectParser.java @@ -220,6 +220,27 @@ public void declareField(BiConsumer consumer, ContextParser void declareObjectArrayOrNull( + BiConsumer> consumer, + ContextParser objectParser, + ParseField field + ) { + declareField( + consumer, + (p, c) -> p.currentToken() == XContentParser.Token.VALUE_NULL ? null : parseArray(p, c, objectParser), + field, + ValueType.OBJECT_ARRAY_OR_NULL + ); + } + @Override public void declareNamedObject( BiConsumer consumer, diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 82dfcd9ca56e2..ea2824624e3f7 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -343,6 +343,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00); public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00); public static final TransportVersion PROJECT_STATE_REGISTRY_ENTRY = def(9_124_0_00); + public static final TransportVersion ML_INFERENCE_LLAMA_ADDED = def(9_125_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index db31aafc8c190..b6f724e69d40f 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -121,7 +121,7 @@ public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Para * - Key: {@link #MODEL_FIELD}, Value: modelId * - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()} */ - public static Params withMaxCompletionTokensTokens(String modelId, Params params) { + public static Params withMaxCompletionTokens(String modelId, Params params) { return new DelegatingMapParams( Map.ofEntries(Map.entry(MODEL_ID_PARAM, modelId), Map.entry(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD)), params diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java index 56dc2a6d0212a..fd5632606867e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -119,7 +119,7 @@ public void testParseAllFields() throws IOException { assertThat(request, is(expected)); assertThat( - Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokensTokens("gpt-4o", ToXContent.EMPTY_PARAMS)), + Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokens("gpt-4o", ToXContent.EMPTY_PARAMS)), is(XContentHelper.stripWhitespace(requestJson)) ); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 5028cc6873cbb..6fd07cd4c2831 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -106,6 +106,8 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; @@ -175,6 +177,7 @@ public static List getNamedWriteables() { addJinaAINamedWriteables(namedWriteables); addVoyageAINamedWriteables(namedWriteables); addCustomNamedWriteables(namedWriteables); + addLlamaNamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -274,8 +277,25 @@ private static void addMistralNamedWriteables(List MistralChatCompletionServiceSettings::new ) ); + // no task settings for Mistral + } - // note - no task settings for Mistral embeddings... 
+ private static void addLlamaNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + LlamaEmbeddingsServiceSettings.NAME, + LlamaEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + LlamaChatCompletionServiceSettings.NAME, + LlamaChatCompletionServiceSettings::new + ) + ); + // no task settings for Llama } private static void addAzureAiStudioNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index a35d64ab84c7f..bbb1bd1a2fec2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -133,6 +133,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; +import org.elasticsearch.xpack.inference.services.llama.LlamaService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient; @@ -402,6 +403,7 @@ public List getInferenceServiceFactories() { context -> new JinaAIService(httpFactory.get(), serviceComponents.get(), context), context -> new VoyageAIService(httpFactory.get(), serviceComponents.get(), context), context -> new DeepSeekService(httpFactory.get(), serviceComponents.get(), context), + context -> new LlamaService(httpFactory.get(), serviceComponents.get(), context), ElasticsearchInternalService::new, context -> new 
CustomService(httpFactory.get(), serviceComponents.get(), context) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index adbec49328804..f5f1074bfbb86 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; @@ -304,6 +305,12 @@ public static String invalidSettingError(String settingName, String scope) { return Strings.format("[%s] does not allow the setting [%s]", scope, settingName); } + public static URI extractUri(Map map, String fieldName, ValidationException validationException) { + String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); + + return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); + } + public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) { try { return createOptionalUri(url); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java index b45d4449251f4..007dc820c629f 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java @@ -28,7 +28,7 @@ public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInpu @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(modelId, params)); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokens(modelId, params)); builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java index 7429153835ee3..91735d39f3973 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java @@ -31,11 +31,10 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; -import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; public class HuggingFaceServiceSettings extends FilteredXContentObject implements ServiceSettings, HuggingFaceRateLimitServiceSettings { public static final String NAME = "hugging_face_service_settings"; @@ -70,12 +69,6 @@ public static HuggingFaceServiceSettings fromMap(Map map, Config return new HuggingFaceServiceSettings(uri, similarityMeasure, dims, maxInputTokens, rateLimitSettings); } - public static URI extractUri(Map map, String fieldName, ValidationException validationException) { - String parsedUrl = extractRequiredString(map, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); - - return convertToUri(parsedUrl, fieldName, ModelConfigurations.SERVICE_SETTINGS, validationException); - } - private final URI uri; private final SimilarityMeasure similarity; private final Integer dimensions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java index cdc2529428bed..64da6e32bc1f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -31,7 +31,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; /** * Settings for the Hugging Face chat completion service. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java index b1d3297fc6328..ad771e72b6b35 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettings.java @@ -28,7 +28,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; public class HuggingFaceElserServiceSettings extends FilteredXContentObject implements diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java index b0b21b26395af..57c103bbbf3b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankServiceSettings.java @@ -27,7 +27,7 @@ import java.util.Objects; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; public class HuggingFaceRerankServiceSettings extends FilteredXContentObject implements diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java new file mode 100644 index 0000000000000..3e24d058d8540 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaModel.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +/** + * Abstract class representing a Llama model for inference. + * This class extends RateLimitGroupingModel and provides common functionality for Llama models. + */ +public abstract class LlamaModel extends RateLimitGroupingModel { + protected URI uri; + protected RateLimitSettings rateLimitSettings; + + /** + * Constructor for creating a LlamaModel with specified configurations and secrets. + * + * @param configurations the model configurations + * @param secrets the secret settings for the model + */ + protected LlamaModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + /** + * Constructor for creating a LlamaModel with specified model, service settings, and secret settings. 
+ * @param model the model configurations + * @param serviceSettings the settings for the inference service + */ + protected LlamaModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + public URI uri() { + return this.uri; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public int rateLimitGroupingHash() { + return Objects.hash(getServiceSettings().modelId(), uri, getSecretSettings()); + } + + // Needed for testing only + public void setURI(String newUri) { + try { + this.uri = new URI(newUri); + } catch (URISyntaxException e) { + // swallow any error + } + } + + /** + * Retrieves the secret settings from the provided map of secrets. + * If the map is null or empty, it returns an instance of EmptySecretSettings. + * Caused by the fact that Llama model doesn't have out of the box security settings and can be used witout authentication. + * + * @param secrets the map containing secret settings + * @return an instance of SecretSettings + */ + protected static SecretSettings retrieveSecretSettings(Map secrets) { + return (secrets != null && secrets.isEmpty()) ? EmptySecretSettings.INSTANCE : DefaultSecretSettings.fromMap(secrets); + } + + protected abstract ExecutableAction accept(LlamaActionVisitor creator); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java new file mode 100644 index 0000000000000..bd6b3c91fc9e9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -0,0 +1,423 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import 
org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionCreator; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; +import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +/** + * LlamaService is an inference service for Llama models, supporting text embedding and chat completion tasks. + * It extends SenderService to handle HTTP requests and responses for Llama models. + */ +public class LlamaService extends SenderService { + public static final String NAME = "llama"; + private static final String SERVICE_NAME = "Llama"; + /** + * The optimal batch size depends on the hardware the model is deployed on. + * For Llama use a conservatively small max batch size as it is + * unknown how the model is deployed + */ + static final int EMBEDDING_MAX_BATCH_SIZE = 20; + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION); + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler( + "llama chat completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + /** + * Constructor for creating a LlamaService with specified HTTP request sender factory and service components. 
+ * + * @param factory the factory to create HTTP request senders + * @param serviceComponents the components required for the inference service + * @param context the context for the inference service factory + */ + public LlamaService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public LlamaService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents()); + if (model instanceof LlamaModel llamaModel) { + llamaModel.accept(actionCreator).execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + /** + * Creates a LlamaModel based on the provided parameters. 
+ * + * @param inferenceId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param serviceSettings the settings for the inference service + * @param chunkingSettings the settings for chunking, if applicable + * @param secretSettings the secret settings for the model, such as API keys or tokens + * @param failureMessage the message to use in case of failure + * @param context the context for parsing configuration settings + * @return a new instance of LlamaModel based on the provided parameters + */ + protected LlamaModel createModel( + String inferenceId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + switch (taskType) { + case TEXT_EMBEDDING: + return new LlamaEmbeddingsModel(inferenceId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context); + case CHAT_COMPLETION, COMPLETION: + return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context); + default: + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof LlamaEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? 
SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + var updatedServiceSettings = new LlamaEmbeddingsServiceSettings( + serviceSettings.modelId(), + serviceSettings.uri(), + embeddingSize, + similarityToUse, + serviceSettings.maxInputTokens(), + serviceSettings.rateLimitSettings() + ); + + return new LlamaEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + @Override + protected void doChunkedInfer( + Model model, + EmbeddingsInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof LlamaEmbeddingsModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var llamaModel = (LlamaEmbeddingsModel) model; + var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + llamaModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = llamaModel.accept(actionCreator); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof LlamaChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var llamaChatCompletionModel = (LlamaChatCompletionModel) model; + var overriddenModel = LlamaChatCompletionModel.of(llamaChatCompletionModel, inputs.getRequest()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + UNIFIED_CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new 
LlamaChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = LlamaActionCreator.buildErrorMessage(CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(COMPLETION, CHAT_COMPLETION); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + + LlamaModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + 
removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + private LlamaModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + chunkingSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + null, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_LLAMA_ADDED; + } + + @Override + public boolean hideFromConfigurationApi() { + // The Llama service is very configurable so we're going to hide it from being exposed in the service API. + return true; + } + + /** + * Configuration class for the Llama inference service. 
+ * It provides the settings and configurations required for the service. + */ + public static class Configuration { + public static InferenceServiceConfiguration get() { + return CONFIGURATION.getOrCompute(); + } + + private Configuration() {} + + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") + .setLabel("URL") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "Refer to the Llama models documentation for the list of available models." + ) + .setLabel("Model") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java new file mode 100644 index 0000000000000..52e284ba7ccca --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.llama.request.embeddings.LlamaEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; + +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import 
static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +/** + * Creates actions for Llama inference requests, handling both embeddings and completions. + * This class implements the {@link LlamaActionVisitor} interface to provide specific action creation methods. + */ +public class LlamaActionCreator implements LlamaActionVisitor { + + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Llama %s request from inference entity id [%s]"; + private static final String COMPLETION_ERROR_PREFIX = "Llama completions"; + private static final String USER_ROLE = "user"; + + private static final ResponseHandler EMBEDDINGS_HANDLER = new LlamaEmbeddingsResponseHandler( + "llama text embedding", + HuggingFaceEmbeddingsResponseEntity::fromResponse + ); + private static final ResponseHandler COMPLETION_HANDLER = new LlamaCompletionResponseHandler( + "llama completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + private final Sender sender; + private final ServiceComponents serviceComponents; + + /** + * Constructs a new LlamaActionCreator with the specified sender and service components. 
+ * + * @param sender the sender to use for executing actions + * @param serviceComponents the service components providing necessary services + */ + public LlamaActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(LlamaEmbeddingsModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + EMBEDDINGS_HANDLER, + embeddingsInput -> new LlamaEmbeddingsRequest( + serviceComponents.truncator(), + truncate(embeddingsInput.getStringInputs(), model.getServiceSettings().maxInputTokens()), + model + ), + EmbeddingsInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + + @Override + public ExecutableAction create(LlamaChatCompletionModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + inputs -> new LlamaChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId()); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); + } + + /** + * Builds an error message for failed requests. 
+ * + * @param requestType the type of request that failed + * @param inferenceId the inference entity ID associated with the request + * @return a formatted error message + */ + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java new file mode 100644 index 0000000000000..1521b83b668c7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionVisitor.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; + +/** + * Visitor interface for creating executable actions for Llama inference models. + * This interface defines methods to create actions for both embeddings and chat completion models. + */ +public interface LlamaActionVisitor { + /** + * Creates an executable action for the given Llama embeddings model. + * + * @param model the Llama embeddings model + * @return an executable action for the embeddings model + */ + ExecutableAction create(LlamaEmbeddingsModel model); + + /** + * Creates an executable action for the given Llama chat completion model. 
+ * + * @param model the Llama chat completion model + * @return an executable action for the chat completion model + */ + ExecutableAction create(LlamaChatCompletionModel model); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java new file mode 100644 index 0000000000000..a1a38f1eae326 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModel.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaModel; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor; + +import java.util.Map; + +/** + * Represents a Llama chat completion model for inference. + * This class extends the LlamaModel and provides specific configurations and settings for chat completion tasks. 
+ */ +public class LlamaChatCompletionModel extends LlamaModel { + + /** + * Constructor for creating a LlamaChatCompletionModel with specified parameters. + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to chat completion + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public LlamaChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + LlamaChatCompletionServiceSettings.fromMap(serviceSettings, context), + retrieveSecretSettings(secrets) + ); + } + + /** + * Constructor for creating a LlamaChatCompletionModel with specified parameters. + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to chat completion + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public LlamaChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + LlamaChatCompletionServiceSettings serviceSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE), + new ModelSecrets(secrets) + ); + setPropertiesFromServiceSettings(serviceSettings); + } + + /** + * Factory method to create a LlamaChatCompletionModel with overridden model settings based on the request. + * If the request does not specify a model, the original model is returned. 
+ * + * @param model the original LlamaChatCompletionModel + * @param request the UnifiedCompletionRequest containing potential overrides + * @return a new LlamaChatCompletionModel with overridden settings or the original model if no overrides are specified + */ + public static LlamaChatCompletionModel of(LlamaChatCompletionModel model, UnifiedCompletionRequest request) { + if (request.model() == null) { + // If no model id is specified in the request, return the original model + return model; + } + + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new LlamaChatCompletionServiceSettings( + request.model(), + originalModelServiceSettings.uri(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new LlamaChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getSecretSettings() + ); + } + + private void setPropertiesFromServiceSettings(LlamaChatCompletionServiceSettings serviceSettings) { + this.uri = serviceSettings.uri(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + } + + /** + * Returns the service settings specific to Llama chat completion. + * + * @return the LlamaChatCompletionServiceSettings associated with this model + */ + @Override + public LlamaChatCompletionServiceSettings getServiceSettings() { + return (LlamaChatCompletionServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor that creates an executable action for this Llama chat completion model. 
+ * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing this model + */ + @Override + public ExecutableAction accept(LlamaActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..85d60308d77d3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandler.java @@ -0,0 +1,180 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; + +import java.util.Locale; +import java.util.Optional; + +import static org.elasticsearch.core.Strings.format; + +/** + * Handles streaming chat completion responses and error parsing for Llama inference endpoints. + * This handler is designed to work with the unified Llama chat completion API. + */ +public class LlamaChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { + + private static final String LLAMA_ERROR = "llama_error"; + private static final String STREAM_ERROR = "stream_error"; + + /** + * Constructor for creating a LlamaChatCompletionResponseHandler with specified request type and response parser. 
+ * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public LlamaChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LlamaErrorResponse::fromResponse); + } + + /** + * Builds an exception for a failed Llama chat completion request based on the HTTP error response. + * @param message the error message to include in the exception + * @param request the request that caused the error + * @param result the HTTP result containing the response + * @param errorResponse the error response parsed from the HTTP result + * @return an exception representing the error, specific to Llama chat completion + */ + @Override + protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + assert request.isStreaming() : "Only streaming requests support this format"; + var responseStatusCode = result.response().getStatusLine().getStatusCode(); + if (request.isStreaming()) { + var errorMessage = constructErrorMessage(message, request, errorResponse, responseStatusCode); + var restStatus = toRestStatus(responseStatusCode); + return errorResponse instanceof LlamaErrorResponse + ? new UnifiedChatCompletionException(restStatus, errorMessage, LLAMA_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) + : new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } else { + return super.buildError(message, request, result, errorResponse); + } + } + + /** + * Builds an exception for mid-stream errors encountered during Llama chat completion requests. 
+ * + * @param request the request that caused the error + * @param message the error message + * @param e the exception that occurred, if any + * @return a UnifiedChatCompletionException representing the error + */ + @Override + protected Exception buildMidStreamError(Request request, String message, Exception e) { + var errorResponse = StreamingLlamaErrorResponseEntity.fromString(message); + if (errorResponse instanceof StreamingLlamaErrorResponseEntity) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + request.getInferenceEntityId(), + errorResponse.getErrorMessage() + ), + LLAMA_ERROR, + STREAM_ERROR + ); + } else if (e != null) { + return UnifiedChatCompletionException.fromThrowable(e); + } else { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), + createErrorType(errorResponse), + STREAM_ERROR + ); + } + } + + /** + * StreamingLlamaErrorResponseEntity allows creation of {@link ErrorResponse} from a JSON string. + * This entity is used to parse error responses from streaming Llama requests. + * For non-streaming requests {@link LlamaErrorResponse} should be used. + * Example error response for Bad Request error would look like: + *

+     *  {
+     *      "error": {
+     *          "message": "400: Invalid value: Model 'llama3.12:3b' not found"
+     *      }
+     *  }
+     * 
+ */ + private static class StreamingLlamaErrorResponseEntity extends ErrorResponse { + private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( + LLAMA_ERROR, + true, + args -> Optional.ofNullable((LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity) args[0]) + ); + private static final ConstructingObjectParser< + LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity, + Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>( + LLAMA_ERROR, + true, + args -> new LlamaChatCompletionResponseHandler.StreamingLlamaErrorResponseEntity( + args[0] != null ? (String) args[0] : "unknown" + ) + ); + + static { + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message")); + + ERROR_PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + ERROR_BODY_PARSER, + null, + new ParseField("error") + ); + } + + /** + * Parses a streaming Llama error response from a JSON string. + * + * @param response the raw JSON string representing an error + * @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails + */ + private static ErrorResponse fromString(String response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } + + /** + * Constructs a StreamingLlamaErrorResponseEntity with the specified error message. 
+ * + * @param errorMessage the error message to include in the response entity + */ + StreamingLlamaErrorResponseEntity(String errorMessage) { + super(errorMessage); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..7917a8cba5b48 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettings.java @@ -0,0 +1,183 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import 
static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Represents the settings for a Llama chat completion service. + * This class encapsulates the model ID, URI, and rate limit settings for the Llama chat completion service. + */ +public class LlamaChatCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "llama_completion_service_settings"; + // There is no default rate limit for Llama, so we set a reasonable default of 3000 requests per minute + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + private final String modelId; + private final URI uri; + private final RateLimitSettings rateLimitSettings; + + /** + * Creates a new instance of LlamaChatCompletionServiceSettings from a map of settings. 
+ * + * @param map the map containing the service settings + * @param context the context for parsing configuration settings + * @return a new instance of LlamaChatCompletionServiceSettings + * @throws ValidationException if required fields are missing or invalid + */ + public static LlamaChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + LlamaService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new LlamaChatCompletionServiceSettings(model, uri, rateLimitSettings); + } + + /** + * Constructs a new LlamaChatCompletionServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public LlamaChatCompletionServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.uri = createUri(in.readString()); + this.rateLimitSettings = new RateLimitSettings(in); + } + + /** + * Constructs a new LlamaChatCompletionServiceSettings with the specified model ID, URI, and rate limit settings. 
+ * + * @param modelId the ID of the model + * @param uri the URI of the service + * @param rateLimitSettings the rate limit settings for the service + */ + public LlamaChatCompletionServiceSettings(String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = modelId; + this.uri = uri; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + /** + * Constructs a new LlamaChatCompletionServiceSettings with the specified model ID, URL, and rate limit settings. + * If the rate limit settings are null, the default value is used. + * + * @param modelId the ID of the model + * @param url the URL of the service + * @param rateLimitSettings the rate limit settings for the service, or null to use the default + */ + public LlamaChatCompletionServiceSettings(String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) { + this(modelId, createUri(url), rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_LLAMA_ADDED; + } + + @Override + public String modelId() { + return this.modelId; + } + + /** + * Returns the URI of the Llama chat completion service. + * + * @return the URI of the service + */ + public URI uri() { + return this.uri; + } + + /** + * Returns the rate limit settings for the Llama chat completion service. 
+ * + * @return the rate limit settings + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(uri.toString()); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + builder.field(URL, uri.toString()); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LlamaChatCompletionServiceSettings that = (LlamaChatCompletionServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java new file mode 100644 index 0000000000000..8e3b5b10df900 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaCompletionResponseHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; + +/** + * Handles non-streaming completion responses for Llama models, extending the OpenAI completion response handler. + * This class is specifically designed to handle Llama's error response format. + */ +public class LlamaCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + + /** + * Constructs a LlamaCompletionResponseHandler with the specified request type and response parser. + * + * @param requestType The type of request being handled (e.g., "llama completions"). + * @param parseFunction The function to parse the response. + */ + public LlamaCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LlamaErrorResponse::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java new file mode 100644 index 0000000000000..ebf0b7e8132c1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModel.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaModel; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionVisitor; + +import java.util.Map; + +/** + * Represents a Llama embeddings model for inference. + * This class extends the LlamaModel and provides specific configurations and settings for embeddings tasks. + */ +public class LlamaEmbeddingsModel extends LlamaModel { + + /** + * Constructor for creating a LlamaEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param chunkingSettings the chunking settings for processing input data + * @param secrets the secret settings for the model, such as API keys or tokens + * @param context the context for parsing configuration settings + */ + public LlamaEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + LlamaEmbeddingsServiceSettings.fromMap(serviceSettings, context), + chunkingSettings, + retrieveSecretSettings(secrets) + ); + } + + /** + * Constructor for creating a LlamaEmbeddingsModel with specified parameters.
+ * + * @param model the base LlamaEmbeddingsModel to copy properties from + * @param serviceSettings the settings for the inference service, specific to embeddings + */ + public LlamaEmbeddingsModel(LlamaEmbeddingsModel model, LlamaEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + setPropertiesFromServiceSettings(serviceSettings); + } + + /** + * Sets properties from the provided LlamaEmbeddingsServiceSettings. + * + * @param serviceSettings the service settings to extract properties from + */ + private void setPropertiesFromServiceSettings(LlamaEmbeddingsServiceSettings serviceSettings) { + this.uri = serviceSettings.uri(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + } + + /** + * Constructor for creating a LlamaEmbeddingsModel with specified parameters. + * + * @param inferenceEntityId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param service the name of the inference service + * @param serviceSettings the settings for the inference service, specific to embeddings + * @param chunkingSettings the chunking settings for processing input data + * @param secrets the secret settings for the model, such as API keys or tokens + */ + public LlamaEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + LlamaEmbeddingsServiceSettings serviceSettings, + ChunkingSettings chunkingSettings, + SecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings), + new ModelSecrets(secrets) + ); + setPropertiesFromServiceSettings(serviceSettings); + } + + @Override + public LlamaEmbeddingsServiceSettings getServiceSettings() { + return (LlamaEmbeddingsServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor to create an executable action for this Llama embeddings model. 
+ * + * @param creator the visitor that creates the executable action + * @return an ExecutableAction representing the Llama embeddings model + */ + @Override + public ExecutableAction accept(LlamaActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java new file mode 100644 index 0000000000000..240ccf46c7482 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsResponseHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.llama.response.LlamaErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler; + +/** + * Handles responses for Llama embeddings requests, parsing the response and handling errors. + * This class extends OpenAiResponseHandler to provide specific functionality for Llama embeddings. + */ +public class LlamaEmbeddingsResponseHandler extends OpenAiResponseHandler { + + /** + * Constructs a new LlamaEmbeddingsResponseHandler with the specified request type and response parser. 
+ * + * @param requestType the type of request this handler will process + * @param parseFunction the function to parse the response + */ + public LlamaEmbeddingsResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, LlamaErrorResponse::fromResponse, false); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..a14146070247a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettings.java @@ -0,0 +1,257 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.llama.LlamaService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractUri; + +/** + * Settings for the Llama embeddings 
service. + * This class encapsulates the configuration settings required to use Llama for generating embeddings. + */ +public class LlamaEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "llama_embeddings_service_settings"; + // There is no default rate limit for Llama, so we set a reasonable default of 3000 requests per minute + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + private final String modelId; + private final URI uri; + private final Integer dimensions; + private final SimilarityMeasure similarity; + private final Integer maxInputTokens; + private final RateLimitSettings rateLimitSettings; + + /** + * Creates a new instance of LlamaEmbeddingsServiceSettings from a map of settings. + * + * @param map the map containing the settings + * @param context the context for parsing configuration settings + * @return a new instance of LlamaEmbeddingsServiceSettings + * @throws ValidationException if any required fields are missing or invalid + */ + public static LlamaEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = extractUri(map, URL, validationException); + var dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + var similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + var maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + var rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException, LlamaService.NAME, context); + + if (validationException.validationErrors().isEmpty() 
== false) { + throw validationException; + } + + return new LlamaEmbeddingsServiceSettings(model, uri, dimensions, similarity, maxInputTokens, rateLimitSettings); + } + + /** + * Constructs a new LlamaEmbeddingsServiceSettings from a StreamInput. + * + * @param in the StreamInput to read from + * @throws IOException if an I/O error occurs during reading + */ + public LlamaEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.uri = createUri(in.readString()); + this.dimensions = in.readOptionalVInt(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.maxInputTokens = in.readOptionalVInt(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + /** + * Constructs a new LlamaEmbeddingsServiceSettings with the specified parameters. + * + * @param modelId the identifier for the model + * @param uri the URI of the Llama service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public LlamaEmbeddingsServiceSettings( + String modelId, + URI uri, + @Nullable Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.modelId = modelId; + this.uri = uri; + this.dimensions = dimensions; + this.similarity = similarity; + this.maxInputTokens = maxInputTokens; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + /** + * Constructs a new LlamaEmbeddingsServiceSettings with the specified parameters. 
+ * + * @param modelId the identifier for the model + * @param url the URL of the Llama service + * @param dimensions the number of dimensions for the embeddings, can be null + * @param similarity the similarity measure to use, can be null + * @param maxInputTokens the maximum number of input tokens, can be null + * @param rateLimitSettings the rate limit settings for the service, can be null + */ + public LlamaEmbeddingsServiceSettings( + String modelId, + String url, + @Nullable Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + this(modelId, createUri(url), dimensions, similarity, maxInputTokens, rateLimitSettings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_LLAMA_ADDED; + } + + @Override + public String modelId() { + return this.modelId; + } + + public URI uri() { + return this.uri; + } + + @Override + public Integer dimensions() { + return this.dimensions; + } + + @Override + public SimilarityMeasure similarity() { + return this.similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + /** + * Returns the maximum number of input tokens allowed for this service. + * + * @return the maximum input tokens, or null if not specified + */ + public Integer maxInputTokens() { + return this.maxInputTokens; + } + + /** + * Returns the rate limit settings for this service. 
+ * + * @return the rate limit settings, never null + */ + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(uri.toString()); + out.writeOptionalVInt(dimensions); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(maxInputTokens); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + builder.field(URL, uri.toString()); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LlamaEmbeddingsServiceSettings that = (LlamaEmbeddingsServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, dimensions, maxInputTokens, similarity, rateLimitSettings); + } + +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java new file mode 100644 index 0000000000000..3bb01f215087e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequest.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * Llama Chat Completion Request + * This class is responsible for creating a request to the Llama chat completion model. + * It constructs an HTTP POST request with the necessary headers and body content. 
+ */ +public class LlamaChatCompletionRequest implements Request { + + private final LlamaChatCompletionModel model; + private final UnifiedChatInput chatInput; + + /** + * Constructs a new LlamaChatCompletionRequest with the specified chat input and model. + * + * @param chatInput the chat input containing the messages and parameters for the completion request + * @param model the Llama chat completion model to be used for the request + */ + public LlamaChatCompletionRequest(UnifiedChatInput chatInput, LlamaChatCompletionModel model) { + this.chatInput = Objects.requireNonNull(chatInput); + this.model = Objects.requireNonNull(model); + } + + /** + * Creates the HTTP POST request for this chat completion, setting the JSON content type and, + * when secret settings are configured, an authorization bearer header. + * + * @return the constructed HttpRequest + */ + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new LlamaChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + if (model.getSecretSettings() instanceof DefaultSecretSettings secretSettings) { + httpPost.setHeader(createAuthBearerHeader(secretSettings.apiKey())); + } + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + // No truncation for Llama chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for Llama chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return chatInput.stream(); + } +} diff --git
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..fc80dab09f6f5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntity.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.completion; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +/** + * LlamaChatCompletionRequestEntity is responsible for creating the request entity for Llama chat completion. + * It implements ToXContentObject to allow serialization to XContent format. + */ +public class LlamaChatCompletionRequestEntity implements ToXContentObject { + + private final LlamaChatCompletionModel model; + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; + + /** + * Constructs a LlamaChatCompletionRequestEntity with the specified unified chat input and model. 
+ * + * @param unifiedChatInput the unified chat input containing messages and parameters for the completion request + * @param model the Llama chat completion model to be used for the request + */ + public LlamaChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, LlamaChatCompletionModel model) { + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params)); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequest.java new file mode 100644 index 0000000000000..5883880dbb812 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequest.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.request.embeddings; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * Llama Embeddings Request + * This class is responsible for creating a request to the Llama embeddings model. + * It constructs an HTTP POST request with the necessary headers and body content. + */ +public class LlamaEmbeddingsRequest implements Request { + private final URI uri; + private final LlamaEmbeddingsModel model; + private final String inferenceEntityId; + private final Truncator.TruncationResult truncationResult; + private final Truncator truncator; + + /** + * Constructs a new LlamaEmbeddingsRequest with the specified truncator, input, and model. + * + * @param truncator the truncator to handle input truncation + * @param input the input to be truncated + * @param model the Llama embeddings model to be used for the request + */ + public LlamaEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, LlamaEmbeddingsModel model) { + this.uri = model.uri(); + this.model = model; + this.inferenceEntityId = model.getInferenceEntityId(); + this.truncator = truncator; + this.truncationResult = input; + } + + /** + * Creates the HTTP POST request for these embeddings, setting the JSON content type and, + * when secret settings are configured, an authorization bearer header. + * + * @return the constructed HttpRequest + */ + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(this.uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new LlamaEmbeddingsRequestEntity(model.getServiceSettings().modelId(), truncationResult.input())) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + if (model.getSecretSettings() instanceof DefaultSecretSettings secretSettings) { + httpPost.setHeader(createAuthBearerHeader(secretSettings.apiKey())); + } + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return uri; + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + return new LlamaEmbeddingsRequest(truncator, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..3f734bacec87d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntity.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.llama.request.embeddings; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * LlamaEmbeddingsRequestEntity is responsible for creating the request entity for Llama embeddings. + * It implements ToXContentObject to allow serialization to XContent format. + */ +public record LlamaEmbeddingsRequestEntity(String modelId, List contents) implements ToXContentObject { + + public static final String CONTENTS_FIELD = "contents"; + public static final String MODEL_ID_FIELD = "model_id"; + + /** + * Constructs a LlamaEmbeddingsRequestEntity with the specified model ID and contents. + * + * @param modelId the ID of the model to use for embeddings + * @param contents the list of contents to generate embeddings for + */ + public LlamaEmbeddingsRequestEntity { + Objects.requireNonNull(modelId); + Objects.requireNonNull(contents); + } + + /** + * Serializes this request entity to XContent, writing the model ID and contents fields. + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_ID_FIELD, modelId); + builder.field(CONTENTS_FIELD, contents); + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java new file mode 100644 index 0000000000000..727231209fdf1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponse.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V.
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.response; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import java.nio.charset.StandardCharsets; + +/** + * LlamaErrorResponse is responsible for handling error responses from Llama inference services. + * It extends ErrorResponse to provide specific functionality for Llama errors. + * An example error response for Not Found error would look like: + *

+ *  {
+ *      "detail": "Not Found"
+ *  }
+ * 
+ * An example error response for Bad Request error would look like: + *

+ *  {
+ *     "error": {
+ *         "detail": {
+ *             "errors": [
+ *                 {
+ *                     "loc": [
+ *                         "body",
+ *                         "model"
+ *                     ],
+ *                     "msg": "Field required",
+ *                     "type": "missing"
+ *                 }
+ *             ]
+ *         }
+ *     }
+ *  }
+ * 
+ */ +public class LlamaErrorResponse extends ErrorResponse { + + public LlamaErrorResponse(String message) { + super(message); + } + + public static ErrorResponse fromResponse(HttpResult response) { + try { + String errorMessage = new String(response.body(), StandardCharsets.UTF_8); + return new LlamaErrorResponse(errorMessage); + } catch (Exception e) { + // swallow the error + } + return ErrorResponse.UNDEFINED_ERROR; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java index 57219a03b3bdb..55a5b4fe71047 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java @@ -10,19 +10,21 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.net.URI; import java.net.URISyntaxException; +import java.util.Objects; /** * Represents a Mistral model that can be used for inference tasks. * This class extends RateLimitGroupingModel to handle rate limiting based on model and API key. 
*/ public abstract class MistralModel extends RateLimitGroupingModel { - protected String model; protected URI uri; protected RateLimitSettings rateLimitSettings; @@ -34,10 +36,6 @@ protected MistralModel(RateLimitGroupingModel model, ServiceSettings serviceSett super(model, serviceSettings); } - public String model() { - return this.model; - } - public URI uri() { return this.uri; } @@ -49,7 +47,7 @@ public RateLimitSettings rateLimitSettings() { @Override public int rateLimitGroupingHash() { - return 0; + return Objects.hash(getServiceSettings().modelId(), getSecretSettings().apiKey()); } // Needed for testing only @@ -65,4 +63,6 @@ public void setURI(String newUri) { public DefaultSecretSettings getSecretSettings() { return (DefaultSecretSettings) super.getSecretSettings(); } + + public abstract ExecutableAction accept(MistralActionVisitor creator); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 3048847ea90d7..c1eee5eb27338 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -108,16 +108,10 @@ protected void doInfer( ) { var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); - switch (model) { - case MistralEmbeddingsModel mistralEmbeddingsModel: - mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener); - break; - case MistralChatCompletionModel mistralChatCompletionModel: - mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener); - break; - default: - listener.onFailure(createInvalidModelException(model)); - break; + if (model instanceof MistralModel mistralModel) { + 
mistralModel.accept(actionCreator).execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); } } @@ -172,7 +166,7 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); + var action = mistralEmbeddingsModel.accept(actionCreator); action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { @@ -217,7 +211,6 @@ public void parseRequestConfig( modelId, taskType, serviceSettingsMap, - taskSettingsMap, chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), @@ -242,7 +235,7 @@ public MistralModel parsePersistedConfigWithSecrets( Map secrets ) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); ChunkingSettings chunkingSettings = null; @@ -254,7 +247,6 @@ public MistralModel parsePersistedConfigWithSecrets( modelId, taskType, serviceSettingsMap, - taskSettingsMap, chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(modelId, NAME) @@ -264,7 +256,7 @@ public MistralModel parsePersistedConfigWithSecrets( @Override public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType)) { @@ -275,7 +267,6 @@ 
public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map< modelId, taskType, serviceSettingsMap, - taskSettingsMap, chunkingSettings, null, parsePersistedConfigErrorMsg(modelId, NAME) @@ -296,7 +287,6 @@ private static MistralModel createModel( String modelId, TaskType taskType, Map serviceSettings, - Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, @@ -304,16 +294,7 @@ private static MistralModel createModel( ) { switch (taskType) { case TEXT_EMBEDDING: - return new MistralEmbeddingsModel( - modelId, - taskType, - NAME, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - context - ); + return new MistralEmbeddingsModel(modelId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context); case CHAT_COMPLETION, COMPLETION: return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context); default: @@ -325,7 +306,6 @@ private MistralModel createModelFromPersistent( String inferenceEntityId, TaskType taskType, Map serviceSettings, - Map taskSettings, ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage @@ -334,7 +314,6 @@ private MistralModel createModelFromPersistent( inferenceEntityId, taskType, serviceSettings, - taskSettings, chunkingSettings, secretSettings, failureMessage, @@ -369,10 +348,10 @@ public MistralEmbeddingsModel updateModelWithEmbeddingDetails(Model model, int e */ public static class Configuration { public static InferenceServiceConfiguration get() { - return configuration.getOrCompute(); + return CONFIGURATION.getOrCompute(); } - private static final LazyInitializable configuration = new LazyInitializable<>( + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( () -> { var configurationMap = new HashMap(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java index fbf842f4fb789..ba7377c3209e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java @@ -24,7 +24,6 @@ import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; -import java.util.Map; import java.util.Objects; import static org.elasticsearch.core.Strings.format; @@ -51,7 +50,7 @@ public MistralActionCreator(Sender sender, ServiceComponents serviceComponents) } @Override - public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map taskSettings) { + public ExecutableAction create(MistralEmbeddingsModel embeddingsModel) { var requestManager = new MistralEmbeddingsRequestManager( embeddingsModel, serviceComponents.truncator(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java index 5f494e4d65477..e1c4b12883c56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java @@ -11,8 +11,6 @@ import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; -import java.util.Map; - /** * Interface for creating {@link ExecutableAction} instances for 
Mistral models. *

@@ -25,10 +23,9 @@ public interface MistralActionVisitor { * Creates an {@link ExecutableAction} for the given {@link MistralEmbeddingsModel}. * * @param embeddingsModel The model to create the action for. - * @param taskSettings The task settings to use. * @return An {@link ExecutableAction} for the given model. */ - ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map taskSettings); + ExecutableAction create(MistralEmbeddingsModel embeddingsModel); /** * Creates an {@link ExecutableAction} for the given {@link MistralChatCompletionModel}. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java index 03fe502a82807..876c46edcb70d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java @@ -22,7 +22,6 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.Map; -import java.util.Objects; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_COMPLETIONS_PATH; @@ -95,23 +94,17 @@ public MistralChatCompletionModel( DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE), new ModelSecrets(secrets) ); setPropertiesFromServiceSettings(serviceSettings); } private void setPropertiesFromServiceSettings(MistralChatCompletionServiceSettings serviceSettings) { - this.model = serviceSettings.modelId(); this.rateLimitSettings = serviceSettings.rateLimitSettings(); 
setEndpointUrl(); } - @Override - public int rateLimitGroupingHash() { - return Objects.hash(model, getSecretSettings().apiKey()); - } - private void setEndpointUrl() { try { this.uri = new URI(API_COMPLETIONS_PATH); @@ -131,6 +124,7 @@ public MistralChatCompletionServiceSettings getServiceSettings() { * @param creator The visitor that creates the executable action. * @return An ExecutableAction that can be executed. */ + @Override public ExecutableAction accept(MistralActionVisitor creator) { return creator.create(this); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java index 48d2fecc5ce13..8ac186ac9d642 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java @@ -12,7 +12,6 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; -import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; @@ -37,7 +36,6 @@ public MistralEmbeddingsModel( TaskType taskType, String service, Map serviceSettings, - Map taskSettings, ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context @@ -47,7 +45,6 @@ public MistralEmbeddingsModel( taskType, service, MistralEmbeddingsServiceSettings.fromMap(serviceSettings, context), - EmptyTaskSettings.INSTANCE, // no task settings for Mistral embeddings chunkingSettings, 
DefaultSecretSettings.fromMap(secrets) ); @@ -59,7 +56,6 @@ public MistralEmbeddingsModel(MistralEmbeddingsModel model, MistralEmbeddingsSer } private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) { - this.model = serviceSettings.modelId(); this.rateLimitSettings = serviceSettings.rateLimitSettings(); setEndpointUrl(); } @@ -77,12 +73,11 @@ public MistralEmbeddingsModel( TaskType taskType, String service, MistralEmbeddingsServiceSettings serviceSettings, - TaskSettings taskSettings, ChunkingSettings chunkingSettings, DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings), new ModelSecrets(secrets) ); setPropertiesFromServiceSettings(serviceSettings); @@ -93,7 +88,8 @@ public MistralEmbeddingsServiceSettings getServiceSettings() { return (MistralEmbeddingsServiceSettings) super.getServiceSettings(); } - public ExecutableAction accept(MistralActionVisitor creator, Map taskSettings) { - return creator.create(this, taskSettings); + @Override + public ExecutableAction accept(MistralActionVisitor creator) { + return creator.create(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java index 6b1c7d36a9fe6..4cf1fef3c92c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java @@ -178,12 +178,13 @@ public boolean 
equals(Object o) { return Objects.equals(model, that.model) && Objects.equals(dimensions, that.dimensions) && Objects.equals(maxInputTokens, that.maxInputTokens) - && Objects.equals(similarity, that.similarity); + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); } @Override public int hashCode() { - return Objects.hash(model, dimensions, maxInputTokens, similarity); + return Objects.hash(model, dimensions, maxInputTokens, similarity, rateLimitSettings); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java index 8b772d4b8f2ed..b7d3866bcebfd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java @@ -42,7 +42,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(this.uri); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new MistralEmbeddingsRequestEntity(embeddingsModel.model(), truncationResult.input())) + Strings.toString(new MistralEmbeddingsRequestEntity(embeddingsModel.getServiceSettings().modelId(), truncationResult.input())) .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java index 3120f1ff92e48..957203b5ee802 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java @@ -198,7 +198,7 @@ private static class DeltaParser { PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD)); PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD)); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD)); - PARSER.declareObjectArray( + PARSER.declareObjectArrayOrNull( ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ChatCompletionChunkParser.ToolCallParser.parse(p), new ParseField(TOOL_CALLS_FIELD) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java index 928ed3ff444e6..2ae70cb52b565 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java @@ -34,7 +34,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); unifiedRequestEntity.toXContent( builder, - UnifiedCompletionRequest.withMaxCompletionTokensTokens(model.getServiceSettings().modelId(), params) + UnifiedCompletionRequest.withMaxCompletionTokens(model.getServiceSettings().modelId(), params) ); if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java index aeb09af03ebab..4a4c59f091abf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.custom.CustomModel; import org.junit.After; import org.junit.Assume; import org.junit.Before; @@ -141,7 +140,7 @@ public boolean isEnabled() { return true; } - protected abstract CustomModel createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); + protected abstract Model createEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure); } private static final UpdateModelConfiguration DISABLED_UPDATE_MODEL_TESTS = new UpdateModelConfiguration() { @@ -151,7 +150,7 @@ public boolean isEnabled() { } @Override - protected CustomModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + protected Model createEmbeddingModel(SimilarityMeasure similarityMeasure) { throw new UnsupportedOperationException("Update model tests are disabled"); } }; @@ -351,11 +350,17 @@ public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() thr assertThat( exception.getMessage(), - containsString(Strings.format("service does not support task type [%s]", parseConfigTestConfig.unsupportedTaskType)) + containsString( + Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType) + ) ); } } + protected String 
fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { + return "service does not support task type [%s]"; + } + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { var parseConfigTestConfig = testConfiguration.commonConfig; @@ -374,7 +379,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists persistedConfigMap.secrets() ); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -396,7 +401,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServ persistedConfigMap.secrets() ); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -413,7 +418,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTask var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -430,7 +435,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecr var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets()); - parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING); + parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType); } } @@ -468,7 +473,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO () -> service.updateModelWithEmbeddingDetails(getInvalidModel("id", "service"), randomNonNegativeInt()) ); - assertThat(exception.getMessage(), containsString("Can't update embedding details for model of type:")); + assertThat(exception.getMessage(), containsString("Can't update embedding 
details for model")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java new file mode 100644 index 0000000000000..dd68c43f5e62d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -0,0 +1,840 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import 
org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; +import 
org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests.createChatCompletionModel; +import static org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionServiceSettingsTests.getServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettingsTests.buildServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; 
+import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; + +public class LlamaServiceTests extends AbstractInferenceServiceTests { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + public LlamaServiceTests() { + super(createTestConfiguration()); + } + + private static TestConfiguration createTestConfiguration() { + return new TestConfiguration.Builder(new CommonConfig(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING) { + + @Override + protected SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + return LlamaServiceTests.createService(threadPool, clientManager); + } + + @Override + protected Map createServiceSettingsMap(TaskType taskType) { + return LlamaServiceTests.createServiceSettingsMap(taskType); + } + + @Override + protected Map createTaskSettingsMap() { + return new HashMap<>(); + } + + @Override + protected Map createSecretSettingsMap() { + return LlamaServiceTests.createSecretSettingsMap(); + } + + @Override + protected void assertModel(Model model, TaskType taskType) { + LlamaServiceTests.assertModel(model, taskType); + } + + @Override + protected EnumSet supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } + }).enableUpdateModelTests(new UpdateModelConfiguration() { + @Override + protected LlamaEmbeddingsModel createEmbeddingModel(SimilarityMeasure similarityMeasure) { + return createInternalEmbeddingModel(similarityMeasure); + } + }).build(); + } + + private static void assertModel(Model model, TaskType taskType) { + switch (taskType) { + case TEXT_EMBEDDING -> assertTextEmbeddingModel(model); + case COMPLETION -> assertCompletionModel(model); + case CHAT_COMPLETION -> 
assertChatCompletionModel(model); + default -> fail("unexpected task type [" + taskType + "]"); + } + } + + private static void assertTextEmbeddingModel(Model model) { + var llamaModel = assertCommonModelFields(model); + + assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING)); + } + + private static LlamaModel assertCommonModelFields(Model model) { + assertThat(model, instanceOf(LlamaModel.class)); + + var llamaModel = (LlamaModel) model; + assertThat(llamaModel.getServiceSettings().modelId(), is("model_id")); + assertThat(llamaModel.uri.toString(), Matchers.is("http://www.abc.com")); + assertThat(llamaModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE)); + assertThat( + ((DefaultSecretSettings) llamaModel.getSecretSettings()).apiKey(), + Matchers.is(new SecureString("secret".toCharArray())) + ); + + return llamaModel; + } + + private static void assertCompletionModel(Model model) { + var llamaModel = assertCommonModelFields(model); + assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.COMPLETION)); + } + + private static void assertChatCompletionModel(Model model) { + var llamaModel = assertCommonModelFields(model); + assertThat(llamaModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION)); + } + + public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + return new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + + private static Map createServiceSettingsMap(TaskType taskType) { + Map settingsMap = new HashMap<>( + Map.of(ServiceFields.URL, "http://www.abc.com", ServiceFields.MODEL_ID, "model_id") + ); + + if (taskType == TaskType.TEXT_EMBEDDING) { + settingsMap.putAll( + Map.of( + ServiceFields.SIMILARITY, + SimilarityMeasure.COSINE.toString(), + ServiceFields.DIMENSIONS, + 1536, + ServiceFields.MAX_INPUT_TOKENS, + 512 + ) + ); + } + + 
return settingsMap; + } + + private static Map createSecretSettingsMap() { + return new HashMap<>(Map.of("api_key", "secret")); + } + + private static LlamaEmbeddingsModel createInternalEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure) { + var inferenceId = "inference_id"; + + return new LlamaEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + LlamaService.NAME, + new LlamaEmbeddingsServiceSettings( + "model_id", + "http://www.abc.com", + 1536, + similarityMeasure, + 512, + new RateLimitSettings(10_000) + ), + ChunkingSettingsTests.createRandomChunkingSettings(), + new DefaultSecretSettings(new SecureString("secret".toCharArray())) + ); + } + + protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() { + return "Failed to parse stored model [id] for [llama] service, please delete and add the service again"; + } + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(LlamaEmbeddingsModel.class)); + + var embeddingsModel = (LlamaEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(((DefaultSecretSettings) (embeddingsModel.getSecretSettings())).apiKey().toString(), is("secret")); + }, e -> fail("parse request should not fail " + e.getMessage())); + + 
service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", "url"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(LlamaEmbeddingsModel.class)); + + var embeddingsModel = (LlamaEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(((DefaultSecretSettings) (embeddingsModel.getSecretSettings())).apiKey().toString(), is("secret")); + }, e -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getServiceSettingsMap("model", "url"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOException { + var url = "url"; + var secret = "secret"; + + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(LlamaChatCompletionModel.class)); + + var chatCompletionModel = (LlamaChatCompletionModel) m; + + assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(url)); + assertNull(chatCompletionModel.getServiceSettings().modelId()); + assertThat(((DefaultSecretSettings) (chatCompletionModel.getSecretSettings())).apiKey().toString(), is("secret")); + + }, exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + 
is("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + }); + + service.parseRequestConfig( + "id", + TaskType.CHAT_COMPLETION, + getRequestConfigMap(getServiceSettingsMap(null, url), getSecretSettingsMap(secret)), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsException_WithoutUrl() throws IOException { + var model = "model"; + var secret = "secret"; + + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(LlamaChatCompletionModel.class)); + + var chatCompletionModel = (LlamaChatCompletionModel) m; + + assertThat(chatCompletionModel.getServiceSettings().modelId(), is(model)); + assertNull(chatCompletionModel.getServiceSettings().modelId()); + assertThat(((DefaultSecretSettings) (chatCompletionModel.getSecretSettings())).apiKey().toString(), is("secret")); + + }, exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + }); + + service.parseRequestConfig( + "id", + TaskType.CHAT_COMPLETION, + getRequestConfigMap(getServiceSettingsMap(model, null), getSecretSettingsMap(secret)), + modelVerificationListener + ); + } + } + + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + ],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ 
+ } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = createChatCompletionModel("model", getUrl(webServer), "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace(""" + { + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26", + "choices": [{ + "delta": { + "content": "Deep", + "role": "assistant" + }, + "index": 0 + } + ], + "model": "llama3.2:3b", + "object": "chat.completion.chunk" + } + """)); + } + } + + public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { + String responseJson = """ + { + "detail": "Not Found" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = LlamaChatCompletionModelTests.createChatCompletionModel("model", getUrl(webServer), "secret"); + var latch = new CountDownLatch(1); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + 
ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> { + try (var builder = XContentFactory.jsonBuilder()) { + var t = unwrapCause(e); + assertThat(t, isA(UnifiedChatCompletionException.class)); + ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + assertThat(json, is(String.format(Locale.ROOT, XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [%s] for request from inference entity id [id] status \ + [404]. Error message: [{\\n \\"detail\\": \\"Not Found\\"\\n}\\n]", + "type" : "llama_error" + } + }"""), getUrl(webServer)))); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }), latch::countDown) + ); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + } + } + + public void testMidStreamUnifiedCompletionError() throws Exception { + String responseJson = """ + data: {"error": {"message": "400: Invalid value: Model 'llama3.12:3b' not found"}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "stream_error", + "message": "Received an error response for request from inference entity id [id].\ + Error message: [400: Invalid value: Model 'llama3.12:3b' not found]", + "type": "llama_error" + } + } + """)); + } + + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + 
],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + private void testStreamError(String expectedResponse) throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = LlamaChatCompletionModelTests.createChatCompletionModel("model", getUrl(webServer), "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { + e = unwrapCause(e); + assertThat(e, isA(UnifiedChatCompletionException.class)); + try (var builder = XContentFactory.jsonBuilder()) { + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(expectedResponse)); + } + }); + } + } + + public void testInfer_StreamRequest_ErrorResponse() { + String responseJson = """ + { + "detail": "Not Found" + }"""; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var e = 
assertThrows(ElasticsearchStatusException.class, this::streamCompletion); + assertThat(e.status(), equalTo(RestStatus.NOT_FOUND)); + assertThat(e.getMessage(), equalTo(String.format(Locale.ROOT, """ + Resource not found at [%s] for request from inference entity id [id] status [404]. Error message: [{ + "detail": "Not Found" + }]""", getUrl(webServer)))); + } + + public void testInfer_StreamRequestRetry() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(503).setBody(""" + { + "error": { + "message": "server busy" + } + }""")); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {\ + "id": "chatcmpl-8425dd3d-78f3-4143-93cb-dd576ab8ae26",\ + "choices": [{\ + "delta": {\ + "content": "Deep",\ + "function_call": null,\ + "refusal": null,\ + "role": "assistant",\ + "tool_calls": null\ + },\ + "finish_reason": null,\ + "index": 0,\ + "logprobs": null\ + }\ + ],\ + "created": 1750158492,\ + "model": "llama3.2:3b",\ + "object": "chat.completion.chunk",\ + "service_tier": null,\ + "system_fingerprint": "fp_ollama",\ + "usage": null\ + } + + """)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"Deep"}]}"""); + } + + public void testSupportsStreaming() throws IOException { + try (var service = new LlamaService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) { + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { + try (var service = createService()) { + var secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got 
model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Configuration contains settings [{extra_key=value}] unknown to the [llama] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + } + } + + public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("id", "url", "api_key"); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsSet() throws IOException { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModelWithChunkingSettings("id", "url", "api_key"); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + String responseJson = """ + { + "embeddings": [ + [ + 0.010060793, + -0.0017529363 + ], + [ + 0.110060793, + -0.1017529363 + ] + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().get(0).embedding(), 
Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertTrue( + Arrays.equals( + new float[] { 0.010060793f, -0.0017529363f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ) + ); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class)); + var floatResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); + assertTrue( + Arrays.equals( + new float[] { 0.110060793f, -0.1017529363f }, + ((TextEmbeddingFloatResults.Embedding) floatResult.chunks().get(0).embedding()).values() + ) + ); + } + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer api_key")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), Matchers.is(2)); + assertThat(requestMap.get("contents"), Matchers.is(List.of("abc", "def"))); + assertThat(requestMap.get("model_id"), Matchers.is("id")); + } + } + + public void testGetConfiguration() throws Exception { + try (var service = createService()) { + String content = XContentHelper.stripWhitespace(""" + { + "service": "llama", + "name": "Llama", + "task_types": ["text_embedding", "completion", "chat_completion"], + "configurations": { + "api_key": { + "description": "API Key for the provider you're connecting to.", + "label": "API Key", + "required": true, + "sensitive": true, + "updatable": true, + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + }, + "model_id": { + "description": "Refer to the Llama 
models documentation for the list of available models.", + "label": "Model", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + }, + "rate_limit.requests_per_minute": { + "description": "Minimize the number of rate limit errors.", + "label": "Rate Limit", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + }, + "url": { + "description": "The URL endpoint to use for the requests.", + "label": "URL", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "completion", "chat_completion"] + } + } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + assertToXContentEquivalent( + originalBytes, + toXContent(serviceConfiguration, XContentType.JSON, humanReadable), + XContentType.JSON + ); + } + } + + private InferenceEventsAssertion streamCompletion() throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return 
InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + } + } + + private LlamaService createService() { + return new LlamaService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), mockClusterServiceEmpty()); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); + } + + private static Map getEmbeddingsServiceSettingsMap() { + return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java new file mode 100644 index 0000000000000..366e0926f0daa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java @@ -0,0 +1,283 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; 
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
+import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class LlamaActionCreatorTests extends ESTestCase {
+
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            // Well-formed success response. Fixed: the object was previously terminated with a
+            // stray "{" instead of "}"; the lenient response parser masked the malformed JSON,
+            // but a success-path fixture should be valid (the invalid-format case is covered by
+            // testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction below).
+            String responseJson = """
+                {
+                    "embeddings": [
+                        [
+                            -0.0123,
+                            0.123
+                        ]
+                    ]
+                }
+                """;
+            webServer.enqueue(new
MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool)); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + + assertEmbeddingsRequest(); + } + } + + public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + [ + { + "embeddings": [ + [ + -0.0123, + 0.123 + ] + ] + { + ] + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool)); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + + assertEmbeddingsRequest(); + } + } + + public void testExecute_ReturnsSuccessfulResponse_ForCompletionAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "chatcmpl-03e70a75-efb6-447d-b661-e5ed0bd59ce9", + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "Hello there, how may I assist you today?", + "refusal": null, + "role": "assistant", + "annotations": null, + "audio": null, + "function_call": null, + 
"tool_calls": null + } + } + ], + "created": 1750157476, + "model": "llama3.2:3b", + "object": "chat.completion", + "service_tier": null, + "system_fingerprint": "fp_ollama", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 30, + "total_tokens": 40, + "completion_tokens_details": null, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createCompletionFuture(sender, createWithEmptySettings(threadPool)); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); + + assertCompletionRequest(); + } + } + + public void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "invalid_field": "unexpected" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createCompletionFuture( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to send Llama completion request from inference entity id [id]. 
Cause: Required [choices]") + ); + + assertCompletionRequest(); + } + } + + private PlainActionFuture createEmbeddingsFuture(Sender sender, ServiceComponents threadPool) { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret"); + var actionCreator = new LlamaActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + return listener; + } + + private PlainActionFuture createCompletionFuture(Sender sender, ServiceComponents threadPool) { + var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret"); + var actionCreator = new LlamaActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + return listener; + } + + private void assertCompletionRequest() throws IOException { + assertCommonRequestProperties(); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + + @SuppressWarnings("unchecked") + private void assertEmbeddingsRequest() throws IOException { + assertCommonRequestProperties(); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("contents"), instanceOf(List.class)); + var inputList = (List) requestMap.get("contents"); + assertThat(inputList, 
contains("abc")); + } + + private void assertCommonRequestProperties() { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java new file mode 100644 index 0000000000000..844d17addac6d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionModelTests.java @@ -0,0 +1,142 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class LlamaChatCompletionModelTests extends ESTestCase { + + public static LlamaChatCompletionModel createCompletionModel(String modelId, String url, String apiKey) { + return new LlamaChatCompletionModel( + "id", + TaskType.COMPLETION, + "llama", + new LlamaChatCompletionServiceSettings(modelId, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaChatCompletionModel createChatCompletionModel(String modelId, String url, String apiKey) { + return new LlamaChatCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "llama", + new LlamaChatCompletionServiceSettings(modelId, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaChatCompletionModel createChatCompletionModelNoAuth(String modelId, String url) { + return new LlamaChatCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "llama", + new LlamaChatCompletionServiceSettings(modelId, url, null), + EmptySecretSettings.INSTANCE + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsSameModelId() { + var model = createCompletionModel("model_name", "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "model_name", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + 
assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { + var model = createCompletionModel("model_name", "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() { + var model = createCompletionModel(null, "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { + var model = createCompletionModel(null, "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertNull(overriddenModel.getServiceSettings().modelId()); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel("model_name", "url", "api_key"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", 
null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = LlamaChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..c9b6069d383ed --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class LlamaChatCompletionResponseHandlerTests extends ESTestCase { + private final LlamaChatCompletionResponseHandler responseHandler = new LlamaChatCompletionResponseHandler( + "chat completions", + (a, b) -> mock() + ); + + public void testFailNotFound() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "detail": "Not Found" + } + """); + + var errorJson = invalidResponseJson(responseJson, 404); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [https://api.llama.ai/v1/chat/completions] for request from inference entity id [id] \ + status [404]. 
Error message: [{\\"detail\\":\\"Not Found\\"}]", + "type" : "llama_error" + } + }"""))); + } + + public void testFailBadRequest() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "error": { + "detail": { + "errors": [{ + "loc": [ + "body", + "messages" + ], + "msg": "Field required", + "type": "missing" + } + ] + } + } + } + """); + + var errorJson = invalidResponseJson(responseJson, 400); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "bad_request", + "message": "Received a bad request status code for request from inference entity id [id] status [400].\ + Error message: [{\\"error\\":{\\"detail\\":{\\"errors\\":[{\\"loc\\":[\\"body\\",\\"messages\\"],\\"msg\\":\\"Field\ + required\\",\\"type\\":\\"missing\\"}]}}}]", + "type": "llama_error" + } + } + """))); + } + + public void testFailValidationWithInvalidJson() throws IOException { + var responseJson = """ + what? this isn't a json + """; + + var errorJson = invalidResponseJson(responseJson, 500); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error": { + "code": "bad_request", + "message": "Received a server error status code for request from inference entity id [id] status [500]. Error message: \ + [what? 
this isn't a json\\n]", + "type": "llama_error" + } + } + """))); + } + + private String invalidResponseJson(String responseJson, int statusCode) throws IOException { + var exception = invalidResponse(responseJson, statusCode); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private Exception invalidResponse(String responseJson, int statusCode) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)), + true + ) + ); + } + + private static Request mockRequest() throws URISyntaxException { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn("id"); + when(request.isStreaming()).thenReturn(true); + when(request.getURI()).thenReturn(new URI("https://api.llama.ai/v1/chat/completions")); + return request; + } + + private static HttpResponse mockErrorResponse(int statusCode) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..21b42453d9c39 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionServiceSettingsTests.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class LlamaChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + + public static final String MODEL_ID = "some model"; + public static final String CORRECT_URL = "https://www.elastic.co"; + public static final 
int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, new RateLimitSettings(RATE_LIMIT)))); + } + + public void testFromMap_MissingModelId_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + } + + public void testFromMap_MissingUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + + public void testFromMap_MissingRateLimit_Success() { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, 
is(new LlamaChatCompletionServiceSettings(MODEL_ID, CORRECT_URL, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 2 + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { + var serviceSettings = LlamaChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 3000 + } + } + """); + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return LlamaChatCompletionServiceSettings::new; + } + + @Override + protected LlamaChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected LlamaChatCompletionServiceSettings mutateInstance(LlamaChatCompletionServiceSettings instance) throws IOException { + return 
randomValueOtherThan(instance, LlamaChatCompletionServiceSettingsTests::createRandom); + } + + @Override + protected LlamaChatCompletionServiceSettings mutateInstanceForVersion( + LlamaChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + private static LlamaChatCompletionServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + var url = randomAlphaOfLength(15); + return new LlamaChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); + } + + public static Map getServiceSettingsMap(String model, String url) { + var map = new HashMap(); + + map.put(ServiceFields.MODEL_ID, model); + map.put(ServiceFields.URL, url); + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java new file mode 100644 index 0000000000000..4e75cab196a6d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsModelTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; + +public class LlamaEmbeddingsModelTests extends ESTestCase { + public static LlamaEmbeddingsModel createEmbeddingsModel(String modelId, String url, String apiKey) { + return new LlamaEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "llama", + new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null), + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaEmbeddingsModel createEmbeddingsModelWithChunkingSettings(String modelId, String url, String apiKey) { + return new LlamaEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "llama", + new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null), + createRandomChunkingSettings(), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static LlamaEmbeddingsModel createEmbeddingsModelNoAuth(String modelId, String url) { + return new LlamaEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "llama", + new LlamaEmbeddingsServiceSettings(modelId, url, null, null, null, null), + null, + EmptySecretSettings.INSTANCE + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..5fd3ce704540c --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/embeddings/LlamaEmbeddingsServiceSettingsTests.java @@ -0,0 +1,479 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.ByteArrayStreamInput; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class LlamaEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + private static final String MODEL_ID = "some model"; + private static final String CORRECT_URL = "https://www.elastic.co"; + private static final int DIMENSIONS = 384; + private static final SimilarityMeasure SIMILARITY_MEASURE = 
SimilarityMeasure.DOT_PRODUCT; + private static final int MAX_INPUT_TOKENS = 128; + private static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_NoModelId_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + null, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + } + + public void testFromMap_NoUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + null, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [url];") + ); + } + + public void testFromMap_EmptyUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> 
LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + "", + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value empty string. [url] must be a non-empty string;") + ); + } + + public void testFromMap_InvalidUrl_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + "^^^", + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString( + "Validation Failed: 1: [service_settings] Invalid url [^^^] received for field [url]. " + + "Error: unable to parse url [^^^]. 
Reason: Illegal character in path;" + ) + ); + } + + public void testFromMap_NoSimilarity_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + null, + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + null, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + "by_size", + DIMENSIONS, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString( + "Validation Failed: 1: [service_settings] Invalid value [by_size] received. 
" + + "[similarity] must be one of [cosine, dot_product, l2_norm];" + ) + ); + } + + public void testFromMap_NoDimensions_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + null, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + null, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_ZeroDimensions_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + 0, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NegativeDimensions_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + -10, + MAX_INPUT_TOKENS, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. 
[dimensions] must be a positive integer;") + ); + } + + public void testFromMap_NoInputTokens_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + null, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + null, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_ZeroInputTokens_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + 0, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NegativeInputTokens_Failure() { + var thrownException = expectThrows( + ValidationException.class, + () -> LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap( + MODEL_ID, + CORRECT_URL, + SIMILARITY_MEASURE.toString(), + DIMENSIONS, + -10, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [-10]. 
[max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_NoRateLimit_Success() { + var serviceSettings = LlamaEmbeddingsServiceSettings.fromMap( + buildServiceSettingsMap(MODEL_ID, CORRECT_URL, SIMILARITY_MEASURE.toString(), DIMENSIONS, MAX_INPUT_TOKENS, null), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3000) + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "dimensions": 384, + "similarity": "dot_product", + "max_input_tokens": 128, + "rate_limit": { + "requests_per_minute": 3 + } + } + """))); + } + + public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException { + var outputBuffer = new BytesStreamOutput(); + var settings = new LlamaEmbeddingsServiceSettings( + MODEL_ID, + CORRECT_URL, + DIMENSIONS, + SIMILARITY_MEASURE, + MAX_INPUT_TOKENS, + new RateLimitSettings(3) + ); + settings.writeTo(outputBuffer); + + var outputBufferRef = outputBuffer.bytes(); + var inputBuffer = new ByteArrayStreamInput(outputBufferRef.array()); + + var settingsFromBuffer = new LlamaEmbeddingsServiceSettings(inputBuffer); + + assertEquals(settings, settingsFromBuffer); + } + + @Override + protected Writeable.Reader instanceReader() { + return LlamaEmbeddingsServiceSettings::new; + } + + @Override + protected LlamaEmbeddingsServiceSettings 
createTestInstance() { + return createRandom(); + } + + @Override + protected LlamaEmbeddingsServiceSettings mutateInstance(LlamaEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, LlamaEmbeddingsServiceSettingsTests::createRandom); + } + + private static LlamaEmbeddingsServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + var url = randomAlphaOfLength(15); + var similarityMeasure = randomFrom(SimilarityMeasure.values()); + var dimensions = randomIntBetween(32, 256); + var maxInputTokens = randomIntBetween(128, 256); + return new LlamaEmbeddingsServiceSettings( + modelId, + url, + dimensions, + similarityMeasure, + maxInputTokens, + RateLimitSettingsTests.createRandom() + ); + } + + public static HashMap buildServiceSettingsMap( + @Nullable String modelId, + @Nullable String url, + @Nullable String similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable HashMap rateLimitSettings + ) { + HashMap result = new HashMap<>(); + if (modelId != null) { + result.put(ServiceFields.MODEL_ID, modelId); + } + if (url != null) { + result.put(ServiceFields.URL, url); + } + if (similarity != null) { + result.put(ServiceFields.SIMILARITY, similarity); + } + if (dimensions != null) { + result.put(ServiceFields.DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + result.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + } + if (rateLimitSettings != null) { + result.put(RateLimitSettings.FIELD_NAME, rateLimitSettings); + } + return result; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..dd8b3d7dfa38c --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestEntityTests.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; + +import java.io.IOException; +import java.util.ArrayList; + +public class LlamaChatCompletionRequestEntityTests extends ESTestCase { + private static final String ROLE = "user"; + + public void testModelUserFieldsSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + LlamaChatCompletionModel model = LlamaChatCompletionModelTests.createChatCompletionModel("model", "url", "api-key"); + + LlamaChatCompletionRequestEntity entity = new 
LlamaChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String expectedJson = """ + { + "messages": [{ + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java new file mode 100644 index 0000000000000..6f0701a810fb1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/completion/LlamaChatCompletionRequestTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.llama.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class LlamaChatCompletionRequestTests extends ESTestCase { + + public void testCreateRequest_WithStreaming() throws IOException { + String input = randomAlphaOfLength(15); + var request = createRequest("model", "url", "secret", input, true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(request.getURI().toString(), is("url")); + assertThat(requestMap.get("stream"), is(true)); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + } + + public void testCreateRequest_NoStreaming_NoAuthorization() throws IOException { + String input = randomAlphaOfLength(15); + var request = createRequestWithNoAuth("model", "url", input, false); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) 
httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(request.getURI().toString(), is("url")); + assertThat(requestMap.get("stream"), is(false)); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertNull(requestMap.get("stream_options")); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertNull(httpPost.getFirstHeader("Authorization")); + } + + public void testTruncate_DoesNotReduceInputTextSize() { + String input = randomAlphaOfLength(5); + var request = createRequest("model", "url", "secret", input, true); + assertThat(request.truncate(), is(request)); + } + + public void testTruncationInfo_ReturnsNull() { + var request = createRequest("model", "url", "secret", randomAlphaOfLength(5), true); + assertNull(request.getTruncationInfo()); + } + + public static LlamaChatCompletionRequest createRequest(String modelId, String url, String apiKey, String input, boolean stream) { + var chatCompletionModel = LlamaChatCompletionModelTests.createChatCompletionModel(modelId, url, apiKey); + return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } + + public static LlamaChatCompletionRequest createRequestWithNoAuth(String modelId, String url, String input, boolean stream) { + var chatCompletionModel = LlamaChatCompletionModelTests.createChatCompletionModelNoAuth(modelId, url); + return new LlamaChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..a055a0870e30d --- 
/dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestEntityTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class LlamaEmbeddingsRequestEntityTests extends ESTestCase { + + public void testXContent_Success() throws IOException { + var entity = new LlamaEmbeddingsRequestEntity("llama-embed", List.of("ABDC")); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + "model_id": "llama-embed", + "contents": ["ABDC"] + } + """))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..ab24fa9a0bc56 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/request/embeddings/LlamaEmbeddingsRequestTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.request.embeddings; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class LlamaEmbeddingsRequestTests extends ESTestCase { + + public void testCreateRequest_WithAuth_Success() throws IOException { + var request = createRequest(); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("contents"), is(List.of("ABCD"))); + assertThat(requestMap.get("model_id"), is("llama-embed")); + assertThat(httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey")); + } + + public void testCreateRequest_NoAuth_Success() throws IOException { + var request = createRequestNoAuth(); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + 
assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("contents"), is(List.of("ABCD"))); + assertThat(requestMap.get("model_id"), is("llama-embed")); + assertNull(httpPost.getFirstHeader("Authorization")); + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var request = createRequest(); + var truncatedRequest = request.truncate(); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("contents"), is(List.of("AB"))); + assertThat(requestMap.get("model_id"), is("llama-embed")); + } + + public void testIsTruncated_ReturnsTrue() { + var request = createRequest(); + assertFalse(request.getTruncationInfo()[0]); + + var truncatedRequest = request.truncate(); + assertTrue(truncatedRequest.getTruncationInfo()[0]); + } + + private HttpPost validateRequestUrlAndContentType(HttpRequest request) { + assertThat(request.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) request.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters())); + return httpPost; + } + + private static LlamaEmbeddingsRequest createRequest() { + var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModel("llama-embed", "url", "apikey"); + return new LlamaEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }), + embeddingsModel + ); + } + + private static LlamaEmbeddingsRequest createRequestNoAuth() { + var embeddingsModel = LlamaEmbeddingsModelTests.createEmbeddingsModelNoAuth("llama-embed", "url"); + return new LlamaEmbeddingsRequest( + 
TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of("ABCD"), new boolean[] { false }), + embeddingsModel + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java new file mode 100644 index 0000000000000..aa3c6f6c20b6e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/response/LlamaErrorResponseTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; + +import static org.mockito.Mockito.mock; + +public class LlamaErrorResponseTests extends ESTestCase { + + public static final String ERROR_RESPONSE_JSON = """ + { + "error": "A valid user token is required" + } + """; + + public void testFromResponse() { + var errorResponse = LlamaErrorResponse.fromResponse( + new HttpResult(mock(HttpResponse.class), ERROR_RESPONSE_JSON.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorResponse); + assertEquals(ERROR_RESPONSE_JSON, errorResponse.getErrorMessage()); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java index 6f8b40fd7f19c..9aa076e224efe 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -37,7 +36,6 @@ public static MistralEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "mistral", new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), - EmptyTaskSettings.INSTANCE, chunkingSettings, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); @@ -57,7 +55,6 @@ public static MistralEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "mistral", new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), - EmptyTaskSettings.INSTANCE, null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java index 4a70861932d28..2c8fb4fd48698 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java @@ -49,7 +49,7 @@ public void testTruncate_DoesNotReduceInputTextSize() throws 
IOException { var requestMap = entityAsMap(httpPost.getEntity().getContent()); assertThat(requestMap, aMapWithSize(4)); - // We do not truncate for Hugging Face chat completions + // We do not truncate for Mistral chat completions assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1));