Skip to content
Closed
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
4f4c603
add inference custom model
Huaixinww Mar 7, 2025
e53b2e4
add unit test
Huaixinww Mar 7, 2025
c593b1a
spotless apply
Huaixinww Mar 7, 2025
0a851c1
add custom validation
Huaixinww Mar 7, 2025
e240f06
xpack core spotless apply
Huaixinww Mar 7, 2025
3ea3053
update commons-lang3's version
Huaixinww Mar 7, 2025
83daf69
Fix compilation after rebase
davidkyle Mar 25, 2025
3cb0cfb
Add missing licences and fix build checks
davidkyle Mar 25, 2025
a3c862c
Remove some unused code
davidkyle Mar 25, 2025
2b7e6fe
Update docs/changelog/125679.yaml
davidkyle Mar 26, 2025
6cc593d
Fix services it
davidkyle Mar 27, 2025
a4630e3
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 7, 2025
95f23f0
Contuing refactor of service settings
jonathan-buttner Apr 9, 2025
014f95b
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 9, 2025
189edba
Moving classes to reflect new structure
jonathan-buttner Apr 9, 2025
4fe3a1f
Refactoring service settings
jonathan-buttner Apr 9, 2025
4ef37f5
Refactoring the request
jonathan-buttner Apr 10, 2025
6bac18b
Adding files to handle generic error response
jonathan-buttner Apr 11, 2025
f644471
Making progress on tests
jonathan-buttner Apr 15, 2025
11cf7cc
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 15, 2025
f962d74
Adding more tests
jonathan-buttner Apr 16, 2025
eb63e8b
Adding more tests
jonathan-buttner Apr 18, 2025
adc3210
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 18, 2025
c9ff298
Adding tests for remaining parsers
jonathan-buttner Apr 21, 2025
de83271
More tests
jonathan-buttner Apr 22, 2025
34df922
Need to address quoted strings
jonathan-buttner Apr 24, 2025
b496732
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 28, 2025
097246b
Adding query parameter handling and tests
jonathan-buttner Apr 28, 2025
e7f6ac5
Adding encoding tests
jonathan-buttner Apr 29, 2025
a8c5241
Fixing embedding dimensions issue and test field names
jonathan-buttner Apr 29, 2025
3df0f70
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner Apr 29, 2025
ad55337
Fixing tests
jonathan-buttner Apr 29, 2025
4714fd3
[CI] Auto commit changes from spotless
Apr 29, 2025
d13191c
Removing licenses
jonathan-buttner Apr 29, 2025
12d46d7
Adding custom service tests
jonathan-buttner May 2, 2025
0134346
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
e6fefc4
[CI] Auto commit changes from spotless
May 2, 2025
eef7188
Correcting tranport version number
jonathan-buttner May 2, 2025
83837c8
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 2, 2025
dc02425
Merge branch 'custom-inference-service' of github.com:davidkyle/elast…
jonathan-buttner May 2, 2025
59f75b9
Cleaning up
jonathan-buttner May 2, 2025
8a82163
Fixing counts
jonathan-buttner May 5, 2025
5f13d28
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 5, 2025
c211d83
Fixing rerank and chat completions
jonathan-buttner May 7, 2025
133ef4e
Missing a few changes
jonathan-buttner May 7, 2025
5c28ee8
Passing request to the error response handler
jonathan-buttner May 7, 2025
be84291
Merge remote-tracking branch 'origin/ml-expose-request-in-error-parse…
jonathan-buttner May 7, 2025
8d1bd22
Adding inference id to error parser log message
jonathan-buttner May 8, 2025
a0984c7
Reverting exposing request to error parsing logic
jonathan-buttner May 8, 2025
4242a37
Refactoring the error parsing logic
jonathan-buttner May 8, 2025
6492cd7
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/125679.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 125679
summary: Adding support for generic Inference services
area: Machine Learning
type: enhancement
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INTRODUCE_FAILURES_DEFAULT_RETENTION_BACKPORT_8_19 = def(8_841_0_26);
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL_8_X = def(8_841_0_29);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -250,6 +251,7 @@ static TransportVersion def(int id) {
public static final TransportVersion FILE_SETTINGS_HEALTH_INFO = def(9_072_0_00);
public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_00_0);
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_075_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(22));
assertThat(services.size(), equalTo(23));

var providers = providers(services);

Expand All @@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"elastic",
"elasticsearch",
Expand Down Expand Up @@ -70,7 +71,7 @@ private Iterable<String> providers(List<Object> services) {

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(16));
assertThat(services.size(), equalTo(17));

var providers = providers(services);

