Skip to content

Commit 87a93ca

Browse files
committed
[Elastic Inference Service] Add ElasticInferenceService Unified ChatCompletions Integration (elastic#118871)
(cherry picked from commit 18345c4) # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java
1 parent 338376c commit 87a93ca

File tree

28 files changed

+1813
-958
lines changed

28 files changed

+1813
-958
lines changed

docs/changelog/118871.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 118871
2+
summary: "[Elastic Inference Service] Add ElasticInferenceService Unified ChatCompletions Integration"
3+
area: Inference
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,13 @@ public void writeTo(StreamOutput out) throws IOException {
111111
out.writeOptionalFloat(topP);
112112
}
113113

114-
public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List<ToolCall> toolCalls)
115-
implements
116-
Writeable {
114+
public record Message(
115+
Content content,
116+
String role,
117+
@Nullable String name,
118+
@Nullable String toolCallId,
119+
@Nullable List<ToolCall> toolCalls
120+
) implements Writeable {
117121

118122
@SuppressWarnings("unchecked")
119123
static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
243243
@SuppressWarnings("unchecked")
244244
public void testGetServicesWithCompletionTaskType() throws IOException {
245245
List<Object> services = getServices(TaskType.COMPLETION);
246-
assertThat(services.size(), equalTo(9));
246+
assertThat(services.size(), equalTo(10));
247247

248248
var providers = new ArrayList<String>();
249249
for (int i = 0; i < services.size(); i++) {
@@ -259,6 +259,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
259259
"azureaistudio",
260260
"azureopenai",
261261
"cohere",
262+
"elastic",
262263
"googleaistudio",
263264
"openai",
264265
"streaming_completion_test_service"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.elastic;
9+
10+
import org.elasticsearch.inference.InferenceServiceResults;
11+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
12+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
13+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
14+
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
15+
import org.elasticsearch.xpack.inference.external.request.Request;
16+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
17+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
18+
19+
import java.util.concurrent.Flow;
20+
21+
public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
22+
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
23+
super(requestType, parseFunction);
24+
}
25+
26+
@Override
27+
public boolean canHandleStreamingResponses() {
28+
return true;
29+
}
30+
31+
@Override
32+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
33+
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
34+
var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec
35+
36+
flow.subscribe(serverSentEventProcessor);
37+
serverSentEventProcessor.subscribe(openAiProcessor);
38+
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
39+
}
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler;
16+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
17+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceUnifiedChatCompletionRequest;
19+
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
20+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
22+
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
public class ElasticInferenceServiceUnifiedCompletionRequestManager extends ElasticInferenceServiceRequestManager {
27+
28+
private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceUnifiedCompletionRequestManager.class);
29+
30+
private static final ResponseHandler HANDLER = createCompletionHandler();
31+
32+
public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
33+
ElasticInferenceServiceCompletionModel model,
34+
ThreadPool threadPool,
35+
TraceContext traceContext
36+
) {
37+
return new ElasticInferenceServiceUnifiedCompletionRequestManager(
38+
Objects.requireNonNull(model),
39+
Objects.requireNonNull(threadPool),
40+
Objects.requireNonNull(traceContext)
41+
);
42+
}
43+
44+
private final ElasticInferenceServiceCompletionModel model;
45+
private final TraceContext traceContext;
46+
47+
private ElasticInferenceServiceUnifiedCompletionRequestManager(
48+
ElasticInferenceServiceCompletionModel model,
49+
ThreadPool threadPool,
50+
TraceContext traceContext
51+
) {
52+
super(threadPool, model);
53+
this.model = model;
54+
this.traceContext = traceContext;
55+
}
56+
57+
@Override
58+
public void execute(
59+
InferenceInputs inferenceInputs,
60+
RequestSender requestSender,
61+
Supplier<Boolean> hasRequestCompletedFunction,
62+
ActionListener<InferenceServiceResults> listener
63+
) {
64+
65+
ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest(
66+
inferenceInputs.castTo(UnifiedChatInput.class),
67+
model,
68+
traceContext
69+
);
70+
71+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
72+
}
73+
74+
private static ResponseHandler createCompletionHandler() {
75+
return new ElasticInferenceServiceUnifiedChatCompletionResponseHandler(
76+
"elastic inference service completion",
77+
// We use OpenAiChatCompletionResponseEntity here as the ElasticInferenceServiceResponseEntity fields are a subset of the OpenAI
78+
// one.
79+
OpenAiChatCompletionResponseEntity::fromResponse
80+
);
81+
}
82+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
import org.apache.http.entity.ByteArrayEntity;
1313
import org.apache.http.message.BasicHeader;
1414
import org.elasticsearch.common.Strings;
15-
import org.elasticsearch.tasks.Task;
1615
import org.elasticsearch.xcontent.XContentType;
1716
import org.elasticsearch.xpack.inference.common.Truncator;
1817
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1918
import org.elasticsearch.xpack.inference.external.request.Request;
2019
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
2120
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
2222

2323
import java.net.URI;
2424
import java.nio.charset.StandardCharsets;
@@ -27,13 +27,10 @@
2727
public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest {
2828

2929
private final URI uri;
30-
3130
private final ElasticInferenceServiceSparseEmbeddingsModel model;
32-
3331
private final Truncator.TruncationResult truncationResult;
3432
private final Truncator truncator;
35-
36-
private final TraceContext traceContext;
33+
private final TraceContextHandler traceContextHandler;
3734

3835
public ElasticInferenceServiceSparseEmbeddingsRequest(
3936
Truncator truncator,
@@ -45,7 +42,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
4542
this.truncationResult = truncationResult;
4643
this.model = Objects.requireNonNull(model);
4744
this.uri = model.uri();
48-
this.traceContext = traceContext;
45+
this.traceContextHandler = new TraceContextHandler(traceContext);
4946
}
5047

5148
@Override
@@ -56,15 +53,16 @@ public HttpRequest createHttpRequest() {
5653
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
5754
httpPost.setEntity(byteEntity);
5855

59-
if (traceContext != null) {
60-
propagateTraceContext(httpPost);
61-
}
62-
56+
traceContextHandler.propagateTraceContext(httpPost);
6357
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
6458

6559
return new HttpRequest(httpPost, getInferenceEntityId());
6660
}
6761

62+
public TraceContext getTraceContext() {
63+
return traceContextHandler.traceContext();
64+
}
65+
6866
@Override
6967
public String getInferenceEntityId() {
7068
return model.getInferenceEntityId();
@@ -75,32 +73,15 @@ public URI getURI() {
7573
return this.uri;
7674
}
7775

78-
public TraceContext getTraceContext() {
79-
return traceContext;
80-
}
81-
8276
@Override
8377
public Request truncate() {
8478
var truncatedInput = truncator.truncate(truncationResult.input());
85-
86-
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContext);
79+
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContextHandler.traceContext());
8780
}
8881

8982
@Override
9083
public boolean[] getTruncationInfo() {
9184
return truncationResult.truncated().clone();
9285
}
9386

94-
private void propagateTraceContext(HttpPost httpPost) {
95-
var traceParent = traceContext.traceParent();
96-
var traceState = traceContext.traceState();
97-
98-
if (traceParent != null) {
99-
httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent);
100-
}
101-
102-
if (traceState != null) {
103-
httpPost.setHeader(Task.TRACE_STATE, traceState);
104-
}
105-
}
10687
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.apache.http.message.BasicHeader;
14+
import org.elasticsearch.common.Strings;
15+
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
17+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
18+
import org.elasticsearch.xpack.inference.external.request.Request;
19+
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
20+
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
21+
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
22+
23+
import java.net.URI;
24+
import java.nio.charset.StandardCharsets;
25+
import java.util.Objects;
26+
27+
public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request {
28+
29+
private final ElasticInferenceServiceCompletionModel model;
30+
private final UnifiedChatInput unifiedChatInput;
31+
private final TraceContextHandler traceContextHandler;
32+
33+
public ElasticInferenceServiceUnifiedChatCompletionRequest(
34+
UnifiedChatInput unifiedChatInput,
35+
ElasticInferenceServiceCompletionModel model,
36+
TraceContext traceContext
37+
) {
38+
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
39+
this.model = Objects.requireNonNull(model);
40+
this.traceContextHandler = new TraceContextHandler(traceContext);
41+
}
42+
43+
@Override
44+
public HttpRequest createHttpRequest() {
45+
var httpPost = new HttpPost(model.uri());
46+
var requestEntity = Strings.toString(
47+
new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
48+
);
49+
50+
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
51+
httpPost.setEntity(byteEntity);
52+
53+
traceContextHandler.propagateTraceContext(httpPost);
54+
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
55+
56+
return new HttpRequest(httpPost, getInferenceEntityId());
57+
}
58+
59+
@Override
60+
public URI getURI() {
61+
return model.uri();
62+
}
63+
64+
@Override
65+
public Request truncate() {
66+
// No truncation
67+
return this;
68+
}
69+
70+
@Override
71+
public boolean[] getTruncationInfo() {
72+
// No truncation
73+
return null;
74+
}
75+
76+
@Override
77+
public String getInferenceEntityId() {
78+
return model.getInferenceEntityId();
79+
}
80+
81+
@Override
82+
public boolean isStreaming() {
83+
return true;
84+
}
85+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic;
9+
10+
import org.elasticsearch.xcontent.ToXContentObject;
11+
import org.elasticsearch.xcontent.XContentBuilder;
12+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
13+
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
14+
15+
import java.io.IOException;
16+
import java.util.Objects;
17+
18+
public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
19+
private static final String MODEL_FIELD = "model";
20+
21+
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
22+
private final String modelId;
23+
24+
public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
25+
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
26+
this.modelId = Objects.requireNonNull(modelId);
27+
}
28+
29+
@Override
30+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
31+
builder.startObject();
32+
unifiedRequestEntity.toXContent(builder, params);
33+
builder.field(MODEL_FIELD, modelId);
34+
builder.endObject();
35+
36+
return builder;
37+
}
38+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
2828
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;
2929

30-
public class OpenAiEmbeddingsRequest implements OpenAiRequest {
30+
public class OpenAiEmbeddingsRequest implements Request {
3131

3232
private final Truncator truncator;
3333
private final OpenAiAccount account;

0 commit comments

Comments
 (0)