Skip to content

Commit 14a5383

Browse files
Inference changes
1 parent dbfaf30 commit 14a5383

File tree

56 files changed

+5547
-173
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

56 files changed

+5547
-173
lines changed

x-pack/plugin/inference/build.gradle

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
apply plugin: 'elasticsearch.internal-es-plugin'
99
apply plugin: 'elasticsearch.internal-cluster-test'
1010
apply plugin: 'elasticsearch.internal-yaml-rest-test'
11-
apply plugin: 'elasticsearch.internal-test-artifact'
1211

1312
restResources {
1413
restApi {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2525

2626
public void testGetServicesWithoutTaskType() throws IOException {
2727
List<Object> services = getAllServices();
28-
assertThat(services.size(), equalTo(22));
28+
assertThat(services.size(), equalTo(23));
2929

3030
var providers = providers(services);
3131

@@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
3939
"azureaistudio",
4040
"azureopenai",
4141
"cohere",
42+
"custom",
4243
"deepseek",
4344
"elastic",
4445
"elasticsearch",
@@ -70,7 +71,7 @@ private Iterable<String> providers(List<Object> services) {
7071

7172
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
7273
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
73-
assertThat(services.size(), equalTo(16));
74+
assertThat(services.size(), equalTo(17));
7475

7576
var providers = providers(services);
7677

@@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
8384
"azureaistudio",
8485
"azureopenai",
8586
"cohere",
87+
"custom",
8688
"elasticsearch",
8789
"googleaistudio",
8890
"googlevertexai",
@@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
101103

102104
public void testGetServicesWithRerankTaskType() throws IOException {
103105
List<Object> services = getServices(TaskType.RERANK);
104-
assertThat(services.size(), equalTo(7));
106+
assertThat(services.size(), equalTo(8));
105107

106108
var providers = providers(services);
107109

@@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
111113
List.of(
112114
"alibabacloud-ai-search",
113115
"cohere",
116+
"custom",
114117
"elasticsearch",
115118
"googlevertexai",
116119
"jinaai",
@@ -123,7 +126,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
123126

124127
public void testGetServicesWithCompletionTaskType() throws IOException {
125128
List<Object> services = getServices(TaskType.COMPLETION);
126-
assertThat(services.size(), equalTo(10));
129+
assertThat(services.size(), equalTo(11));
127130

128131
var providers = providers(services);
129132

@@ -137,6 +140,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
137140
"azureaistudio",
138141
"azureopenai",
139142
"cohere",
143+
"custom",
140144
"deepseek",
141145
"googleaistudio",
142146
"openai",
@@ -157,7 +161,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
157161

158162
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
159163
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
160-
assertThat(services.size(), equalTo(6));
164+
assertThat(services.size(), equalTo(7));
161165

162166
var providers = providers(services);
163167

@@ -166,6 +170,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
166170
containsInAnyOrder(
167171
List.of(
168172
"alibabacloud-ai-search",
173+
"custom",
169174
"elastic",
170175
"elasticsearch",
171176
"hugging_face",

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import org.elasticsearch.xcontent.ToXContent;
3535
import org.elasticsearch.xcontent.ToXContentObject;
3636
import org.elasticsearch.xcontent.XContentBuilder;
37-
import org.elasticsearch.xpack.core.inference.DequeUtils;
3837
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
3938
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
4039
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
@@ -257,24 +256,37 @@ public void cancel() {}
257256
"object": "chat.completion.chunk"
258257
}
259258
*/
260-
private StreamingUnifiedChatCompletionResults.Results unifiedCompletionChunk(String delta) {
261-
return new StreamingUnifiedChatCompletionResults.Results(
262-
DequeUtils.of(
263-
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
264-
"id",
265-
List.of(
266-
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
267-
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(delta, null, null, null),
268-
null,
269-
0
270-
)
271-
),
272-
"gpt-4o-2024-08-06",
273-
"chat.completion.chunk",
274-
null
275-
)
276-
)
277-
);
259+
private InferenceServiceResults.Result unifiedCompletionChunk(String delta) {
260+
return new InferenceServiceResults.Result() {
261+
@Override
262+
public String getWriteableName() {
263+
return "test_unifiedCompletionChunk";
264+
}
265+
266+
@Override
267+
public void writeTo(StreamOutput out) throws IOException {
268+
out.writeString(delta);
269+
}
270+
271+
@Override
272+
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
273+
return ChunkedToXContentHelper.chunk(
274+
(b, p) -> b.startObject()
275+
.field("id", "id")
276+
.startArray("choices")
277+
.startObject()
278+
.startObject("delta")
279+
.field("content", delta)
280+
.endObject()
281+
.field("index", 0)
282+
.endObject()
283+
.endArray()
284+
.field("model", "gpt-4o-2024-08-06")
285+
.field("object", "chat.completion.chunk")
286+
.endObject()
287+
);
288+
}
289+
};
278290
}
279291

280292
@Override

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99

1010
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
1111

12+
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
1213
import org.elasticsearch.action.bulk.BulkItemResponse;
1314
import org.elasticsearch.action.bulk.BulkRequestBuilder;
1415
import org.elasticsearch.action.bulk.BulkResponse;
1516
import org.elasticsearch.action.delete.DeleteRequestBuilder;
1617
import org.elasticsearch.action.index.IndexRequestBuilder;
1718
import org.elasticsearch.action.search.SearchRequest;
1819
import org.elasticsearch.action.search.SearchResponse;
19-
import org.elasticsearch.action.support.WriteRequest;
2020
import org.elasticsearch.action.update.UpdateRequestBuilder;
2121
import org.elasticsearch.cluster.metadata.IndexMetadata;
2222
import org.elasticsearch.common.settings.Settings;
@@ -242,10 +242,12 @@ public void testRestart() throws Exception {
242242

243243
private void assertRandomBulkOperations(String indexName, Function<Boolean, Map<String, Object>> sourceSupplier) throws Exception {
244244
int numHits = numHits(indexName);
245-
int totalBulkReqs = randomIntBetween(2, 10);
245+
int totalBulkReqs = randomIntBetween(2, 100);
246+
long totalDocs = numHits;
246247
Set<String> ids = new HashSet<>();
247-
for (int bulkReqs = 0; bulkReqs < totalBulkReqs; bulkReqs++) {
248-
BulkRequestBuilder bulkReqBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
248+
249+
for (int bulkReqs = numHits; bulkReqs < totalBulkReqs; bulkReqs++) {
250+
BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
249251
int totalBulkSize = randomIntBetween(1, 100);
250252
for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
251253
if (ids.size() > 0 && rarely(random())) {
@@ -255,15 +257,24 @@ private void assertRandomBulkOperations(String indexName, Function<Boolean, Map<
255257
bulkReqBuilder.add(request);
256258
continue;
257259
}
258-
boolean isIndexRequest = ids.size() == 0 || randomBoolean();
260+
String id = Long.toString(totalDocs++);
261+
boolean isIndexRequest = randomBoolean();
259262
Map<String, Object> source = sourceSupplier.apply(isIndexRequest);
260263
if (isIndexRequest) {
261-
String id = randomAlphaOfLength(20);
262264
bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(indexName).setId(id).setSource(source));
263265
ids.add(id);
264266
} else {
265-
String id = randomFrom(ids);
266-
bulkReqBuilder.add(new UpdateRequestBuilder(client()).setIndex(indexName).setId(id).setDoc(source));
267+
boolean isUpsert = randomBoolean();
268+
UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(indexName).setDoc(source);
269+
if (isUpsert || ids.size() == 0) {
270+
request.setDocAsUpsert(true);
271+
} else {
272+
// Update already existing document
273+
id = randomFrom(ids);
274+
}
275+
request.setId(id);
276+
bulkReqBuilder.add(request);
277+
ids.add(id);
267278
}
268279
}
269280
BulkResponse bulkResponse = bulkReqBuilder.get();
@@ -282,7 +293,8 @@ private void assertRandomBulkOperations(String indexName, Function<Boolean, Map<
282293
}
283294
assertFalse(bulkResponse.hasFailures());
284295
}
285-
assertThat(numHits(indexName), equalTo(numHits + ids.size()));
296+
client().admin().indices().refresh(new RefreshRequest(indexName)).get();
297+
assertThat(numHits(indexName), equalTo(ids.size() + numHits));
286298
}
287299

288300
private int numHits(String indexName) throws Exception {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@
5959
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
6060
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
6161
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
62+
import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
63+
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
64+
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
65+
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
66+
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
67+
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
68+
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
69+
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
70+
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
6271
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
6372
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
6473
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
@@ -154,6 +163,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
154163
addAlibabaCloudSearchNamedWriteables(namedWriteables);
155164
addJinaAINamedWriteables(namedWriteables);
156165
addVoyageAINamedWriteables(namedWriteables);
166+
addCustomNamedWriteables(namedWriteables);
157167

158168
addUnifiedNamedWriteables(namedWriteables);
159169

@@ -165,6 +175,38 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
165175
return namedWriteables;
166176
}
167177

178+
private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
179+
namedWriteables.add(
180+
new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new)
181+
);
182+
183+
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new));
184+
185+
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new));
186+
187+
namedWriteables.add(
188+
new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new)
189+
);
190+
191+
namedWriteables.add(
192+
new NamedWriteableRegistry.Entry(
193+
CustomResponseParser.class,
194+
SparseEmbeddingResponseParser.NAME,
195+
SparseEmbeddingResponseParser::new
196+
)
197+
);
198+
199+
namedWriteables.add(
200+
new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new)
201+
);
202+
203+
namedWriteables.add(new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new));
204+
205+
namedWriteables.add(
206+
new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new)
207+
);
208+
}
209+
168210
private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
169211
var writeables = UnifiedCompletionRequest.getNamedWriteables();
170212
namedWriteables.addAll(writeables);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
120120
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
121121
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
122+
import org.elasticsearch.xpack.inference.services.custom.CustomService;
122123
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
123124
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
124125
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
@@ -276,17 +277,13 @@ public Collection<?> createComponents(PluginServices services) {
276277
var inferenceServices = new ArrayList<>(inferenceServiceExtensions);
277278
inferenceServices.add(this::getInferenceServiceFactories);
278279

279-
var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
280-
inferenceServiceSettings.init(services.clusterService());
281-
282280
// Create a separate instance of HTTPClientManager with its own SSL configuration (`xpack.inference.elastic.http.ssl.*`).
283281
var elasticInferenceServiceHttpClientManager = HttpClientManager.create(
284282
settings,
285283
services.threadPool(),
286284
services.clusterService(),
287285
throttlerManager,
288-
getSslService(),
289-
inferenceServiceSettings.getConnectionTtl()
286+
getSslService()
290287
);
291288

292289
var elasticInferenceServiceRequestSenderFactory = new HttpRequestSender.Factory(
@@ -296,6 +293,9 @@ public Collection<?> createComponents(PluginServices services) {
296293
);
297294
elasicInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory);
298295

296+
var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
297+
inferenceServiceSettings.init(services.clusterService());
298+
299299
var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
300300
inferenceServiceSettings.getElasticInferenceServiceUrl(),
301301
services.threadPool()
@@ -396,6 +396,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
396396
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
397397
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
398398
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
399+
context -> new CustomService(httpFactory.get(), serviceComponents.get()),
399400
ElasticsearchInternalService::new
400401
);
401402
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ protected void masterOperation(
177177
return;
178178
}
179179