Expand All @@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"elasticsearch",
"googleaistudio",
"googlevertexai",
Expand All @@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));

var providers = providers(services);

Expand All @@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
List.of(
"alibabacloud-ai-search",
"cohere",
"custom",
"elasticsearch",
"googlevertexai",
"jinaai",
Expand All @@ -123,7 +126,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(10));
assertThat(services.size(), equalTo(11));

var providers = providers(services);

Expand All @@ -137,6 +140,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"googleaistudio",
"openai",
Expand All @@ -157,7 +161,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));

var providers = providers(services);

Expand All @@ -166,6 +170,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"custom",
"elastic",
"elasticsearch",
"hugging_face",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
Expand Down Expand Up @@ -154,6 +163,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
addCustomNamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand All @@ -165,6 +175,38 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
return namedWriteables;
}

private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new)
);

namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new));

namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
CustomResponseParser.class,
SparseEmbeddingResponseParser.NAME,
SparseEmbeddingResponseParser::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new)
);

namedWriteables.add(new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new)
);
}

private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
var writeables = UnifiedCompletionRequest.getNamedWriteables();
namedWriteables.addAll(writeables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.custom.CustomService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
Expand Down Expand Up @@ -396,6 +397,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
context -> new CustomService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;

import java.io.ByteArrayOutputStream;
import java.util.concurrent.Flow;
Expand All @@ -21,7 +22,7 @@ public boolean isSuccessfulResponse() {
return RestStatus.isSuccessful(response.getStatusLine().getStatusCode());
}

public Flow.Publisher<HttpResult> toHttpResult() {
public Flow.Publisher<HttpResult> toHttpResult(HttpRequest httpRequest) {
return subscriber -> body().subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
Expand All @@ -45,7 +46,7 @@ public void onComplete() {
});
}

public void readFullResponse(ActionListener<HttpResult> fullResponse) {
public void readFullResponse(HttpRequest httpRequest, ActionListener<HttpResult> fullResponse) {
var stream = new ByteArrayOutputStream();
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
body.subscribe(new Flow.Subscriber<>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.Objects;
import java.util.function.Function;
import java.util.function.BiFunction;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
Expand All @@ -36,18 +36,22 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";

protected final String requestType;
private final ResponseParser parseFunction;
private final Function<HttpResult, ErrorResponse> errorParseFunction;
protected final ResponseParser parseFunction;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making this available so the custom response handler can immediately return on a parse failure instead of retrying.

private final BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction;
private final boolean canHandleStreamingResponses;

public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function<HttpResult, ErrorResponse> errorParseFunction) {
public BaseResponseHandler(
String requestType,
ResponseParser parseFunction,
BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction
) {
this(requestType, parseFunction, errorParseFunction, false);
}

public BaseResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
BiFunction<Request, HttpResult, ErrorResponse> errorParseFunction,
boolean canHandleStreamingResponses
) {
this.requestType = Objects.requireNonNull(requestType);
Expand Down Expand Up @@ -96,7 +100,7 @@ public void validateResponse(
protected abstract void checkForFailureStatusCode(Request request, HttpResult result);

private void checkForErrorObject(Request request, HttpResult result) {
var errorEntity = errorParseFunction.apply(result);
var errorEntity = errorParseFunction.apply(request, result);

if (errorEntity.errorStructureFound()) {
// We don't really know what happened because the status code was 200 so we'll return a failure and let the
Expand All @@ -109,7 +113,7 @@ private void checkForErrorObject(Request request, HttpResult result) {
}

protected Exception buildError(String message, Request request, HttpResult result) {
var errorEntityMsg = errorParseFunction.apply(result);
var errorEntityMsg = errorParseFunction.apply(request, result);
return buildError(message, request, result, errorEntityMsg);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,12 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {

try {
if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) {
httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
var httpRequest = request.createHttpRequest();
httpClient.stream(httpRequest, context, retryableListener.delegateFailure((l, r) -> {
if (r.isSuccessfulResponse()) {
l.onResponse(responseHandler.parseResult(request, r.toHttpResult()));
l.onResponse(responseHandler.parseResult(request, r.toHttpResult(httpRequest)));
} else {
r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
r.readFullResponse(httpRequest, l.delegateFailureAndWrap((ll, httpResult) -> {
try {
responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true);
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -68,4 +69,8 @@ public static ErrorResponse fromResponse(HttpResult response, String defaultMess
public static ErrorResponse fromResponse(HttpResult response) {
return fromResponse(response, "");
}

public static ErrorResponse fromResponse(Request request, HttpResult response) {
return fromResponse(response);
}
}
Loading
Loading