Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/118301.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118301
summary: EIS Unified chat completions integration
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,13 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(topP);
}

public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List<ToolCall> toolCalls)
implements
Writeable {
public record Message(
Content content,
String role,
@Nullable String name,
@Nullable String toolCallId,
@Nullable List<ToolCall> toolCalls
) implements Writeable {

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.elastic;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;

import java.util.concurrent.Flow;

public class EISUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
public EISUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction);
}

@Override
public boolean canHandleStreamingResponses() {
return true;
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec

flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.elastic.EISUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.elastic.EISUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.Objects;
import java.util.function.Supplier;

public class EISUnifiedCompletionRequestManager extends ElasticInferenceServiceRequestManager {

private static final Logger logger = LogManager.getLogger(EISUnifiedCompletionRequestManager.class);

private static final ResponseHandler HANDLER = createCompletionHandler();

public static EISUnifiedCompletionRequestManager of(
ElasticInferenceServiceCompletionModel model,
ThreadPool threadPool,
TraceContext traceContext
) {
return new EISUnifiedCompletionRequestManager(
Objects.requireNonNull(model),
Objects.requireNonNull(threadPool),
Objects.requireNonNull(traceContext)
);
}

private final ElasticInferenceServiceCompletionModel model;
private final TraceContext traceContext;

private EISUnifiedCompletionRequestManager(
ElasticInferenceServiceCompletionModel model,
ThreadPool threadPool,
TraceContext traceContext
) {
super(threadPool, model);
this.model = model;
this.traceContext = traceContext;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {

EISUnifiedChatCompletionRequest request = new EISUnifiedChatCompletionRequest(
inferenceInputs.castTo(UnifiedChatInput.class),
model,
traceContext
);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

private static ResponseHandler createCompletionHandler() {
return new EISUnifiedChatCompletionResponseHandler("eis completion", OpenAiChatCompletionResponseEntity::fromResponse);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.request.openai.OpenAiRequest;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

public class EISUnifiedChatCompletionRequest implements OpenAiRequest {

private final ElasticInferenceServiceCompletionModel model;
private final UnifiedChatInput unifiedChatInput;
private final URI uri;
private final TraceContext traceContext;

public EISUnifiedChatCompletionRequest(
UnifiedChatInput unifiedChatInput,
ElasticInferenceServiceCompletionModel model,
TraceContext traceContext
) {
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
this.traceContext = traceContext;

}

@Override
public HttpRequest createHttpRequest() {
var httpPost = new HttpPost(uri);
var requestEntity = Strings.toString(
// TODO remove the modelId() call if not used
new EISUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

if (traceContext != null) {
propagateTraceContext(httpPost);
}

httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
// TODO remove EIS doesn't use an API key
httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey()));

return new HttpRequest(httpPost, getInferenceEntityId());
}

@Override
public URI getURI() {
return uri;
}

@Override
public Request truncate() {
// No truncation
return this;
}

@Override
public boolean[] getTruncationInfo() {
// No truncation
return null;
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public boolean isStreaming() {
return true;
}

public TraceContext getTraceContext() {
return traceContext;
}

private void propagateTraceContext(HttpPost httpPost) {
var traceParent = traceContext.traceParent();
var traceState = traceContext.traceState();

if (traceParent != null) {
httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent);
}

if (traceState != null) {
httpPost.setHeader(Task.TRACE_STATE, traceState);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;

import java.io.IOException;
import java.util.Objects;

public class EISUnifiedChatCompletionRequestEntity implements ToXContentObject {
// TODO remove this if EIS doesn't use it
private static final String MODEL_FIELD = "model";

private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
private final String modelId;

public EISUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
this.modelId = Objects.requireNonNull(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
unifiedRequestEntity.toXContent(builder, params);
// TODO remove this if EIS doesn't use it
builder.field(MODEL_FIELD, modelId);
builder.endObject();

return builder;
}
}
Loading
Loading