Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/122218.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 122218
summary: Integrate with `DeepSeek` API
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ static TransportVersion def(int id) {
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED = def(9_014_0_00);
public static final TransportVersion BYTE_SIZE_VALUE_ALWAYS_USES_BYTES = def(9_015_0_00);
public static final TransportVersion ESQL_SERIALIZE_SOURCE_FUNCTIONS_WARNINGS = def(9_016_0_00);
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_017_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(20));
assertThat(services.size(), equalTo(21));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -41,6 +41,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"elastic",
"elasticsearch",
"googleaistudio",
Expand Down Expand Up @@ -114,7 +115,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(9));
assertThat(services.size(), equalTo(10));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -130,6 +131,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
Expand All @@ -141,15 +143,15 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(3));
assertThat(services.size(), equalTo(4));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers);
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
Expand Down Expand Up @@ -153,6 +154,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addUnifiedNamedWriteables(namedWriteables);

namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());

return namedWriteables;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
Expand Down Expand Up @@ -361,6 +362,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.request.deepseek.DeepSeekChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;

import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;

/**
 * Request manager for the DeepSeek service. Accepts either a {@link UnifiedChatInput}
 * (chat-completion task) or a {@link ChatCompletionInput} (completion task), builds a
 * {@link DeepSeekChatCompletionRequest} for it, and hands the request off for execution
 * with the response handler matching the task type.
 */
public class DeepSeekRequestManager extends BaseRequestManager {

    private static final Logger logger = LogManager.getLogger(DeepSeekRequestManager.class);

    // DeepSeek's API is OpenAI-compatible, so the OpenAI response handlers and
    // response entity parser are reused for both task types.
    private static final ResponseHandler CHAT_COMPLETION = new OpenAiUnifiedChatCompletionResponseHandler(
        "deepseek chat completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );
    private static final ResponseHandler COMPLETION = new OpenAiChatCompletionResponseHandler(
        "deepseek completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );

    private final DeepSeekChatCompletionModel model;

    /**
     * @param model      the configured DeepSeek model; must not be null. Supplies the
     *                   inference entity id, rate-limit group, and rate-limit settings
     *                   passed to {@link BaseRequestManager}.
     * @param threadPool thread pool forwarded to the base manager
     */
    public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
        super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
        this.model = Objects.requireNonNull(model);
    }

    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        switch (inferenceInputs) {
            // Chat-completion inputs are already in the unified shape.
            case UnifiedChatInput unified -> dispatch(unified, CHAT_COMPLETION, requestSender, hasRequestCompletedFunction, listener);
            // Plain completion inputs are adapted into the unified shape with the
            // "user" role before being sent, but keep the non-unified response handler.
            case ChatCompletionInput completion -> dispatch(
                new UnifiedChatInput(completion.getInputs(), "user", completion.stream()),
                COMPLETION,
                requestSender,
                hasRequestCompletedFunction,
                listener
            );
            default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
        }
    }

    /** Builds the DeepSeek HTTP request for {@code input} and queues it with {@code handler}. */
    private void dispatch(
        UnifiedChatInput input,
        ResponseHandler handler,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        var request = new DeepSeekChatCompletionRequest(input, model);
        execute(new ExecutableInferenceRequest(requestSender, logger, request, handler, hasRequestCompletedFunction, listener));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.deepseek;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * An HTTP request to the DeepSeek chat-completion endpoint.
 *
 * <p>Serializes a {@link UnifiedChatInput} (via {@link DeepSeekChatCompletionRequestEntity})
 * as the JSON body of a POST to the model's configured URI, authenticated with a bearer
 * token from the model's API key.
 */
public class DeepSeekChatCompletionRequest implements Request {

    // Supplies the endpoint URI, API key, and inference entity id.
    private final DeepSeekChatCompletionModel model;
    // The chat input to serialize; also determines whether the response is streamed.
    private final UnifiedChatInput unifiedChatInput;

    /**
     * @param unifiedChatInput the chat messages to send; must not be null
     * @param model            the configured DeepSeek model; must not be null
     */
    public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
        this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Builds the POST request: JSON content type, bearer-token auth header, and the
     * serialized chat-completion payload as the body.
     */
    @Override
    public HttpRequest createHttpRequest() {
        HttpPost httpPost = new HttpPost(model.uri());

        httpPost.setEntity(createEntity());

        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
        httpPost.setHeader(createAuthBearerHeader(model.apiKey()));

        return new HttpRequest(httpPost, getInferenceEntityId());
    }

    // Serializes the request entity to UTF-8 JSON bytes. IOExceptions from the
    // XContent builder are wrapped since serialization failure here is unexpected.
    private ByteArrayEntity createEntity() {
        try (var builder = JsonXContent.contentBuilder()) {
            new DeepSeekChatCompletionRequestEntity(unifiedChatInput, model).toXContent(builder, ToXContent.EMPTY_PARAMS);
            return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8));
        } catch (IOException e) {
            throw new ElasticsearchException("Failed to serialize request payload.", e);
        }
    }

    @Override
    public URI getURI() {
        return model.uri();
    }

    @Override
    public Request truncate() {
        // No truncation for DeepSeek chat completions
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        // No truncation for DeepSeek chat completions
        return null;
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    /** Streaming is driven by the input, not the model configuration. */
    @Override
    public boolean isStreaming() {
        return unifiedChatInput.stream();
    }
}
Loading