Skip to content

Commit ec94f53

Browse files
committed
[ML] Integrate with DeepSeek API
Integrates the Chat Completion and Completion task types, both of which call DeepSeek's chat completion API.
1 parent 27adf20 commit ec94f53

File tree

8 files changed

+960
-0
lines changed

8 files changed

+960
-0
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ static TransportVersion def(int id) {
185185
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION = def(9_005_0_00);
186186
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE = def(9_006_0_00);
187187
public static final TransportVersion ESQL_PROFILE_ASYNC_NANOS = def(9_007_00_0);
188+
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_008_00_0);
188189

189190
/*
190191
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
5757
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
5858
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
59+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
5960
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6061
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
6162
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
@@ -144,6 +145,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
144145
addUnifiedNamedWriteables(namedWriteables);
145146

146147
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
148+
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
147149

148150
return namedWriteables;
149151
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
115115
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
116116
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
117+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
117118
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
118119
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
119120
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
@@ -357,6 +358,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
357358
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
358359
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
359360
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
361+
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
360362
ElasticsearchInternalService::new
361363
);
362364
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler;
19+
import org.elasticsearch.xpack.inference.external.request.deepseek.DeepSeekChatCompletionRequest;
20+
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
21+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
22+
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;
27+
28+
public class DeepSeekRequestManager extends BaseRequestManager {
29+
30+
private static final Logger logger = LogManager.getLogger(DeepSeekRequestManager.class);
31+
32+
private static final ResponseHandler CHAT_COMPLETION = createChatCompletionHandler();
33+
private static final ResponseHandler COMPLETION = createCompletionHandler();
34+
35+
private final DeepSeekChatCompletionModel model;
36+
37+
public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
38+
super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
39+
this.model = Objects.requireNonNull(model);
40+
}
41+
42+
@Override
43+
public void execute(
44+
InferenceInputs inferenceInputs,
45+
RequestSender requestSender,
46+
Supplier<Boolean> hasRequestCompletedFunction,
47+
ActionListener<InferenceServiceResults> listener
48+
) {
49+
switch (inferenceInputs) {
50+
case UnifiedChatInput uci -> execute(uci, requestSender, hasRequestCompletedFunction, listener);
51+
case ChatCompletionInput cci -> execute(cci, requestSender, hasRequestCompletedFunction, listener);
52+
default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
53+
}
54+
}
55+
56+
private void execute(
57+
UnifiedChatInput inferenceInputs,
58+
RequestSender requestSender,
59+
Supplier<Boolean> hasRequestCompletedFunction,
60+
ActionListener<InferenceServiceResults> listener
61+
) {
62+
var request = new DeepSeekChatCompletionRequest(inferenceInputs, model);
63+
execute(new ExecutableInferenceRequest(requestSender, logger, request, CHAT_COMPLETION, hasRequestCompletedFunction, listener));
64+
}
65+
66+
private void execute(
67+
ChatCompletionInput inferenceInputs,
68+
RequestSender requestSender,
69+
Supplier<Boolean> hasRequestCompletedFunction,
70+
ActionListener<InferenceServiceResults> listener
71+
) {
72+
var unifiedInputs = new UnifiedChatInput(inferenceInputs.getInputs(), "user", inferenceInputs.stream());
73+
var request = new DeepSeekChatCompletionRequest(unifiedInputs, model);
74+
execute(new ExecutableInferenceRequest(requestSender, logger, request, COMPLETION, hasRequestCompletedFunction, listener));
75+
}
76+
77+
private static ResponseHandler createChatCompletionHandler() {
78+
return new OpenAiUnifiedChatCompletionResponseHandler("deepseek chat completion", OpenAiChatCompletionResponseEntity::fromResponse);
79+
}
80+
81+
private static ResponseHandler createCompletionHandler() {
82+
return new OpenAiChatCompletionResponseHandler("deepseek completion", OpenAiChatCompletionResponseEntity::fromResponse);
83+
}
84+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.deepseek;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.elasticsearch.ElasticsearchException;
14+
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.xcontent.ToXContent;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xcontent.json.JsonXContent;
18+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
19+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
20+
import org.elasticsearch.xpack.inference.external.request.Request;
21+
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
22+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
23+
24+
import java.io.IOException;
25+
import java.net.URI;
26+
import java.nio.charset.StandardCharsets;
27+
import java.util.Objects;
28+
29+
import static org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor.MODEL_FIELD;
30+
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
31+
32+
public class DeepSeekChatCompletionRequest implements Request {
33+
34+
private final DeepSeekChatCompletionModel model;
35+
private final UnifiedChatInput unifiedChatInput;
36+
37+
public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
38+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
39+
this.model = Objects.requireNonNull(model);
40+
}
41+
42+
@Override
43+
public HttpRequest createHttpRequest() {
44+
HttpPost httpPost = new HttpPost(model.uri());
45+
46+
httpPost.setEntity(createEntity());
47+
48+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
49+
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
50+
51+
return new HttpRequest(httpPost, getInferenceEntityId());
52+
}
53+
54+
private ByteArrayEntity createEntity() {
55+
var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model);
56+
try (var builder = JsonXContent.contentBuilder()) {
57+
builder.startObject();
58+
new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS);
59+
builder.field(MODEL_FIELD, modelId);
60+
builder.endObject();
61+
return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8));
62+
} catch (IOException e) {
63+
throw new ElasticsearchException("Failed to serialize request payload.", e);
64+
}
65+
}
66+
67+
@Override
68+
public URI getURI() {
69+
return model.uri();
70+
}
71+
72+
@Override
73+
public Request truncate() {
74+
// No truncation for OpenAI chat completions
75+
return this;
76+
}
77+
78+
@Override
79+
public boolean[] getTruncationInfo() {
80+
// No truncation for OpenAI chat completions
81+
return null;
82+
}
83+
84+
@Override
85+
public String getInferenceEntityId() {
86+
return model.getInferenceEntityId();
87+
}
88+
89+
@Override
90+
public boolean isStreaming() {
91+
return unifiedChatInput.stream();
92+
}
93+
}

0 commit comments

Comments (0)