From ec94f530a66d6b101e822cb205f1651eb1b73f47 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Fri, 31 Jan 2025 15:49:38 -0500 Subject: [PATCH 01/11] [ML] Integrate with DeepSeek API Integrating for Chat Completion and Completion task types, both calling the chat completion API for DeepSeek. --- .../org/elasticsearch/TransportVersions.java | 1 + .../InferenceNamedWriteablesProvider.java | 2 + .../xpack/inference/InferencePlugin.java | 2 + .../http/sender/DeepSeekRequestManager.java | 84 +++++ .../DeepSeekChatCompletionRequest.java | 93 +++++ .../deepseek/DeepSeekChatCompletionModel.java | 198 ++++++++++ .../services/deepseek/DeepSeekService.java | 233 ++++++++++++ .../deepseek/DeepSeekServiceTests.java | 347 ++++++++++++++++++ 8 files changed, 960 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 37473c565189b..fadf67ebcfbd4 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -185,6 +185,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION = def(9_005_0_00); public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE = def(9_006_0_00); public static final TransportVersion ESQL_PROFILE_ASYNC_NANOS = def(9_007_00_0); + public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_008_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index e8dc763116707..94042932c9127 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -56,6 +56,7 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; @@ -144,6 +145,7 @@ public static List getNamedWriteables() { addUnifiedNamedWriteables(namedWriteables); namedWriteables.addAll(StreamingTaskManager.namedWriteables()); + namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables()); return namedWriteables; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index e3604351c1937..694dd244f128f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -114,6 +114,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; @@ -357,6 +358,7 @@ public List getInferenceServiceFactories() { context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), + context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java new file mode 100644 index 0000000000000..5f925e3286959 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.request.deepseek.DeepSeekChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; + +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException; + +public class DeepSeekRequestManager extends BaseRequestManager { + + private static final Logger logger = LogManager.getLogger(DeepSeekRequestManager.class); + + private static final ResponseHandler CHAT_COMPLETION = createChatCompletionHandler(); + private static final ResponseHandler COMPLETION = createCompletionHandler(); + + private final DeepSeekChatCompletionModel model; + + public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) { + super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings()); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + switch (inferenceInputs) { + case UnifiedChatInput uci -> execute(uci, requestSender, hasRequestCompletedFunction, listener); + case ChatCompletionInput cci -> execute(cci, requestSender, hasRequestCompletedFunction, listener); + default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class); + } + } + + private void execute( + UnifiedChatInput inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var request = new DeepSeekChatCompletionRequest(inferenceInputs, model); + execute(new ExecutableInferenceRequest(requestSender, logger, request, CHAT_COMPLETION, hasRequestCompletedFunction, listener)); + } + + private void execute( + ChatCompletionInput inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var unifiedInputs = new UnifiedChatInput(inferenceInputs.getInputs(), "user", inferenceInputs.stream()); + var request = new DeepSeekChatCompletionRequest(unifiedInputs, model); + execute(new ExecutableInferenceRequest(requestSender, logger, request, COMPLETION, hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createChatCompletionHandler() { + return new OpenAiUnifiedChatCompletionResponseHandler("deepseek chat completion", OpenAiChatCompletionResponseEntity::fromResponse); + } + + private static 
ResponseHandler createCompletionHandler() { + return new OpenAiChatCompletionResponseHandler("deepseek completion", OpenAiChatCompletionResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java new file mode 100644 index 0000000000000..e0db6d2021375 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.deepseek; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor.MODEL_FIELD; +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class DeepSeekChatCompletionRequest implements Request { + + private final DeepSeekChatCompletionModel model; + private final UnifiedChatInput unifiedChatInput; + + public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) { + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + httpPost.setEntity(createEntity()); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(model.apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + private ByteArrayEntity createEntity() { + var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model); + try (var builder = JsonXContent.contentBuilder()) { + builder.startObject(); + new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.field(MODEL_FIELD, modelId); + builder.endObject(); + return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8)); + } catch (IOException e) { + throw new ElasticsearchException("Failed to serialize request payload.", e); + } + } + + @Override + public URI getURI() { + return model.uri(); + } + + 
@Override + public Request truncate() { + // No truncation for OpenAI chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for OpenAI chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return unifiedChatInput.stream(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java new file mode 100644 index 0000000000000..2391eb49ff04b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.deepseek; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +/** + * Design notes: + * This provider tries to match the OpenAI, so we'll design around that as well. + * + * Task Type: + * - Chat Completion + * + * Service Settings: + * - api_key + * - model + * - url + * + * Task Settings: + * - nothing? + * + * Rate Limiting: + * - The website claims to want unlimited, so we're setting it as MAX_INT per minute? 
+ */ +public class DeepSeekChatCompletionModel extends Model { + private static final Object RATE_LIMIT_GROUP = new Object(); + private static final RateLimitSettings RATE_LIMIT_SETTINGS = new RateLimitSettings(Integer.MAX_VALUE); + private static final URI DEFAULT_URI = URI.create("https://api.deepseek.com/chat/completions"); + private final DeepSeekServiceSettings serviceSettings; + private final DefaultSecretSettings secretSettings; + + public static List namedWriteables() { + return List.of(new NamedWriteableRegistry.Entry(ServiceSettings.class, DeepSeekServiceSettings.NAME, DeepSeekServiceSettings::new)); + } + + public static DeepSeekChatCompletionModel createFromNewInput( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettingsMap + ) { + var validationException = new ValidationException(); + + var model = extractRequiredString(serviceSettingsMap, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = createOptionalUri( + extractOptionalString(serviceSettingsMap, URL, ModelConfigurations.SERVICE_SETTINGS, validationException) + ); + var secureApiToken = extractRequiredSecureString( + serviceSettingsMap, + "api_key", + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + var serviceSettings = new DeepSeekServiceSettings(model, uri); + var taskSettings = new EmptyTaskSettings(); + var secretSettings = new DefaultSecretSettings(secureApiToken); + var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings); + return new DeepSeekChatCompletionModel(serviceSettings, secretSettings, modelConfigurations, new ModelSecrets(secretSettings)); + } + + public static DeepSeekChatCompletionModel readFromStorage( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettingsMap, + Map secrets + ) { + var validationException = new ValidationException(); + + var model = extractRequiredString(serviceSettingsMap, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = createOptionalUri( + extractOptionalString(serviceSettingsMap, "url", ModelConfigurations.SERVICE_SETTINGS, validationException) + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + var serviceSettings = new DeepSeekServiceSettings(model, uri); + var taskSettings = new EmptyTaskSettings(); + var secretSettings = DefaultSecretSettings.fromMap(secrets); + var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings); + return new DeepSeekChatCompletionModel(serviceSettings, secretSettings, modelConfigurations, new ModelSecrets(secretSettings)); + } + + private DeepSeekChatCompletionModel( + DeepSeekServiceSettings serviceSettings, + DefaultSecretSettings secretSettings, + ModelConfigurations configurations, + ModelSecrets secrets + ) { + super(configurations, secrets); + this.serviceSettings = serviceSettings; + this.secretSettings = secretSettings; + } + + public SecureString apiKey() { + return secretSettings.apiKey(); + } + + public String model() { + return serviceSettings.modelId(); + } + + public URI uri() { + return serviceSettings.uri() != null ? 
serviceSettings.uri() : DEFAULT_URI; + } + + public Object rateLimitGroup() { + return RATE_LIMIT_GROUP; + } + + public RateLimitSettings rateLimitSettings() { + return RATE_LIMIT_SETTINGS; + } + + private record DeepSeekServiceSettings(String modelId, URI uri) implements ServiceSettings { + private static final String NAME = "deep_seek_service_settings"; + + DeepSeekServiceSettings { + Objects.requireNonNull(modelId); + } + + DeepSeekServiceSettings(StreamInput in) throws IOException { + this(in.readString(), in.readOptional(url -> URI.create(url.readString()))); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_DEEPSEEK; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalString(uri != null ? uri.toString() : null); + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID, modelId); + if (uri != null) { + builder.field(URL, uri.toString()); + } + return builder.endObject(); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java new file mode 100644 index 0000000000000..c0478763a8803 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -0,0 +1,233 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.deepseek; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.DeepSeekRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class DeepSeekService extends SenderService { + private static final String NAME = "deepseek"; + private static final String CHAT_COMPLETION_ERROR_PREFIX = "deepseek chat completions"; + private static final String COMPLETION_ERROR_PREFIX = "deepseek completions"; + private static final String SERVICE_NAME = "DeepSeek"; + // The task types exposed via the _inference/_services API + private static final EnumSet SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of( + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION + ); + + public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + doInfer(model, inputs, timeout, COMPLETION_ERROR_PREFIX, listener); + } + + private void doInfer( + Model model, + InferenceInputs inputs, + TimeValue timeout, + String errorPrefix, + ActionListener listener + ) { + if (model instanceof DeepSeekChatCompletionModel deepSeekModel) { + var 
requestCreator = new DeepSeekRequestManager(deepSeekModel, getServiceComponents().threadPool()); + var errorMessage = constructFailedToSendRequestMessage(deepSeekModel.uri(), errorPrefix); + var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage); + action.execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + doInfer(model, inputs, timeout, CHAT_COMPLETION_ERROR_PREFIX, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + DocumentsOnlyInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure(new UnsupportedOperationException(Strings.format("The %s service only supports unified completion", NAME))); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + ActionListener.completeWith(parsedModelListener, () -> { + var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + try { + return DeepSeekChatCompletionModel.createFromNewInput(modelId, taskType, NAME, serviceSettingsMap); + } finally { + throwIfNotEmptyMap(serviceSettingsMap, NAME); + } + }); + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + var secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + return DeepSeekChatCompletionModel.readFromStorage(modelId, taskType, NAME, serviceSettingsMap, secretSettingsMap); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + return parsePersistedConfigWithSecrets(modelId, taskType, config, config); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES_FOR_SERVICES_API; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_DEEPSEEK; + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION); + } + + @Override + public void checkModelConfig(Model model, ActionListener listener) { + // TODO: Remove this function once all services have been updated to use the new model validators + ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); + } + + private static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription( + "The name of the model to use for the inference task." 
+ ) + .setLabel("Model ID") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.putAll( + DefaultSecretSettings.toSettingsConfigurationWithDescription( + "The DeepSeek API authentication key. For more details about generating DeepSeek API keys, " + + "refer to https://api-docs.deepseek.com.", + SUPPORTED_TASK_TYPES_FOR_SERVICES_API + ) + ); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDefaultValue( + "https://api.deepseek.com/chat/completions" + ) + .setDescription("The URL endpoint to use for the requests.") + .setLabel("URL") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java new file mode 100644 index 0000000000000..9b5b92b9c90c3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -0,0 +1,347 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.deepseek; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener; +import static 
org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener; +import static org.elasticsearch.common.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; + +public class DeepSeekServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig() throws IOException, URISyntaxException { + parseRequestConfig(format(""" + { + "service_settings": { + "api_key": "12345", + "model_id": "some-cool-model", + "url": "%s" + } + } + """, webServer.getUri(null).toString()), assertNoFailureListener(model -> { + if (model instanceof DeepSeekChatCompletionModel deepSeekModel) { + assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray())); + assertThat(deepSeekModel.model(), equalTo("some-cool-model")); + assertThat(deepSeekModel.uri(), equalTo(webServer.getUri(null))); + } else { + fail("Expected DeepSeekModel, found " + (model != null ? 
model.getClass().getSimpleName() : "null")); + } + })); + } + + public void testParseRequestConfigWithoutApiKey() throws IOException { + parseRequestConfig(""" + { + "service_settings": { + "model_id": "some-cool-model" + } + } + """, assertNoSuccessListener(e -> { + if (e instanceof ValidationException ve) { + assertThat( + ve.getMessage(), + equalTo("Validation Failed: 1: [service_settings] does not contain the required setting [api_key];") + ); + } + })); + } + + public void testParseRequestConfigWithoutModel() throws IOException { + parseRequestConfig(""" + { + "service_settings": { + "api_key": "1234" + } + } + """, assertNoSuccessListener(e -> { + if (e instanceof ValidationException ve) { + assertThat( + ve.getMessage(), + equalTo("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + } + })); + } + + public void testParseRequestConfigWithExtraSettings() throws IOException { + parseRequestConfig( + """ + { + "service_settings": { + "api_key": "12345", + "model_id": "some-cool-model", + "so": "extra" + } + } + """, + assertNoSuccessListener( + e -> assertThat( + e.getMessage(), + equalTo("Model configuration contains settings [{so=extra}] unknown to the [deepseek] service") + ) + ) + ); + } + + public void testParsePersistedConfig() throws IOException { + var deepSeekModel = parsePersistedConfig(""" + { + "service_settings": { + "model_id": "some-cool-model" + }, + "secret_settings": { + "api_key": "12345" + } + } + """); + assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray())); + assertThat(deepSeekModel.model(), equalTo("some-cool-model")); + } + + public void testParsePersistedConfigWithUrl() throws IOException { + var deepSeekModel = parsePersistedConfig(""" + { + "service_settings": { + "model_id": "some-cool-model", + "url": "http://localhost:989" + }, + "secret_settings": { + "api_key": "12345" + } + } + """); + assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray())); + assertThat(deepSeekModel.model(), equalTo("some-cool-model")); + assertThat(deepSeekModel.uri(), equalTo(URI.create("http://localhost:989"))); + } + + public void testParsePersistedConfigWithoutApiKey() { + assertThrows( + "Validation Failed: 1: [secret_settings] does not contain the required setting [api_key];", + ValidationException.class, + () -> parsePersistedConfig(""" + { + "service_settings": { + "model_id": "some-cool-model" + }, + "secret_settings": { + } + } + """) + ); + } + + public void testParsePersistedConfigWithoutModel() { + assertThrows( + "Validation Failed: 1: [service_settings] does not contain the required setting [model];", + ValidationException.class, + () -> parsePersistedConfig(""" + { + "service_settings": { + }, + "secret_settings": { + "api_key": "12345" + } + } + """) + ); + } + + public void testParsePersistedConfigWithoutServiceSettings() { + assertThrows( + "Validation Failed: 1: [service_settings] does not contain the required setting [model];", + ElasticsearchStatusException.class, + () -> parsePersistedConfig(""" + { + "secret_settings": { + "api_key": "12345" + } + } + """) + ); + } + + public void testDoUnifiedInfer() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {"choices": [{"delta": {"content": "hello, world", "role": "assistant"}, "finish_reason": null, "index": 0, \ + "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \ + "object": "chat.completion.chunk", "system_fingerprint": "fp_1234"} + + 
data: [DONE] + + """)); + var result = doUnifiedCompletionInfer(); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"id":"12345","choices":[{"delta":{"content":"hello, world","role":"assistant"},"index":0}],""" + """ + "model":"deepseek-chat","object":"chat.completion.chunk"}"""); + } + + public void testDoInfer() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + {"choices": [{"message": {"content": "hello, world", "role": "assistant"}, "finish_reason": "stop", "index": 0, \ + "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \ + "object": "chat.completion", "system_fingerprint": "fp_1234"}""")); + var result = doInfer(false); + assertThat(result, isA(ChatCompletionResults.class)); + var completionResults = (ChatCompletionResults) result; + assertThat( + completionResults.results().stream().map(ChatCompletionResults.Result::predictedValue).toList(), + equalTo(List.of("hello, world")) + ); + } + + public void testDoInferStream() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {"choices": [{"delta": {"content": "hello, world", "role": "assistant"}, "finish_reason": null, "index": 0, \ + "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \ + "object": "chat.completion.chunk", "system_fingerprint": "fp_1234"} + + data: [DONE] + + """)); + var result = doInfer(true); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } + + public void testDoChunkedInferAlwaysFails() throws IOException { + try (var service = createService()) { + service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> { + assertThat(e, isA(UnsupportedOperationException.class)); + assertThat(e.getMessage(), equalTo("The deepseek service only supports unified completion")); + })); + } + } + + private DeepSeekService createService() { + return new DeepSeekService( + HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), + createWithEmptySettings(threadPool) + ); + } + + private void parseRequestConfig(String json, ActionListener listener) throws IOException { + try (var service = createService()) { + service.parseRequestConfig("inference-id", TaskType.CHAT_COMPLETION, map(json), listener); + } + } + + private Map map(String json) throws IOException { + try ( + var parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, json.getBytes(StandardCharsets.UTF_8)) + ) { + return parser.map(); + } + } + + private DeepSeekChatCompletionModel parsePersistedConfig(String json) throws IOException { + try (var service = createService()) { + var model = service.parsePersistedConfig("inference-id", TaskType.CHAT_COMPLETION, map(json)); + assertThat(model, isA(DeepSeekChatCompletionModel.class)); + return (DeepSeekChatCompletionModel) model; + } + } + + private InferenceServiceResults doUnifiedCompletionInfer() throws Exception { + try (var service = createService()) { + var model = createModel(service, TaskType.CHAT_COMPLETION); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + TIMEOUT, + listener + ); + return listener.get(30, TimeUnit.SECONDS); + 
} + } + + private InferenceServiceResults doInfer(boolean stream) throws Exception { + try (var service = createService()) { + var model = createModel(service, TaskType.COMPLETION); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, null, List.of("hello"), stream, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + return listener.get(30, TimeUnit.SECONDS); + } + } + + private DeepSeekChatCompletionModel createModel(DeepSeekService service, TaskType taskType) throws URISyntaxException, IOException { + var model = service.parsePersistedConfig("inference-id", taskType, map(Strings.format(""" + { + "service_settings": { + "model_id": "some-cool-model", + "url": "%s" + }, + "secret_settings": { + "api_key": "12345" + } + } + """, webServer.getUri(null).toString()))); + assertThat(model, isA(DeepSeekChatCompletionModel.class)); + return (DeepSeekChatCompletionModel) model; + } +} From 746ed7c3b4b6692862d198342508a1579fd0a944 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Mon, 10 Feb 2025 17:05:59 -0500 Subject: [PATCH 02/11] Update docs/changelog/122218.yaml --- docs/changelog/122218.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/122218.yaml diff --git a/docs/changelog/122218.yaml b/docs/changelog/122218.yaml new file mode 100644 index 0000000000000..bfd44399e3e8d --- /dev/null +++ b/docs/changelog/122218.yaml @@ -0,0 +1,5 @@ +pr: 122218 +summary: Integrate with `DeepSeek` API +area: Machine Learning +type: enhancement +issues: [] From 5d5bd0a7b146770ee580173e013a2fde4d9653e6 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 25 Feb 2025 12:19:00 -0500 Subject: [PATCH 03/11] Finish integration --- .../DeepSeekChatCompletionRequest.java | 8 +- .../DeepSeekChatCompletionRequestEntity.java | 186 +++++ ...pSeekChatCompletionRequestEntityTests.java | 648 ++++++++++++++++++ .../deepseek/DeepSeekServiceTests.java | 95 +++ 4 files changed, 930 insertions(+), 7 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntityTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java index e0db6d2021375..4a927f9fac851 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; import java.io.IOException; @@ -26,7 +25,6 @@ import java.nio.charset.StandardCharsets; import java.util.Objects; -import static org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor.MODEL_FIELD; import static 
org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; public class DeepSeekChatCompletionRequest implements Request { @@ -52,12 +50,8 @@ public HttpRequest createHttpRequest() { } private ByteArrayEntity createEntity() { - var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model); try (var builder = JsonXContent.contentBuilder()) { - builder.startObject(); - new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS); - builder.field(MODEL_FIELD, modelId); - builder.endObject(); + new DeepSeekChatCompletionRequestEntity(unifiedChatInput, model).toXContent(builder, ToXContent.EMPTY_PARAMS); return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8)); } catch (IOException e) { throw new ElasticsearchException("Failed to serialize request payload.", e); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..e2c942e7cef37 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntity.java @@ -0,0 +1,186 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.deepseek; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor.MODEL_FIELD; + +class DeepSeekChatCompletionRequestEntity implements ToXContentFragment { + + public static final String NAME_FIELD = "name"; + public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; + public static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_FIELD = "function"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String STRICT_FIELD = "strict"; + public static final String TOP_P_FIELD = "top_p"; + public static final String STREAM_FIELD = "stream"; + private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; + public static final String MESSAGES_FIELD = "messages"; + private static final String ROLE_FIELD = "role"; + private static final String CONTENT_FIELD = "content"; + private static final String MAX_TOKENS = "max_tokens"; + private static final String STOP_FIELD = "stop"; + private static final String TEMPERATURE_FIELD = "temperature"; + private static final String TOOL_CHOICE_FIELD = "tool_choice"; + private static final String TOOL_FIELD = "tools"; + private static final String TEXT_FIELD = "text"; 
+ private static final String TYPE_FIELD = "type"; + private static final String STREAM_OPTIONS_FIELD = "stream_options"; + private static final String INCLUDE_USAGE_FIELD = "include_usage"; + + private final DeepSeekChatCompletionModel model; + private final UnifiedCompletionRequest unifiedRequest; + private final boolean stream; + + DeepSeekChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) { + Objects.requireNonNull(unifiedChatInput); + this.model = Objects.requireNonNull(model); + this.unifiedRequest = unifiedChatInput.getRequest(); + this.stream = unifiedChatInput.stream(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.startArray(MESSAGES_FIELD); + { + for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { + builder.startObject(); + { + switch (message.content()) { + case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); + case UnifiedCompletionRequest.ContentObjects contentObjects -> { + builder.startArray(CONTENT_FIELD); + for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { + builder.startObject(); + builder.field(TEXT_FIELD, contentObject.text()); + builder.field(TYPE_FIELD, contentObject.type()); + builder.endObject(); + } + builder.endArray(); + } + case null -> { + // do nothing because content is optional + } + } + + builder.field(ROLE_FIELD, message.role()); + if (message.toolCallId() != null) { + builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); + } + if (message.toolCalls() != null) { + builder.startArray(TOOL_CALLS_FIELD); + for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { + builder.startObject(); + { + builder.field(ID_FIELD, toolCall.id()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); + builder.field(NAME_FIELD, toolCall.function().name()); + } + builder.endObject(); + builder.field(TYPE_FIELD, toolCall.type()); + } + builder.endObject(); + } + builder.endArray(); + } + } + builder.endObject(); + } + } + builder.endArray(); + + var modelId = Objects.requireNonNullElseGet(unifiedRequest.model(), model::model); + builder.field(MODEL_FIELD, modelId); + + if (unifiedRequest.maxCompletionTokens() != null) { + builder.field(MAX_TOKENS, unifiedRequest.maxCompletionTokens()); + } + + // Underlying providers expect OpenAI to only return 1 possible choice. 
+ builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); + + if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { + builder.field(STOP_FIELD, unifiedRequest.stop()); + } + if (unifiedRequest.temperature() != null) { + builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); + } + if (unifiedRequest.toolChoice() != null) { + if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { + builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); + } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { + builder.startObject(TOOL_CHOICE_FIELD); + { + builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field( + NAME_FIELD, + ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() + ); + } + builder.endObject(); + } + builder.endObject(); + } + } + boolean usesTools = unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false; + + if (usesTools) { + builder.startArray(TOOL_FIELD); + for (UnifiedCompletionRequest.Tool tool : unifiedRequest.tools()) { + builder.startObject(); + { + builder.field(TYPE_FIELD, tool.type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(DESCRIPTION_FIELD, tool.function().description()); + builder.field(NAME_FIELD, tool.function().name()); + builder.field(PARAMETERS_FIELD, tool.function().parameters()); + if (tool.function().strict() != null) { + builder.field(STRICT_FIELD, tool.function().strict()); + } + } + builder.endObject(); + } + builder.endObject(); + } + builder.endArray(); + } + if (unifiedRequest.topP() != null) { + builder.field(TOP_P_FIELD, unifiedRequest.topP()); + } + + builder.field(STREAM_FIELD, stream); + if (stream) { + builder.startObject(STREAM_OPTIONS_FIELD); + builder.field(INCLUDE_USAGE_FIELD, true); + builder.endObject(); + } + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..2ae423637c1f0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntityTests.java @@ -0,0 +1,648 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.deepseek; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Random; + +import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; + +public class DeepSeekChatCompletionRequestEntityTests extends ESTestCase { + + private static final String ROLE = "user"; + + public void testBasicSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + String jsonString = entityString(unifiedChatInput); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + private String entityString(UnifiedChatInput unifiedChatInput) throws IOException { + Map map = new HashMap<>(); + map.put(MODEL_ID, "model-name"); + map.put("api_key", "1234"); + DeepSeekChatCompletionModel model = DeepSeekChatCompletionModel.createFromNewInput( + "inference-id", + TaskType.CHAT_COMPLETION, + "deepseek", + map + ); + + DeepSeekChatCompletionRequestEntity entity = new DeepSeekChatCompletionRequestEntity(unifiedChatInput, model); + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + return Strings.toString(builder); + } + + public void testSerializationWithAllFields() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + "tool_call_id", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), + "type" + ) + ) + ); + + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "request-model", + 100L, // maxTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new 
UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList(tool), + 0.8f // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + String jsonString = entityString(unifiedChatInput); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "function_name" + }, + "type": "type" + } + ] + } + ], + "model": "request-model", + "max_tokens": 100, + "n": 1, + "stop": ["stop"], + "temperature": 0.9, + "tool_choice": "tool_choice", + "tools": [ + { + "type": "type", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "description": "The location to get the weather for", + "type": "string" + }, + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": 0.8, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + + } + + public void testSerializationWithNullOptionalFields() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + String jsonString = entityString(unifiedChatInput); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithEmptyLists() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + Collections.emptyList() // empty toolCalls list + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxTokens + Collections.emptyList(), // empty stop list + null, // temperature + null, // toolChoice + Collections.emptyList(), // empty tools list + null // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + String jsonString = entityString(unifiedChatInput); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "tool_calls": [] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithNestedObjects() throws IOException { + Random random = Randomness.get(); + + String randomContent = "Hello, world! 
" + random.nextInt(1000); + String randomToolCallId = "tool_call_id" + random.nextInt(1000); + String randomArguments = "arguments" + random.nextInt(1000); + String randomFunctionName = "function_name" + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + String randomModel = "model" + random.nextInt(1000); + String randomStop = "stop" + random.nextInt(1000); + float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContent), + ROLE, + randomToolCallId, + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), + randomType + ) + ) + ); + + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + randomType, + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + randomModel, + 100L, // maxTokens + Collections.singletonList(randomStop), + randomTemperature, // temperature + new UnifiedCompletionRequest.ToolChoiceObject( + randomType, + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) + ), + Collections.singletonList(tool), + randomTopP // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + String jsonString = entityString(unifiedChatInput); + String expectedJson = String.format( + Locale.US, + """ + { + "messages": [ + { + "content": "%s", + "role": "user", + "tool_call_id": "%s", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "%s", + "name": "%s" + }, + "type": "%s" + } + ] + } + ], + "model": "%s", + "max_tokens": 100, + "n": 1, + "stop": ["%s"], + "temperature": %.5f, + "tool_choice": { + "type": "%s", + "function": { + "name": "%s" + } + }, + "tools": [ + { + "type": "%s", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + }, + "location": { + "description": "The location to get the weather for", + "type": "string" + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": %.5f, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, + randomContent, + randomToolCallId, + randomArguments, + randomFunctionName, + randomType, + randomModel, + randomStop, + randomTemperature, + randomType, + randomFunctionName, + randomType, + randomTopP + ); + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithDifferentContentTypes() throws IOException { + Random random = Randomness.get(); + + String randomContentString = "Hello, world! 
" + random.nextInt(1000); + + String randomText = "Random text " + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); + + var contentObjectsList = new ArrayList(); + contentObjectsList.add(contentObject); + UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); + + UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContentString), + ROLE, + null, + null + ); + + UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null); + var messageList = new ArrayList(); + messageList.add(messageWithString); + messageList.add(messageWithObjects); + + UnifiedCompletionRequest unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + String jsonString = entityString(unifiedChatInput); + String expectedJson = String.format(Locale.US, """ + { + "messages": [ + { + "content": "%s", + "role": "user" + }, + { + "content": [ + { + "text": "%s", + "type": "%s" + } + ], + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, randomContentString, randomText, randomType); + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithSpecialCharacters() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world! \n \"Special\" characters: \t \\ /"), + ROLE, + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + String jsonString = entityString(unifiedChatInput); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world! 
\\n \\"Special\\" characters: \\t \\\\ /", + "role": "user", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithBooleanFields() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); + String jsonStringTrue = entityString(unifiedChatInputTrue); + String expectedJsonTrue = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(expectedJsonTrue, jsonStringTrue); + + UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); + String jsonStringFalse = entityString(unifiedChatInputFalse); + String expectedJsonFalse = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": false + } + """; + assertJsonEquals(expectedJsonFalse, jsonStringFalse); + } + + public void testSerializationWithoutContentField() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + null, + "assistant", + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + String jsonString = entityString(unifiedChatInput); + String expectedJson = """ + { + "messages": [ + { + "role": "assistant", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + private static Map createParameters() { + Map parameters = new LinkedHashMap<>(); + parameters.put("type", "object"); + + Map properties = new HashMap<>(); + + Map location = new HashMap<>(); + location.put("type", "string"); + location.put("description", "The location to get the weather for"); + properties.put("location", location); + + Map unit = new HashMap<>(); + unit.put("type", "string"); + unit.put("description", "The unit to return the temperature 
in"); + unit.put("enum", new String[] { "F", "C" }); + properties.put("unit", unit); + + parameters.put("properties", properties); + parameters.put("additionalProperties", false); + parameters.put("required", new String[] { "location", "unit" }); + + return parameters; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 9b5b92b9c90c3..fdc88df87f815 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -12,7 +12,9 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -23,9 +25,12 @@ import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -41,12 +46,15 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener; import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener; import static org.elasticsearch.common.Strings.format; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.isA; import static org.mockito.Mockito.mock; @@ -266,6 +274,93 @@ public void testDoInferStream() throws Exception { {"completion":[{"delta":"hello, world"}]}"""); } + public void testUnifiedCompletionError() throws Exception { + String responseJson = """ + { + "error": { + "message": "The model `deepseek-not-chat` does not exist or you do not have access to it.", + "type": "invalid_request_error", + "param": null, + "code": "model_not_found" + } + }"""; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"model_not_found",\ + "message":"Received an 
unsuccessful status code for request from inference entity id [inference-id] status \ + [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]",\ + "type":"invalid_request_error"\ + }}"""); + } + + private void testStreamError(String expectedResponse) throws Exception { + try (var service = createService()) { + var model = createModel(service, TaskType.CHAT_COMPLETION); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { + e = unwrapCause(e); + assertThat(e, isA(UnifiedChatCompletionException.class)); + try (var builder = XContentFactory.jsonBuilder()) { + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(expectedResponse)); + } + }); + } + } + + public void testMidStreamUnifiedCompletionError() throws Exception { + String responseJson = """ + event: error + data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "message":"Received an error response for request from inference entity id [inference-id]. 
Error message: \ + [Timed out waiting for more data]",\ + "type":"timeout"\ + }}"""); + } + + public void testUnifiedCompletionMalformedError() throws Exception { + String responseJson = """ + data: { invalid json } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"bad_request",\ + "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\ + at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\ + "type":"x_content_parse_exception"\ + }}"""); + } + public void testDoChunkedInferAlwaysFails() throws IOException { try (var service = createService()) { service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> { From 65e8b26e08d004682818f65b649fab01ef67b4e4 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 25 Feb 2025 13:39:15 -0500 Subject: [PATCH 04/11] Fix InferenceGetServicesIT --- .../xpack/inference/InferenceGetServicesIT.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 859a065b6e1a0..6f9a550481049 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @SuppressWarnings("unchecked") public void testGetServicesWithoutTaskType() throws IOException { List services = getAllServices(); - assertThat(services.size(), equalTo(20)); + assertThat(services.size(), equalTo(21)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -41,6 +41,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "deepseek", "elastic", "elasticsearch", "googleaistudio", @@ -114,7 +115,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(9)); + assertThat(services.size(), equalTo(10)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -130,6 +131,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "deepseek", "googleaistudio", "openai", "streaming_completion_test_service" @@ -141,7 +143,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithChatCompletionTaskType() throws IOException { List services = getServices(TaskType.CHAT_COMPLETION); - assertThat(services.size(), equalTo(3)); + assertThat(services.size(), equalTo(4)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -149,7 +151,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { providers[i] = (String) serviceConfig.get("service"); } - 
assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers); + assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers); } @SuppressWarnings("unchecked") From 9e9fbc17e731e9c9b42f2f5037454ea1f5ac24e6 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Mon, 10 Mar 2025 23:09:49 -0400 Subject: [PATCH 05/11] Adding 8.x TransportVersion --- server/src/main/java/org/elasticsearch/TransportVersions.java | 1 + .../request/deepseek/DeepSeekChatCompletionRequest.java | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index bd6dde503dea7..7d5ffd5f6dc96 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -146,6 +146,7 @@ static TransportVersion def(int id) { public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05); public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06); public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07); + public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_08); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java index 4a927f9fac851..20ba5629a69a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java @@ -65,13 +65,11 @@ public URI getURI() { @Override public Request truncate() { - // No truncation for OpenAI chat completions return this; } @Override public boolean[] getTruncationInfo() { - // No truncation for OpenAI chat completions return null; } From 0173947eab9a59f899dd5521d757fc17134e746c Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 11 Mar 2025 08:35:34 -0400 Subject: [PATCH 06/11] Use new error message API --- .../xpack/inference/services/deepseek/DeepSeekService.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index c0478763a8803..4433c43e1b8f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -83,7 +83,7 @@ private void doInfer( ) { if (model instanceof DeepSeekChatCompletionModel deepSeekModel) { var requestCreator = new DeepSeekRequestManager(deepSeekModel, getServiceComponents().threadPool()); - var errorMessage = 
constructFailedToSendRequestMessage(deepSeekModel.uri(), errorPrefix); + var errorMessage = constructFailedToSendRequestMessage(errorPrefix); var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage); action.execute(inputs, timeout, listener); } else { From aa79769e059748ac10b8fe7ddf0a2d4d6cd62d65 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 11 Mar 2025 08:52:45 -0400 Subject: [PATCH 07/11] Move request packages to external --- .../{request => }/deepseek/DeepSeekChatCompletionRequest.java | 2 +- .../deepseek/DeepSeekChatCompletionRequestEntity.java | 2 +- .../inference/external/http/sender/DeepSeekRequestManager.java | 2 +- .../deepseek/DeepSeekChatCompletionRequestEntityTests.java | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/{request => }/deepseek/DeepSeekChatCompletionRequest.java (97%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/{request => }/deepseek/DeepSeekChatCompletionRequestEntity.java (99%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/{request => }/deepseek/DeepSeekChatCompletionRequestEntityTests.java (99%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java similarity index 97% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java index 20ba5629a69a6..6acab87ce835d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.external.request.deepseek; +package org.elasticsearch.xpack.inference.external.deepseek; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntity.java similarity index 99% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntity.java index e2c942e7cef37..fa81ce6cf2cbf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntity.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.external.request.deepseek; +package org.elasticsearch.xpack.inference.external.deepseek; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.ToXContentFragment; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java index 5f925e3286959..ec6d88f48a275 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java @@ -12,11 +12,11 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.deepseek.DeepSeekChatCompletionRequest; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.request.deepseek.DeepSeekChatCompletionRequest; import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntityTests.java similarity index 99% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntityTests.java index 2ae423637c1f0..3023de2ccb4db 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/deepseek/DeepSeekChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntityTests.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.external.request.deepseek; +package org.elasticsearch.xpack.inference.external.deepseek; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Strings; From 9c31d35b6fea67282780922f7dcef6b1cced4de4 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 11 Mar 2025 09:34:44 -0400 Subject: [PATCH 08/11] Move max tokens up a level so it can be customized like model id --- .../DeepSeekChatCompletionRequest.java | 14 +- .../DeepSeekChatCompletionRequestEntity.java | 186 ----- ...iceUnifiedChatCompletionRequestEntity.java | 10 +- ...nAiUnifiedChatCompletionRequestEntity.java | 9 +- .../UnifiedChatCompletionRequestEntity.java | 5 - ...pSeekChatCompletionRequestEntityTests.java | 648 ------------------ 6 files changed, 30 insertions(+), 842 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntityTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java index 6acab87ce835d..5fbc8883d5051 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; import java.io.IOException; @@ -28,6 +29,8 @@ import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; public class DeepSeekChatCompletionRequest implements Request { + private static final String MODEL_FIELD = "model"; + private static final String MAX_TOKENS = "max_tokens"; private final DeepSeekChatCompletionModel model; private final UnifiedChatInput unifiedChatInput; @@ -50,8 +53,17 @@ public HttpRequest createHttpRequest() { } private ByteArrayEntity createEntity() { + var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model); try (var builder = JsonXContent.contentBuilder()) { - new DeepSeekChatCompletionRequestEntity(unifiedChatInput, model).toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.startObject(); + new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.field(MODEL_FIELD, modelId); + + if (unifiedChatInput.getRequest().maxCompletionTokens() != null) { + builder.field(MAX_TOKENS, unifiedChatInput.getRequest().maxCompletionTokens()); + } + + builder.endObject(); return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8)); } catch (IOException e) { throw new ElasticsearchException("Failed to serialize request payload.", e); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntity.java deleted file mode 100644 index fa81ce6cf2cbf..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntity.java +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.deepseek; - -import org.elasticsearch.inference.UnifiedCompletionRequest; -import org.elasticsearch.xcontent.ToXContentFragment; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; - -import java.io.IOException; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor.MODEL_FIELD; - -class DeepSeekChatCompletionRequestEntity implements ToXContentFragment { - - public static final String NAME_FIELD = "name"; - public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; - public static final String TOOL_CALLS_FIELD = "tool_calls"; - public static final String ID_FIELD = "id"; - public static final String FUNCTION_FIELD = "function"; - public static final String ARGUMENTS_FIELD = "arguments"; - public static final String DESCRIPTION_FIELD = "description"; - public static final String PARAMETERS_FIELD = "parameters"; - public static final String STRICT_FIELD = "strict"; - public static final String TOP_P_FIELD = "top_p"; - public static final String STREAM_FIELD = "stream"; - private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; - public static final String MESSAGES_FIELD = "messages"; - private static final String ROLE_FIELD = "role"; - private static final String CONTENT_FIELD = "content"; - private static final String MAX_TOKENS = "max_tokens"; - private static final String STOP_FIELD = "stop"; - private static final String TEMPERATURE_FIELD = "temperature"; - private static final String TOOL_CHOICE_FIELD = "tool_choice"; - private static final String TOOL_FIELD = "tools"; - private static final String TEXT_FIELD = "text"; - private static final String TYPE_FIELD = "type"; - private static final String STREAM_OPTIONS_FIELD = "stream_options"; - private static final String INCLUDE_USAGE_FIELD = "include_usage"; - - private final DeepSeekChatCompletionModel model; - private final UnifiedCompletionRequest unifiedRequest; - private final boolean stream; - - DeepSeekChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) { - Objects.requireNonNull(unifiedChatInput); - this.model = Objects.requireNonNull(model); - this.unifiedRequest = unifiedChatInput.getRequest(); - this.stream = unifiedChatInput.stream(); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - builder.startArray(MESSAGES_FIELD); - { - for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { - builder.startObject(); - { - switch 
(message.content()) { - case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); - case UnifiedCompletionRequest.ContentObjects contentObjects -> { - builder.startArray(CONTENT_FIELD); - for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { - builder.startObject(); - builder.field(TEXT_FIELD, contentObject.text()); - builder.field(TYPE_FIELD, contentObject.type()); - builder.endObject(); - } - builder.endArray(); - } - case null -> { - // do nothing because content is optional - } - } - - builder.field(ROLE_FIELD, message.role()); - if (message.toolCallId() != null) { - builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); - } - if (message.toolCalls() != null) { - builder.startArray(TOOL_CALLS_FIELD); - for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { - builder.startObject(); - { - builder.field(ID_FIELD, toolCall.id()); - builder.startObject(FUNCTION_FIELD); - { - builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); - builder.field(NAME_FIELD, toolCall.function().name()); - } - builder.endObject(); - builder.field(TYPE_FIELD, toolCall.type()); - } - builder.endObject(); - } - builder.endArray(); - } - } - builder.endObject(); - } - } - builder.endArray(); - - var modelId = Objects.requireNonNullElseGet(unifiedRequest.model(), model::model); - builder.field(MODEL_FIELD, modelId); - - if (unifiedRequest.maxCompletionTokens() != null) { - builder.field(MAX_TOKENS, unifiedRequest.maxCompletionTokens()); - } - - // Underlying providers expect OpenAI to only return 1 possible choice. - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); - - if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { - builder.field(STOP_FIELD, unifiedRequest.stop()); - } - if (unifiedRequest.temperature() != null) { - builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); - } - if (unifiedRequest.toolChoice() != null) { - if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { - builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); - } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { - builder.startObject(TOOL_CHOICE_FIELD); - { - builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); - builder.startObject(FUNCTION_FIELD); - { - builder.field( - NAME_FIELD, - ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() - ); - } - builder.endObject(); - } - builder.endObject(); - } - } - boolean usesTools = unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false; - - if (usesTools) { - builder.startArray(TOOL_FIELD); - for (UnifiedCompletionRequest.Tool tool : unifiedRequest.tools()) { - builder.startObject(); - { - builder.field(TYPE_FIELD, tool.type()); - builder.startObject(FUNCTION_FIELD); - { - builder.field(DESCRIPTION_FIELD, tool.function().description()); - builder.field(NAME_FIELD, tool.function().name()); - builder.field(PARAMETERS_FIELD, tool.function().parameters()); - if (tool.function().strict() != null) { - builder.field(STRICT_FIELD, tool.function().strict()); - } - } - builder.endObject(); - } - builder.endObject(); - } - builder.endArray(); - } - if (unifiedRequest.topP() != null) { - builder.field(TOP_P_FIELD, unifiedRequest.topP()); - } - - builder.field(STREAM_FIELD, stream); - if 
(stream) { - builder.startObject(STREAM_OPTIONS_FIELD); - builder.field(INCLUDE_USAGE_FIELD, true); - builder.endObject(); - } - - builder.endObject(); - - return builder; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java index ded8a074478cf..2631eaa085fb1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java @@ -17,12 +17,15 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject { private static final String MODEL_FIELD = "model"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private final UnifiedChatInput unifiedChatInput; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; private final String modelId; public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) { - this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); this.modelId = Objects.requireNonNull(modelId); } @@ -31,6 +34,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); unifiedRequestEntity.toXContent(builder, params); builder.field(MODEL_FIELD, modelId); + + if (unifiedChatInput.getRequest().maxCompletionTokens() != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens()); + } + builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index b80100c9e2f79..e7d97b47b0837 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -21,12 +21,15 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec public static final String USER_FIELD = "user"; private static final String MODEL_FIELD = "model"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private final UnifiedChatInput unifiedChatInput; private final OpenAiChatCompletionModel model; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { - this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + 
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); this.model = Objects.requireNonNull(model); } @@ -41,6 +44,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(USER_FIELD, model.getTaskSettings().user()); } + if (unifiedChatInput.getRequest().maxCompletionTokens() != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens()); + } + builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java index 5e6d09cde2b9f..6a6f8d92c74ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java @@ -32,7 +32,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment { public static final String MESSAGES_FIELD = "messages"; private static final String ROLE_FIELD = "role"; private static final String CONTENT_FIELD = "content"; - private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; private static final String STOP_FIELD = "stop"; private static final String TEMPERATURE_FIELD = "temperature"; private static final String TOOL_CHOICE_FIELD = "tool_choice"; @@ -104,10 +103,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); - if (unifiedRequest.maxCompletionTokens() != null) { - builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); - } - // Underlying providers expect OpenAI to only return 1 possible choice. builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntityTests.java deleted file mode 100644 index 3023de2ccb4db..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequestEntityTests.java +++ /dev/null @@ -1,648 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.deepseek; - -import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.Strings; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnifiedCompletionRequest; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Locale; -import java.util.Map; -import java.util.Random; - -import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; - -public class DeepSeekChatCompletionRequestEntityTests extends ESTestCase { - - private static final String ROLE = "user"; - - public void testBasicSerialization() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - ROLE, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - String jsonString = entityString(unifiedChatInput); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - private String entityString(UnifiedChatInput unifiedChatInput) throws IOException { - Map map = new HashMap<>(); - map.put(MODEL_ID, "model-name"); - map.put("api_key", "1234"); - DeepSeekChatCompletionModel model = DeepSeekChatCompletionModel.createFromNewInput( - "inference-id", - TaskType.CHAT_COMPLETION, - "deepseek", - map - ); - - DeepSeekChatCompletionRequestEntity entity = new DeepSeekChatCompletionRequestEntity(unifiedChatInput, model); - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - return Strings.toString(builder); - } - - public void testSerializationWithAllFields() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - ROLE, - "tool_call_id", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id", - new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), - "type" - ) - ) - ); - - UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( - "type", - new UnifiedCompletionRequest.Tool.FunctionField( - "Fetches the weather in the given location", - "get_weather", - createParameters(), - true - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - "request-model", - 100L, // maxTokens - Collections.singletonList("stop"), - 0.9f, // temperature - new 
UnifiedCompletionRequest.ToolChoiceString("tool_choice"), - Collections.singletonList(tool), - 0.8f // topP - ); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - String jsonString = entityString(unifiedChatInput); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user", - "tool_call_id": "tool_call_id", - "tool_calls": [ - { - "id": "id", - "function": { - "arguments": "arguments", - "name": "function_name" - }, - "type": "type" - } - ] - } - ], - "model": "request-model", - "max_tokens": 100, - "n": 1, - "stop": ["stop"], - "temperature": 0.9, - "tool_choice": "tool_choice", - "tools": [ - { - "type": "type", - "function": { - "description": "Fetches the weather in the given location", - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "description": "The location to get the weather for", - "type": "string" - }, - "unit": { - "description": "The unit to return the temperature in", - "type": "string", - "enum": ["F", "C"] - } - }, - "additionalProperties": false, - "required": ["location", "unit"] - }, - "strict": true - } - } - ], - "top_p": 0.8, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - - } - - public void testSerializationWithNullOptionalFields() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - ROLE, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - String jsonString = entityString(unifiedChatInput); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - public void testSerializationWithEmptyLists() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - ROLE, - null, - Collections.emptyList() // empty toolCalls list - ); - var messageList = new ArrayList(); - messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxTokens - Collections.emptyList(), // empty stop list - null, // temperature - null, // toolChoice - Collections.emptyList(), // empty tools list - null // topP - ); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - String jsonString = entityString(unifiedChatInput); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user", - "tool_calls": [] - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - public void testSerializationWithNestedObjects() throws IOException { - Random random = Randomness.get(); - - String randomContent = "Hello, world! 
" + random.nextInt(1000); - String randomToolCallId = "tool_call_id" + random.nextInt(1000); - String randomArguments = "arguments" + random.nextInt(1000); - String randomFunctionName = "function_name" + random.nextInt(1000); - String randomType = "type" + random.nextInt(1000); - String randomModel = "model" + random.nextInt(1000); - String randomStop = "stop" + random.nextInt(1000); - float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); - float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); - - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString(randomContent), - ROLE, - randomToolCallId, - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id", - new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), - randomType - ) - ) - ); - - UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( - randomType, - new UnifiedCompletionRequest.Tool.FunctionField( - "Fetches the weather in the given location", - "get_weather", - createParameters(), - true - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - randomModel, - 100L, // maxTokens - Collections.singletonList(randomStop), - randomTemperature, // temperature - new UnifiedCompletionRequest.ToolChoiceObject( - randomType, - new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) - ), - Collections.singletonList(tool), - randomTopP // topP - ); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - String jsonString = entityString(unifiedChatInput); - String expectedJson = String.format( - Locale.US, - """ - { - "messages": [ - { - "content": "%s", - "role": "user", - "tool_call_id": "%s", - "tool_calls": [ - { - "id": "id", - "function": { - "arguments": "%s", - "name": "%s" - }, - "type": "%s" - } - ] - } - ], - "model": "%s", - "max_tokens": 100, - "n": 1, - "stop": ["%s"], - "temperature": %.5f, - "tool_choice": { - "type": "%s", - "function": { - "name": "%s" - } - }, - "tools": [ - { - "type": "%s", - "function": { - "description": "Fetches the weather in the given location", - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "unit": { - "description": "The unit to return the temperature in", - "type": "string", - "enum": ["F", "C"] - }, - "location": { - "description": "The location to get the weather for", - "type": "string" - } - }, - "additionalProperties": false, - "required": ["location", "unit"] - }, - "strict": true - } - } - ], - "top_p": %.5f, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """, - randomContent, - randomToolCallId, - randomArguments, - randomFunctionName, - randomType, - randomModel, - randomStop, - randomTemperature, - randomType, - randomFunctionName, - randomType, - randomTopP - ); - assertJsonEquals(jsonString, expectedJson); - } - - public void testSerializationWithDifferentContentTypes() throws IOException { - Random random = Randomness.get(); - - String randomContentString = "Hello, world! 
" + random.nextInt(1000); - - String randomText = "Random text " + random.nextInt(1000); - String randomType = "type" + random.nextInt(1000); - UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); - - var contentObjectsList = new ArrayList(); - contentObjectsList.add(contentObject); - UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); - - UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString(randomContentString), - ROLE, - null, - null - ); - - UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null); - var messageList = new ArrayList(); - messageList.add(messageWithString); - messageList.add(messageWithObjects); - - UnifiedCompletionRequest unifiedRequest = UnifiedCompletionRequest.of(messageList); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - String jsonString = entityString(unifiedChatInput); - String expectedJson = String.format(Locale.US, """ - { - "messages": [ - { - "content": "%s", - "role": "user" - }, - { - "content": [ - { - "text": "%s", - "type": "%s" - } - ], - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """, randomContentString, randomText, randomType); - assertJsonEquals(jsonString, expectedJson); - } - - public void testSerializationWithSpecialCharacters() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world! \n \"Special\" characters: \t \\ /"), - ROLE, - "tool_call_id\twith\ttabs", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id\\with\\backslashes", - new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), - "type" - ) - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - String jsonString = entityString(unifiedChatInput); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world! 
\\n \\"Special\\" characters: \\t \\\\ /", - "role": "user", - "tool_call_id": "tool_call_id\\twith\\ttabs", - "tool_calls": [ - { - "id": "id\\\\with\\\\backslashes", - "function": { - "arguments": "arguments\\"with\\"quotes", - "name": "function_name/with/slashes" - }, - "type": "type" - } - ] - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - public void testSerializationWithBooleanFields() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - ROLE, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); - String jsonStringTrue = entityString(unifiedChatInputTrue); - String expectedJsonTrue = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(expectedJsonTrue, jsonStringTrue); - - UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); - String jsonStringFalse = entityString(unifiedChatInputFalse); - String expectedJsonFalse = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": false - } - """; - assertJsonEquals(expectedJsonFalse, jsonStringFalse); - } - - public void testSerializationWithoutContentField() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - null, - "assistant", - "tool_call_id\twith\ttabs", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id\\with\\backslashes", - new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), - "type" - ) - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - String jsonString = entityString(unifiedChatInput); - String expectedJson = """ - { - "messages": [ - { - "role": "assistant", - "tool_call_id": "tool_call_id\\twith\\ttabs", - "tool_calls": [ - { - "id": "id\\\\with\\\\backslashes", - "function": { - "arguments": "arguments\\"with\\"quotes", - "name": "function_name/with/slashes" - }, - "type": "type" - } - ] - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - private static Map createParameters() { - Map parameters = new LinkedHashMap<>(); - parameters.put("type", "object"); - - Map properties = new HashMap<>(); - - Map location = new HashMap<>(); - location.put("type", "string"); - location.put("description", "The location to get the weather for"); - properties.put("location", location); - - Map unit = new HashMap<>(); - unit.put("type", "string"); - unit.put("description", "The unit to return the temperature 
in"); - unit.put("enum", new String[] { "F", "C" }); - properties.put("unit", unit); - - parameters.put("properties", properties); - parameters.put("additionalProperties", false); - parameters.put("required", new String[] { "location", "unit" }); - - return parameters; - } -} From fe5d29c7d47c28e5c8ddbbcb27aef77e5b2ff3e5 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 11 Mar 2025 09:44:26 -0400 Subject: [PATCH 09/11] Fix tests to match the new stream error handling --- .../deepseek/DeepSeekServiceTests.java | 63 +++++++++---------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index fdc88df87f815..277eba9e7dbfc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -240,8 +240,7 @@ public void testDoUnifiedInfer() throws Exception { data: [DONE] """)); - var result = doUnifiedCompletionInfer(); - InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + doUnifiedCompletionInfer().hasNoErrors().hasEvent(""" {"id":"12345","choices":[{"delta":{"content":"hello, world","role":"assistant"},"index":0}],""" + """ "model":"deepseek-chat","object":"chat.completion.chunk"}"""); } @@ -251,13 +250,18 @@ public void testDoInfer() throws Exception { {"choices": [{"message": {"content": "hello, world", "role": "assistant"}, "finish_reason": "stop", "index": 0, \ "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \ "object": "chat.completion", "system_fingerprint": "fp_1234"}""")); - var result = doInfer(false); - assertThat(result, isA(ChatCompletionResults.class)); - var completionResults = (ChatCompletionResults) result; - assertThat( - completionResults.results().stream().map(ChatCompletionResults.Result::predictedValue).toList(), - equalTo(List.of("hello, world")) - ); + try (var service = createService()) { + var model = createModel(service, TaskType.COMPLETION); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + var result = listener.actionGet(TIMEOUT); + assertThat(result, isA(ChatCompletionResults.class)); + var completionResults = (ChatCompletionResults) result; + assertThat( + completionResults.results().stream().map(ChatCompletionResults.Result::predictedValue).toList(), + equalTo(List.of("hello, world")) + ); + } } public void testDoInferStream() throws Exception { @@ -269,12 +273,16 @@ public void testDoInferStream() throws Exception { data: [DONE] """)); - var result = doInfer(true); - InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" - {"completion":[{"delta":"hello, world"}]}"""); + try (var service = createService()) { + var model = createModel(service, TaskType.COMPLETION); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } } - public void 
testUnifiedCompletionError() throws Exception { + public void testUnifiedCompletionError() { String responseJson = """ { "error": { @@ -285,14 +293,14 @@ public void testUnifiedCompletionError() throws Exception { } }"""; webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); - testStreamError(""" - {\ - "error":{\ - "code":"model_not_found",\ - "message":"Received an unsuccessful status code for request from inference entity id [inference-id] status \ - [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]",\ - "type":"invalid_request_error"\ - }}"""); + var e = assertThrows(UnifiedChatCompletionException.class, this::doUnifiedCompletionInfer); + assertThat( + e.getMessage(), + equalTo( + "Received an unsuccessful status code for request from inference entity id [inference-id] status" + + " [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]" + ) + ); } private void testStreamError(String expectedResponse) throws Exception { @@ -399,7 +407,7 @@ private DeepSeekChatCompletionModel parsePersistedConfig(String json) throws IOE } } - private InferenceServiceResults doUnifiedCompletionInfer() throws Exception { + private InferenceEventsAssertion doUnifiedCompletionInfer() throws Exception { try (var service = createService()) { var model = createModel(service, TaskType.CHAT_COMPLETION); PlainActionFuture listener = new PlainActionFuture<>(); @@ -411,16 +419,7 @@ private InferenceServiceResults doUnifiedCompletionInfer() throws Exception { TIMEOUT, listener ); - return listener.get(30, TimeUnit.SECONDS); - } - } - - private InferenceServiceResults doInfer(boolean stream) throws Exception { - try (var service = createService()) { - var model = createModel(service, TaskType.COMPLETION); - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, List.of("hello"), stream, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); - return listener.get(30, TimeUnit.SECONDS); + return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); } } From ea24ac1039502f631c41032453fa956be6595c36 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 11 Mar 2025 09:45:49 -0400 Subject: [PATCH 10/11] Adding per-node rate limit comment --- .../services/deepseek/DeepSeekChatCompletionModel.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java index 2391eb49ff04b..bcfcf279ab768 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java @@ -57,8 +57,10 @@ * - The website claims to want unlimited, so we're setting it as MAX_INT per minute? */ public class DeepSeekChatCompletionModel extends Model { + // Per-node rate limit group and settings, limiting the outbound requests this node can make to INTEGER.MAX_VALUE per minute. 
private static final Object RATE_LIMIT_GROUP = new Object(); private static final RateLimitSettings RATE_LIMIT_SETTINGS = new RateLimitSettings(Integer.MAX_VALUE); + private static final URI DEFAULT_URI = URI.create("https://api.deepseek.com/chat/completions"); private final DeepSeekServiceSettings serviceSettings; private final DefaultSecretSettings secretSettings; From 68dbda32d70c9f9bd2a4b7755d938ef429539af9 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Tue, 11 Mar 2025 10:27:25 -0400 Subject: [PATCH 11/11] Fix import statements from merge --- .../inference/external/http/sender/DeepSeekRequestManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java index ec6d88f48a275..ffc5bfb1eb918 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java @@ -15,9 +15,9 @@ import org.elasticsearch.xpack.inference.external.deepseek.DeepSeekChatCompletionRequest; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; import java.util.Objects;
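
Note on the per-node rate limit comment added in PATCH 10/11 above: RateLimitSettings(Integer.MAX_VALUE) lets each node send up to Integer.MAX_VALUE outbound requests to DeepSeek per minute, which in practice never throttles — matching the "unlimited" policy the class comment refers to. The sketch below is an illustration of that idea only; PerNodeMinuteRateLimiter is a hypothetical stand-in written for this note, not Elasticsearch's actual RateLimitSettings implementation.

import java.util.concurrent.TimeUnit;

// Illustrative fixed-window limiter: at most requestsPerMinute acquisitions per
// one-minute window. With requestsPerMinute = Integer.MAX_VALUE, tryAcquire()
// never returns false in practice, so requests are effectively unthrottled.
final class PerNodeMinuteRateLimiter {
    private final long requestsPerMinute;
    private long windowStart = System.nanoTime();
    private long used = 0;

    PerNodeMinuteRateLimiter(long requestsPerMinute) {
        this.requestsPerMinute = requestsPerMinute;
    }

    synchronized boolean tryAcquire() {
        long now = System.nanoTime();
        if (now - windowStart >= TimeUnit.MINUTES.toNanos(1)) {
            windowStart = now; // start a fresh one-minute window
            used = 0;
        }
        if (used < requestsPerMinute) {
            used++;
            return true;
        }
        return false;
    }

    public static void main(String[] args) {
        var limiter = new PerNodeMinuteRateLimiter(Integer.MAX_VALUE);
        System.out.println(limiter.tryAcquire()); // prints: true
    }
}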