Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/122218.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 122218
summary: Integrate with `DeepSeek` API
area: Machine Learning
type: enhancement
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ static TransportVersion def(int id) {
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_08);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand Down Expand Up @@ -180,6 +181,7 @@ static TransportVersion def(int id) {
public static final TransportVersion MAX_OPERATION_SIZE_REJECTIONS_ADDED = def(9_024_0_00);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR = def(9_025_0_00);
public static final TransportVersion ESQL_SERIALIZE_BLOCK_TYPE_CODE = def(9_026_0_00);
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_027_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(20));
assertThat(services.size(), equalTo(21));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -41,6 +41,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"elastic",
"elasticsearch",
"googleaistudio",
Expand Down Expand Up @@ -114,7 +115,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(9));
assertThat(services.size(), equalTo(10));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -130,6 +131,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
Expand All @@ -141,15 +143,15 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(3));
assertThat(services.size(), equalTo(4));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers);
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
Expand Down Expand Up @@ -153,6 +154,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addUnifiedNamedWriteables(namedWriteables);

namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());

return namedWriteables;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
Expand Down Expand Up @@ -358,6 +359,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.deepseek;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
 * A chat completion {@link Request} targeting the endpoint configured on a {@link DeepSeekChatCompletionModel}.
 * <p>
 * The request body is produced by {@link DeepSeekChatCompletionRequestEntity} and sent as a JSON {@code POST}
 * carrying a bearer authorization header built from the model's API key.
 */
public class DeepSeekChatCompletionRequest implements Request {

    private final DeepSeekChatCompletionModel model;
    private final UnifiedChatInput unifiedChatInput;

    /**
     * @param unifiedChatInput the chat input to serialize; must not be {@code null}
     * @param model            the model supplying the URI, API key and inference entity id; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
        this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * @return the URI configured on the model
     */
    @Override
    public URI getURI() {
        return model.uri();
    }

    /**
     * @return the inference entity id of the model backing this request
     */
    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    /**
     * @return whether the wrapped chat input asks for a streamed response
     */
    @Override
    public boolean isStreaming() {
        return unifiedChatInput.stream();
    }

    /**
     * This request is never shortened.
     *
     * @return this same instance
     */
    @Override
    public Request truncate() {
        return this;
    }

    /**
     * @return always {@code null}; no truncation information is tracked for this request
     */
    @Override
    public boolean[] getTruncationInfo() {
        return null;
    }

    /**
     * Builds the {@code POST} request: serialized JSON body first, then the content type and authorization headers.
     *
     * @return the HTTP request paired with this request's inference entity id
     * @throws ElasticsearchException if the body cannot be serialized
     */
    @Override
    public HttpRequest createHttpRequest() {
        var post = new HttpPost(model.uri());

        ByteArrayEntity body = serializeBody();
        post.setEntity(body);

        post.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
        post.setHeader(createAuthBearerHeader(model.apiKey()));

        return new HttpRequest(post, getInferenceEntityId());
    }

