Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/122218.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 122218
summary: Integrate with `DeepSeek` API
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ static TransportVersion def(int id) {
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);
public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(20));
assertThat(services.size(), equalTo(21));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -42,6 +42,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"elastic",
"elasticsearch",
"googleaistudio",
Expand Down Expand Up @@ -115,7 +116,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(9));
assertThat(services.size(), equalTo(10));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -131,6 +132,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
Expand All @@ -143,15 +145,15 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(3));
assertThat(services.size(), equalTo(4));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers);
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
Expand Down Expand Up @@ -153,6 +154,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addUnifiedNamedWriteables(namedWriteables);

namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());

return namedWriteables;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
Expand Down Expand Up @@ -362,6 +363,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.deepseek;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * A chat completion {@link Request} targeting DeepSeek.
 * <p>
 * The request body is the shared {@link UnifiedChatCompletionRequestEntity} fragment plus a {@code model} field and,
 * when the caller supplied one, a {@code max_tokens} field. The URI, API key and fallback model id are read from the
 * {@link DeepSeekChatCompletionModel}.
 */
public class DeepSeekChatCompletionRequest implements Request {
    private static final String MODEL_FIELD = "model";
    private static final String MAX_TOKENS = "max_tokens";

    private final DeepSeekChatCompletionModel model;
    private final UnifiedChatInput unifiedChatInput;

    /**
     * @param unifiedChatInput the chat input to serialize; must not be {@code null}
     * @param model            the model configuration; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
        this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Builds a JSON {@code POST} to the model's URI carrying the serialized payload and a bearer authorization header.
     */
    @Override
    public HttpRequest createHttpRequest() {
        var post = new HttpPost(model.uri());
        post.setEntity(createEntity());
        post.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
        post.setHeader(createAuthBearerHeader(model.apiKey()));
        return new HttpRequest(post, getInferenceEntityId());
    }

    /**
     * Serializes the request payload as UTF-8 encoded JSON.
     *
     * @throws ElasticsearchException if writing the JSON fails
     */
    private ByteArrayEntity createEntity() {
        var chatRequest = unifiedChatInput.getRequest();
        // A model named on the individual request takes precedence over the one configured on the endpoint.
        var modelId = Objects.requireNonNullElseGet(chatRequest.model(), model::model);

        try (var json = JsonXContent.contentBuilder()) {
            json.startObject();
            new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(json, ToXContent.EMPTY_PARAMS);
            json.field(MODEL_FIELD, modelId);

            // The completion token limit is optional and is sent under the "max_tokens" key.
            var maxTokens = chatRequest.maxCompletionTokens();
            if (maxTokens != null) {
                json.field(MAX_TOKENS, maxTokens);
            }

            json.endObject();
            byte[] payload = Strings.toString(json).getBytes(StandardCharsets.UTF_8);
            return new ByteArrayEntity(payload);
        } catch (IOException e) {
            throw new ElasticsearchException("Failed to serialize request payload.", e);
        }
    }

    /** Returns the URI configured on the model. */
    @Override
    public URI getURI() {
        return model.uri();
    }

    /** No truncation is applied; this same instance is returned. */
    @Override
    public Request truncate() {
        return this;
    }

    /** Always {@code null}: this request records no truncation information. */
    @Override
    public boolean[] getTruncationInfo() {
        return null;
    }

    /** Returns the inference entity id of the underlying model. */
    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    /** Returns whether the wrapped chat input asked for a streamed response. */
    @Override
    public boolean isStreaming() {
        return unifiedChatInput.stream();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.deepseek.DeepSeekChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;

import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;

/**
 * Request manager that turns inference inputs into {@link DeepSeekChatCompletionRequest}s and submits them.
 * <p>
 * Two input shapes are accepted: {@link UnifiedChatInput} (handled with the unified chat completion response handler)
 * and {@link ChatCompletionInput} (converted to a {@link UnifiedChatInput} and handled with the completion response
 * handler). Response handling reuses the OpenAI handler and response entity classes.
 */
public class DeepSeekRequestManager extends BaseRequestManager {

    private static final Logger logger = LogManager.getLogger(DeepSeekRequestManager.class);

    private static final ResponseHandler CHAT_COMPLETION = new OpenAiUnifiedChatCompletionResponseHandler(
        "deepseek chat completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );
    private static final ResponseHandler COMPLETION = new OpenAiChatCompletionResponseHandler(
        "deepseek completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );

    private final DeepSeekChatCompletionModel model;

    /**
     * @param model      supplies the inference entity id, rate limit group and rate limit settings; must not be {@code null}
     * @param threadPool passed through to {@link BaseRequestManager}
     */
    public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
        super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Dispatches on the concrete input type; any type other than {@link UnifiedChatInput} or
     * {@link ChatCompletionInput} results in the unsupported-type exception being thrown.
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        switch (inferenceInputs) {
            case UnifiedChatInput unified -> sendChatRequest(
                unified,
                CHAT_COMPLETION,
                requestSender,
                hasRequestCompletedFunction,
                listener
            );
            case ChatCompletionInput completion -> sendChatRequest(
                // Plain completion inputs are re-expressed as a unified chat input; "user" is presumably the
                // message role (confirm against UnifiedChatInput's constructor).
                new UnifiedChatInput(completion.getInputs(), "user", completion.stream()),
                COMPLETION,
                requestSender,
                hasRequestCompletedFunction,
                listener
            );
            default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
        }
    }

    /**
     * Wraps the chat input in a {@link DeepSeekChatCompletionRequest} for this manager's model and hands it to the
     * base class for execution with the given response handler.
     */
    private void sendChatRequest(
        UnifiedChatInput chatInput,
        ResponseHandler responseHandler,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        var request = new DeepSeekChatCompletionRequest(chatInput, model);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec

public static final String USER_FIELD = "user";
private static final String MODEL_FIELD = "model";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";

private final UnifiedChatInput unifiedChatInput;
private final OpenAiChatCompletionModel model;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;

public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.model = Objects.requireNonNull(model);
}

Expand All @@ -41,6 +44,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(USER_FIELD, model.getTaskSettings().user());
}

if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
}

builder.endObject();

return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
private static final String MODEL_FIELD = "model";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";

private final UnifiedChatInput unifiedChatInput;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
private final String modelId;

public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.modelId = Objects.requireNonNull(modelId);
}

Expand All @@ -31,6 +34,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
unifiedRequestEntity.toXContent(builder, params);
builder.field(MODEL_FIELD, modelId);

if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
}

builder.endObject();

return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
public static final String MESSAGES_FIELD = "messages";
private static final String ROLE_FIELD = "role";
private static final String CONTENT_FIELD = "content";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
private static final String STOP_FIELD = "stop";
private static final String TEMPERATURE_FIELD = "temperature";
private static final String TOOL_CHOICE_FIELD = "tool_choice";
Expand Down Expand Up @@ -102,10 +101,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.endArray();

if (unifiedRequest.maxCompletionTokens() != null) {
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
}

// Underlying providers expect OpenAI to only return 1 possible choice.
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);

Expand Down
Loading