180-
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener);
180+
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
181181
}
182182

183183
private void parseAndStoreModel(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClientManager.java

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import java.io.Closeable;
3333
import java.io.IOException;
3434
import java.util.List;
35-
import java.util.concurrent.TimeUnit;
3635

3736
import static org.elasticsearch.core.Strings.format;
3837
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX;
@@ -113,15 +112,14 @@ public static HttpClientManager create(
113112
ThreadPool threadPool,
114113
ClusterService clusterService,
115114
ThrottlerManager throttlerManager,
116-
SSLService sslService,
117-
TimeValue connectionTtl
115+
SSLService sslService
118116
) {
119117
// Set the sslStrategy to ensure an encrypted connection, as Elastic Inference Service requires it.
120118
SSLIOSessionStrategy sslioSessionStrategy = sslService.sslIOSessionStrategy(
121119
sslService.getSSLConfiguration(ELASTIC_INFERENCE_SERVICE_SSL_CONFIGURATION_PREFIX)
122120
);
123121

124-
PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(sslioSessionStrategy, connectionTtl);
122+
PoolingNHttpClientConnectionManager connectionManager = createConnectionManager(sslioSessionStrategy);
125123
return new HttpClientManager(settings, connectionManager, threadPool, clusterService, throttlerManager);
126124
}
127125

@@ -148,7 +146,7 @@ public static HttpClientManager create(
148146
this.addSettingsUpdateConsumers(clusterService);
149147
}
150148

151-
private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIOSessionStrategy sslStrategy, TimeValue connectionTtl) {
149+
private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIOSessionStrategy sslStrategy) {
152150
ConnectingIOReactor ioReactor;
153151
try {
154152
var configBuilder = IOReactorConfig.custom().setSoKeepAlive(true);
@@ -164,15 +162,7 @@ private static PoolingNHttpClientConnectionManager createConnectionManager(SSLIO
164162
.register("https", sslStrategy)
165163
.build();
166164

167-
return new PoolingNHttpClientConnectionManager(
168-
ioReactor,
169-
null,
170-
registry,
171-
null,
172-
null,
173-
Math.toIntExact(connectionTtl.getMillis()),
174-
TimeUnit.MILLISECONDS
175-
);
165+
return new PoolingNHttpClientConnectionManager(ioReactor, registry);
176166
}
177167

178168
private static PoolingNHttpClientConnectionManager createConnectionManager() {

0 commit comments

Comments
 (0)