Skip to content

Commit 2a81970

Browse files
Merge branch 'main' into code-comments
2 parents 8c51826 + 9f89a3b commit 2a81970

File tree

15 files changed

+1114
-20
lines changed

15 files changed

+1114
-20
lines changed

docs/changelog/122218.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 122218
2+
summary: Integrate with `DeepSeek` API
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ static TransportVersion def(int id) {
147147
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
148148
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
149149
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);
150+
public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09);
150151
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
151152
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
152153
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -183,6 +184,7 @@ static TransportVersion def(int id) {
183184
public static final TransportVersion ESQL_SERIALIZE_BLOCK_TYPE_CODE = def(9_026_0_00);
184185
public static final TransportVersion ESQL_THREAD_NAME_IN_DRIVER_PROFILE = def(9_027_0_00);
185186
public static final TransportVersion INFERENCE_CONTEXT = def(9_028_0_00);
187+
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_029_00_0);
186188

187189
/*
188190
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/BlobCacheMetrics.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.index.store.LuceneFilesExtensions;
1213
import org.elasticsearch.telemetry.TelemetryProvider;
1314
import org.elasticsearch.telemetry.metric.DoubleHistogram;
1415
import org.elasticsearch.telemetry.metric.LongCounter;
@@ -24,8 +25,8 @@ public class BlobCacheMetrics {
2425
private static final double BYTES_PER_NANOSECONDS_TO_MEBIBYTES_PER_SECOND = 1e9D / (1 << 20);
2526
public static final String CACHE_POPULATION_REASON_ATTRIBUTE_KEY = "reason";
2627
public static final String CACHE_POPULATION_SOURCE_ATTRIBUTE_KEY = "source";
27-
public static final String SHARD_ID_ATTRIBUTE_KEY = "shard_id";
28-
public static final String INDEX_ATTRIBUTE_KEY = "index_name";
28+
public static final String LUCENE_FILE_EXTENSION_ATTRIBUTE_KEY = "file_extension";
29+
public static final String NON_LUCENE_EXTENSION_TO_RECORD = "other";
2930

3031
private final LongCounter cacheMissCounter;
3132
private final LongCounter evictedCountNonZeroFrequency;
@@ -113,22 +114,28 @@ public LongHistogram getCacheMissLoadTimes() {
113114
/**
114115
* Record the various cache population metrics after a chunk is copied to the cache
115116
*
117+
* @param blobName The file that was requested and triggered the cache population.
116118
* @param bytesCopied The number of bytes copied
117119
* @param copyTimeNanos The time taken to copy the bytes in nanoseconds
118120
* @param cachePopulationReason The reason for the cache being populated
119121
* @param cachePopulationSource The source from which the data is being loaded
120122
*/
121123
public void recordCachePopulationMetrics(
124+
String blobName,
122125
int bytesCopied,
123126
long copyTimeNanos,
124127
CachePopulationReason cachePopulationReason,
125128
CachePopulationSource cachePopulationSource
126129
) {
130+
LuceneFilesExtensions luceneFilesExtensions = LuceneFilesExtensions.fromFile(blobName);
131+
String blobFileExtension = luceneFilesExtensions != null ? luceneFilesExtensions.getExtension() : NON_LUCENE_EXTENSION_TO_RECORD;
127132
Map<String, Object> metricAttributes = Map.of(
128133
CACHE_POPULATION_REASON_ATTRIBUTE_KEY,
129134
cachePopulationReason.name(),
130135
CACHE_POPULATION_SOURCE_ATTRIBUTE_KEY,
131-
cachePopulationSource.name()
136+
cachePopulationSource.name(),
137+
LUCENE_FILE_EXTENSION_ATTRIBUTE_KEY,
138+
blobFileExtension
132139
);
133140
assert bytesCopied > 0 : "We shouldn't be recording zero-sized copies";
134141
cachePopulationBytes.incrementBy(bytesCopied, metricAttributes);

x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/BlobCacheMetricsTests.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,18 @@
88
package org.elasticsearch.blobcache;
99

1010
import org.elasticsearch.common.unit.ByteSizeValue;
11+
import org.elasticsearch.index.store.LuceneFilesExtensions;
1112
import org.elasticsearch.telemetry.InstrumentType;
1213
import org.elasticsearch.telemetry.Measurement;
1314
import org.elasticsearch.telemetry.RecordingMeterRegistry;
1415
import org.elasticsearch.test.ESTestCase;
1516
import org.junit.Before;
1617

18+
import java.util.Arrays;
1719
import java.util.concurrent.TimeUnit;
1820

21+
import static org.hamcrest.Matchers.is;
22+
1923
public class BlobCacheMetricsTests extends ESTestCase {
2024

2125
private RecordingMeterRegistry recordingMeterRegistry;
@@ -32,7 +36,10 @@ public void testRecordCachePopulationMetricsRecordsThroughput() {
3236
int secondsTaken = randomIntBetween(1, 5);
3337
BlobCacheMetrics.CachePopulationReason cachePopulationReason = randomFrom(BlobCacheMetrics.CachePopulationReason.values());
3438
CachePopulationSource cachePopulationSource = randomFrom(CachePopulationSource.values());
39+
String fileExtension = randomFrom(Arrays.stream(LuceneFilesExtensions.values()).map(LuceneFilesExtensions::getExtension).toList());
40+
String luceneBlobFile = randomAlphanumericOfLength(15) + "." + fileExtension;
3541
metrics.recordCachePopulationMetrics(
42+
luceneBlobFile,
3643
Math.toIntExact(ByteSizeValue.ofMb(mebiBytesSent).getBytes()),
3744
TimeUnit.SECONDS.toNanos(secondsTaken),
3845
cachePopulationReason,
@@ -44,29 +51,31 @@ public void testRecordCachePopulationMetricsRecordsThroughput() {
4451
.getMeasurements(InstrumentType.DOUBLE_HISTOGRAM, "es.blob_cache.population.throughput.histogram")
4552
.get(0);
4653
assertEquals(throughputMeasurement.getDouble(), (double) mebiBytesSent / secondsTaken, 0.0);
47-
assertExpectedAttributesPresent(throughputMeasurement, cachePopulationReason, cachePopulationSource);
54+
assertExpectedAttributesPresent(throughputMeasurement, cachePopulationReason, cachePopulationSource, fileExtension);
4855

4956
// bytes counter
5057
Measurement totalBytesMeasurement = recordingMeterRegistry.getRecorder()
5158
.getMeasurements(InstrumentType.LONG_COUNTER, "es.blob_cache.population.bytes.total")
5259
.get(0);
5360
assertEquals(totalBytesMeasurement.getLong(), ByteSizeValue.ofMb(mebiBytesSent).getBytes());
54-
assertExpectedAttributesPresent(totalBytesMeasurement, cachePopulationReason, cachePopulationSource);
61+
assertExpectedAttributesPresent(totalBytesMeasurement, cachePopulationReason, cachePopulationSource, fileExtension);
5562

5663
// time counter
5764
Measurement totalTimeMeasurement = recordingMeterRegistry.getRecorder()
5865
.getMeasurements(InstrumentType.LONG_COUNTER, "es.blob_cache.population.time.total")
5966
.get(0);
6067
assertEquals(totalTimeMeasurement.getLong(), TimeUnit.SECONDS.toMillis(secondsTaken));
61-
assertExpectedAttributesPresent(totalTimeMeasurement, cachePopulationReason, cachePopulationSource);
68+
assertExpectedAttributesPresent(totalTimeMeasurement, cachePopulationReason, cachePopulationSource, fileExtension);
6269
}
6370

6471
private static void assertExpectedAttributesPresent(
6572
Measurement measurement,
6673
BlobCacheMetrics.CachePopulationReason cachePopulationReason,
67-
CachePopulationSource cachePopulationSource
74+
CachePopulationSource cachePopulationSource,
75+
String fileExtension
6876
) {
69-
assertEquals(measurement.attributes().get(BlobCacheMetrics.CACHE_POPULATION_REASON_ATTRIBUTE_KEY), cachePopulationReason.name());
70-
assertEquals(measurement.attributes().get(BlobCacheMetrics.CACHE_POPULATION_SOURCE_ATTRIBUTE_KEY), cachePopulationSource.name());
77+
assertThat(measurement.attributes().get(BlobCacheMetrics.CACHE_POPULATION_REASON_ATTRIBUTE_KEY), is(cachePopulationReason.name()));
78+
assertThat(measurement.attributes().get(BlobCacheMetrics.CACHE_POPULATION_SOURCE_ATTRIBUTE_KEY), is(cachePopulationSource.name()));
79+
assertThat(measurement.attributes().get(BlobCacheMetrics.LUCENE_FILE_EXTENSION_ATTRIBUTE_KEY), is(fileExtension));
7180
}
7281
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2525
@SuppressWarnings("unchecked")
2626
public void testGetServicesWithoutTaskType() throws IOException {
2727
List<Object> services = getAllServices();
28-
assertThat(services.size(), equalTo(20));
28+
assertThat(services.size(), equalTo(21));
2929

3030
String[] providers = new String[services.size()];
3131
for (int i = 0; i < services.size(); i++) {
@@ -41,6 +41,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
4141
"azureaistudio",
4242
"azureopenai",
4343
"cohere",
44+
"deepseek",
4445
"elastic",
4546
"elasticsearch",
4647
"googleaistudio",
@@ -114,7 +115,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
114115
@SuppressWarnings("unchecked")
115116
public void testGetServicesWithCompletionTaskType() throws IOException {
116117
List<Object> services = getServices(TaskType.COMPLETION);
117-
assertThat(services.size(), equalTo(9));
118+
assertThat(services.size(), equalTo(10));
118119

119120
String[] providers = new String[services.size()];
120121
for (int i = 0; i < services.size(); i++) {
@@ -130,6 +131,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
130131
"azureaistudio",
131132
"azureopenai",
132133
"cohere",
134+
"deepseek",
133135
"googleaistudio",
134136
"openai",
135137
"streaming_completion_test_service"
@@ -141,15 +143,15 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
141143
@SuppressWarnings("unchecked")
142144
public void testGetServicesWithChatCompletionTaskType() throws IOException {
143145
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
144-
assertThat(services.size(), equalTo(3));
146+
assertThat(services.size(), equalTo(4));
145147

146148
String[] providers = new String[services.size()];
147149
for (int i = 0; i < services.size(); i++) {
148150
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
149151
providers[i] = (String) serviceConfig.get("service");
150152
}
151153

152-
assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers);
154+
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
153155
}
154156

155157
@SuppressWarnings("unchecked")

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
5959
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
6060
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
61+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
6162
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6263
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
6364
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
@@ -153,6 +154,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
153154
addUnifiedNamedWriteables(namedWriteables);
154155

155156
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
157+
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
156158

157159
return namedWriteables;
158160
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
117117
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
118118
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
119+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
119120
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
120121
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
121122
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
@@ -362,6 +363,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
362363
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
363364
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
364365
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
366+
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
365367
ElasticsearchInternalService::new
366368
);
367369
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.deepseek;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.elasticsearch.ElasticsearchException;
14+
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.xcontent.ToXContent;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xcontent.json.JsonXContent;
18+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
19+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
20+
import org.elasticsearch.xpack.inference.external.request.Request;
21+
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
22+
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
23+
24+
import java.io.IOException;
25+
import java.net.URI;
26+
import java.nio.charset.StandardCharsets;
27+
import java.util.Objects;
28+
29+
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
30+
31+
public class DeepSeekChatCompletionRequest implements Request {
32+
private static final String MODEL_FIELD = "model";
33+
private static final String MAX_TOKENS = "max_tokens";
34+
35+
private final DeepSeekChatCompletionModel model;
36+
private final UnifiedChatInput unifiedChatInput;
37+
38+
public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
39+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
40+
this.model = Objects.requireNonNull(model);
41+
}
42+
43+
@Override
44+
public HttpRequest createHttpRequest() {
45+
HttpPost httpPost = new HttpPost(model.uri());
46+
47+
httpPost.setEntity(createEntity());
48+
49+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
50+
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
51+
52+
return new HttpRequest(httpPost, getInferenceEntityId());
53+
}
54+
55+
private ByteArrayEntity createEntity() {
56+
var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model);
57+
try (var builder = JsonXContent.contentBuilder()) {
58+
builder.startObject();
59+
new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS);
60+
builder.field(MODEL_FIELD, modelId);
61+
62+
if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
63+
builder.field(MAX_TOKENS, unifiedChatInput.getRequest().maxCompletionTokens());
64+
}
65+
66+
builder.endObject();
67+
return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8));
68+
} catch (IOException e) {
69+
throw new ElasticsearchException("Failed to serialize request payload.", e);
70+
}
71+
}
72+
73+
@Override
74+
public URI getURI() {
75+
return model.uri();
76+
}
77+
78+
@Override
79+
public Request truncate() {
80+
return this;
81+
}
82+
83+
@Override
84+
public boolean[] getTruncationInfo() {
85+
return null;
86+
}
87+
88+
@Override
89+
public String getInferenceEntityId() {
90+
return model.getInferenceEntityId();
91+
}
92+
93+
@Override
94+
public boolean isStreaming() {
95+
return unifiedChatInput.stream();
96+
}
97+
}

0 commit comments

Comments
 (0)