Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127939.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127939
summary: Add Custom inference service
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_38);
public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED_8_19 = def(8_841_0_39);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
*/
public enum FeatureFlag {
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
INFERENCE_CUSTOM_SERVICE_ENABLED("es.inference_custom_service_feature_flag_enabled=true", Version.fromString("8.19.0"), null);

public final String systemProperty;
public final Version from;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.FeatureFlag;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.junit.ClassRule;
Expand Down Expand Up @@ -46,6 +47,7 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase {
// This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin
.plugin("inference-service-test")
.user("x_pack_rest_user", "x-pack-test-password")
.feature(FeatureFlag.INFERENCE_CUSTOM_SERVICE_ENABLED)
.build();

// The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.FeatureFlag;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -50,6 +51,7 @@ public class InferenceBaseRestTest extends ESRestTestCase {
.setting("xpack.security.enabled", "true")
.plugin("inference-service-test")
.user("x_pack_rest_user", "x-pack-test-password")
.feature(FeatureFlag.INFERENCE_CUSTOM_SERVICE_ENABLED)
.build();

@ClassRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(22));
assertThat(services.size(), equalTo(23));

var providers = providers(services);

Expand All @@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"elastic",
"elasticsearch",
Expand Down Expand Up @@ -70,7 +71,7 @@ private Iterable<String> providers(List<Object> services) {

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(16));
assertThat(services.size(), equalTo(17));

var providers = providers(services);

Expand All @@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"elasticsearch",
"googleaistudio",
"googlevertexai",
Expand All @@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(8));
assertThat(services.size(), equalTo(9));

var providers = providers(services);

Expand All @@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
List.of(
"alibabacloud-ai-search",
"cohere",
"custom",
"elasticsearch",
"googlevertexai",
"jinaai",
Expand All @@ -124,7 +127,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(12));
assertThat(services.size(), equalTo(13));

var providers = providers(services);

Expand All @@ -138,6 +141,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"googleaistudio",
"openai",
Expand Down Expand Up @@ -173,7 +177,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));

var providers = providers(services);

Expand All @@ -182,6 +186,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"custom",
"elastic",
"elasticsearch",
"hugging_face",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.common.util.FeatureFlag;

/**
 * Holder for the feature flag that gates the custom inference service.
 *
 * <p>This class is not instantiable; it only exposes {@link #CUSTOM_SERVICE_FEATURE_FLAG}.
 */
public final class CustomServiceFeatureFlag {
    /**
     * {@link org.elasticsearch.xpack.inference.services.custom.CustomService} feature flag. When the feature is complete,
     * this flag will be removed.
     *
     * <p>Enable the feature via the JVM option {@code -Des.inference_custom_service_feature_flag_enabled=true}.
     */
    public static final FeatureFlag CUSTOM_SERVICE_FEATURE_FLAG = new FeatureFlag("inference_custom_service");

    // Constants holder: never instantiated.
    private CustomServiceFeatureFlag() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
Expand Down Expand Up @@ -108,6 +117,8 @@
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.xpack.inference.CustomServiceFeatureFlag.CUSTOM_SERVICE_FEATURE_FLAG;

public class InferenceNamedWriteablesProvider {

private InferenceNamedWriteablesProvider() {}
Expand Down Expand Up @@ -158,6 +169,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
addCustomNamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand All @@ -169,6 +181,42 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
return namedWriteables;
}

/**
 * Appends the named writeable entries for the custom inference service to {@code namedWriteables}:
 * its service, task and secret settings, followed by its response parsers.
 *
 * <p>Nothing is added while {@code CUSTOM_SERVICE_FEATURE_FLAG} is disabled.
 *
 * @param namedWriteables the list the entries are appended to
 */
private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    if (CUSTOM_SERVICE_FEATURE_FLAG.isEnabled() == false) {
        return;
    }

    // Entries are listed in the order they are registered: settings first, then response parsers.
    namedWriteables.addAll(
        List.of(
            new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new),
            new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new),
            new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new),
            new NamedWriteableRegistry.Entry(
                CustomResponseParser.class,
                TextEmbeddingResponseParser.NAME,
                TextEmbeddingResponseParser::new
            ),
            new NamedWriteableRegistry.Entry(
                CustomResponseParser.class,
                SparseEmbeddingResponseParser.NAME,
                SparseEmbeddingResponseParser::new
            ),
            new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new),
            new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new),
            new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new)
        )
    );
}

private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
var writeables = UnifiedCompletionRequest.getNamedWriteables();
namedWriteables.addAll(writeables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.custom.CustomService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
Expand Down Expand Up @@ -150,8 +151,10 @@
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.inference.CustomServiceFeatureFlag.CUSTOM_SERVICE_FEATURE_FLAG;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;

Expand Down Expand Up @@ -381,7 +384,11 @@ public void loadExtensions(ExtensionLoader loader) {
}

public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
return List.of(
List<InferenceServiceExtension.Factory> conditionalServices = CUSTOM_SERVICE_FEATURE_FLAG.isEnabled()
? List.of(context -> new CustomService(httpFactory.get(), serviceComponents.get()))
: List.of();

List<InferenceServiceExtension.Factory> availableServices = List.of(
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
Expand All @@ -400,6 +407,8 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);

return Stream.concat(availableServices.stream(), conditionalServices.stream()).toList();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";

protected final String requestType;
private final ResponseParser parseFunction;
protected final ResponseParser parseFunction;
private final Function<HttpResult, ErrorResponse> errorParseFunction;
private final boolean canHandleStreamingResponses;

Expand Down
Loading