    /**
     * Renders the chat input as JSON and wraps its UTF-8 bytes in an HTTP entity.
     */
    private ByteArrayEntity serializeBody() {
        try (var json = JsonXContent.contentBuilder()) {
            var payload = new DeepSeekChatCompletionRequestEntity(unifiedChatInput, model);
            payload.toXContent(json, ToXContent.EMPTY_PARAMS);

            byte[] utf8 = Strings.toString(json).getBytes(StandardCharsets.UTF_8);
            return new ByteArrayEntity(utf8);
        } catch (IOException e) {
            throw new ElasticsearchException("Failed to serialize request payload.", e);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.deepseek;

import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;

import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor.MODEL_FIELD;

/**
 * Writes the JSON body of a DeepSeek chat completion request from a {@link UnifiedChatInput}.
 * <p>
 * Fields are emitted in this order: {@code messages}, the model field, {@code max_tokens} (if set), {@code n},
 * {@code stop} (if non-empty), {@code temperature} (if set), {@code tool_choice} (if set), {@code tools}
 * (if non-empty), {@code top_p} (if set), {@code stream} and, only when streaming, {@code stream_options}.
 * <p>
 * Note that {@link #toXContent} opens and closes the enclosing JSON object itself.
 */
class DeepSeekChatCompletionRequestEntity implements ToXContentFragment {

    public static final String NAME_FIELD = "name";
    public static final String TOOL_CALL_ID_FIELD = "tool_call_id";
    public static final String TOOL_CALLS_FIELD = "tool_calls";
    public static final String ID_FIELD = "id";
    public static final String FUNCTION_FIELD = "function";
    public static final String ARGUMENTS_FIELD = "arguments";
    public static final String DESCRIPTION_FIELD = "description";
    public static final String PARAMETERS_FIELD = "parameters";
    public static final String STRICT_FIELD = "strict";
    public static final String TOP_P_FIELD = "top_p";
    public static final String STREAM_FIELD = "stream";
    private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n";
    public static final String MESSAGES_FIELD = "messages";
    private static final String ROLE_FIELD = "role";
    private static final String CONTENT_FIELD = "content";
    private static final String MAX_TOKENS = "max_tokens";
    private static final String STOP_FIELD = "stop";
    private static final String TEMPERATURE_FIELD = "temperature";
    private static final String TOOL_CHOICE_FIELD = "tool_choice";
    private static final String TOOL_FIELD = "tools";
    private static final String TEXT_FIELD = "text";
    private static final String TYPE_FIELD = "type";
    private static final String STREAM_OPTIONS_FIELD = "stream_options";
    private static final String INCLUDE_USAGE_FIELD = "include_usage";

    private final DeepSeekChatCompletionModel model;
    private final UnifiedCompletionRequest unifiedRequest;
    private final boolean stream;

    /**
     * @param unifiedChatInput supplies the unified request and the streaming flag; must not be {@code null}
     * @param model            supplies the fallback model id when the request does not name one; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    DeepSeekChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
        Objects.requireNonNull(unifiedChatInput);
        this.model = Objects.requireNonNull(model);
        this.unifiedRequest = unifiedChatInput.getRequest();
        this.stream = unifiedChatInput.stream();
    }

    /**
     * Writes the complete request object, including its opening and closing braces.
     *
     * @param builder the builder to write into
     * @param params  unused
     * @return the same {@code builder}
     * @throws IOException if the builder fails to write
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        writeMessages(builder);

        // A model named on the request takes precedence over the one configured on the endpoint.
        var modelId = Objects.requireNonNullElseGet(unifiedRequest.model(), model::model);
        builder.field(MODEL_FIELD, modelId);

        if (unifiedRequest.maxCompletionTokens() != null) {
            builder.field(MAX_TOKENS, unifiedRequest.maxCompletionTokens());
        }

        // Always ask for exactly one choice.
        // NOTE(review): this mirrors the OpenAI request shape - confirm the DeepSeek API accepts `n`.
        builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);

        if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) {
            builder.field(STOP_FIELD, unifiedRequest.stop());
        }
        if (unifiedRequest.temperature() != null) {
            builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature());
        }

        writeToolChoice(builder);
        writeTools(builder);

        if (unifiedRequest.topP() != null) {
            builder.field(TOP_P_FIELD, unifiedRequest.topP());
        }

        builder.field(STREAM_FIELD, stream);
        if (stream) {
            // Usage is only requested explicitly for streamed responses.
            builder.startObject(STREAM_OPTIONS_FIELD);
            builder.field(INCLUDE_USAGE_FIELD, true);
            builder.endObject();
        }

        builder.endObject();

        return builder;
    }

    /** Writes the {@code messages} array, one object per message of the unified request. */
    private void writeMessages(XContentBuilder builder) throws IOException {
        builder.startArray(MESSAGES_FIELD);
        for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) {
            writeMessage(builder, message);
        }
        builder.endArray();
    }

    /** Writes one message object: optional content, role, optional tool call id and optional tool calls. */
    private static void writeMessage(XContentBuilder builder, UnifiedCompletionRequest.Message message) throws IOException {
        builder.startObject();

        switch (message.content()) {
            case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content());
            case UnifiedCompletionRequest.ContentObjects contentObjects -> {
                builder.startArray(CONTENT_FIELD);
                for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) {
                    builder.startObject();
                    builder.field(TEXT_FIELD, contentObject.text());
                    builder.field(TYPE_FIELD, contentObject.type());
                    builder.endObject();
                }
                builder.endArray();
            }
            case null -> {
                // content is optional, so nothing is written when it is absent
            }
        }

        builder.field(ROLE_FIELD, message.role());
        if (message.toolCallId() != null) {
            builder.field(TOOL_CALL_ID_FIELD, message.toolCallId());
        }
        if (message.toolCalls() != null) {
            builder.startArray(TOOL_CALLS_FIELD);
            for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) {
                writeToolCall(builder, toolCall);
            }
            builder.endArray();
        }

        builder.endObject();
    }

    /** Writes one entry of a message's {@code tool_calls} array. */
    private static void writeToolCall(XContentBuilder builder, UnifiedCompletionRequest.ToolCall toolCall) throws IOException {
        builder.startObject();
        builder.field(ID_FIELD, toolCall.id());
        builder.startObject(FUNCTION_FIELD);
        builder.field(ARGUMENTS_FIELD, toolCall.function().arguments());
        builder.field(NAME_FIELD, toolCall.function().name());
        builder.endObject();
        builder.field(TYPE_FIELD, toolCall.type());
        builder.endObject();
    }

    /**
     * Writes {@code tool_choice} either as a plain string or as an object naming a function.
     * Nothing is written when the request has no tool choice or it is of neither kind.
     */
    private void writeToolChoice(XContentBuilder builder) throws IOException {
        var toolChoice = unifiedRequest.toolChoice();
        // instanceof is false for null, so an absent tool choice falls through both branches.
        if (toolChoice instanceof UnifiedCompletionRequest.ToolChoiceString toolChoiceString) {
            builder.field(TOOL_CHOICE_FIELD, toolChoiceString.value());
        } else if (toolChoice instanceof UnifiedCompletionRequest.ToolChoiceObject toolChoiceObject) {
            builder.startObject(TOOL_CHOICE_FIELD);
            builder.field(TYPE_FIELD, toolChoiceObject.type());
            builder.startObject(FUNCTION_FIELD);
            builder.field(NAME_FIELD, toolChoiceObject.function().name());
            builder.endObject();
            builder.endObject();
        }
    }

    /** Writes the {@code tools} array; skipped entirely when the request has no tools. */
    private void writeTools(XContentBuilder builder) throws IOException {
        var tools = unifiedRequest.tools();
        if (tools == null || tools.isEmpty()) {
            return;
        }

        builder.startArray(TOOL_FIELD);
        for (UnifiedCompletionRequest.Tool tool : tools) {
            builder.startObject();
            builder.field(TYPE_FIELD, tool.type());
            builder.startObject(FUNCTION_FIELD);
            builder.field(DESCRIPTION_FIELD, tool.function().description());
            builder.field(NAME_FIELD, tool.function().name());
            builder.field(PARAMETERS_FIELD, tool.function().parameters());
            if (tool.function().strict() != null) {
                builder.field(STRICT_FIELD, tool.function().strict());
            }
            builder.endObject();
            builder.endObject();
        }
        builder.endArray();
    }
}
Loading
Loading