From bf507e79582c587f7ef614d1c5c9bbfe3345bd2c Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 20 Nov 2024 17:24:35 -0500 Subject: [PATCH 01/53] Adding some shell classes --- .../org/elasticsearch/TransportVersions.java | 1 + .../inference/InferenceService.java | 16 ++++++ .../inference/action/InferenceAction.java | 53 +++++++++++++++++-- .../action/TransportInferenceAction.java | 34 +++++++++++- .../action/openai/OpenAiActionCreator.java | 2 +- .../external/http/sender/InferenceInputs.java | 8 +++ .../http/sender/UnifiedCompletionInputs.java | 32 +++++++++++ .../xpack/inference/rest/Paths.java | 6 +++ .../RestUnifiedCompletionInferenceAction.java | 42 +++++++++++++++ .../rest/UnifiedCompletionFeature.java | 20 +++++++ .../inference/services/SenderService.java | 14 +++++ .../services/openai/OpenAiService.java | 30 +++++++++++ 12 files changed, 251 insertions(+), 7 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedCompletionInputs.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/UnifiedCompletionFeature.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5b5d12d738194..622c8a1f8371f 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -200,6 +200,7 @@ static TransportVersion def(int id) { public static final TransportVersion SKIP_INNER_HITS_SEARCH_SOURCE = def(8_791_00_0); public static final TransportVersion QUERY_RULES_LIST_INCLUDES_TYPES = def(8_792_00_0); public static final TransportVersion INDEX_STATS_ADDITIONAL_FIELDS = def(8_793_00_0); + public static final TransportVersion 
ML_INFERENCE_UNIFIED_COMPLETIONS_API = def(8_794_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index c6e09f61befa0..45dc49308008f 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -111,6 +111,22 @@ void infer( ActionListener listener ); + /** + * Perform completion inference on the model using the unified schema. + * + * @param model The model + * @param parameters Parameters for the request + * @param timeout The timeout for the request + * @param listener Inference result listener + */ + void completionInfer( + Model model, + // TODO create the class for this object + Object parameters, + TimeValue timeout, + ActionListener listener + ); + /** * Chunk long text according to {@code chunkingOptions} or the * model defaults if {@code chunkingOptions} contains unset diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index a19edd5a08162..186002a5ce0b6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -93,6 +93,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType, private final InputType inputType; private final TimeValue inferenceTimeout; private final boolean stream; + private final boolean isUnifiedCompletionMode; public Request( TaskType taskType, @@ -102,7 +103,8 @@ public Request( Map taskSettings, InputType inputType, TimeValue inferenceTimeout, - boolean stream + boolean stream, + boolean isUnifiedCompletionsMode ) { this.taskType = 
taskType; this.inferenceEntityId = inferenceEntityId; @@ -112,6 +114,7 @@ public Request( this.inputType = inputType; this.inferenceTimeout = inferenceTimeout; this.stream = stream; + this.isUnifiedCompletionMode = isUnifiedCompletionsMode; } public Request(StreamInput in) throws IOException { @@ -138,6 +141,12 @@ public Request(StreamInput in) throws IOException { this.inferenceTimeout = DEFAULT_TIMEOUT; } + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) { + this.isUnifiedCompletionMode = in.readBoolean(); + } else { + this.isUnifiedCompletionMode = false; + } + // streaming is not supported yet for transport traffic this.stream = false; } @@ -174,6 +183,10 @@ public boolean isStreaming() { return stream; } + public boolean isUnifiedCompletionMode() { + return isUnifiedCompletionMode; + } + @Override public ActionRequestValidationException validate() { if (input == null) { @@ -224,6 +237,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(query); out.writeTimeValue(inferenceTimeout); } + + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) { + out.writeBoolean(isUnifiedCompletionMode); + } } // default for easier testing @@ -250,12 +267,22 @@ public boolean equals(Object o) { && Objects.equals(taskSettings, request.taskSettings) && Objects.equals(inputType, request.inputType) && Objects.equals(query, request.query) - && Objects.equals(inferenceTimeout, request.inferenceTimeout); + && Objects.equals(inferenceTimeout, request.inferenceTimeout) + && Objects.equals(isUnifiedCompletionMode, request.isUnifiedCompletionMode); } @Override public int hashCode() { - return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout); + return Objects.hash( + taskType, + inferenceEntityId, + input, + taskSettings, + inputType, + query, + inferenceTimeout, + isUnifiedCompletionMode + ); } public static class 
Builder { @@ -268,6 +295,7 @@ public static class Builder { private String query; private TimeValue timeout = DEFAULT_TIMEOUT; private boolean stream = false; + private boolean unifiedCompletionMode = false; private Builder() {} @@ -315,8 +343,23 @@ public Builder setStream(boolean stream) { return this; } + public Builder setUnifiedCompletionMode(boolean unified) { + this.unifiedCompletionMode = unified; + return this; + } + public Request build() { - return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream); + return new Request( + taskType, + inferenceEntityId, + query, + input, + taskSettings, + inputType, + timeout, + stream, + unifiedCompletionMode + ); } } @@ -335,6 +378,8 @@ public String toString() { + this.getInputType() + ", timeout=" + this.getInferenceTimeout() + + ", isUnifiedCompletionsMode=" + + this.isUnifiedCompletionMode() + ")"; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index ba9ab3c133731..0c8f4755bf5b8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -40,10 +40,10 @@ import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; public class TransportInferenceAction extends HandledTransportAction { + private static final Logger log = LogManager.getLogger(TransportInferenceAction.class); private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; - private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; private final InferenceStats 
inferenceStats; @@ -86,6 +86,13 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe return; } + if (isInvalidTaskTypeForUnifiedCompletionMode(request, unparsedModel)) { + var e = incompatibleUnifiedModeTaskTypeException(request.getTaskType()); + recordMetrics(unparsedModel, timer, e); + listener.onFailure(e); + return; + } + var model = service.get() .parsePersistedConfigWithSecrets( unparsedModel.inferenceEntityId(), @@ -106,6 +113,19 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); } + private static boolean isInvalidTaskTypeForUnifiedCompletionMode(InferenceAction.Request request, UnparsedModel unparsedModel) { + return request.isUnifiedCompletionMode() && request.getTaskType() != TaskType.COMPLETION; + } + + private static ElasticsearchStatusException incompatibleUnifiedModeTaskTypeException(TaskType requested) { + return new ElasticsearchStatusException( + "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]", + RestStatus.BAD_REQUEST, + requested, + TaskType.COMPLETION.toString() + ); + } + private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { try { inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); @@ -155,6 +175,16 @@ private void inferOnService( InferenceService service, ActionListener listener ) { + Runnable a = () -> service.infer( + model, + request.getQuery(), + request.getInput(), + request.isStreaming(), + request.getTaskSettings(), + request.getInputType(), + request.getInferenceTimeout(), + listener + ); if (request.isStreaming() == false || service.canStream(request.getTaskType())) { service.infer( model, @@ -206,6 +236,7 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskTy } private class PublisherWithMetrics extends DelegatingProcessor { + private 
final InferenceTimer timer; private final Model model; @@ -237,5 +268,4 @@ public void onComplete() { super.onComplete(); } } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java index 9c83264b5581f..bd5c53d589df0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java @@ -26,7 +26,7 @@ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type. */ public class OpenAiActionCreator implements OpenAiActionVisitor { - private static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions"; + public static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions"; private final Sender sender; private final ServiceComponents serviceComponents; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index dd241857ef0c4..45244a25db891 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -13,4 +13,12 @@ public abstract class InferenceInputs { public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs) { return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass())); } + + public static T abc(InferenceInputs inputs, Class 
clazz) { + if (inputs.getClass().isInstance(clazz) == false) { + throw createUnsupportedTypeException(inputs); + } + + return clazz.cast(inputs); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedCompletionInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedCompletionInputs.java new file mode 100644 index 0000000000000..35dd9990214a1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedCompletionInputs.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +public class UnifiedCompletionInputs extends InferenceInputs { + public static UnifiedCompletionInputs of(InferenceInputs inferenceInputs) { + return InferenceInputs.abc(inferenceInputs, UnifiedCompletionInputs.class); + } + + private final Object parameters; + private final boolean stream; + + public UnifiedCompletionInputs(Object parameters) { + super(); + this.parameters = parameters; + // TODO retrieve this from the parameters eventually + this.stream = true; + } + + public Object parameters() { + return parameters; + } + + public boolean stream() { + return stream; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index 55d6443b43c03..c46f211bb26af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -30,6 +30,12 @@ public final class 
Paths { + "}/{" + INFERENCE_ID + "}/_stream"; + static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_unified"; + static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + + TASK_TYPE_OR_INFERENCE_ID + + "}/{" + + INFERENCE_ID + + "}/_unified"; private Paths() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java new file mode 100644 index 0000000000000..0056e80af15ca --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.rest; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.Scope; +import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; + +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH; + +@ServerlessScope(Scope.PUBLIC) +public class RestUnifiedCompletionInferenceAction extends BaseInferenceAction { + @Override + public String getName() { + return "unified_inference_action"; + } + + @Override + public List routes() { + return List.of(new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH), new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH)); + } + + @Override + protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { + return builder.setUnifiedCompletionMode(true).setStream(true).build(); + } + + @Override + protected ActionListener listener(RestChannel channel) { + return new ServerSentEventsRestActionListener(channel); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/UnifiedCompletionFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/UnifiedCompletionFeature.java new file mode 100644 index 0000000000000..a02e5591174d4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/UnifiedCompletionFeature.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.rest; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * Unified Completion feature flag. When the feature is complete, this flag will be removed. + * Enable feature via JVM option: `-Des.unified_feature_flag_enabled=true`. + */ +public class UnifiedCompletionFeature { + public static final FeatureFlag UNIFIED_COMPLETION_FEATURE_FLAG = new FeatureFlag("unified"); + + private UnifiedCompletionFeature() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index b8a99227cf517..94fb4e59f3c79 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedCompletionInputs; import java.io.IOException; import java.util.EnumSet; @@ -69,6 +70,12 @@ public void infer( } } + @Override + public void completionInfer(Model model, Object parameters, TimeValue timeout, ActionListener listener) { + init(); + doUnifiedCompletionInfer(model, new UnifiedCompletionInputs(parameters), timeout, listener); + } + @Override public void chunkedInfer( Model model, @@ -94,6 +101,13 @@ protected abstract void doInfer( ActionListener listener ); + protected abstract void doUnifiedCompletionInfer( + Model model, + UnifiedCompletionInputs inputs, + TimeValue timeout, + ActionListener listener + ); + protected abstract void doChunkedInfer( Model model, DocumentsOnlyInput inputs, diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 81ab87a461696..b463635c0d2f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -33,10 +33,13 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedCompletionInputs; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -54,6 +57,8 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator.COMPLETION_ERROR_PREFIX; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -258,6 +263,31 @@ public void doInfer( action.execute(inputs, timeout, listener); } + @Override + public void doUnifiedCompletionInfer( + Model model, + UnifiedCompletionInputs inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof OpenAiChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + OpenAiChatCompletionModel openAiModel = (OpenAiChatCompletionModel) model; + + // TODO override fields from the persisted model + // var overriddenModel = OpenAiChatCompletionModel.of(model, taskSettings); + // TODO create a new OpenAiCompletionRequestManager with the appropriate unified completion input + // or look into merging the functionality but that'd require potentially a lot more fields for the old version? + var requestCreator = OpenAiCompletionRequestManager.of(openAiModel, getServiceComponents().threadPool()); + var errorMessage = constructFailedToSendRequestMessage(openAiModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX); + var action = new SingleInputSenderExecutableAction(getSender(), requestCreator, errorMessage, COMPLETION_ERROR_PREFIX); + + action.execute(inputs, timeout, listener); + } + @Override protected void doChunkedInfer( Model model, From 705aa42f3475a3205d5b4bd058d26233a2fd8065 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 21 Nov 2024 16:52:39 -0500 Subject: [PATCH 02/53] modeling the request objects --- .../inference/action/InferenceAction.java | 4 + .../action/UnifiedCompletionRequest.java | 138 ++++++++++++++++++ .../action/TransportInferenceAction.java | 32 ++-- ...etionInputs.java => CompletionInputs.java} | 8 +- ...OpenAiUnifiedCompletionRequestManager.java | 61 ++++++++ .../inference/services/SenderService.java | 29 +++- .../services/openai/OpenAiService.java | 4 +- 7 files changed, 248 insertions(+), 28 deletions(-) create mode 100644 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/{UnifiedCompletionInputs.java => CompletionInputs.java} (70%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 186002a5ce0b6..9511fbbe011f4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -212,6 +212,10 @@ public ActionRequestValidationException validate() { e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK)); return e; } + } else if (query != null) { + var e = new ActionRequestValidationException(); + e.addValidationError(format("Task type [%s] does not support field [query]", TaskType.RERANK)); + return e; } return null; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java new file mode 100644 index 0000000000000..32deae60c0749 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; + +public class UnifiedCompletionRequest { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + InferenceAction.NAME, + args -> new Thing(args[0]) + ); + + static { + + } + + public static class Thing { + private final Object obj; + + Thing(Object obj) { + this.obj = obj; + } + } + + // TODO convert these to static classes instead of record to make transport changes easier in the future + public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List toolCalls) { + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Message.class.getSimpleName(), + args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List) args[4]) + ); + static { + PARSER.declareField( + ConstructingObjectParser.constructorArg(), + (p, c) -> parseContent(p), + new ParseField("content"), + ObjectParser.ValueType.VALUE + ); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("role")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("name")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("tool_call_id")); + PARSER.declareObjectArray( + ConstructingObjectParser.optionalConstructorArg(), + ToolCall.PARSER::apply, + new ParseField("tool_calls") + ); + } + + private static Content parseContent(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == 
XContentParser.Token.START_OBJECT) { + return ContentObject.PARSER.parse(parser, null); + } else if (token == XContentParser.Token.VALUE_STRING) { + return ContentString.of(parser); + } + + throw new XContentParseException("Unsupported token [" + token + "]"); + } + } + + sealed interface Content permits ContentObject, ContentString {} + + public record ContentObject(String text, String type) implements Content { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ContentObject.class.getSimpleName(), + args -> new ContentObject((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("text")); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type")); + } + } + + public record ContentString(String content) implements Content { + public static ContentString of(XContentParser parser) throws IOException { + var content = parser.text(); + return new ContentString(content); + } + } + + public record ToolCall(String id, FunctionField function, String type) { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ToolCall.class.getSimpleName(), + args -> new ToolCall((String) args[0], (FunctionField) args[1], (String) args[2]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("id")); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type")); + } + + public record FunctionField(String arguments, String name) { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + FunctionField.class.getSimpleName(), + args -> new FunctionField((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("arguments")); + 
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("name")); + } + } + } + + public static class Builder { + private Builder() {} + + public Builder setRole(String role) { + return this; + } + + } + + private static void moveToFirstToken(XContentParser parser) throws IOException { + if (parser.currentToken() == null) { + parser.nextToken(); + } + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 0c8f4755bf5b8..8459dbff75f36 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -175,18 +175,25 @@ private void inferOnService( InferenceService service, ActionListener listener ) { - Runnable a = () -> service.infer( - model, - request.getQuery(), - request.getInput(), - request.isStreaming(), - request.getTaskSettings(), - request.getInputType(), - request.getInferenceTimeout(), - listener - ); + Runnable inferenceRunnable = inferRunnable(model, request, service, listener); + if (request.isStreaming() == false || service.canStream(request.getTaskType())) { - service.infer( + inferenceRunnable.run(); + } else { + listener.onFailure(unsupportedStreamingTaskException(request, service)); + } + } + + private static Runnable inferRunnable( + Model model, + InferenceAction.Request request, + InferenceService service, + ActionListener listener + ) { + return request.isUnifiedCompletionMode() + // TODO add parameters + ? 
() -> service.completionInfer(model, null, request.getInferenceTimeout(), listener) + : () -> service.infer( model, request.getQuery(), request.getInput(), @@ -196,9 +203,6 @@ private void inferOnService( request.getInferenceTimeout(), listener ); - } else { - listener.onFailure(unsupportedStreamingTaskException(request, service)); - } } private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedCompletionInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CompletionInputs.java similarity index 70% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedCompletionInputs.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CompletionInputs.java index 35dd9990214a1..8f79ceca47b79 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedCompletionInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CompletionInputs.java @@ -7,15 +7,15 @@ package org.elasticsearch.xpack.inference.external.http.sender; -public class UnifiedCompletionInputs extends InferenceInputs { - public static UnifiedCompletionInputs of(InferenceInputs inferenceInputs) { - return InferenceInputs.abc(inferenceInputs, UnifiedCompletionInputs.class); +public class CompletionInputs extends InferenceInputs { + public static CompletionInputs of(InferenceInputs inferenceInputs) { + return InferenceInputs.abc(inferenceInputs, CompletionInputs.class); } private final Object parameters; private final boolean stream; - public UnifiedCompletionInputs(Object parameters) { + public CompletionInputs(Object parameters) { super(); this.parameters = 
parameters; // TODO retrieve this from the parameters eventually diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java new file mode 100644 index 0000000000000..4d8695475ff6b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.util.Objects; +import java.util.function.Supplier; + +// TODO remove this class and instead create a +public class OpenAiUnifiedCompletionRequestManager extends OpenAiRequestManager { + + private static final Logger logger = 
LogManager.getLogger(OpenAiUnifiedCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + public static OpenAiUnifiedCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { + return new OpenAiUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final OpenAiChatCompletionModel model; + + private OpenAiUnifiedCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { + super(threadPool, model, OpenAiChatCompletionRequest::buildDefaultUri); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var docsOnly = DocumentsOnlyInput.of(inferenceInputs); + var docsInput = docsOnly.getInputs(); + var stream = docsOnly.stream(); + OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model, stream); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createCompletionHandler() { + return new OpenAiChatCompletionResponseHandler("openai completion", OpenAiChatCompletionResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 94fb4e59f3c79..8340852866db9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.inference.services; +import org.elasticsearch.ElasticsearchStatusException; import 
org.elasticsearch.action.ActionListener; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; @@ -18,12 +20,13 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.http.sender.CompletionInputs; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedCompletionInputs; import java.io.IOException; import java.util.EnumSet; @@ -63,17 +66,27 @@ public void infer( ActionListener listener ) { init(); - if (query != null) { - doInfer(model, new QueryAndDocsInputs(query, input, stream), taskSettings, inputType, timeout, listener); - } else { - doInfer(model, new DocumentsOnlyInput(input, stream), taskSettings, inputType, timeout, listener); - } + var inferenceInput = createInput(model, input, query, stream); + doInfer(model, inferenceInput, taskSettings, inputType, timeout, listener); + } + + private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) { + return switch (model.getTaskType()) { + // TODO implement parameters + case COMPLETION -> new CompletionInputs(null); + case RERANK -> new QueryAndDocsInputs(query, input, stream); + case TEXT_EMBEDDING -> new DocumentsOnlyInput(input, stream); + default -> throw new 
ElasticsearchStatusException( + Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), + RestStatus.BAD_REQUEST + ); + }; } @Override public void completionInfer(Model model, Object parameters, TimeValue timeout, ActionListener listener) { init(); - doUnifiedCompletionInfer(model, new UnifiedCompletionInputs(parameters), timeout, listener); + doUnifiedCompletionInfer(model, new CompletionInputs(parameters), timeout, listener); } @Override @@ -103,7 +116,7 @@ protected abstract void doInfer( protected abstract void doUnifiedCompletionInfer( Model model, - UnifiedCompletionInputs inputs, + CompletionInputs inputs, TimeValue timeout, ActionListener listener ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index b463635c0d2f9..538fba607b4cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -35,11 +35,11 @@ import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.CompletionInputs; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedCompletionInputs; 
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -266,7 +266,7 @@ public void doInfer( @Override public void doUnifiedCompletionInfer( Model model, - UnifiedCompletionInputs inputs, + CompletionInputs inputs, TimeValue timeout, ActionListener listener ) { From bd5df97f4e8a38b5a9f62390f2406f2ff12141a4 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 22 Nov 2024 14:21:16 -0500 Subject: [PATCH 03/53] Writeable changes to schema --- .../action/UnifiedCompletionRequest.java | 422 ++++++++++++++++-- 1 file changed, 376 insertions(+), 46 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java index 32deae60c0749..e62b401d6a0cc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java @@ -7,6 +7,11 @@ package org.elasticsearch.xpack.core.inference.action; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; @@ -16,28 +21,95 @@ import java.io.IOException; import java.util.List; +import java.util.Map; -public class UnifiedCompletionRequest { +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static 
org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; - static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( +public record UnifiedCompletionRequest( + List messages, + @Nullable String model, + @Nullable Long maxCompletionTokens, + @Nullable Integer n, + @Nullable Stop stop, + @Nullable Boolean stream, + @Nullable Float temperature, + @Nullable ToolChoice toolChoice, + @Nullable Tool tool, + @Nullable Float topP, + @Nullable String user +) implements Writeable { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( InferenceAction.NAME, - args -> new Thing(args[0]) + args -> new UnifiedCompletionRequest( + (List) args[0], + (String) args[1], + (Long) args[2], + (Integer) args[3], + (Stop) args[4], + (Boolean) args[5], + (Float) args[6], + (ToolChoice) args[7], + (Tool) args[8], + (Float) args[9], + (String) args[10] + ) ); static { - + PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages")); + PARSER.declareString(optionalConstructorArg(), new ParseField("model")); + PARSER.declareLong(optionalConstructorArg(), new ParseField("max_tokens")); + PARSER.declareInt(optionalConstructorArg(), new ParseField("n")); + PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), ObjectParser.ValueType.VALUE); + PARSER.declareBoolean(optionalConstructorArg(), new ParseField("stream")); + PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature")); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> parseToolChoice(p), + new ParseField("tool_choice"), + ObjectParser.ValueType.VALUE + ); + PARSER.declareObjectArray(optionalConstructorArg(), Tool.PARSER::apply, new ParseField("tool")); + PARSER.declareFloat(optionalConstructorArg(), new ParseField("top_p")); + PARSER.declareString(optionalConstructorArg(), new ParseField("user")); } - public static class Thing { - private final Object obj; + public 
UnifiedCompletionRequest(StreamInput in) throws IOException { + this( + in.readCollectionAsImmutableList(Message::new), + in.readOptionalString(), + in.readOptionalVLong(), + in.readOptionalVInt(), + in.readOptionalNamedWriteable(Stop.class), + in.readOptionalBoolean(), + in.readOptionalFloat(), + in.readOptionalNamedWriteable(ToolChoice.class), + in.readOptionalWriteable(Tool::new), + in.readOptionalFloat(), + in.readOptionalString() + ); + } - Thing(Object obj) { - this.obj = obj; - } + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(messages); + out.writeOptionalString(model); + out.writeOptionalVLong(maxCompletionTokens); + out.writeOptionalVInt(n); + out.writeOptionalNamedWriteable(stop); + out.writeOptionalBoolean(stream); + out.writeOptionalFloat(temperature); + out.writeOptionalNamedWriteable(toolChoice); + out.writeOptionalWriteable(tool); + out.writeOptionalFloat(topP); + out.writeOptionalString(user); } - // TODO convert these to static classes instead of record to make transport changes easier in the future - public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List toolCalls) { + public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List toolCalls) + implements + Writeable { @SuppressWarnings("unchecked") static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -45,56 +117,112 @@ public record Message(Content content, String role, @Nullable String name, @Null args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List) args[4]) ); static { - PARSER.declareField( - ConstructingObjectParser.constructorArg(), - (p, c) -> parseContent(p), - new ParseField("content"), - ObjectParser.ValueType.VALUE - ); - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("role")); - 
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("name")); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("tool_call_id")); - PARSER.declareObjectArray( - ConstructingObjectParser.optionalConstructorArg(), - ToolCall.PARSER::apply, - new ParseField("tool_calls") - ); + PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE); + PARSER.declareString(constructorArg(), new ParseField("role")); + PARSER.declareString(optionalConstructorArg(), new ParseField("name")); + PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id")); + PARSER.declareObjectArray(optionalConstructorArg(), ToolCall.PARSER::apply, new ParseField("tool_calls")); } private static Content parseContent(XContentParser parser) throws IOException { var token = parser.currentToken(); - if (token == XContentParser.Token.START_OBJECT) { - return ContentObject.PARSER.parse(parser, null); + if (token == XContentParser.Token.START_ARRAY) { + var parsedContentObjects = XContentParserUtils.parseList(parser, (p) -> ContentObject.PARSER.apply(p, null)); + return new ContentObjects(parsedContentObjects); } else if (token == XContentParser.Token.VALUE_STRING) { return ContentString.of(parser); } throw new XContentParseException("Unsupported token [" + token + "]"); } + + public Message(StreamInput in) throws IOException { + this( + in.readNamedWriteable(Content.class), + in.readString(), + in.readOptionalString(), + in.readOptionalString(), + in.readCollectionAsImmutableList(ToolCall::new) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(content); + out.writeString(role); + out.writeOptionalString(name); + out.writeOptionalString(toolCallId); + out.writeCollection(toolCalls); + } } - sealed interface Content permits ContentObject, ContentString {} + public sealed interface Content extends 
NamedWriteable permits ContentObjects, ContentString {} + + public record ContentObjects(List contentObjects) implements Content, Writeable { - public record ContentObject(String text, String type) implements Content { + public static final String NAME = "content_objects"; + + public ContentObjects(StreamInput in) throws IOException { + this(in.readCollectionAsImmutableList(ContentObject::new)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(contentObjects); + } + + @Override + public String getWriteableName() { + return NAME; + } + } + + public record ContentObject(String text, String type) implements Writeable { static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( ContentObject.class.getSimpleName(), args -> new ContentObject((String) args[0], (String) args[1]) ); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("text")); - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type")); + PARSER.declareString(constructorArg(), new ParseField("text")); + PARSER.declareString(constructorArg(), new ParseField("type")); + } + + public ContentObject(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(text); + out.writeString(type); } } - public record ContentString(String content) implements Content { + public record ContentString(String content) implements Content, NamedWriteable { + public static final String NAME = "content_string"; + public static ContentString of(XContentParser parser) throws IOException { var content = parser.text(); return new ContentString(content); } + + public ContentString(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(content); + } + + @Override + public String 
getWriteableName() { + return NAME; + } } - public record ToolCall(String id, FunctionField function, String type) { + public record ToolCall(String id, FunctionField function, String type) implements Writeable { static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( ToolCall.class.getSimpleName(), @@ -102,37 +230,239 @@ public record ToolCall(String id, FunctionField function, String type) { ); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("id")); - PARSER.declareObject(ConstructingObjectParser.constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type")); + PARSER.declareString(constructorArg(), new ParseField("id")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + PARSER.declareString(constructorArg(), new ParseField("type")); + } + + public ToolCall(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in), in.readString()); } - public record FunctionField(String arguments, String name) { + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + function.writeTo(out); + out.writeString(type); + } + + public record FunctionField(String arguments, String name) implements Writeable { static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( FunctionField.class.getSimpleName(), args -> new FunctionField((String) args[0], (String) args[1]) ); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("arguments")); - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("name")); + PARSER.declareString(constructorArg(), new ParseField("arguments")); + PARSER.declareString(constructorArg(), new ParseField("name")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readString(), 
in.readString()); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(arguments); + out.writeString(name); + } + } + } + + private static Stop parseStop(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.START_ARRAY) { + var parsedStopValues = XContentParserUtils.parseList(parser, XContentParser::text); + return new StopValues(parsedStopValues); + } else if (token == XContentParser.Token.VALUE_STRING) { + return StopString.of(parser); + } + + throw new XContentParseException("Unsupported token [" + token + "]"); + } + + public sealed interface Stop extends NamedWriteable permits StopString, StopValues {} + + public record StopString(String value) implements Stop, NamedWriteable { + public static final String NAME = "stop_string"; + + public static StopString of(XContentParser parser) throws IOException { + var content = parser.text(); + return new StopString(content); + } + + public StopString(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + + @Override + public String getWriteableName() { + return NAME; + } + } + + public record StopValues(List values) implements Stop, NamedWriteable { + public static final String NAME = "stop_values"; + + public StopValues(StreamInput in) throws IOException { + this(in.readStringCollectionAsImmutableList()); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(values); + } + + @Override + public String getWriteableName() { + return NAME; + } + } + + private static ToolChoice parseToolChoice(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.START_OBJECT) { + return ToolChoiceObject.PARSER.apply(parser, null); + } else if (token == XContentParser.Token.VALUE_STRING) { + return 
ToolChoiceString.of(parser); + } + + throw new XContentParseException("Unsupported token [" + token + "]"); } - public static class Builder { - private Builder() {} + public sealed interface ToolChoice extends NamedWriteable permits ToolChoiceObject, ToolChoiceString {} - public Builder setRole(String role) { - return this; + public record ToolChoiceObject(String type, FunctionField function) implements ToolChoice, NamedWriteable { + + public static final String NAME = "tool_choice_object"; + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ToolChoiceObject.class.getSimpleName(), + args -> new ToolChoiceObject((String) args[0], (FunctionField) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("type")); + PARSER.declareObject(constructorArg(), ToolCall.FunctionField.PARSER::apply, new ParseField("function")); + } + + public ToolChoiceObject(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in)); } + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + function.writeTo(out); + } + + @Override + public String getWriteableName() { + return NAME; + } + + public record FunctionField(String name) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + FunctionField.class.getSimpleName(), + args -> new FunctionField((String) args[0]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("name")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + } + } } - private static void moveToFirstToken(XContentParser parser) throws IOException { - if (parser.currentToken() == null) { - parser.nextToken(); + public record ToolChoiceString(String value) implements ToolChoice, NamedWriteable { + public static final String NAME = 
"tool_choice_string"; + + public static ToolChoiceString of(XContentParser parser) throws IOException { + var content = parser.text(); + return new ToolChoiceString(content); + } + + public ToolChoiceString(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + + @Override + public String getWriteableName() { + return NAME; } } + public record Tool(String type, FunctionField function) implements Writeable { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Tool.class.getSimpleName(), + args -> new Tool((String) args[0], (FunctionField) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("type")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + } + + public Tool(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + function.writeTo(out); + } + + public record FunctionField( + @Nullable String description, + String name, + @Nullable Map parameters, + @Nullable Boolean strict + ) implements Writeable { + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + FunctionField.class.getSimpleName(), + args -> new FunctionField((String) args[0], (String) args[1], (Map) args[2], (Boolean) args[3]) + ); + + static { + PARSER.declareString(optionalConstructorArg(), new ParseField("description")); + PARSER.declareString(constructorArg(), new ParseField("name")); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), new ParseField("name")); + PARSER.declareBoolean(optionalConstructorArg(), new ParseField("strict")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readOptionalString(), in.readString(), 
in.readGenericMap(), in.readOptionalBoolean()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(description); + out.writeString(name); + out.writeGenericMap(parameters); + out.writeOptionalBoolean(strict); + } + } + } } From bd59543cadc478e8fb4a52feef3aacbfd7a22233 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 22 Nov 2024 16:34:14 -0500 Subject: [PATCH 04/53] Working parsing tests --- .../org/elasticsearch/test/ESTestCase.java | 12 + .../action/UnifiedCompletionRequest.java | 30 +- .../action/InferenceActionRequestTests.java | 21 ++ .../action/UnifiedCompletionRequestTests.java | 292 ++++++++++++++++++ 4 files changed, 341 insertions(+), 14 deletions(-) create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 207409dfcf751..67dc36cb29b6b 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -1209,6 +1209,18 @@ public static String randomNullOrAlphaOfLength(int codeUnits) { return randomBoolean() ? null : randomAlphaOfLength(codeUnits); } + public static Long randomNullOrLong() { + return randomBoolean() ? null : randomLong(); + } + + public static Integer randomNullOrInt() { + return randomBoolean() ? null : randomInt(); + } + + public static Float randomNullOrFloat() { + return randomBoolean() ? 
null : randomFloat(); + } + /** * Creates a valid random identifier such as node id or index name */ diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java index e62b401d6a0cc..2799ab6e1b6ef 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java @@ -35,11 +35,12 @@ public record UnifiedCompletionRequest( @Nullable Boolean stream, @Nullable Float temperature, @Nullable ToolChoice toolChoice, - @Nullable Tool tool, + @Nullable List tool, @Nullable Float topP, @Nullable String user ) implements Writeable { + @SuppressWarnings("unchecked") static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( InferenceAction.NAME, args -> new UnifiedCompletionRequest( @@ -51,7 +52,7 @@ public record UnifiedCompletionRequest( (Boolean) args[5], (Float) args[6], (ToolChoice) args[7], - (Tool) args[8], + (List) args[8], (Float) args[9], (String) args[10] ) @@ -60,18 +61,18 @@ public record UnifiedCompletionRequest( static { PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages")); PARSER.declareString(optionalConstructorArg(), new ParseField("model")); - PARSER.declareLong(optionalConstructorArg(), new ParseField("max_tokens")); + PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens")); PARSER.declareInt(optionalConstructorArg(), new ParseField("n")); - PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), ObjectParser.ValueType.VALUE); + PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), ObjectParser.ValueType.VALUE_ARRAY); 
PARSER.declareBoolean(optionalConstructorArg(), new ParseField("stream")); PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature")); PARSER.declareField( optionalConstructorArg(), (p, c) -> parseToolChoice(p), new ParseField("tool_choice"), - ObjectParser.ValueType.VALUE + ObjectParser.ValueType.OBJECT_OR_STRING ); - PARSER.declareObjectArray(optionalConstructorArg(), Tool.PARSER::apply, new ParseField("tool")); + PARSER.declareObjectArray(optionalConstructorArg(), Tool.PARSER::apply, new ParseField("tools")); PARSER.declareFloat(optionalConstructorArg(), new ParseField("top_p")); PARSER.declareString(optionalConstructorArg(), new ParseField("user")); } @@ -86,7 +87,7 @@ public UnifiedCompletionRequest(StreamInput in) throws IOException { in.readOptionalBoolean(), in.readOptionalFloat(), in.readOptionalNamedWriteable(ToolChoice.class), - in.readOptionalWriteable(Tool::new), + in.readCollectionAsImmutableList(Tool::new), in.readOptionalFloat(), in.readOptionalString() ); @@ -102,7 +103,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(stream); out.writeOptionalFloat(temperature); out.writeOptionalNamedWriteable(toolChoice); - out.writeOptionalWriteable(tool); + out.writeOptionalCollection(tool); out.writeOptionalFloat(topP); out.writeOptionalString(user); } @@ -116,8 +117,9 @@ public record Message(Content content, String role, @Nullable String name, @Null Message.class.getSimpleName(), args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List) args[4]) ); + static { - PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE); + PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY); PARSER.declareString(constructorArg(), new ParseField("role")); PARSER.declareString(optionalConstructorArg(), new ParseField("name")); 
PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id")); @@ -248,7 +250,7 @@ public void writeTo(StreamOutput out) throws IOException { public record FunctionField(String arguments, String name) implements Writeable { static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - FunctionField.class.getSimpleName(), + "tool_call_function_field", args -> new FunctionField((String) args[0], (String) args[1]) ); @@ -348,7 +350,7 @@ public record ToolChoiceObject(String type, FunctionField function) implements T static { PARSER.declareString(constructorArg(), new ParseField("type")); - PARSER.declareObject(constructorArg(), ToolCall.FunctionField.PARSER::apply, new ParseField("function")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); } public ToolChoiceObject(StreamInput in) throws IOException { @@ -368,7 +370,7 @@ public String getWriteableName() { public record FunctionField(String name) implements Writeable { static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - FunctionField.class.getSimpleName(), + "tool_choice_function_field", args -> new FunctionField((String) args[0]) ); @@ -441,14 +443,14 @@ public record FunctionField( @SuppressWarnings("unchecked") static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - FunctionField.class.getSimpleName(), + "tool_function_field", args -> new FunctionField((String) args[0], (String) args[1], (Map) args[2], (Boolean) args[3]) ); static { PARSER.declareString(optionalConstructorArg(), new ParseField("description")); PARSER.declareString(constructorArg(), new ParseField("name")); - PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), new ParseField("name")); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), new ParseField("parameters")); PARSER.declareBoolean(optionalConstructorArg(), new ParseField("strict")); } diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index a9ca5e6da8720..a7eb9ce3e3fd0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -47,6 +47,7 @@ protected InferenceAction.Request createTestInstance() { randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), TimeValue.timeValueMillis(randomLongBetween(1, 2048)), + false, false ); } @@ -82,6 +83,7 @@ public void testValidation_TextEmbedding() { null, null, null, + false, false ); ActionRequestValidationException e = request.validate(); @@ -97,6 +99,7 @@ public void testValidation_Rerank() { null, null, null, + false, false ); ActionRequestValidationException e = request.validate(); @@ -112,6 +115,7 @@ public void testValidation_TextEmbedding_Null() { null, null, null, + false, false ); ActionRequestValidationException inputNullError = inputNullRequest.validate(); @@ -128,6 +132,7 @@ public void testValidation_TextEmbedding_Empty() { null, null, null, + false, false ); ActionRequestValidationException inputEmptyError = inputEmptyRequest.validate(); @@ -144,6 +149,7 @@ public void testValidation_Rerank_Null() { null, null, null, + false, false ); ActionRequestValidationException queryNullError = queryNullRequest.validate(); @@ -160,6 +166,7 @@ public void testValidation_Rerank_Empty() { null, null, null, + false, false ); ActionRequestValidationException queryEmptyError = queryEmptyRequest.validate(); @@ -193,6 +200,7 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), + 
false, false ); } @@ -204,6 +212,7 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), + false, false ); case 2 -> { @@ -217,6 +226,7 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), + false, false ); } @@ -236,6 +246,7 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc taskSettings, instance.getInputType(), instance.getInferenceTimeout(), + false, false ); } @@ -249,6 +260,7 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), nextInputType, instance.getInferenceTimeout(), + false, false ); } @@ -260,6 +272,7 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), + false, false ); case 6 -> { @@ -276,6 +289,7 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()), + false, false ); } @@ -294,6 +308,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, + false, false ); } else if (version.before(TransportVersions.V_8_13_0)) { @@ -305,6 +320,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, + false, false ); } else if (version.before(TransportVersions.V_8_13_0) @@ -319,6 +335,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), InputType.INGEST, 
InferenceAction.Request.DEFAULT_TIMEOUT, + false, false ); } else if (version.before(TransportVersions.V_8_13_0) @@ -331,6 +348,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, + false, false ); } else if (version.before(TransportVersions.V_8_14_0)) { @@ -342,6 +360,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), instance.getInputType(), InferenceAction.Request.DEFAULT_TIMEOUT, + false, false ); } @@ -359,6 +378,7 @@ public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOExceptio Map.of(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, + false, false ), TransportVersions.V_8_13_0 @@ -374,6 +394,7 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn Map.of(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, + false, false ); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java new file mode 100644 index 0000000000000..58d21ebef3038 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -0,0 +1,292 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class UnifiedCompletionRequestTests extends AbstractBWCWireSerializationTestCase { + + public void testParseAllFields() throws IOException { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "text": "some text", + "type": "string" + } + ], + "name": "a name", + "tool_call_id": "100", + "tool_calls": [ + { + "id": "call_62136354", + "type": "function", + "function": { + "arguments": "{'order_id': 'order_12345'}", + "name": "get_delivery_date" + } + } + ] + } + ], + "max_completion_tokens": 100, + "n": 1, + "stop": ["stop"], + "stream": true, + "temperature": 0.1, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + } + ], + "tool_choice": { + "type": "function", + "function": { + "name": "some function" + } + }, + "top_p": 0.2, + "user": "user" + } + """; + + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = UnifiedCompletionRequest.PARSER.apply(parser, null); + var expected = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects( + List.of(new UnifiedCompletionRequest.ContentObject("some text", "string")) + ), + "user", + "a name", + "100", + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField("{'order_id': 'order_12345'}", "get_delivery_date"), + "function" + ) + 
) + ) + ), + "gpt-4o", + 100L, + 1, + new UnifiedCompletionRequest.StopValues(List.of("stop")), + true, + 0.1F, + new UnifiedCompletionRequest.ToolChoiceObject( + "function", + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField("some function") + ), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + 0.2F, + "user" + ); + + assertThat(request, is(expected)); + } + } + + public void testParsing() throws IOException { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "What is the weather like in Boston today?" + } + ], + "stop": "none", + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + } + ], + "tool_choice": "auto" + } + """; + + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = UnifiedCompletionRequest.PARSER.apply(parser, null); + var expected = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("What is the weather like in Boston today?"), + "user", + null, + null, + null + ) + ), + "gpt-4o", + null, + null, + new UnifiedCompletionRequest.StopString("none"), + null, + null, + new UnifiedCompletionRequest.ToolChoiceString("auto"), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + null, + null + ); + + assertThat(request, is(expected)); + } + } + + public static UnifiedCompletionRequest randomUnifiedCompletionRequest() { + return new UnifiedCompletionRequest( + randomList(5, 
UnifiedCompletionRequestTests::randomMessage), + randomNullOrAlphaOfLength(10), + randomNullOrLong(), + randomNullOrInt(), + randomNullOrStop(), + randomOptionalBoolean(), + randomNullOrFloat(), + randomNullOrToolChoice(), + randomList(5, UnifiedCompletionRequestTests::randomTool), + randomNullOrFloat(), + randomNullOrAlphaOfLength(10) + ); + } + + public static UnifiedCompletionRequest.Message randomMessage() { + return new UnifiedCompletionRequest.Message( + randomContent(), + randomAlphaOfLength(10), + randomNullOrAlphaOfLength(10), + randomNullOrAlphaOfLength(10), + randomList(10, UnifiedCompletionRequestTests::randomToolCall) + ); + } + + public static UnifiedCompletionRequest.Content randomContent() { + return randomBoolean() + ? new UnifiedCompletionRequest.ContentString(randomAlphaOfLength(10)) + : new UnifiedCompletionRequest.ContentObjects(randomList(10, UnifiedCompletionRequestTests::randomContentObject)); + } + + public static UnifiedCompletionRequest.ContentObject randomContentObject() { + return new UnifiedCompletionRequest.ContentObject(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.ToolCall randomToolCall() { + return new UnifiedCompletionRequest.ToolCall(randomAlphaOfLength(10), randomToolCallFunctionField(), randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.ToolCall.FunctionField randomToolCallFunctionField() { + return new UnifiedCompletionRequest.ToolCall.FunctionField(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.Stop randomNullOrStop() { + return randomBoolean() ? randomStop() : null; + } + + public static UnifiedCompletionRequest.Stop randomStop() { + return randomBoolean() + ? 
new UnifiedCompletionRequest.StopString(randomAlphaOfLength(10)) + : new UnifiedCompletionRequest.StopValues(randomList(5, () -> randomAlphaOfLength(10))); + } + + public static UnifiedCompletionRequest.ToolChoice randomNullOrToolChoice() { + return randomBoolean() ? randomToolChoice() : null; + } + + public static UnifiedCompletionRequest.ToolChoice randomToolChoice() { + return randomBoolean() + ? new UnifiedCompletionRequest.ToolChoiceString(randomAlphaOfLength(10)) + : new UnifiedCompletionRequest.ToolChoiceObject(randomAlphaOfLength(10), randomToolChoiceObjectFunctionField()); + } + + public static UnifiedCompletionRequest.ToolChoiceObject.FunctionField randomToolChoiceObjectFunctionField() { + return new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.Tool randomTool() { + return new UnifiedCompletionRequest.Tool(randomAlphaOfLength(10), randomToolFunctionField()); + } + + public static UnifiedCompletionRequest.Tool.FunctionField randomToolFunctionField() { + return new UnifiedCompletionRequest.Tool.FunctionField( + randomNullOrAlphaOfLength(10), + randomAlphaOfLength(10), + null, + randomOptionalBoolean() + ); + } + + @Override + protected UnifiedCompletionRequest mutateInstanceForVersion(UnifiedCompletionRequest instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return UnifiedCompletionRequest::new; + } + + @Override + protected UnifiedCompletionRequest createTestInstance() { + return randomUnifiedCompletionRequest(); + } + + @Override + protected UnifiedCompletionRequest mutateInstance(UnifiedCompletionRequest instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } +} From 1e30c6d2d147e87a91930c67a961657c3adbf95a Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 25 Nov 2024 11:45:52 -0500 Subject: [PATCH 05/53] Creating a new action --- 
.../inference/action/InferenceAction.java | 41 ++- .../action/UnifiedCompletionAction.java | 114 +++++++ .../action/BaseInferenceActionRequest.java | 19 ++ .../action/BaseTransportInferenceAction.java | 284 ++++++++++++++++++ 4 files changed, 457 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseInferenceActionRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 9511fbbe011f4..ad8e7871b5ee6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.common.xcontent.ChunkedToXContentObject; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; @@ -94,6 +95,7 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType, private final TimeValue inferenceTimeout; private final boolean stream; private final boolean isUnifiedCompletionMode; + private final UnifiedCompletionRequest unifiedCompletionRequest; public Request( TaskType taskType, @@ -105,6 +107,32 @@ public Request( TimeValue inferenceTimeout, boolean stream, boolean isUnifiedCompletionsMode + ) { + this( + taskType, + 
inferenceEntityId, + query, + input, + taskSettings, + inputType, + inferenceTimeout, + stream, + isUnifiedCompletionsMode, + null + ); + } + + public Request( + TaskType taskType, + String inferenceEntityId, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue inferenceTimeout, + boolean stream, + boolean isUnifiedCompletionsMode, + @Nullable UnifiedCompletionRequest unifiedCompletionRequest ) { this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; @@ -115,6 +143,7 @@ public Request( this.inferenceTimeout = inferenceTimeout; this.stream = stream; this.isUnifiedCompletionMode = isUnifiedCompletionsMode; + this.unifiedCompletionRequest = unifiedCompletionRequest; } public Request(StreamInput in) throws IOException { @@ -143,8 +172,10 @@ public Request(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) { this.isUnifiedCompletionMode = in.readBoolean(); + this.unifiedCompletionRequest = in.readOptionalWriteable(UnifiedCompletionRequest::new); } else { this.isUnifiedCompletionMode = false; + this.unifiedCompletionRequest = null; } // streaming is not supported yet for transport traffic @@ -244,6 +275,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) { out.writeBoolean(isUnifiedCompletionMode); + out.writeOptionalWriteable(unifiedCompletionRequest); } } @@ -300,6 +332,7 @@ public static class Builder { private TimeValue timeout = DEFAULT_TIMEOUT; private boolean stream = false; private boolean unifiedCompletionMode = false; + private UnifiedCompletionRequest unifiedCompletionRequest; private Builder() {} @@ -352,6 +385,11 @@ public Builder setUnifiedCompletionMode(boolean unified) { return this; } + public Builder setUnifiedCompletionRequest(UnifiedCompletionRequest unifiedCompletionRequest) { + this.unifiedCompletionRequest = 
unifiedCompletionRequest; + return this; + } + public Request build() { return new Request( taskType, @@ -362,7 +400,8 @@ public Request build() { inputType, timeout, stream, - unifiedCompletionMode + unifiedCompletionMode, + unifiedCompletionRequest ); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java new file mode 100644 index 0000000000000..adc6a343ea782 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class UnifiedCompletionAction extends ActionType { + public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction(); + public static final String NAME = "cluster:monitor/xpack/inference/unified"; + + public UnifiedCompletionAction() { + super(NAME); + } + + public static class Request extends ActionRequest { + public static Request parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException { + var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null); + return new 
Request(inferenceEntityId, taskType, unifiedRequest); + } + + private final String inferenceEntityId; + private final TaskType taskType; + private final UnifiedCompletionRequest unifiedCompletionRequest; + + public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest) { + this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); + this.taskType = Objects.requireNonNull(taskType); + this.unifiedCompletionRequest = Objects.requireNonNull(unifiedCompletionRequest); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.inferenceEntityId = in.readString(); + this.taskType = TaskType.fromStream(in); + this.unifiedCompletionRequest = new UnifiedCompletionRequest(in); + } + + public TaskType getTaskType() { + return taskType; + } + + public String getInferenceEntityId() { + return inferenceEntityId; + } + + public UnifiedCompletionRequest getUnifiedCompletionRequest() { + return unifiedCompletionRequest; + } + + public boolean isStreaming() { + return Objects.requireNonNullElse(unifiedCompletionRequest.stream(), false); + } + + @Override + public ActionRequestValidationException validate() { + if (unifiedCompletionRequest == null || unifiedCompletionRequest.messages() == null) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [messages] cannot be null"); + return e; + } + + if (unifiedCompletionRequest.messages().isEmpty()) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [messages] cannot be an empty array"); + return e; + } + + if (taskType != TaskType.COMPLETION) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [taskType] must be [completion]"); + return e; + } + + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(inferenceEntityId); + taskType.writeTo(out); + unifiedCompletionRequest.writeTo(out); + } + + 
@Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(inferenceEntityId, request.inferenceEntityId) + && taskType == request.taskType + && Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest); + } + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseInferenceActionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseInferenceActionRequest.java new file mode 100644 index 0000000000000..5d3b470edc933 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseInferenceActionRequest.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.inference.TaskType; + +public abstract class BaseInferenceActionRequest extends ActionRequest { + public abstract boolean isStreaming(); + + public abstract TaskType getTaskType(); + + public abstract String getInferenceEntityId(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java new file mode 100644 index 0000000000000..7fb813ebd777f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -0,0 +1,284 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; + +import java.util.stream.Collectors; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; + +public abstract class BaseTransportInferenceAction extends HandledTransportAction< + T, + InferenceAction.Response> { + + private static final Logger log = 
LogManager.getLogger(BaseTransportInferenceAction.class); + private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; + private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; + private final ModelRegistry modelRegistry; + private final InferenceServiceRegistry serviceRegistry; + private final InferenceStats inferenceStats; + private final StreamingTaskManager streamingTaskManager; + + // TODO remove the inject here? + @Inject + public BaseTransportInferenceAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager, + Writeable.Reader requestReader + ) { + super(InferenceAction.NAME, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); + this.modelRegistry = modelRegistry; + this.serviceRegistry = serviceRegistry; + this.inferenceStats = inferenceStats; + this.streamingTaskManager = streamingTaskManager; + } + + @Override + protected void doExecute(Task task, T request, ActionListener listener) { + var timer = InferenceTimer.start(); + + var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { + var service = serviceRegistry.getService(unparsedModel.service()); + if (service.isEmpty()) { + var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()); + recordMetrics(unparsedModel, timer, e); + listener.onFailure(e); + return; + } + + if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { + // not the wildcard task type and not the model task type + var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()); + recordMetrics(unparsedModel, timer, e); + listener.onFailure(e); + return; + } + + if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) { + var e = 
incompatibleUnifiedModeTaskTypeException(request.getTaskType()); + recordMetrics(unparsedModel, timer, e); + listener.onFailure(e); + return; + } + + var model = service.get() + .parsePersistedConfigWithSecrets( + unparsedModel.inferenceEntityId(), + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ); + inferOnServiceWithMetrics(model, request, service.get(), timer, listener); + }, e -> { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); + } catch (Exception metricsException) { + log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics"); + } + listener.onFailure(e); + }); + + modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); + } + + protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(T request, UnparsedModel unparsedModel); + + protected abstract ElasticsearchStatusException createIncompatibleTaskTypeException(T request, UnparsedModel unparsedModel); + + private boolean isInvalidTaskTypeForUnifiedCompletionMode(T request, UnparsedModel unparsedModel) { + return request.isUnifiedCompletionMode() && request.getTaskType() != TaskType.COMPLETION; + } + + private static ElasticsearchStatusException incompatibleUnifiedModeTaskTypeException(TaskType requested) { + return new ElasticsearchStatusException( + "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]", + RestStatus.BAD_REQUEST, + requested, + TaskType.COMPLETION.toString() + ); + } + + private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } catch (Exception e) { + log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); + } + } + + private void inferOnServiceWithMetrics( + Model model, + 
InferenceAction.Request request, + InferenceService service, + InferenceTimer timer, + ActionListener listener + ) { + inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); + inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { + if (request.isStreaming()) { + var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); + inferenceResults.publisher().subscribe(taskProcessor); + + var instrumentedStream = new PublisherWithMetrics(timer, model); + taskProcessor.subscribe(instrumentedStream); + + listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream)); + } else { + recordMetrics(model, timer, null); + listener.onResponse(new InferenceAction.Response(inferenceResults)); + } + }, e -> { + recordMetrics(model, timer, e); + listener.onFailure(e); + })); + } + + private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } catch (Exception e) { + log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); + } + } + + private void inferOnService( + Model model, + InferenceAction.Request request, + InferenceService service, + ActionListener listener + ) { + Runnable inferenceRunnable = inferRunnable(model, request, service, listener); + + if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + inferenceRunnable.run(); + } else { + listener.onFailure(unsupportedStreamingTaskException(request, service)); + } + } + + private static Runnable inferRunnable( + Model model, + InferenceAction.Request request, + InferenceService service, + ActionListener listener + ) { + return request.isUnifiedCompletionMode() + // TODO add parameters + ? 
() -> service.completionInfer(model, null, request.getInferenceTimeout(), listener) + : () -> service.infer( + model, + request.getQuery(), + request.getInput(), + request.isStreaming(), + request.getTaskSettings(), + request.getInputType(), + request.getInferenceTimeout(), + listener + ); + } + + private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) { + var supportedTasks = service.supportedStreamingTasks(); + if (supportedTasks.isEmpty()) { + return new ElasticsearchStatusException( + format("Streaming is not allowed for service [%s].", service.name()), + RestStatus.METHOD_NOT_ALLOWED + ); + } else { + var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); + return new ElasticsearchStatusException( + format( + "Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", + service.name(), + request.getTaskType(), + validTasks + ), + RestStatus.METHOD_NOT_ALLOWED + ); + } + } + + private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { + return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. 
", RestStatus.BAD_REQUEST, service, inferenceId); + } + + private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) { + return new ElasticsearchStatusException( + "Incompatible task_type, the requested type [{}] does not match the model type [{}]", + RestStatus.BAD_REQUEST, + requested, + expected + ); + } + + private class PublisherWithMetrics extends DelegatingProcessor { + + private final InferenceTimer timer; + private final Model model; + + private PublisherWithMetrics(InferenceTimer timer, Model model) { + this.timer = timer; + this.model = model; + } + + @Override + protected void next(ChunkedToXContent item) { + downstream().onNext(item); + } + + @Override + public void onError(Throwable throwable) { + recordMetrics(model, timer, throwable); + super.onError(throwable); + } + + @Override + protected void onCancel() { + recordMetrics(model, timer, null); + super.onCancel(); + } + + @Override + public void onComplete() { + recordMetrics(model, timer, null); + super.onComplete(); + } + } +} From 284694263183c237497a5a6f386196d12dba9dab Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 25 Nov 2024 15:53:55 -0500 Subject: [PATCH 06/53] Add outbound request writing (WIP) --- .../action/UnifiedCompletionRequest.java | 33 ++- .../OpenAiCompletionRequestManager.java | 13 +- ...OpenAiUnifiedCompletionRequestManager.java | 61 ------ .../http/sender/UnifiedChatInput.java | 40 ++++ ...> OpenAiUnifiedChatCompletionRequest.java} | 21 +- ...nAiUnifiedChatCompletionRequestEntity.java | 200 ++++++++++++++++++ 6 files changed, 284 insertions(+), 84 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java rename 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/{OpenAiChatCompletionRequest.java => OpenAiUnifiedChatCompletionRequest.java} (78%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java index 2799ab6e1b6ef..0e163792f8b06 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java @@ -16,6 +16,8 @@ import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentParser; @@ -40,6 +42,10 @@ public record UnifiedCompletionRequest( @Nullable String user ) implements Writeable { + public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString { + void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException; + } + @SuppressWarnings("unchecked") static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( InferenceAction.NAME, @@ -158,8 +164,6 @@ public void writeTo(StreamOutput out) throws IOException { } } - public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {} - public record ContentObjects(List contentObjects) implements Content, Writeable { public static final String NAME = "content_objects"; @@ -173,6 +177,17 @@ 
public void writeTo(StreamOutput out) throws IOException { out.writeCollection(contentObjects); } + @Override + public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startArray(); + for (ContentObject contentObject : contentObjects) { + builder.startObject(); + contentObject.toXContentObject(builder, params); + builder.endObject(); + } + builder.endArray(); + } + @Override public String getWriteableName() { return NAME; @@ -199,6 +214,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(text); out.writeString(type); } + + public XContentBuilder toXContentObject(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field("text", text); + builder.field("type", type); + builder.endObject(); + return builder; + } } public record ContentString(String content) implements Content, NamedWriteable { @@ -222,6 +245,10 @@ public void writeTo(StreamOutput out) throws IOException { public String getWriteableName() { return NAME; } + + public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.value(content); + } } public record ToolCall(String id, FunctionField function, String type) implements Writeable { @@ -437,7 +464,7 @@ public void writeTo(StreamOutput out) throws IOException { public record FunctionField( @Nullable String description, String name, - @Nullable Map parameters, + @Nullable Map parameters, // TODO can we parse this as a string? 
@Nullable Boolean strict ) implements Writeable { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index cea89332e5bf0..0717eef1e52e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -15,7 +15,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; @@ -35,7 +35,7 @@ public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, private final OpenAiChatCompletionModel model; private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { - super(threadPool, model, OpenAiChatCompletionRequest::buildDefaultUri); + super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); this.model = Objects.requireNonNull(model); } @@ -46,10 +46,11 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - 
OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model, stream); + + OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest( + UnifiedChatInput.of(inferenceInputs).getRequestEntity(), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java deleted file mode 100644 index 4d8695475ff6b..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest; -import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; -import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; - -import java.util.Objects; -import java.util.function.Supplier; - -// TODO remove this class and instead create a -public class OpenAiUnifiedCompletionRequestManager extends OpenAiRequestManager { - - private static final Logger logger = LogManager.getLogger(OpenAiUnifiedCompletionRequestManager.class); - - private static final ResponseHandler HANDLER = createCompletionHandler(); - - public static OpenAiUnifiedCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { - return new OpenAiUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); - } - - private final OpenAiChatCompletionModel model; - - private OpenAiUnifiedCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { - super(threadPool, model, OpenAiChatCompletionRequest::buildDefaultUri); - this.model = Objects.requireNonNull(model); - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - var docsOnly = 
DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model, stream); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } - - private static ResponseHandler createCompletionHandler() { - return new OpenAiChatCompletionResponseHandler("openai completion", OpenAiChatCompletionResponseEntity::fromResponse); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java new file mode 100644 index 0000000000000..ef2fc7fafc2bd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity; + +import java.util.Objects; + +public class UnifiedChatInput extends InferenceInputs { + + public static UnifiedChatInput of(InferenceInputs inferenceInputs) { + + if (inferenceInputs instanceof DocumentsOnlyInput docsOnly) { + return new UnifiedChatInput(new OpenAiUnifiedChatCompletionRequestEntity(docsOnly)); + } else if (inferenceInputs instanceof UnifiedChatInput == false) { + throw createUnsupportedTypeException(inferenceInputs); + } + + return (UnifiedChatInput) inferenceInputs; + } + + public OpenAiUnifiedChatCompletionRequestEntity getRequestEntity() { + return requestEntity; + } + + private final OpenAiUnifiedChatCompletionRequestEntity requestEntity; + + public UnifiedChatInput(OpenAiUnifiedChatCompletionRequestEntity requestEntity) { + this.requestEntity = Objects.requireNonNull(requestEntity); + } + + public boolean stream() { + return requestEntity.isStream(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java similarity index 78% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index 99a025e70d003..40e1145e0b256 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -21,35 +21,28 
@@ import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; -import java.util.List; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader; -public class OpenAiChatCompletionRequest implements OpenAiRequest { +public class OpenAiUnifiedChatCompletionRequest implements OpenAiRequest { private final OpenAiAccount account; - private final List input; + private final OpenAiUnifiedChatCompletionRequestEntity requestEntity; private final OpenAiChatCompletionModel model; - private final boolean stream; - public OpenAiChatCompletionRequest(List input, OpenAiChatCompletionModel model, boolean stream) { - this.account = OpenAiAccount.of(model, OpenAiChatCompletionRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); + public OpenAiUnifiedChatCompletionRequest(OpenAiUnifiedChatCompletionRequestEntity requestEntity, OpenAiChatCompletionModel model) { + this.account = OpenAiAccount.of(model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); + this.requestEntity = Objects.requireNonNull(requestEntity); this.model = Objects.requireNonNull(model); - this.stream = stream; } @Override public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString( - new OpenAiChatCompletionRequestEntity(input, model.getServiceSettings().modelId(), model.getTaskSettings().user(), stream) - ).getBytes(StandardCharsets.UTF_8) - ); + ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8)); httpPost.setEntity(byteEntity); httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); @@ -87,7 +80,7 @@ public String getInferenceEntityId() { @Override public boolean isStreaming() { - return stream; + 
return requestEntity.isStream(); } public static URI buildDefaultUri() throws URISyntaxException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..ad1bbc424bfba --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,200 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject { + + private static final String MESSAGES_FIELD = "messages"; + private static final String MODEL_FIELD = "model"; + + private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; + + private static final String ROLE_FIELD = "role"; + private static final String USER_FIELD = "user"; + private static final String CONTENT_FIELD = "content"; + private static final String STREAM_FIELD = "stream"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private static final 
String STOP_FIELD = "stop"; + private static final String TEMPERATURE_FIELD = "temperature"; + private static final String TOOL_CHOICE_FIELD = "tool_choice"; + private static final String TOOL_FIELD = "tool"; + private static final String TOP_P_FIELD = "top_p"; + + private final String user; + + public boolean isStream() { + return stream; + } + + private final boolean stream; + private final Long maxCompletionTokens; + private final Integer n; + private final UnifiedCompletionRequest.Stop stop; + private final Float temperature; + private final UnifiedCompletionRequest.ToolChoice toolChoice; + private final List tool; + private final Float topP; + private final List messages; + private final String model; + + public OpenAiUnifiedChatCompletionRequestEntity(DocumentsOnlyInput input) { + this(convertDocumentsOnlyInputToMessages(input), null, null, null, null, null, null, null, null, null); + } + + private static List convertDocumentsOnlyInputToMessages(DocumentsOnlyInput input) { + return input.getInputs() + .stream() + .map(doc -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(doc), "user", null, null, null)) + .toList(); + } + + public OpenAiUnifiedChatCompletionRequestEntity( + List messages, + @Nullable String model, + @Nullable Long maxCompletionTokens, + @Nullable Integer n, + @Nullable UnifiedCompletionRequest.Stop stop, + @Nullable Float temperature, + @Nullable UnifiedCompletionRequest.ToolChoice toolChoice, + @Nullable List tool, + @Nullable Float topP, + @Nullable String user + ) { + Objects.requireNonNull(messages); + Objects.requireNonNull(model); + + this.user = user; + this.stream = true; // always stream in unified API + this.maxCompletionTokens = maxCompletionTokens; + this.n = n; + this.stop = stop; + this.temperature = temperature; + this.toolChoice = toolChoice; + this.tool = tool; + this.topP = topP; + this.messages = messages; + this.model = model; + + } + + @Override + public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(MESSAGES_FIELD); + { + for (UnifiedCompletionRequest.Message message : messages) { + builder.startObject(); + { + builder.field(CONTENT_FIELD); + message.content().toXContent(builder, params); + builder.field(ROLE_FIELD, message.role()); + if (message.name() != null) { + builder.field("name", message.name());// <---- HERE + } + if (message.toolCallId() != null) { + builder.field("tool_call_id", message.toolCallId()); + } + if (message.toolCalls() != null) { + builder.startArray("tool_calls"); + for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { + builder.startObject(); + { + builder.field("id", toolCall.id()); + builder.startObject("function"); + { + builder.field("arguments", toolCall.function().arguments()); + builder.field("name", toolCall.function().name()); + } + builder.endObject(); + builder.field("type", toolCall.type()); + } + builder.endObject(); + } + builder.endArray(); + } + } + builder.endObject(); + } + } + builder.endArray(); + + if (model != null) { + builder.field(MODEL_FIELD, model); + } + if (maxCompletionTokens != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, maxCompletionTokens); + } + if (n != null) { + builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, n); + } + if (stop != null) { + if (stop instanceof UnifiedCompletionRequest.StopString) { + builder.field(STOP_FIELD, ((UnifiedCompletionRequest.StopString) stop).value()); + } else if (stop instanceof UnifiedCompletionRequest.StopValues) { + builder.field(STOP_FIELD, ((UnifiedCompletionRequest.StopValues) stop).values()); + } + } + if (temperature != null) { + builder.field(TEMPERATURE_FIELD, temperature); + } + if (toolChoice != null) { + if (toolChoice instanceof UnifiedCompletionRequest.ToolChoiceString) { + builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) toolChoice).value()); + } else if (toolChoice instanceof 
UnifiedCompletionRequest.ToolChoiceObject) { + builder.startObject(TOOL_CHOICE_FIELD); + { + builder.field("type", ((UnifiedCompletionRequest.ToolChoiceObject) toolChoice).type()); + builder.startObject("function"); + { + builder.field("name", ((UnifiedCompletionRequest.ToolChoiceObject) toolChoice).function().name()); + } + builder.endObject(); + } + builder.endObject(); + } + } + if (tool != null) { + builder.startArray(TOOL_FIELD); + for (UnifiedCompletionRequest.Tool t : tool) { + builder.startObject(); + { + builder.field("type", t.type()); + builder.startObject("function"); + { + builder.field("description", t.function().description()); + builder.field("name", t.function().name()); + builder.field("parameters", t.function().parameters()); + builder.field("strict", t.function().strict()); + } + builder.endObject(); + } + builder.endObject(); + } + builder.endArray(); + } + if (topP != null) { + builder.field(TOP_P_FIELD, topP); + } + if (Strings.isNullOrEmpty(user) == false) { + builder.field(USER_FIELD, user); + } + builder.endObject(); + return builder; + } +} From 9cb401caf90268772cb595b6be6dd840f08a9246 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 25 Nov 2024 16:06:38 -0500 Subject: [PATCH 07/53] Improvements to request serialization --- .../action/UnifiedCompletionRequest.java | 22 +----------------- ...nAiUnifiedChatCompletionRequestEntity.java | 23 +++++++++++++------ 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java index 0e163792f8b06..4d093b4ed600a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java @@ -42,9 +42,7 @@ public 
record UnifiedCompletionRequest( @Nullable String user ) implements Writeable { - public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString { - void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException; - } + public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {} @SuppressWarnings("unchecked") static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -177,17 +175,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(contentObjects); } - @Override - public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { - builder.startArray(); - for (ContentObject contentObject : contentObjects) { - builder.startObject(); - contentObject.toXContentObject(builder, params); - builder.endObject(); - } - builder.endArray(); - } - @Override public String getWriteableName() { return NAME; @@ -215,13 +202,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(type); } - public XContentBuilder toXContentObject(XContentBuilder builder, ToXContent.Params params) throws IOException { - builder.startObject(); - builder.field("text", text); - builder.field("type", type); - builder.endObject(); - return builder; - } } public record ContentString(String content) implements Content, NamedWriteable { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index ad1bbc424bfba..4c36806a88b76 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -101,11 +101,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws for (UnifiedCompletionRequest.Message message : messages) { builder.startObject(); { - builder.field(CONTENT_FIELD); - message.content().toXContent(builder, params); + switch (message.content()) { + case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); + case UnifiedCompletionRequest.ContentObjects contentObjects -> { + for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { + builder.startObject(CONTENT_FIELD); + builder.field("text", contentObject.text()); + builder.field("type", contentObject.type()); + builder.endObject(); + } + } + } + builder.field(ROLE_FIELD, message.role()); if (message.name() != null) { - builder.field("name", message.name());// <---- HERE + builder.field("name", message.name()); } if (message.toolCallId() != null) { builder.field("tool_call_id", message.toolCallId()); @@ -144,10 +154,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, n); } if (stop != null) { - if (stop instanceof UnifiedCompletionRequest.StopString) { - builder.field(STOP_FIELD, ((UnifiedCompletionRequest.StopString) stop).value()); - } else if (stop instanceof UnifiedCompletionRequest.StopValues) { - builder.field(STOP_FIELD, ((UnifiedCompletionRequest.StopValues) stop).values()); + switch (stop) { + case UnifiedCompletionRequest.StopString stopString -> builder.field(STOP_FIELD, stopString.value()); + case UnifiedCompletionRequest.StopValues stopValues -> builder.field(STOP_FIELD, stopValues.values()); } } if (temperature != null) { From 1e0eb204ecfc05475147af2dac3e1187e2eac0ac Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 25 Nov 
2024 16:37:40 -0500 Subject: [PATCH 08/53] Adding separate transport classes --- .../action/BaseInferenceActionRequest.java | 14 +- .../inference/action/InferenceAction.java | 3 +- .../action/UnifiedCompletionAction.java | 25 +- .../action/UnifiedCompletionRequest.java | 5 - .../authz/store/ReservedRolesStoreTests.java | 1 + .../xpack/inference/InferencePlugin.java | 3 + .../action/BaseTransportInferenceAction.java | 87 +++--- .../action/TransportInferenceAction.java | 251 ++---------------- ...sportUnifiedCompletionInferenceAction.java | 76 ++++++ .../inference/rest/BaseInferenceAction.java | 32 ++- .../RestUnifiedCompletionInferenceAction.java | 28 +- .../xpack/security/operator/Constants.java | 1 + 12 files changed, 212 insertions(+), 314 deletions(-) rename x-pack/plugin/{inference/src/main/java/org/elasticsearch/xpack => core/src/main/java/org/elasticsearch/xpack/core}/inference/action/BaseInferenceActionRequest.java (64%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseInferenceActionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java similarity index 64% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseInferenceActionRequest.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java index 5d3b470edc933..e426574c52ce6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseInferenceActionRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java @@ -5,12 +5,24 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.action; +package org.elasticsearch.xpack.core.inference.action; import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.inference.TaskType; +import java.io.IOException; + public abstract class BaseInferenceActionRequest extends ActionRequest { + + public BaseInferenceActionRequest() { + super(); + } + + public BaseInferenceActionRequest(StreamInput in) throws IOException { + super(in); + } + public abstract boolean isStreaming(); public abstract TaskType getTaskType(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index ad8e7871b5ee6..4d342c4fce701 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -55,7 +54,7 @@ public InferenceAction() { super(NAME); } - public static class Request extends ActionRequest { + public static class Request extends BaseInferenceActionRequest { public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30); public static final ParseField INPUT = new ParseField("input"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java index 
adc6a343ea782..13db70dd04f72 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -7,11 +7,11 @@ package org.elasticsearch.xpack.core.inference.action; -import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentParser; @@ -26,20 +26,22 @@ public UnifiedCompletionAction() { super(NAME); } - public static class Request extends ActionRequest { - public static Request parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException { + public static class Request extends BaseInferenceActionRequest { + public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser) throws IOException { var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null); - return new Request(inferenceEntityId, taskType, unifiedRequest); + return new Request(inferenceEntityId, taskType, unifiedRequest, timeout); } private final String inferenceEntityId; private final TaskType taskType; private final UnifiedCompletionRequest unifiedCompletionRequest; + private final TimeValue timeout; - public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest) { + public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest, TimeValue timeout) { this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); this.taskType = Objects.requireNonNull(taskType); this.unifiedCompletionRequest = 
Objects.requireNonNull(unifiedCompletionRequest); + this.timeout = Objects.requireNonNull(timeout); } public Request(StreamInput in) throws IOException { @@ -47,6 +49,7 @@ public Request(StreamInput in) throws IOException { this.inferenceEntityId = in.readString(); this.taskType = TaskType.fromStream(in); this.unifiedCompletionRequest = new UnifiedCompletionRequest(in); + this.timeout = in.readTimeValue(); } public TaskType getTaskType() { @@ -62,7 +65,11 @@ public UnifiedCompletionRequest getUnifiedCompletionRequest() { } public boolean isStreaming() { - return Objects.requireNonNullElse(unifiedCompletionRequest.stream(), false); + return true; + } + + public TimeValue getTimeout() { + return timeout; } @Override @@ -94,6 +101,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(inferenceEntityId); taskType.writeTo(out); unifiedCompletionRequest.writeTo(out); + out.writeTimeValue(timeout); } @Override @@ -102,12 +110,13 @@ public boolean equals(Object o) { Request request = (Request) o; return Objects.equals(inferenceEntityId, request.inferenceEntityId) && taskType == request.taskType - && Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest); + && Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest) && + Objects.equals(timeout, request.timeout); } @Override public int hashCode() { - return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest); + return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest, timeout); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java index 2799ab6e1b6ef..a17fb6694dc7c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java @@ -32,7 +32,6 @@ public record UnifiedCompletionRequest( @Nullable Long maxCompletionTokens, @Nullable Integer n, @Nullable Stop stop, - @Nullable Boolean stream, @Nullable Float temperature, @Nullable ToolChoice toolChoice, @Nullable List tool, @@ -49,7 +48,6 @@ public record UnifiedCompletionRequest( (Long) args[2], (Integer) args[3], (Stop) args[4], - (Boolean) args[5], (Float) args[6], (ToolChoice) args[7], (List) args[8], @@ -64,7 +62,6 @@ public record UnifiedCompletionRequest( PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens")); PARSER.declareInt(optionalConstructorArg(), new ParseField("n")); PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), ObjectParser.ValueType.VALUE_ARRAY); - PARSER.declareBoolean(optionalConstructorArg(), new ParseField("stream")); PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature")); PARSER.declareField( optionalConstructorArg(), @@ -84,7 +81,6 @@ public UnifiedCompletionRequest(StreamInput in) throws IOException { in.readOptionalVLong(), in.readOptionalVInt(), in.readOptionalNamedWriteable(Stop.class), - in.readOptionalBoolean(), in.readOptionalFloat(), in.readOptionalNamedWriteable(ToolChoice.class), in.readCollectionAsImmutableList(Tool::new), @@ -100,7 +96,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalVLong(maxCompletionTokens); out.writeOptionalVInt(n); out.writeOptionalNamedWriteable(stop); - out.writeOptionalBoolean(stream); out.writeOptionalFloat(temperature); out.writeOptionalNamedWriteable(toolChoice); out.writeOptionalCollection(tool); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java index 
fb4d822b7655c..67d61d68931ca 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java @@ -4236,6 +4236,7 @@ public void testInferenceUserRole() { assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication)); + assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index f068caff805af..e61c25291daf7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -48,6 +48,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceEndpointAction; 
import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction; @@ -56,6 +57,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; @@ -152,6 +154,7 @@ public InferencePlugin(Settings settings) { public List> getActions() { return List.of( new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class), + new ActionHandler<>(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class), new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class), new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 7fb813ebd777f..efb8098c2b455 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -23,10 +23,10 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; -import org.elasticsearch.injection.guice.Inject; import 
org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; @@ -40,8 +40,8 @@ import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; -public abstract class BaseTransportInferenceAction extends HandledTransportAction< - T, +public abstract class BaseTransportInferenceAction extends HandledTransportAction< + Request, InferenceAction.Response> { private static final Logger log = LogManager.getLogger(BaseTransportInferenceAction.class); @@ -52,8 +52,6 @@ public abstract class BaseTransportInferenceAction requestReader + Writeable.Reader requestReader ) { super(InferenceAction.NAME, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.modelRegistry = modelRegistry; @@ -71,7 +69,7 @@ public BaseTransportInferenceAction( } @Override - protected void doExecute(Task task, T request, ActionListener listener) { + protected void doExecute(Task task, Request request, ActionListener listener) { var timer = InferenceTimer.start(); var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { @@ -92,7 +90,7 @@ protected void doExecute(Task task, T request, ActionListener listener @@ -178,43 +163,43 @@ private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwabl } } - private void inferOnService( - Model model, - InferenceAction.Request request, - InferenceService service, - ActionListener listener - ) { - Runnable inferenceRunnable = inferRunnable(model, request, service, listener); - + private void inferOnService(Model model, 
Request request, InferenceService service, ActionListener listener) { if (request.isStreaming() == false || service.canStream(request.getTaskType())) { - inferenceRunnable.run(); + doInference(model, request, service, listener); } else { listener.onFailure(unsupportedStreamingTaskException(request, service)); } } - private static Runnable inferRunnable( + // private static Runnable inferRunnable( + // Model model, + // T request, + // InferenceService service, + // ActionListener listener + // ) { + // return request.isUnifiedCompletionMode() + // // TODO add parameters + // ? () -> service.completionInfer(model, null, request.getInferenceTimeout(), listener) + // : () -> service.infer( + // model, + // request.getQuery(), + // request.getInput(), + // request.isStreaming(), + // request.getTaskSettings(), + // request.getInputType(), + // request.getInferenceTimeout(), + // listener + // ); + // } + + protected abstract void doInference( Model model, - InferenceAction.Request request, + Request request, InferenceService service, ActionListener listener - ) { - return request.isUnifiedCompletionMode() - // TODO add parameters - ? 
() -> service.completionInfer(model, null, request.getInferenceTimeout(), listener) - : () -> service.infer( - model, - request.getQuery(), - request.getInput(), - request.isStreaming(), - request.getTaskSettings(), - request.getInputType(), - request.getInferenceTimeout(), - listener - ); - } + ); - private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) { + private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) { var supportedTasks = service.supportedStreamingTasks(); if (supportedTasks.isEmpty()) { return new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 8459dbff75f36..11e849d4be883 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -7,47 +7,22 @@ package org.elasticsearch.xpack.inference.action; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; import 
org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; -import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; -import java.util.stream.Collectors; - -import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; - -public class TransportInferenceAction extends HandledTransportAction { - - private static final Logger log = LogManager.getLogger(TransportInferenceAction.class); - private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; - private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; - private final ModelRegistry modelRegistry; - private final InferenceServiceRegistry serviceRegistry; - private final InferenceStats inferenceStats; - private final StreamingTaskManager streamingTaskManager; +public class TransportInferenceAction extends BaseTransportInferenceAction { @Inject public TransportInferenceAction( @@ -58,218 +33,46 @@ public TransportInferenceAction( InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager ) { - super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); - this.modelRegistry = modelRegistry; - this.serviceRegistry = serviceRegistry; - this.inferenceStats = inferenceStats; - 
this.streamingTaskManager = streamingTaskManager; - } - - @Override - protected void doExecute(Task task, InferenceAction.Request request, ActionListener listener) { - var timer = InferenceTimer.start(); - - var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { - var service = serviceRegistry.getService(unparsedModel.service()); - if (service.isEmpty()) { - var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { - // not the wildcard task type and not the model task type - var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - if (isInvalidTaskTypeForUnifiedCompletionMode(request, unparsedModel)) { - var e = incompatibleUnifiedModeTaskTypeException(request.getTaskType()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - var model = service.get() - .parsePersistedConfigWithSecrets( - unparsedModel.inferenceEntityId(), - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ); - inferOnServiceWithMetrics(model, request, service.get(), timer, listener); - }, e -> { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); - } catch (Exception metricsException) { - log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics"); - } - listener.onFailure(e); - }); - - modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); - } - - private static boolean isInvalidTaskTypeForUnifiedCompletionMode(InferenceAction.Request request, UnparsedModel unparsedModel) { - return request.isUnifiedCompletionMode() && request.getTaskType() != TaskType.COMPLETION; - } 
- - private static ElasticsearchStatusException incompatibleUnifiedModeTaskTypeException(TaskType requested) { - return new ElasticsearchStatusException( - "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]", - RestStatus.BAD_REQUEST, - requested, - TaskType.COMPLETION.toString() + super( + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager, + InferenceAction.Request::new ); } - private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); - } - } - - private void inferOnServiceWithMetrics( - Model model, - InferenceAction.Request request, - InferenceService service, - InferenceTimer timer, - ActionListener listener - ) { - inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); - inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { - if (request.isStreaming()) { - var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); - inferenceResults.publisher().subscribe(taskProcessor); - - var instrumentedStream = new PublisherWithMetrics(timer, model); - taskProcessor.subscribe(instrumentedStream); - - listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream)); - } else { - recordMetrics(model, timer, null); - listener.onResponse(new InferenceAction.Response(inferenceResults)); - } - }, e -> { - recordMetrics(model, timer, e); - listener.onFailure(e); - })); - } - - private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); - } catch (Exception e) { - 
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); - } + @Override + protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request request, UnparsedModel unparsedModel) { + return false; } - private void inferOnService( - Model model, + @Override + protected ElasticsearchStatusException createIncompatibleTaskTypeException( InferenceAction.Request request, - InferenceService service, - ActionListener listener + UnparsedModel unparsedModel ) { - Runnable inferenceRunnable = inferRunnable(model, request, service, listener); - - if (request.isStreaming() == false || service.canStream(request.getTaskType())) { - inferenceRunnable.run(); - } else { - listener.onFailure(unsupportedStreamingTaskException(request, service)); - } + return null; } - private static Runnable inferRunnable( + @Override + protected void doInference( Model model, InferenceAction.Request request, InferenceService service, ActionListener listener ) { - return request.isUnifiedCompletionMode() - // TODO add parameters - ? () -> service.completionInfer(model, null, request.getInferenceTimeout(), listener) - : () -> service.infer( - model, - request.getQuery(), - request.getInput(), - request.isStreaming(), - request.getTaskSettings(), - request.getInputType(), - request.getInferenceTimeout(), - listener - ); - } - - private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) { - var supportedTasks = service.supportedStreamingTasks(); - if (supportedTasks.isEmpty()) { - return new ElasticsearchStatusException( - format("Streaming is not allowed for service [%s].", service.name()), - RestStatus.METHOD_NOT_ALLOWED - ); - } else { - var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); - return new ElasticsearchStatusException( - format( - "Streaming is not allowed for service [%s] and task [%s]. 
Supported tasks: [%s]", - service.name(), - request.getTaskType(), - validTasks - ), - RestStatus.METHOD_NOT_ALLOWED - ); - } - } - - private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { - return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId); - } - - private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) { - return new ElasticsearchStatusException( - "Incompatible task_type, the requested type [{}] does not match the model type [{}]", - RestStatus.BAD_REQUEST, - requested, - expected + service.infer( + model, + request.getQuery(), + request.getInput(), + request.isStreaming(), + request.getTaskSettings(), + request.getInputType(), + request.getInferenceTimeout(), + listener ); } - - private class PublisherWithMetrics extends DelegatingProcessor { - - private final InferenceTimer timer; - private final Model model; - - private PublisherWithMetrics(InferenceTimer timer, Model model) { - this.timer = timer; - this.model = model; - } - - @Override - protected void next(ChunkedToXContent item) { - downstream().onNext(item); - } - - @Override - public void onError(Throwable throwable) { - recordMetrics(model, timer, throwable); - super.onError(throwable); - } - - @Override - protected void onCancel() { - recordMetrics(model, timer, null); - super.onCancel(); - } - - @Override - public void onComplete() { - recordMetrics(model, timer, null); - super.onComplete(); - } - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java new file mode 100644 index 0000000000000..497929b3e5848 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; + +public class TransportUnifiedCompletionInferenceAction extends BaseTransportInferenceAction { + + @Inject + public TransportUnifiedCompletionInferenceAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + super( + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager, + UnifiedCompletionAction.Request::new + ); + } + + @Override + protected boolean 
isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) { + return request.getTaskType() != TaskType.COMPLETION; + } + + @Override + protected ElasticsearchStatusException createIncompatibleTaskTypeException( + UnifiedCompletionAction.Request request, + UnparsedModel unparsedModel + ) { + return new ElasticsearchStatusException( + "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]", + RestStatus.BAD_REQUEST, + request.getTaskType(), + TaskType.COMPLETION.toString() + ); + } + + @Override + protected void doInference( + Model model, + UnifiedCompletionAction.Request request, + InferenceService service, + ActionListener listener + ) { + service.completionInfer(model, request.getUnifiedCompletionRequest(), null, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java index e72e68052f648..d911158e82296 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestChannel; @@ -21,27 +22,32 @@ import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; abstract class BaseInferenceAction extends BaseRestHandler { - @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String inferenceEntityId; - TaskType taskType; + static Params parseParams(RestRequest restRequest) { if 
(restRequest.hasParam(INFERENCE_ID)) { - inferenceEntityId = restRequest.param(INFERENCE_ID); - taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + var inferenceEntityId = restRequest.param(INFERENCE_ID); + var taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + return new Params(inferenceEntityId, taskType); } else { - inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID); - taskType = TaskType.ANY; + return new Params(restRequest.param(TASK_TYPE_OR_INFERENCE_ID), TaskType.ANY); } + } + + record Params(String inferenceEntityId, TaskType taskType) {} + + static TimeValue parseTimeout(RestRequest restRequest) { + return restRequest.paramAsTime(InferenceAction.Request.TIMEOUT.getPreferredName(), InferenceAction.Request.DEFAULT_TIMEOUT); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + var params = parseParams(restRequest); InferenceAction.Request.Builder requestBuilder; try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser); + requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); } - var inferTimeout = restRequest.paramAsTime( - InferenceAction.Request.TIMEOUT.getPreferredName(), - InferenceAction.Request.DEFAULT_TIMEOUT - ); + var inferTimeout = parseTimeout(restRequest); requestBuilder.setInferenceTimeout(inferTimeout); var request = prepareInferenceRequest(requestBuilder); return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java index 0056e80af15ca..48f031d3df8cc 
100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java @@ -7,19 +7,23 @@ package org.elasticsearch.xpack.inference.rest; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH; import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH; @ServerlessScope(Scope.PUBLIC) -public class RestUnifiedCompletionInferenceAction extends BaseInferenceAction { +public class RestUnifiedCompletionInferenceAction extends BaseRestHandler { @Override public String getName() { return "unified_inference_action"; @@ -27,16 +31,20 @@ public String getName() { @Override public List routes() { - return List.of(new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH), new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH)); + return List.of(new Route(POST, UNIFIED_INFERENCE_ID_PATH), new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH)); } @Override - protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) { - return builder.setUnifiedCompletionMode(true).setStream(true).build(); - } + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + var params = 
BaseInferenceAction.parseParams(restRequest); - @Override - protected ActionListener listener(RestChannel channel) { - return new ServerSentEventsRestActionListener(channel); + var inferTimeout = BaseInferenceAction.parseTimeout(restRequest); + + UnifiedCompletionAction.Request request; + try (var parser = restRequest.contentParser()) { + request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); + } + + return channel -> client.execute(InferenceAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel)); } } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index df97c489cc6b7..2a0fed5f445e3 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -383,6 +383,7 @@ public class Constants { "cluster:monitor/xpack/esql/stats/dist", "cluster:monitor/xpack/inference", "cluster:monitor/xpack/inference/get", + "cluster:monitor/xpack/inference/unified", "cluster:monitor/xpack/inference/diagnostics/get", "cluster:monitor/xpack/inference/services/get", "cluster:monitor/xpack/info", From d6cc22334c2f835eaa8ef074765f09151eaafa1d Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 25 Nov 2024 17:24:50 -0500 Subject: [PATCH 09/53] separate out unified request and combine inputs --- .../http/sender/CompletionInputs.java | 32 ---- .../external/request/UnifiedRequest.java | 27 +++ ...nAiUnifiedChatCompletionRequestEntity.java | 161 +++++++----------- .../inference/services/SenderService.java | 16 +- .../services/openai/OpenAiService.java | 4 +- 5 
files changed, 106 insertions(+), 134 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CompletionInputs.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/UnifiedRequest.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CompletionInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CompletionInputs.java deleted file mode 100644 index 8f79ceca47b79..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CompletionInputs.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -public class CompletionInputs extends InferenceInputs { - public static CompletionInputs of(InferenceInputs inferenceInputs) { - return InferenceInputs.abc(inferenceInputs, CompletionInputs.class); - } - - private final Object parameters; - private final boolean stream; - - public CompletionInputs(Object parameters) { - super(); - this.parameters = parameters; - // TODO retrieve this from the parameters eventually - this.stream = true; - } - - public Object parameters() { - return parameters; - } - - public boolean stream() { - return stream; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/UnifiedRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/UnifiedRequest.java new file mode 100644 index 0000000000000..e994d29f89426 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/UnifiedRequest.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequest; + +import java.util.List; + +public record UnifiedRequest( + List messages, + @Nullable String model, + @Nullable Long maxCompletionTokens, + @Nullable Integer n, + @Nullable UnifiedCompletionRequest.Stop stop, + @Nullable Float temperature, + @Nullable UnifiedCompletionRequest.ToolChoice toolChoice, + @Nullable List tool, + @Nullable Float topP, + @Nullable String user, + boolean stream +) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 4c36806a88b76..b082c80311ba4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -7,98 +7,65 @@ package org.elasticsearch.xpack.inference.external.request.openai; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequest; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.request.UnifiedRequest; import java.io.IOException; import java.util.List; -import java.util.Objects; public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject { - private static final String MESSAGES_FIELD = "messages"; - private static final String MODEL_FIELD = 
"model"; - + public static final String NAME_FIELD = "name"; + public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; + public static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_FIELD = "function"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String STRICT_FIELD = "strict"; + public static final String TOP_P_FIELD = "top_p"; + public static final String USER_FIELD = "user"; + public static final String STREAM_FIELD = "stream"; private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; - + private static final String MODEL_FIELD = "model"; + public static final String MESSAGES_FIELD = "messages"; private static final String ROLE_FIELD = "role"; - private static final String USER_FIELD = "user"; private static final String CONTENT_FIELD = "content"; - private static final String STREAM_FIELD = "stream"; private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; private static final String STOP_FIELD = "stop"; private static final String TEMPERATURE_FIELD = "temperature"; private static final String TOOL_CHOICE_FIELD = "tool_choice"; private static final String TOOL_FIELD = "tool"; - private static final String TOP_P_FIELD = "top_p"; + private static final String TEXT_FIELD = "text"; + private static final String TYPE_FIELD = "type"; - private final String user; + private final UnifiedRequest unifiedRequest; - public boolean isStream() { - return stream; + public OpenAiUnifiedChatCompletionRequestEntity(UnifiedRequest unifiedRequest) { + this.unifiedRequest = unifiedRequest; } - private final boolean stream; - private final Long maxCompletionTokens; - private final Integer n; - private final UnifiedCompletionRequest.Stop stop; - private final Float temperature; - private final 
UnifiedCompletionRequest.ToolChoice toolChoice; - private final List tool; - private final Float topP; - private final List messages; - private final String model; - public OpenAiUnifiedChatCompletionRequestEntity(DocumentsOnlyInput input) { - this(convertDocumentsOnlyInputToMessages(input), null, null, null, null, null, null, null, null, null); + this(new UnifiedRequest(convertDocumentsOnlyInputToMessages(input), null, null, null, null, null, null, null, null, null, true)); } private static List convertDocumentsOnlyInputToMessages(DocumentsOnlyInput input) { return input.getInputs() .stream() - .map(doc -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(doc), "user", null, null, null)) + .map(doc -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(doc), USER_FIELD, null, null, null)) .toList(); } - public OpenAiUnifiedChatCompletionRequestEntity( - List messages, - @Nullable String model, - @Nullable Long maxCompletionTokens, - @Nullable Integer n, - @Nullable UnifiedCompletionRequest.Stop stop, - @Nullable Float temperature, - @Nullable UnifiedCompletionRequest.ToolChoice toolChoice, - @Nullable List tool, - @Nullable Float topP, - @Nullable String user - ) { - Objects.requireNonNull(messages); - Objects.requireNonNull(model); - - this.user = user; - this.stream = true; // always stream in unified API - this.maxCompletionTokens = maxCompletionTokens; - this.n = n; - this.stop = stop; - this.temperature = temperature; - this.toolChoice = toolChoice; - this.tool = tool; - this.topP = topP; - this.messages = messages; - this.model = model; - - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.startArray(MESSAGES_FIELD); { - for (UnifiedCompletionRequest.Message message : messages) { + for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { builder.startObject(); { switch (message.content()) { @@ 
-106,8 +73,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws case UnifiedCompletionRequest.ContentObjects contentObjects -> { for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { builder.startObject(CONTENT_FIELD); - builder.field("text", contentObject.text()); - builder.field("type", contentObject.type()); + builder.field(TEXT_FIELD, contentObject.text()); + builder.field(TYPE_FIELD, contentObject.type()); builder.endObject(); } } @@ -115,24 +82,24 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(ROLE_FIELD, message.role()); if (message.name() != null) { - builder.field("name", message.name()); + builder.field(NAME_FIELD, message.name()); } if (message.toolCallId() != null) { - builder.field("tool_call_id", message.toolCallId()); + builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); } if (message.toolCalls() != null) { - builder.startArray("tool_calls"); + builder.startArray(TOOL_CALLS_FIELD); for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { builder.startObject(); { - builder.field("id", toolCall.id()); - builder.startObject("function"); + builder.field(ID_FIELD, toolCall.id()); + builder.startObject(FUNCTION_FIELD); { - builder.field("arguments", toolCall.function().arguments()); - builder.field("name", toolCall.function().name()); + builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); + builder.field(NAME_FIELD, toolCall.function().name()); } builder.endObject(); - builder.field("type", toolCall.type()); + builder.field(TYPE_FIELD, toolCall.type()); } builder.endObject(); } @@ -144,52 +111,55 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); - if (model != null) { - builder.field(MODEL_FIELD, model); + if (unifiedRequest.model() != null) { + builder.field(MODEL_FIELD, unifiedRequest.model()); } - if (maxCompletionTokens != null) { - 
builder.field(MAX_COMPLETION_TOKENS_FIELD, maxCompletionTokens); + if (unifiedRequest.maxCompletionTokens() != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); } - if (n != null) { - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, n); + if (unifiedRequest.n() != null) { + builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, unifiedRequest.n()); } - if (stop != null) { - switch (stop) { + if (unifiedRequest.stop() != null) { + switch (unifiedRequest.stop()) { case UnifiedCompletionRequest.StopString stopString -> builder.field(STOP_FIELD, stopString.value()); case UnifiedCompletionRequest.StopValues stopValues -> builder.field(STOP_FIELD, stopValues.values()); } } - if (temperature != null) { - builder.field(TEMPERATURE_FIELD, temperature); + if (unifiedRequest.temperature() != null) { + builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); } - if (toolChoice != null) { - if (toolChoice instanceof UnifiedCompletionRequest.ToolChoiceString) { - builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) toolChoice).value()); - } else if (toolChoice instanceof UnifiedCompletionRequest.ToolChoiceObject) { + if (unifiedRequest.toolChoice() != null) { + if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { + builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); + } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { builder.startObject(TOOL_CHOICE_FIELD); { - builder.field("type", ((UnifiedCompletionRequest.ToolChoiceObject) toolChoice).type()); - builder.startObject("function"); + builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); + builder.startObject(FUNCTION_FIELD); { - builder.field("name", ((UnifiedCompletionRequest.ToolChoiceObject) toolChoice).function().name()); + builder.field( + NAME_FIELD, + 
((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() + ); } builder.endObject(); } builder.endObject(); } } - if (tool != null) { + if (unifiedRequest.tool() != null) { builder.startArray(TOOL_FIELD); - for (UnifiedCompletionRequest.Tool t : tool) { + for (UnifiedCompletionRequest.Tool t : unifiedRequest.tool()) { builder.startObject(); { - builder.field("type", t.type()); - builder.startObject("function"); + builder.field(TYPE_FIELD, t.type()); + builder.startObject(FUNCTION_FIELD); { - builder.field("description", t.function().description()); - builder.field("name", t.function().name()); - builder.field("parameters", t.function().parameters()); - builder.field("strict", t.function().strict()); + builder.field(DESCRIPTION_FIELD, t.function().description()); + builder.field(NAME_FIELD, t.function().name()); + builder.field(PARAMETERS_FIELD, t.function().parameters()); + builder.field(STRICT_FIELD, t.function().strict()); } builder.endObject(); } @@ -197,12 +167,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); } - if (topP != null) { - builder.field(TOP_P_FIELD, topP); + if (unifiedRequest.topP() != null) { + builder.field(TOP_P_FIELD, unifiedRequest.topP()); } - if (Strings.isNullOrEmpty(user) == false) { - builder.field(USER_FIELD, user); + if (unifiedRequest.user() != null && unifiedRequest.user().isEmpty() == false) { + builder.field(USER_FIELD, unifiedRequest.user()); } + builder.field(STREAM_FIELD, unifiedRequest.stream()); builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 8340852866db9..2de12886873df 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -21,12 +21,13 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.http.sender.CompletionInputs; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity; import java.io.IOException; import java.util.EnumSet; @@ -73,7 +74,7 @@ public void infer( private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) { return switch (model.getTaskType()) { // TODO implement parameters - case COMPLETION -> new CompletionInputs(null); + case COMPLETION -> new UnifiedChatInput(null); case RERANK -> new QueryAndDocsInputs(query, input, stream); case TEXT_EMBEDDING -> new DocumentsOnlyInput(input, stream); default -> throw new ElasticsearchStatusException( @@ -84,9 +85,14 @@ private static InferenceInputs createInput(Model model, List input, @Nul } @Override - public void completionInfer(Model model, Object parameters, TimeValue timeout, ActionListener listener) { + public void completionInfer( + Model model, + OpenAiUnifiedChatCompletionRequestEntity parameters, + TimeValue timeout, + ActionListener listener + ) { init(); - doUnifiedCompletionInfer(model, new CompletionInputs(parameters), timeout, listener); + doUnifiedCompletionInfer(model, new UnifiedChatInput(parameters), timeout, 
listener); } @Override @@ -116,7 +122,7 @@ protected abstract void doInfer( protected abstract void doUnifiedCompletionInfer( Model model, - CompletionInputs inputs, + UnifiedChatInput inputs, TimeValue timeout, ActionListener listener ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 538fba607b4cb..2986dab872103 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -35,11 +35,11 @@ import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; -import org.elasticsearch.xpack.inference.external.http.sender.CompletionInputs; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -266,7 +266,7 @@ public void doInfer( @Override public void doUnifiedCompletionInfer( Model model, - CompletionInputs inputs, + UnifiedChatInput inputs, TimeValue timeout, ActionListener listener ) { From bf817d00f732c76dab8c8b62edbea35b47632894 Mon Sep 17 00:00:00 2001 From: 
Jonathan Buttner Date: Tue, 26 Nov 2024 11:44:56 -0500 Subject: [PATCH 10/53] Reworking unified inputs --- .../inference/InferenceService.java | 7 +- .../inference}/UnifiedCompletionRequest.java | 6 +- .../inference/action/InferenceAction.java | 96 +------------------ .../action/UnifiedCompletionAction.java | 8 +- .../action/UnifiedCompletionRequestTests.java | 4 +- ...sportUnifiedCompletionInferenceAction.java | 2 +- .../external/http/sender/InferenceInputs.java | 8 +- .../OpenAiCompletionRequestManager.java | 7 +- .../http/sender/UnifiedChatInput.java | 48 +++++++--- .../OpenAiUnifiedChatCompletionRequest.java | 13 ++- ...nAiUnifiedChatCompletionRequestEntity.java | 28 ++---- .../inference/services/SenderService.java | 10 +- 12 files changed, 84 insertions(+), 153 deletions(-) rename {x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action => server/src/main/java/org/elasticsearch/inference}/UnifiedCompletionRequest.java (98%) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 45dc49308008f..674165612697e 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -115,14 +115,13 @@ void infer( * Perform completion inference on the model using the unified schema. 
* * @param model The model - * @param parameters Parameters for the request + * @param request Parameters for the request * @param timeout The timeout for the request * @param listener Inference result listener */ - void completionInfer( + void unifiedCompletionInfer( Model model, - // TODO create the class for this object - Object parameters, + UnifiedCompletionRequest request, TimeValue timeout, ActionListener listener ); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java similarity index 98% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java rename to server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 67a12248e55df..b3fb1ea9a425f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.core.inference.action; +package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.StreamInput; @@ -44,8 +44,8 @@ public record UnifiedCompletionRequest( public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {} @SuppressWarnings("unchecked") - static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - InferenceAction.NAME, + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + UnifiedCompletionRequest.class.getSimpleName(), args -> new UnifiedCompletionRequest( (List) args[0], (String) args[1], diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index 4d342c4fce701..f88909ba4208e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -19,7 +19,6 @@ import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.common.xcontent.ChunkedToXContentObject; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; @@ -93,8 +92,6 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType, private final InputType inputType; private final TimeValue inferenceTimeout; private final boolean stream; - private final boolean isUnifiedCompletionMode; - private final UnifiedCompletionRequest unifiedCompletionRequest; public Request( TaskType taskType, @@ -104,34 +101,7 @@ public Request( Map taskSettings, 
InputType inputType, TimeValue inferenceTimeout, - boolean stream, - boolean isUnifiedCompletionsMode - ) { - this( - taskType, - inferenceEntityId, - query, - input, - taskSettings, - inputType, - inferenceTimeout, - stream, - isUnifiedCompletionsMode, - null - ); - } - - public Request( - TaskType taskType, - String inferenceEntityId, - String query, - List input, - Map taskSettings, - InputType inputType, - TimeValue inferenceTimeout, - boolean stream, - boolean isUnifiedCompletionsMode, - @Nullable UnifiedCompletionRequest unifiedCompletionRequest + boolean stream ) { this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; @@ -141,8 +111,6 @@ public Request( this.inputType = inputType; this.inferenceTimeout = inferenceTimeout; this.stream = stream; - this.isUnifiedCompletionMode = isUnifiedCompletionsMode; - this.unifiedCompletionRequest = unifiedCompletionRequest; } public Request(StreamInput in) throws IOException { @@ -169,14 +137,6 @@ public Request(StreamInput in) throws IOException { this.inferenceTimeout = DEFAULT_TIMEOUT; } - if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) { - this.isUnifiedCompletionMode = in.readBoolean(); - this.unifiedCompletionRequest = in.readOptionalWriteable(UnifiedCompletionRequest::new); - } else { - this.isUnifiedCompletionMode = false; - this.unifiedCompletionRequest = null; - } - // streaming is not supported yet for transport traffic this.stream = false; } @@ -213,10 +173,6 @@ public boolean isStreaming() { return stream; } - public boolean isUnifiedCompletionMode() { - return isUnifiedCompletionMode; - } - @Override public ActionRequestValidationException validate() { if (input == null) { @@ -242,10 +198,6 @@ public ActionRequestValidationException validate() { e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK)); return e; } - } else if (query != null) { - var e = new ActionRequestValidationException(); - 
e.addValidationError(format("Task type [%s] does not support field [query]", TaskType.RERANK)); - return e; } return null; @@ -271,11 +223,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(query); out.writeTimeValue(inferenceTimeout); } - - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_UNIFIED_COMPLETIONS_API)) { - out.writeBoolean(isUnifiedCompletionMode); - out.writeOptionalWriteable(unifiedCompletionRequest); - } } // default for easier testing @@ -302,22 +249,12 @@ public boolean equals(Object o) { && Objects.equals(taskSettings, request.taskSettings) && Objects.equals(inputType, request.inputType) && Objects.equals(query, request.query) - && Objects.equals(inferenceTimeout, request.inferenceTimeout) - && Objects.equals(isUnifiedCompletionMode, request.isUnifiedCompletionMode); + && Objects.equals(inferenceTimeout, request.inferenceTimeout); } @Override public int hashCode() { - return Objects.hash( - taskType, - inferenceEntityId, - input, - taskSettings, - inputType, - query, - inferenceTimeout, - isUnifiedCompletionMode - ); + return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout); } public static class Builder { @@ -330,8 +267,6 @@ public static class Builder { private String query; private TimeValue timeout = DEFAULT_TIMEOUT; private boolean stream = false; - private boolean unifiedCompletionMode = false; - private UnifiedCompletionRequest unifiedCompletionRequest; private Builder() {} @@ -379,29 +314,8 @@ public Builder setStream(boolean stream) { return this; } - public Builder setUnifiedCompletionMode(boolean unified) { - this.unifiedCompletionMode = unified; - return this; - } - - public Builder setUnifiedCompletionRequest(UnifiedCompletionRequest unifiedCompletionRequest) { - this.unifiedCompletionRequest = unifiedCompletionRequest; - return this; - } - public Request build() { - return new Request( - taskType, - inferenceEntityId, - 
query, - input, - taskSettings, - inputType, - timeout, - stream, - unifiedCompletionMode, - unifiedCompletionRequest - ); + return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream); } } @@ -420,8 +334,6 @@ public String toString() { + this.getInputType() + ", timeout=" + this.getInferenceTimeout() - + ", isUnifiedCompletionsMode=" - + this.isUnifiedCompletionMode() + ")"; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java index 13db70dd04f72..39188540cc7eb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; @@ -27,7 +28,8 @@ public UnifiedCompletionAction() { } public static class Request extends BaseInferenceActionRequest { - public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser) throws IOException { + public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser) + throws IOException { var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null); return new Request(inferenceEntityId, taskType, unifiedRequest, timeout); } @@ -110,8 +112,8 @@ public boolean equals(Object o) { Request request = (Request) o; return Objects.equals(inferenceEntityId, request.inferenceEntityId) && taskType == request.taskType - && 
Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest) && - Objects.equals(timeout, request.timeout); + && Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest) + && Objects.equals(timeout, request.timeout); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java index 58d21ebef3038..6f907bd83d17a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; @@ -99,7 +100,6 @@ public void testParseAllFields() throws IOException { 100L, 1, new UnifiedCompletionRequest.StopValues(List.of("stop")), - true, 0.1F, new UnifiedCompletionRequest.ToolChoiceObject( "function", @@ -168,7 +168,6 @@ public void testParsing() throws IOException { null, new UnifiedCompletionRequest.StopString("none"), null, - null, new UnifiedCompletionRequest.ToolChoiceString("auto"), List.of( new UnifiedCompletionRequest.Tool( @@ -196,7 +195,6 @@ public static UnifiedCompletionRequest randomUnifiedCompletionRequest() { randomNullOrLong(), randomNullOrInt(), randomNullOrStop(), - randomOptionalBoolean(), randomNullOrFloat(), randomNullOrToolChoice(), randomList(5, UnifiedCompletionRequestTests::randomTool), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index 497929b3e5848..4291bc046c919 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -71,6 +71,6 @@ protected void doInference( InferenceService service, ActionListener listener ) { - service.completionInfer(model, request.getUnifiedCompletionRequest(), null, listener); + service.unifiedCompletionInfer(model, request.getUnifiedCompletionRequest(), null, listener); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index 45244a25db891..7b0e240a1dc40 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -14,11 +14,11 @@ public static IllegalArgumentException createUnsupportedTypeException(InferenceI return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass())); } - public static T abc(InferenceInputs inputs, Class clazz) { - if (inputs.getClass().isInstance(clazz) == false) { - throw createUnsupportedTypeException(inputs); + public T castTo(Class clazz) { + if (this.getClass().isInstance(clazz) == false) { + throw createUnsupportedTypeException(this); } - return clazz.cast(inputs); + return clazz.cast(this); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index 0717eef1e52e5..6f9f0c1674ead 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -47,10 +47,9 @@ public void execute( ActionListener listener ) { - OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest( - UnifiedChatInput.of(inferenceInputs).getRequestEntity(), - model - ); + // TODO check and see if this works +// OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(UnifiedChatInput.of(inferenceInputs), model); + OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(inferenceInputs.castTo(UnifiedChatInput.class), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java index ef2fc7fafc2bd..3c3577b9ec116 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -7,34 +7,60 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import java.util.List; import java.util.Objects; +import static 
org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD; + public class UnifiedChatInput extends InferenceInputs { public static UnifiedChatInput of(InferenceInputs inferenceInputs) { - - if (inferenceInputs instanceof DocumentsOnlyInput docsOnly) { - return new UnifiedChatInput(new OpenAiUnifiedChatCompletionRequestEntity(docsOnly)); - } else if (inferenceInputs instanceof UnifiedChatInput == false) { + if (inferenceInputs instanceof UnifiedChatInput == false) { throw createUnsupportedTypeException(inferenceInputs); } return (UnifiedChatInput) inferenceInputs; } - public OpenAiUnifiedChatCompletionRequestEntity getRequestEntity() { - return requestEntity; + public static UnifiedChatInput of(List input, boolean stream) { + var unifiedRequest = new UnifiedCompletionRequest( + convertToMessages(input), + null, + null, + null, + null, + null, + null, + null, + null, + // TODO we need to get the user field from task settings if it is there + null + ); + + return new UnifiedChatInput(unifiedRequest, stream); } - private final OpenAiUnifiedChatCompletionRequestEntity requestEntity; + private static List convertToMessages(List inputs) { + return inputs.stream() + .map(doc -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(doc), USER_FIELD, null, null, null)) + .toList(); + } + + private final UnifiedCompletionRequest request; + private final boolean stream; + + public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) { + this.request = Objects.requireNonNull(request); + this.stream = stream; + } - public UnifiedChatInput(OpenAiUnifiedChatCompletionRequestEntity requestEntity) { - this.requestEntity = Objects.requireNonNull(requestEntity); + public UnifiedCompletionRequest getRequest() { + return request; } public boolean stream() { - return requestEntity.isStream(); + return stream; } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index 40e1145e0b256..07add828394f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -13,6 +13,7 @@ import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -29,12 +30,12 @@ public class OpenAiUnifiedChatCompletionRequest implements OpenAiRequest { private final OpenAiAccount account; - private final OpenAiUnifiedChatCompletionRequestEntity requestEntity; private final OpenAiChatCompletionModel model; + private final UnifiedChatInput unifiedChatInput; - public OpenAiUnifiedChatCompletionRequest(OpenAiUnifiedChatCompletionRequestEntity requestEntity, OpenAiChatCompletionModel model) { + public OpenAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { this.account = OpenAiAccount.of(model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); - this.requestEntity = Objects.requireNonNull(requestEntity); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); this.model = Objects.requireNonNull(model); } @@ -42,7 +43,9 @@ public OpenAiUnifiedChatCompletionRequest(OpenAiUnifiedChatCompletionRequestEnti public HttpRequest 
createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); - ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8)); + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput)).getBytes(StandardCharsets.UTF_8) + ); httpPost.setEntity(byteEntity); httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); @@ -80,7 +83,7 @@ public String getInferenceEntityId() { @Override public boolean isStreaming() { - return requestEntity.isStream(); + return unifiedChatInput.stream(); } public static URI buildDefaultUri() throws URISyntaxException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index b082c80311ba4..6c567861ce66d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -7,14 +7,13 @@ package org.elasticsearch.xpack.inference.external.request.openai; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequest; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; -import org.elasticsearch.xpack.inference.external.request.UnifiedRequest; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import java.io.IOException; -import java.util.List; +import java.util.Objects; public 
class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject { @@ -43,21 +42,14 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec private static final String TEXT_FIELD = "text"; private static final String TYPE_FIELD = "type"; - private final UnifiedRequest unifiedRequest; + private final UnifiedCompletionRequest unifiedRequest; + private final boolean stream; - public OpenAiUnifiedChatCompletionRequestEntity(UnifiedRequest unifiedRequest) { - this.unifiedRequest = unifiedRequest; - } - - public OpenAiUnifiedChatCompletionRequestEntity(DocumentsOnlyInput input) { - this(new UnifiedRequest(convertDocumentsOnlyInputToMessages(input), null, null, null, null, null, null, null, null, null, true)); - } + public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) { + Objects.requireNonNull(unifiedChatInput); - private static List convertDocumentsOnlyInputToMessages(DocumentsOnlyInput input) { - return input.getInputs() - .stream() - .map(doc -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(doc), USER_FIELD, null, null, null)) - .toList(); + this.unifiedRequest = unifiedChatInput.getRequest(); + this.stream = unifiedChatInput.stream(); } @Override @@ -173,7 +165,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (unifiedRequest.user() != null && unifiedRequest.user().isEmpty() == false) { builder.field(USER_FIELD, unifiedRequest.user()); } - builder.field(STREAM_FIELD, unifiedRequest.stream()); + builder.field(STREAM_FIELD, stream); builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 2de12886873df..95c84859a28ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -27,7 +28,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity; import java.io.IOException; import java.util.EnumSet; @@ -74,7 +74,7 @@ public void infer( private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) { return switch (model.getTaskType()) { // TODO implement parameters - case COMPLETION -> new UnifiedChatInput(null); + case COMPLETION -> UnifiedChatInput.of(input, stream); case RERANK -> new QueryAndDocsInputs(query, input, stream); case TEXT_EMBEDDING -> new DocumentsOnlyInput(input, stream); default -> throw new ElasticsearchStatusException( @@ -85,14 +85,14 @@ private static InferenceInputs createInput(Model model, List input, @Nul } @Override - public void completionInfer( + public void unifiedCompletionInfer( Model model, - OpenAiUnifiedChatCompletionRequestEntity parameters, + UnifiedCompletionRequest request, TimeValue timeout, ActionListener listener ) { init(); - doUnifiedCompletionInfer(model, new UnifiedChatInput(parameters), timeout, listener); + doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener); } @Override From 81a05b77f6c2a41553865a8516e95fcc24f53465 Mon 
Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 26 Nov 2024 12:04:53 -0500 Subject: [PATCH 11/53] Adding unsupported operation calls --- .../sender/OpenAiCompletionRequestManager.java | 7 +++++-- .../xpack/inference/services/ServiceUtils.java | 4 ++++ .../AlibabaCloudSearchService.java | 14 ++++++++++++-- .../amazonbedrock/AmazonBedrockService.java | 12 ++++++++++++ .../services/anthropic/AnthropicService.java | 12 ++++++++++++ .../azureaistudio/AzureAiStudioService.java | 12 ++++++++++++ .../services/azureopenai/AzureOpenAiService.java | 12 ++++++++++++ .../inference/services/cohere/CohereService.java | 12 ++++++++++++ .../services/elastic/ElasticInferenceService.java | 12 ++++++++++++ .../ElasticsearchInternalService.java | 12 ++++++++++++ .../googleaistudio/GoogleAiStudioService.java | 12 ++++++++++++ .../googlevertexai/GoogleVertexAiService.java | 12 ++++++++++++ .../services/huggingface/HuggingFaceService.java | 13 +++++++++++++ .../huggingface/elser/HuggingFaceElserService.java | 12 ++++++++++++ .../services/ibmwatsonx/IbmWatsonxService.java | 12 ++++++++++++ .../inference/services/mistral/MistralService.java | 12 ++++++++++++ 16 files changed, 178 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index 6f9f0c1674ead..372a57174315a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -48,8 +48,11 @@ public void execute( ) { // TODO check and see if this works -// OpenAiUnifiedChatCompletionRequest request = new 
OpenAiUnifiedChatCompletionRequest(UnifiedChatInput.of(inferenceInputs), model); - OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(inferenceInputs.castTo(UnifiedChatInput.class), model); + // OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(UnifiedChatInput.of(inferenceInputs), model); + OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest( + inferenceInputs.castTo(UnifiedChatInput.class), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index ec4b8d9bb4d3d..7d05bac363fb1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -776,5 +776,9 @@ public static T nonNullOrDefault(@Nullable T requestValue, @Nullable T origi return requestValue == null ? 
originalSettingsValue : requestValue; } + public static void throwUnsupportedUnifiedCompletionOperation(String serviceName) { + throw new UnsupportedOperationException(Strings.format("The %s service does not support unified completion", serviceName)); + } + private ServiceUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index c84b4314b9d1a..d402be04c8a0b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; @@ -58,8 +59,6 @@ import java.util.Map; import java.util.stream.Stream; -import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; -import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -67,6 +66,7 @@ import 
static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.HOST; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME; @@ -263,6 +263,16 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta ); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index f9822c7ab4af9..695ed452ecda8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -41,6 +41,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import 
org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -64,6 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; @@ -89,6 +91,16 @@ public AmazonBedrockService( this.amazonBedrockSender = amazonBedrockFactory.createSender(); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 556b34b945c14..59994de37b0b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -33,6 +33,7 @@ import 
org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -52,6 +53,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class AnthropicService extends SenderService { public static final String NAME = "anthropic"; @@ -186,6 +188,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index a2f8dc409585e..7261e5df80fd4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -39,6 +39,7 @@ import 
org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -64,6 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.ENDPOINT_TYPE_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TARGET_FIELD; @@ -82,6 +84,16 @@ public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents super(factory, serviceComponents); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 6d36e5f6c8fe7..30c8e1037c202 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -59,6 +60,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME; @@ -234,6 +236,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index de1d055e160da..993bfb42b18c1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class CohereService extends SenderService { @@ -232,6 +234,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 98429ed3d001d..6680366977db7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class ElasticInferenceService extends SenderService { @@ -76,6 +78,16 @@ public ElasticInferenceService( this.elasticInferenceServiceComponents = eisComponents; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index fe83acc8574aa..3f9f3290ea689 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -32,6 +32,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption; @@ -78,6 +79,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_THREADS; @@ -541,6 +543,16 @@ private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomE ); } + @Override + public void 
unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void infer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 1c01ebbe2c0e4..39b81e249d940 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -65,6 +66,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class GoogleAiStudioService extends SenderService { @@ -309,6 +311,16 @@ protected 
void doInfer( } } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index a05b1a937d376..e8dde29569ad7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static 
org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; @@ -214,6 +216,16 @@ protected void doInfer( action.execute(inputs, timeout, listener); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index eede14a975234..58cb76395f8be 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SettingsConfiguration; @@ -32,6 +33,7 @@ import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import 
org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -48,6 +50,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; @@ -141,6 +144,16 @@ protected void doChunkedInfer( } } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index a2e22e24172cf..627136a7a7c8c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import 
org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; @@ -50,6 +51,7 @@ import java.util.Map; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; public class HuggingFaceElserService extends HuggingFaceBaseService { @@ -82,6 +84,16 @@ protected HuggingFaceModel createModel( }; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index e960b0b777f2b..95b2e91068a19 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -55,6 +56,7 @@ import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE; @@ -251,6 +253,16 @@ protected void doInfer( action.execute(input, timeout, listener); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 2e810c357f8bd..9adc4dbd6d3fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; 
import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -59,6 +60,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; public class MistralService extends SenderService { @@ -89,6 +91,16 @@ protected void doInfer( } } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, From cb440e1dd76647b7ddad2fd368dc8afc39f183d7 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 26 Nov 2024 14:15:57 -0500 Subject: [PATCH 12/53] Fixing parsing logic --- .../inference/UnifiedCompletionRequest.java | 24 ++++++++++++----- .../org/elasticsearch/test/ESTestCase.java | 8 ++++++ .../action/InferenceActionRequestTests.java | 21 --------------- .../action/UnifiedCompletionRequestTests.java | 14 +++++++--- .../xpack/inference/InferencePlugin.java | 19 ++++++++++--- .../{rest => }/UnifiedCompletionFeature.java | 6 ++--- .../external/request/UnifiedRequest.java | 27 ------------------- .../OpenAiChatCompletionRequestEntity.java | 1 + .../OpenAiUnifiedChatCompletionRequest.java | 2 +- ...nAiUnifiedChatCompletionRequestEntity.java | 12 ++++++--- .../services/openai/OpenAiService.java | 13 ++++----- .../completion/OpenAiChatCompletionModel.java | 6 +++++ ...enAiChatCompletionRequestTaskSettings.java | 5 ++++ .../OpenAiChatCompletionModelTests.java | 2 +- 14 
files changed, 84 insertions(+), 76 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/{rest => }/UnifiedCompletionFeature.java (76%) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/UnifiedRequest.java diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index b3fb1ea9a425f..5bdd0f212e68f 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -8,6 +8,7 @@ package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -52,11 +53,11 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C (Long) args[2], (Integer) args[3], (Stop) args[4], - (Float) args[6], - (ToolChoice) args[7], - (List) args[8], - (Float) args[9], - (String) args[10] + (Float) args[5], + (ToolChoice) args[6], + (List) args[7], + (Float) args[8], + (String) args[9] ) ); @@ -78,6 +79,17 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C PARSER.declareString(optionalConstructorArg(), new ParseField("user")); } + public static List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(Content.class, ContentObjects.NAME, ContentObjects::new), + new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new), + new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new), + new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, 
ToolChoiceString::new), + new NamedWriteableRegistry.Entry(Stop.class, StopValues.NAME, StopValues::new), + new NamedWriteableRegistry.Entry(Stop.class, StopString.NAME, StopString::new) + ); + } + public UnifiedCompletionRequest(StreamInput in) throws IOException { this( in.readCollectionAsImmutableList(Message::new), @@ -157,7 +169,7 @@ public void writeTo(StreamOutput out) throws IOException { } } - public record ContentObjects(List contentObjects) implements Content, Writeable { + public record ContentObjects(List contentObjects) implements Content, NamedWriteable { public static final String NAME = "content_objects"; diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 67dc36cb29b6b..c2ce64aa63cc3 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -1213,10 +1213,18 @@ public static Long randomNullOrLong() { return randomBoolean() ? null : randomLong(); } + public static Long randomNullOrPositiveLong() { + return randomBoolean() ? null : randomLongBetween(0L, Long.MAX_VALUE); + } + public static Integer randomNullOrInt() { return randomBoolean() ? null : randomInt(); } + public static Integer randomNullOrPositiveInt() { + return randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE); + } + public static Float randomNullOrFloat() { return randomBoolean() ? 
null : randomFloat(); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index a7eb9ce3e3fd0..a9ca5e6da8720 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -47,7 +47,6 @@ protected InferenceAction.Request createTestInstance() { randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), TimeValue.timeValueMillis(randomLongBetween(1, 2048)), - false, false ); } @@ -83,7 +82,6 @@ public void testValidation_TextEmbedding() { null, null, null, - false, false ); ActionRequestValidationException e = request.validate(); @@ -99,7 +97,6 @@ public void testValidation_Rerank() { null, null, null, - false, false ); ActionRequestValidationException e = request.validate(); @@ -115,7 +112,6 @@ public void testValidation_TextEmbedding_Null() { null, null, null, - false, false ); ActionRequestValidationException inputNullError = inputNullRequest.validate(); @@ -132,7 +128,6 @@ public void testValidation_TextEmbedding_Empty() { null, null, null, - false, false ); ActionRequestValidationException inputEmptyError = inputEmptyRequest.validate(); @@ -149,7 +144,6 @@ public void testValidation_Rerank_Null() { null, null, null, - false, false ); ActionRequestValidationException queryNullError = queryNullRequest.validate(); @@ -166,7 +160,6 @@ public void testValidation_Rerank_Empty() { null, null, null, - false, false ); ActionRequestValidationException queryEmptyError = queryEmptyRequest.validate(); @@ -200,7 +193,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), 
instance.getInferenceTimeout(), - false, false ); } @@ -212,7 +204,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), - false, false ); case 2 -> { @@ -226,7 +217,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), - false, false ); } @@ -246,7 +236,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc taskSettings, instance.getInputType(), instance.getInferenceTimeout(), - false, false ); } @@ -260,7 +249,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), nextInputType, instance.getInferenceTimeout(), - false, false ); } @@ -272,7 +260,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), instance.getInferenceTimeout(), - false, false ); case 6 -> { @@ -289,7 +276,6 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc instance.getTaskSettings(), instance.getInputType(), TimeValue.timeValueMillis(newDuration.plus(additionalTime).toMillis()), - false, false ); } @@ -308,7 +294,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, - false, false ); } else if (version.before(TransportVersions.V_8_13_0)) { @@ -320,7 +305,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, - false, false ); } else if (version.before(TransportVersions.V_8_13_0) @@ -335,7 +319,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), 
InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - false, false ); } else if (version.before(TransportVersions.V_8_13_0) @@ -348,7 +331,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, - false, false ); } else if (version.before(TransportVersions.V_8_14_0)) { @@ -360,7 +342,6 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque instance.getTaskSettings(), instance.getInputType(), InferenceAction.Request.DEFAULT_TIMEOUT, - false, false ); } @@ -378,7 +359,6 @@ public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOExceptio Map.of(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, - false, false ), TransportVersions.V_8_13_0 @@ -394,7 +374,6 @@ public void testWriteTo_WhenVersionIsBeforeInputTypeAdded_ShouldSetInputTypeToUn Map.of(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, - false, false ); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java index 6f907bd83d17a..612fadfab308f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.inference.action; import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.json.JsonXContent; @@ -51,7 +52,6 @@ public void testParseAllFields() throws IOException { "max_completion_tokens": 100, 
"n": 1, "stop": ["stop"], - "stream": true, "temperature": 0.1, "tools": [ { @@ -192,8 +192,8 @@ public static UnifiedCompletionRequest randomUnifiedCompletionRequest() { return new UnifiedCompletionRequest( randomList(5, UnifiedCompletionRequestTests::randomMessage), randomNullOrAlphaOfLength(10), - randomNullOrLong(), - randomNullOrInt(), + randomNullOrPositiveLong(), + randomNullOrPositiveInt(), randomNullOrStop(), randomNullOrFloat(), randomNullOrToolChoice(), @@ -287,4 +287,12 @@ protected UnifiedCompletionRequest createTestInstance() { protected UnifiedCompletionRequest mutateInstance(UnifiedCompletionRequest instance) throws IOException { return randomValueOtherThan(instance, this::createTestInstance); } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + // List entries = new ArrayList<>(); + // entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + // entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index e61c25291daf7..d0469f923f101 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -83,6 +83,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction; +import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import 
org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; @@ -152,9 +153,9 @@ public InferencePlugin(Settings settings) { @Override public List> getActions() { - return List.of( + var availableActions = List.of( new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class), - new ActionHandler<>(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class), + new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class), new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class), @@ -163,6 +164,13 @@ public InferencePlugin(Settings settings) { new ActionHandler<>(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class), new ActionHandler<>(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class) ); + + List> conditionalActions = + UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() + ? List.of(new ActionHandler<>(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class)) + : List.of(); + + return Stream.concat(availableActions.stream(), conditionalActions.stream()).toList(); } @Override @@ -177,7 +185,7 @@ public List getRestHandlers( Supplier nodesInCluster, Predicate clusterSupportsFeature ) { - return List.of( + var availableRestActions = List.of( new RestInferenceAction(), new RestStreamInferenceAction(), new RestGetInferenceModelAction(), @@ -187,6 +195,11 @@ public List getRestHandlers( new RestGetInferenceDiagnosticsAction(), new RestGetInferenceServicesAction() ); + List conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() + ? 
List.of(new RestUnifiedCompletionInferenceAction()) + : List.of(); + + return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/UnifiedCompletionFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java similarity index 76% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/UnifiedCompletionFeature.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java index a02e5591174d4..3e13d0c1e39de 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/UnifiedCompletionFeature.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java @@ -5,16 +5,16 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.rest; +package org.elasticsearch.xpack.inference; import org.elasticsearch.common.util.FeatureFlag; /** * Unified Completion feature flag. When the feature is complete, this flag will be removed. - * Enable feature via JVM option: `-Des.unified_feature_flag_enabled=true`. + * Enable feature via JVM option: `-Des.inference_unified_feature_flag_enabled=true`. 
*/ public class UnifiedCompletionFeature { - public static final FeatureFlag UNIFIED_COMPLETION_FEATURE_FLAG = new FeatureFlag("unified"); + public static final FeatureFlag UNIFIED_COMPLETION_FEATURE_FLAG = new FeatureFlag("inference_unified"); private UnifiedCompletionFeature() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/UnifiedRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/UnifiedRequest.java deleted file mode 100644 index e994d29f89426..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/UnifiedRequest.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequest; - -import java.util.List; - -public record UnifiedRequest( - List messages, - @Nullable String model, - @Nullable Long maxCompletionTokens, - @Nullable Integer n, - @Nullable UnifiedCompletionRequest.Stop stop, - @Nullable Float temperature, - @Nullable UnifiedCompletionRequest.ToolChoice toolChoice, - @Nullable List tool, - @Nullable Float topP, - @Nullable String user, - boolean stream -) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java index 867a7ca80cbcb..2332e70589104 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java @@ -15,6 +15,7 @@ import java.util.List; import java.util.Objects; +// TODO remove this public class OpenAiChatCompletionRequestEntity implements ToXContentObject { private static final String MESSAGES_FIELD = "messages"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index 07add828394f5..2e6bdb748fd33 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -44,7 +44,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 6c567861ce66d..6e78c023ce77c 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -7,10 +7,12 @@ package org.elasticsearch.xpack.inference.external.request.openai; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import java.io.IOException; import java.util.Objects; @@ -44,12 +46,14 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec private final UnifiedCompletionRequest unifiedRequest; private final boolean stream; + private final OpenAiChatCompletionModel model; - public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) { + public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { Objects.requireNonNull(unifiedChatInput); this.unifiedRequest = unifiedChatInput.getRequest(); this.stream = unifiedChatInput.stream(); + this.model = Objects.requireNonNull(model); } @Override @@ -162,9 +166,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (unifiedRequest.topP() != null) { builder.field(TOP_P_FIELD, unifiedRequest.topP()); } - if (unifiedRequest.user() != null && unifiedRequest.user().isEmpty() == false) { - builder.field(USER_FIELD, unifiedRequest.user()); + + if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { + builder.field(USER_FIELD, model.getTaskSettings().user()); } + builder.field(STREAM_FIELD, stream); builder.endObject(); return builder; diff --git
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 2986dab872103..8851fcc7e5144 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -33,7 +33,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; -import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -277,13 +277,10 @@ public void doUnifiedCompletionInfer( OpenAiChatCompletionModel openAiModel = (OpenAiChatCompletionModel) model; - // TODO override fields from the persisted model - // var overriddenModel = OpenAiChatCompletionModel.of(model, taskSettings); - // TODO create a new OpenAiCompletionRequestManager with the appropriate unified completion input - // or look into merging the functionality but that'd require potentially a lot more fields for the old version? 
- var requestCreator = OpenAiCompletionRequestManager.of(openAiModel, getServiceComponents().threadPool()); - var errorMessage = constructFailedToSendRequestMessage(openAiModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX); - var action = new SingleInputSenderExecutableAction(getSender(), requestCreator, errorMessage, COMPLETION_ERROR_PREFIX); + var overriddenModel = OpenAiChatCompletionModel.of(openAiModel, inputs.getRequest()); + var requestCreator = OpenAiCompletionRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX); + var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage); action.execute(inputs, timeout, listener); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index e721cd2955cf3..6fd94f70af166 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -38,6 +39,11 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map< return new 
OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); } + public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) { + var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromUnifiedRequest(request); + return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + public OpenAiChatCompletionModel( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java index 8029d8579baba..69978b800f8b1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.UnifiedCompletionRequest; import java.util.Map; @@ -49,4 +50,8 @@ public static OpenAiChatCompletionRequestTaskSettings fromMap(Map) null); assertThat(overriddenModel, sameInstance(model)); } From 86d477ea5e5147969fbb7783363db027f5f88ef1 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Tue, 26 Nov 2024 14:50:47 -0500 Subject: [PATCH 13/53] get the build working --- .../mock/TestDenseInferenceServiceExtension.java | 11 +++++++++++ .../mock/TestRerankingServiceExtension.java | 11 +++++++++++ .../mock/TestSparseInferenceServiceExtension.java | 11 +++++++++++ 
.../TestStreamingCompletionServiceExtension.java | 11 +++++++++++ ... OpenAiUnifiedChatCompletionRequestTests.java} | 15 ++++++++------- .../inference/services/SenderServiceTests.java | 9 +++++++++ 6 files changed, 61 insertions(+), 7 deletions(-) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/{OpenAiChatCompletionRequestTests.java => OpenAiUnifiedChatCompletionRequestTests.java} (90%) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 2ddc4f6c3e2f6..31eb9f6454ed6 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -32,6 +32,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -133,6 +134,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); + } + @Override public void chunkedInfer( Model model, diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 2075c1b1924bf..8ad75d895d697 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -30,6 +30,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -121,6 +122,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 3d6f0ce6eba05..bdbe73501e343 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ 
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -30,6 +30,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -124,6 +125,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("unifiedCompletionInfer not supported"); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 595b92a6be66b..9d6249a7b6dcb 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import 
org.elasticsearch.rest.RestStatus; @@ -122,6 +123,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); // TODO + } + private StreamingChatCompletionResults makeResults(List input) { var responseIter = input.stream().map(String::toUpperCase).iterator(); return new StreamingChatCompletionResults(subscriber -> { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java similarity index 90% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java index b6ebfd02941f3..12749dd03ca1a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; import java.io.IOException; @@ -20,13 +21,13 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static 
org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest.buildDefaultUri; +import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest.buildDefaultUri; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class OpenAiChatCompletionRequestTests extends ESTestCase { +public class OpenAiUnifiedChatCompletionRequestTests extends ESTestCase { public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOException { var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user"); @@ -75,7 +76,7 @@ public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(httpPost.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER)); @@ -101,7 +102,7 @@ public void testCreateRequest_WithStreaming() throws URISyntaxException, IOExcep public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { var request = createRequest(null, null, "secret", "abcd", "model", null); var truncatedRequest = request.truncate(); - assertThat(request.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(request.getURI().toString(), 
is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); var httpRequest = truncatedRequest.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -121,7 +122,7 @@ public void testTruncationInfo_ReturnsNull() { assertNull(request.getTruncationInfo()); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -132,7 +133,7 @@ public static OpenAiChatCompletionRequest createRequest( return createRequest(url, org, apiKey, input, model, user, false); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -142,7 +143,7 @@ public static OpenAiChatCompletionRequest createRequest( boolean stream ) { var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, apiKey, model, user); - return new OpenAiChatCompletionRequest(List.of(input), chatCompletionModel, stream); + return new OpenAiUnifiedChatCompletionRequest(UnifiedChatInput.of(List.of(input), stream), chatCompletionModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index d8402c28cec87..c8b688701b50b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import 
org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.junit.After; import org.junit.Before; @@ -120,6 +121,14 @@ protected void doInfer( } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) {} + @Override protected void doChunkedInfer( Model model, From 359d3058631df182a393123af1c0591a628b80d3 Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Tue, 26 Nov 2024 15:03:31 -0500 Subject: [PATCH 14/53] Update docs/changelog/117589.yaml --- docs/changelog/117589.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/117589.yaml diff --git a/docs/changelog/117589.yaml b/docs/changelog/117589.yaml new file mode 100644 index 0000000000000..2a2a483dc7bde --- /dev/null +++ b/docs/changelog/117589.yaml @@ -0,0 +1,5 @@ +pr: 117589 +summary: "[Inference API] Add unified api for chat completions" +area: Machine Learning +type: enhancement +issues: [] From 834676d6fa849e2ce97c3435b22c2ab254cd023f Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 26 Nov 2024 15:37:37 -0500 Subject: [PATCH 15/53] Fixing injection issue --- .../xpack/inference/action/BaseTransportInferenceAction.java | 3 ++- .../xpack/inference/action/TransportInferenceAction.java | 1 + .../action/TransportUnifiedCompletionInferenceAction.java | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index efb8098c2b455..7f4ab7d9e5447 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -53,6 +53,7 @@ public abstract class BaseTransportInferenceAction requestReader ) { - super(InferenceAction.NAME, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); + super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.modelRegistry = modelRegistry; this.serviceRegistry = serviceRegistry; this.inferenceStats = inferenceStats; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 11e849d4be883..c4e7dfd75d218 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -34,6 +34,7 @@ public TransportInferenceAction( StreamingTaskManager streamingTaskManager ) { super( + InferenceAction.NAME, transportService, actionFilters, modelRegistry, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index 4291bc046c919..8c79fc3e8a459 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -36,6 +36,7 @@ public TransportUnifiedCompletionInferenceAction( StreamingTaskManager streamingTaskManager ) { super( + UnifiedCompletionAction.NAME, transportService, actionFilters, 
modelRegistry, From 5909a7dd3927c08f99e38a52af61614f3e76403f Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 26 Nov 2024 17:07:15 -0500 Subject: [PATCH 16/53] Allowing model to be overridden but not working yet --- .../SingleInputSenderExecutableAction.java | 9 +---- .../http/sender/DocumentsOnlyInput.java | 6 ++- .../external/http/sender/InferenceInputs.java | 12 ++++-- .../http/sender/QueryAndDocsInputs.java | 5 ++- .../http/sender/UnifiedChatInput.java | 6 ++- ...nAiUnifiedChatCompletionRequestEntity.java | 2 +- .../completion/OpenAiChatCompletionModel.java | 20 +++++++++- .../http/sender/InferenceInputsTests.java | 38 +++++++++++++++++++ 8 files changed, 81 insertions(+), 17 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java index 4e97554b56445..b43e5ab70e2f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java @@ -12,7 +12,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -34,13 +33,7 @@ public SingleInputSenderExecutableAction( @Override public void 
execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { - if (inferenceInputs instanceof DocumentsOnlyInput == false) { - listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR)); - return; - } - - var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs; - if (docsOnlyInput.getInputs().size() > 1) { + if (inferenceInputs.inputSize() > 1) { listener.onFailure( new ElasticsearchStatusException(requestTypeForInputValidationError + " only accepts 1 input", RestStatus.BAD_REQUEST) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java index 8cf411d84c932..da5bae00c7831 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java @@ -14,7 +14,7 @@ public class DocumentsOnlyInput extends InferenceInputs { public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) { if (inferenceInputs instanceof DocumentsOnlyInput == false) { - throw createUnsupportedTypeException(inferenceInputs); + throw createUnsupportedTypeException(inferenceInputs, DocumentsOnlyInput.class); } return (DocumentsOnlyInput) inferenceInputs; @@ -40,4 +40,8 @@ public List getInputs() { public boolean stream() { return stream; } + + public int inputSize() { + return input.size(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index 7b0e240a1dc40..73719a8e57bc5 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -10,15 +10,19 @@ import org.elasticsearch.common.Strings; public abstract class InferenceInputs { - public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs) { - return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass())); + public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs, Class clazz) { + return new IllegalArgumentException( + Strings.format("Unable to convert inference inputs type: [%s] to [%s]", inferenceInputs.getClass(), clazz) + ); } public T castTo(Class clazz) { - if (this.getClass().isInstance(clazz) == false) { - throw createUnsupportedTypeException(this); + if (clazz.isInstance(this) == false) { + throw createUnsupportedTypeException(this, clazz); } return clazz.cast(this); } + + public abstract int inputSize(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 50bb77b307db3..0218799ee892a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -14,7 +14,7 @@ public class QueryAndDocsInputs extends InferenceInputs { public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { if (inferenceInputs instanceof QueryAndDocsInputs == false) { - throw createUnsupportedTypeException(inferenceInputs); + throw createUnsupportedTypeException(inferenceInputs, 
QueryAndDocsInputs.class); } return (QueryAndDocsInputs) inferenceInputs; @@ -47,4 +47,7 @@ public boolean stream() { return stream; } + public int inputSize() { + return chunks.size(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java index 3c3577b9ec116..00fbd10a24153 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -18,7 +18,7 @@ public class UnifiedChatInput extends InferenceInputs { public static UnifiedChatInput of(InferenceInputs inferenceInputs) { if (inferenceInputs instanceof UnifiedChatInput == false) { - throw createUnsupportedTypeException(inferenceInputs); + throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class); } return (UnifiedChatInput) inferenceInputs; @@ -63,4 +63,8 @@ public UnifiedCompletionRequest getRequest() { public boolean stream() { return stream; } + + public int inputSize() { + return request.messages().size(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 6e78c023ce77c..7093671afa6a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -108,7 +108,7 @@ public XContentBuilder toXContent(XContentBuilder 
builder, Params params) throws builder.endArray(); if (unifiedRequest.model() != null) { - builder.field(MODEL_FIELD, unifiedRequest.model()); + builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); } if (unifiedRequest.maxCompletionTokens() != null) { builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index 6fd94f70af166..b8d6774e4db44 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER; @@ -41,7 +42,24 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map< public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) { var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromUnifiedRequest(request); - return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new OpenAiChatCompletionServiceSettings( + Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + originalModelServiceSettings.uri(), + originalModelServiceSettings.organizationId(), + originalModelServiceSettings.maxInputTokens(), + originalModelServiceSettings.rateLimitSettings() 
+ ); + + var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings); + return new OpenAiChatCompletionModel( + overriddenServiceSettings.modelId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + overriddenTaskSettings, + model.getSecretSettings() + ); } public OpenAiChatCompletionModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java new file mode 100644 index 0000000000000..b0279caa563d5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import java.util.List; + +public class InferenceInputsTests extends ESTestCase { + public void testCastToSucceeds() { + InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false); + assertThat(inputs.castTo(DocumentsOnlyInput.class), Matchers.instanceOf(DocumentsOnlyInput.class)); + + assertThat(UnifiedChatInput.of(List.of(), false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class)); + assertThat( + new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class), + Matchers.instanceOf(QueryAndDocsInputs.class) + ); + } + + public void testCastToFails() { + InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false); + var exception = expectThrows(IllegalArgumentException.class, () -> inputs.castTo(QueryAndDocsInputs.class)); + assertThat( + exception.getMessage(), + Matchers.containsString( + Strings.format("Unable to convert inference inputs type: [%s] to [%s]", DocumentsOnlyInput.class, QueryAndDocsInputs.class) + ) + ); + } +} From 315be2cb5c68fd02922714bce851fc0f1d24adb2 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 27 Nov 2024 10:02:19 -0500 Subject: [PATCH 17/53] Fixing issues --- ...nAiUnifiedChatCompletionRequestEntity.java | 19 +++++++++++-------- .../RestUnifiedCompletionInferenceAction.java | 3 +-- .../completion/OpenAiChatCompletionModel.java | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 7093671afa6a9..3a90a8ccb8c2e 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -40,7 +40,7 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec private static final String STOP_FIELD = "stop"; private static final String TEMPERATURE_FIELD = "temperature"; private static final String TOOL_CHOICE_FIELD = "tool_choice"; - private static final String TOOL_FIELD = "tool"; + private static final String TOOL_FIELD = "tools"; private static final String TEXT_FIELD = "text"; private static final String TYPE_FIELD = "type"; @@ -107,9 +107,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); - if (unifiedRequest.model() != null) { - builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); - } + builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); if (unifiedRequest.maxCompletionTokens() != null) { builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); } @@ -144,9 +142,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); } } - if (unifiedRequest.tool() != null) { + if (unifiedRequest.tools() != null) { builder.startArray(TOOL_FIELD); - for (UnifiedCompletionRequest.Tool t : unifiedRequest.tool()) { + for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) { builder.startObject(); { builder.field(TYPE_FIELD, t.type()); @@ -155,7 +153,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(DESCRIPTION_FIELD, t.function().description()); builder.field(NAME_FIELD, t.function().name()); builder.field(PARAMETERS_FIELD, t.function().parameters()); - builder.field(STRICT_FIELD, t.function().strict()); + if (t.function().strict() != null) 
{ + builder.field(STRICT_FIELD, t.function().strict()); + } } builder.endObject(); } @@ -167,12 +167,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TOP_P_FIELD, unifiedRequest.topP()); } - if (Strings.isNullOrEmpty(model.getTaskSettings().user())) { + if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { builder.field(USER_FIELD, model.getTaskSettings().user()); } builder.field(STREAM_FIELD, stream); builder.endObject(); + + System.out.println(Strings.toString(builder)); + return builder; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java index 48f031d3df8cc..5c71b560a6b9d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java @@ -12,7 +12,6 @@ import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import java.io.IOException; @@ -45,6 +44,6 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); } - return channel -> client.execute(InferenceAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel)); + return channel -> client.execute(UnifiedCompletionAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel)); } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index b8d6774e4db44..ceda56b78d0e3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -53,7 +53,7 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Unif var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings); return new OpenAiChatCompletionModel( - overriddenServiceSettings.modelId(), + model.getInferenceEntityId(), model.getTaskType(), model.getConfigurations().getService(), overriddenServiceSettings, From 657561e71e3a9ce7971851c59af068040bf33bee Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 27 Nov 2024 10:04:44 -0500 Subject: [PATCH 18/53] Switch field name for tool --- .../org/elasticsearch/inference/UnifiedCompletionRequest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 5bdd0f212e68f..5b9ec50f52aba 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -37,7 +37,7 @@ public record UnifiedCompletionRequest( @Nullable Stop stop, @Nullable Float temperature, @Nullable ToolChoice toolChoice, - @Nullable List tool, + @Nullable List tools, @Nullable Float topP, @Nullable String user ) implements Writeable { @@ -114,7 +114,7 @@ public void writeTo(StreamOutput out) throws 
IOException { out.writeOptionalNamedWriteable(stop); out.writeOptionalFloat(temperature); out.writeOptionalNamedWriteable(toolChoice); - out.writeOptionalCollection(tool); + out.writeOptionalCollection(tools); out.writeOptionalFloat(topP); out.writeOptionalString(user); } From 97b330fed0bc44e2621ab73c17d36eea751a34aa Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 27 Nov 2024 10:20:18 -0500 Subject: [PATCH 19/53] Add support for toolCalls and refusal in streaming completion --- .../StreamingChatCompletionResults.java | 89 +++++++++++++++++- .../openai/OpenAiStreamingProcessor.java | 94 ++++++++++++++++--- ...nAiUnifiedChatCompletionRequestEntity.java | 2 - 3 files changed, 168 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java index 05a181d3fc5b6..d087582bab03d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java @@ -20,6 +20,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.Flow; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; @@ -77,16 +78,102 @@ public Iterator toXContentChunked(ToXContent.Params params } } - public record Result(String delta) implements ChunkedToXContent { + public record Result(String delta, String refusal, List toolCalls) implements ChunkedToXContent { + private static final String RESULT = "delta"; + private static final String REFUSAL = "refusal"; + private static final String TOOL_CALLS = "tool_calls"; + + public Result(String delta) { + this(delta, "", List.of()); + } @Override public
Iterator toXContentChunked(ToXContent.Params params) { return Iterators.concat( ChunkedToXContentHelper.startObject(), ChunkedToXContentHelper.field(RESULT, delta), + ChunkedToXContentHelper.field(REFUSAL, refusal), + ChunkedToXContentHelper.startArray(TOOL_CALLS), + Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)), + ChunkedToXContentHelper.endArray(), ChunkedToXContentHelper.endObject() ); } } + + public static class ToolCall implements ChunkedToXContent { + private final int index; + private final String id; + private final String functionName; + private final String functionArguments; + + public ToolCall(int index, String id, String functionName, String functionArguments) { + this.index = index; + this.id = id; + this.functionName = functionName; + this.functionArguments = functionArguments; + } + + public int getIndex() { + return index; + } + + public String getId() { + return id; + } + + public String getFunctionName() { + return functionName; + } + + public String getFunctionArguments() { + return functionArguments; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ToolCall toolCall = (ToolCall) o; + return index == toolCall.index + && Objects.equals(id, toolCall.id) + && Objects.equals(functionName, toolCall.functionName) + && Objects.equals(functionArguments, toolCall.functionArguments); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field("index", index), + ChunkedToXContentHelper.field("id", id), + ChunkedToXContentHelper.field("functionName", functionName), + ChunkedToXContentHelper.field("functionArguments", functionArguments), + ChunkedToXContentHelper.endObject() + ); + } + + @Override + public int hashCode() { + return Objects.hash(index, id, functionName, functionArguments); + } + + @Override + public 
String toString() { + return "ToolCall{" + + "index=" + + index + + ", id='" + + id + + '\'' + + ", functionName='" + + functionName + + '\'' + + ", functionArguments='" + + functionArguments + + '\'' + + '}'; + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index 6e006fe255956..7105667cd7bce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -22,11 +22,12 @@ import java.io.IOException; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.Iterator; +import java.util.List; import java.util.Objects; -import java.util.function.Predicate; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; @@ -111,6 +112,8 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor item) throws Exception { @@ -159,6 +162,10 @@ private Iterator parse(XContentParserConf ensureExpectedToken(XContentParser.Token.START_OBJECT, currentToken, parser); + String content = null; + String refusal = null; + List toolCalls = new ArrayList<>(); + currentToken = parser.nextToken(); // continue until the end of delta @@ -167,25 +174,84 @@ private Iterator parse(XContentParserConf parser.skipChildren(); } - if (currentToken == XContentParser.Token.FIELD_NAME && parser.currentName().equals(CONTENT_FIELD)) { - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - var content = parser.text(); - consumeUntilObjectEnd(parser); // end delta - 
consumeUntilObjectEnd(parser); // end choices - return content; + if (currentToken == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case CONTENT_FIELD: + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + content = parser.text(); + break; + case REFUSAL_FIELD: + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + refusal = parser.text(); + break; + case TOOL_CALLS_FIELD: + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + toolCalls = parseToolCalls(parser); + break; + } } currentToken = parser.nextToken(); } + consumeUntilObjectEnd(parser); // end delta consumeUntilObjectEnd(parser); // end choices - return ""; // stopped - }).stream() - .filter(Objects::nonNull) - .filter(Predicate.not(String::isEmpty)) - .map(StreamingChatCompletionResults.Result::new) - .iterator(); + + return new StreamingChatCompletionResults.Result(content, refusal, toolCalls); + }).stream().filter(Objects::nonNull).iterator(); + } + } + + private List parseToolCalls(XContentParser parser) throws IOException { + List toolCalls = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + int index = -1; + String id = null; + String functionName = null; + String functionArguments = null; + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case "index": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, parser.currentToken(), parser); + index = parser.intValue(); + break; + case "id": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + id = parser.text(); + break; + case 
"function": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case "name": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + functionName = parser.text(); + break; + case "arguments": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + functionArguments = parser.text(); + break; + } + } + } + break; + } + } + } + toolCalls.add(new StreamingChatCompletionResults.ToolCall(index, id, functionName, functionArguments)); } + return toolCalls; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 3a90a8ccb8c2e..132a68c644eb9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -174,8 +174,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(STREAM_FIELD, stream); builder.endObject(); - System.out.println(Strings.toString(builder)); - return builder; } } From 2660ecb7eb2ecf7636b7def16136f343f0c69acf Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 27 Nov 2024 10:47:51 -0500 Subject: [PATCH 20/53] Working tool call response --- .../external/openai/OpenAiStreamingProcessor.java | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index 7105667cd7bce..a78b9514058e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -145,6 +145,7 @@ private Iterator parse(XContentParserConf return Collections.emptyIterator(); } + System.out.println(event.value()); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { moveToFirstToken(jsonParser); @@ -178,13 +179,17 @@ private Iterator parse(XContentParserConf switch (parser.currentName()) { case CONTENT_FIELD: parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - content = parser.text(); + if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { + content = parser.text(); + } + // ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); break; case REFUSAL_FIELD: parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - refusal = parser.text(); + if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { + refusal = parser.text(); + } + // ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); break; case TOOL_CALLS_FIELD: parser.nextToken(); @@ -197,7 +202,7 @@ private Iterator parse(XContentParserConf currentToken = parser.nextToken(); } - consumeUntilObjectEnd(parser); // end delta + // consumeUntilObjectEnd(parser); // end delta consumeUntilObjectEnd(parser); // end choices return new StreamingChatCompletionResults.Result(content, refusal, toolCalls); From 
03fada09839d2f9d1aa3dab72adff92faf6be5fa Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 27 Nov 2024 14:29:24 -0500 Subject: [PATCH 21/53] Separate unified and legacy code paths --- .../StreamingChatCompletionResults.java | 89 +----- ...StreamingUnifiedChatCompletionResults.java | 181 ++++++++++++ ...baCloudSearchCompletionRequestManager.java | 2 +- ...onBedrockChatCompletionRequestManager.java | 8 +- .../AnthropicCompletionRequestManager.java | 8 +- ...eAiStudioChatCompletionRequestManager.java | 8 +- .../AzureOpenAiCompletionRequestManager.java | 8 +- .../http/sender/ChatCompletionInput.java | 38 +++ .../CohereCompletionRequestManager.java | 8 +- ...oogleAiStudioCompletionRequestManager.java | 5 +- .../OpenAiCompletionRequestManager.java | 12 +- ...OpenAiUnifiedCompletionRequestManager.java | 61 ++++ .../http/sender/UnifiedChatInput.java | 36 --- .../openai/OpenAiStreamingProcessor.java | 99 +------ ...iUnifiedChatCompletionResponseHandler.java | 34 +++ .../OpenAiUnifiedStreamingProcessor.java | 262 ++++++++++++++++++ .../GoogleAiStudioCompletionRequest.java | 6 +- .../openai/OpenAiChatCompletionRequest.java | 99 +++++++ .../inference/services/SenderService.java | 4 +- .../services/openai/OpenAiService.java | 4 +- .../http/sender/InferenceInputsTests.java | 4 +- .../GoogleAiStudioCompletionRequestTests.java | 6 +- ... 
=> OpenAiChatCompletionRequestTests.java} | 13 +- 23 files changed, 740 insertions(+), 255 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/{OpenAiUnifiedChatCompletionRequestTests.java => OpenAiChatCompletionRequestTests.java} (91%) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java index d087582bab03d..05a181d3fc5b6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java @@ -20,7 +20,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.concurrent.Flow; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; @@ -78,102 +77,16 @@ public Iterator 
toXContentChunked(ToXContent.Params params } } - public record Result(String delta, String refusal, List toolCalls) implements ChunkedToXContent { - + public record Result(String delta) implements ChunkedToXContent { private static final String RESULT = "delta"; - private static final String REFUSAL = "refusal"; - private static final String TOOL_CALLS = "tool_calls"; - - public Result(String delta) { - this(delta, "", List.of()); - } @Override public Iterator toXContentChunked(ToXContent.Params params) { return Iterators.concat( ChunkedToXContentHelper.startObject(), ChunkedToXContentHelper.field(RESULT, delta), - ChunkedToXContentHelper.field(REFUSAL, refusal), - ChunkedToXContentHelper.startArray(TOOL_CALLS), - Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)), - ChunkedToXContentHelper.endArray(), ChunkedToXContentHelper.endObject() ); } } - - public static class ToolCall implements ChunkedToXContent { - private final int index; - private final String id; - private final String functionName; - private final String functionArguments; - - public ToolCall(int index, String id, String functionName, String functionArguments) { - this.index = index; - this.id = id; - this.functionName = functionName; - this.functionArguments = functionArguments; - } - - public int getIndex() { - return index; - } - - public String getId() { - return id; - } - - public String getFunctionName() { - return functionName; - } - - public String getFunctionArguments() { - return functionArguments; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ToolCall toolCall = (ToolCall) o; - return index == toolCall.index - && Objects.equals(id, toolCall.id) - && Objects.equals(functionName, toolCall.functionName) - && Objects.equals(functionArguments, toolCall.functionArguments); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return 
Iterators.concat( - ChunkedToXContentHelper.startObject(), - ChunkedToXContentHelper.field("index", index), - ChunkedToXContentHelper.field("id", id), - ChunkedToXContentHelper.field("functionName", functionName), - ChunkedToXContentHelper.field("functionArguments", functionArguments), - ChunkedToXContentHelper.endObject() - ); - } - - @Override - public int hashCode() { - return Objects.hash(index, id, functionName, functionArguments); - } - - @Override - public String toString() { - return "ToolCall{" - + "index=" - + index - + ", id='" - + id - + '\'' - + ", functionName='" - + functionName - + '\'' - + ", functionArguments='" - + functionArguments - + '\'' - + '}'; - } - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java new file mode 100644 index 0000000000000..de035f40711b1 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -0,0 +1,181 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ToXContent; + +import java.io.IOException; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.Flow; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; + +/** + * Chat Completion results that only contain a Flow.Publisher. + */ +public record StreamingUnifiedChatCompletionResults(Flow.Publisher publisher) + implements + InferenceServiceResults { + + @Override + public boolean isStreaming() { + return true; + } + + @Override + public List transformToCoordinationFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public List transformToLegacyFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Map asMap() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + throw new UnsupportedOperationException("Not implemented"); + } + + public record Results(Deque results) implements ChunkedToXContent { + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + 
ChunkedToXContentHelper.startArray(COMPLETION), + Iterators.flatMap(results.iterator(), d -> d.toXContentChunked(params)), + ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.endObject() + ); + } + } + + public record Result(String delta, String refusal, List toolCalls) implements ChunkedToXContent { + + private static final String RESULT = "delta"; + private static final String REFUSAL = "refusal"; + private static final String TOOL_CALLS = "tool_calls"; + + public Result(String delta) { + this(delta, "", List.of()); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field(RESULT, delta), + ChunkedToXContentHelper.field(REFUSAL, refusal), + ChunkedToXContentHelper.startArray(TOOL_CALLS), + Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)), + ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.endObject() + ); + } + } + + public static class ToolCall implements ChunkedToXContent { + private final int index; + private final String id; + private final String functionName; + private final String functionArguments; + + public ToolCall(int index, String id, String functionName, String functionArguments) { + this.index = index; + this.id = id; + this.functionName = functionName; + this.functionArguments = functionArguments; + } + + public int getIndex() { + return index; + } + + public String getId() { + return id; + } + + public String getFunctionName() { + return functionName; + } + + public String getFunctionArguments() { + return functionArguments; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ToolCall toolCall = (ToolCall) o; + return index == toolCall.index + && Objects.equals(id, toolCall.id) + && Objects.equals(functionName, toolCall.functionName) + && Objects.equals(functionArguments, 
toolCall.functionArguments); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field("index", index), + ChunkedToXContentHelper.field("id", id), + ChunkedToXContentHelper.field("functionName", functionName), + ChunkedToXContentHelper.field("functionArguments", functionArguments), + ChunkedToXContentHelper.endObject() + ); + } + + @Override + public int hashCode() { + return Objects.hash(index, id, functionName, functionArguments); + } + + @Override + public String toString() { + return "ToolCall{" + + "index=" + + index + + ", id='" + + id + + '\'' + + ", functionName='" + + functionName + + '\'' + + ", functionArguments='" + + functionArguments + + '\'' + + '}'; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java index a0a44e62f9f73..e7a960f1316f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java @@ -69,7 +69,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List input = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + List input = inferenceInputs.castTo(ChatCompletionInput.class).getInputs(); AlibabaCloudSearchCompletionRequest request = new AlibabaCloudSearchCompletionRequest(account, input, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java index 69a5c665feb86..3929585a0745d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java @@ -44,10 +44,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, inputs); var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream); var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java index 5418b3dd9840b..6d4aeb9e31bac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java @@ -46,10 +46,10 @@ 
public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index 21cec68b14a49..affd2e3a7760e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -41,10 +41,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, docsInput, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, inputs, stream); execute(new 
ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java index d036559ec3dcb..c2f5f3e9db5ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java @@ -46,10 +46,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java new file mode 100644 index 0000000000000..eb869682a6c2a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import java.util.List; +import java.util.Objects; + +public class ChatCompletionInput extends InferenceInputs { + private final List input; + private final boolean stream; + + public ChatCompletionInput(List input) { + this(input, false); + } + + public ChatCompletionInput(List input, boolean stream) { + super(); + this.input = Objects.requireNonNull(input); + this.stream = stream; + } + + public List getInputs() { + return this.input; + } + + public boolean stream() { + return stream; + } + + public int inputSize() { + return input.size(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index ae46fbe0fef87..40cd03c87664e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -50,10 +50,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + CohereCompletionRequest request = new CohereCompletionRequest(inputs, model, stream); execute(new 
ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java index abe50c6fae3f9..0097f9c08ea21 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -51,7 +51,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(DocumentsOnlyInput.of(inferenceInputs), model); + GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest( + inferenceInputs.castTo(ChatCompletionInput.class), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index 372a57174315a..f3de9108d0f28 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import 
org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest; import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; @@ -46,12 +47,11 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - - // TODO check and see if this works - // OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest(UnifiedChatInput.of(inferenceInputs), model); - OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest( - inferenceInputs.castTo(UnifiedChatInput.class), - model + var chatCompletionInputs = inferenceInputs.castTo(ChatCompletionInput.class); + OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest( + chatCompletionInputs.getInputs(), + model, + chatCompletionInputs.stream() ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java new file mode 100644 index 0000000000000..3b0f770e3e061 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class OpenAiUnifiedCompletionRequestManager extends OpenAiRequestManager { + + private static final Logger logger = LogManager.getLogger(OpenAiUnifiedCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + public static OpenAiUnifiedCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { + return new OpenAiUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final OpenAiChatCompletionModel model; + + private OpenAiUnifiedCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { + super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier 
hasRequestCompletedFunction, + ActionListener listener + ) { + + OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest( + inferenceInputs.castTo(UnifiedChatInput.class), + model + ); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createCompletionHandler() { + return new OpenAiUnifiedChatCompletionResponseHandler("openai completion", OpenAiChatCompletionResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java index 00fbd10a24153..a6f791f8be660 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -9,45 +9,9 @@ import org.elasticsearch.inference.UnifiedCompletionRequest; -import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD; - public class UnifiedChatInput extends InferenceInputs { - - public static UnifiedChatInput of(InferenceInputs inferenceInputs) { - if (inferenceInputs instanceof UnifiedChatInput == false) { - throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class); - } - - return (UnifiedChatInput) inferenceInputs; - } - - public static UnifiedChatInput of(List input, boolean stream) { - var unifiedRequest = new UnifiedCompletionRequest( - convertToMessages(input), - null, - null, - null, - null, - null, - null, - null, - null, - // TODO we need to get the user field from task settings if it is there - null - ); - - return new UnifiedChatInput(unifiedRequest, 
stream); - } - - private static List convertToMessages(List inputs) { - return inputs.stream() - .map(doc -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(doc), USER_FIELD, null, null, null)) - .toList(); - } - private final UnifiedCompletionRequest request; private final boolean stream; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index a78b9514058e1..6e006fe255956 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -22,12 +22,11 @@ import java.io.IOException; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.Iterator; -import java.util.List; import java.util.Objects; +import java.util.function.Predicate; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; @@ -112,8 +111,6 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor item) throws Exception { @@ -145,7 +142,6 @@ private Iterator parse(XContentParserConf return Collections.emptyIterator(); } - System.out.println(event.value()); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { moveToFirstToken(jsonParser); @@ -163,10 +159,6 @@ private Iterator parse(XContentParserConf ensureExpectedToken(XContentParser.Token.START_OBJECT, currentToken, parser); - String content = null; - String refusal = null; - List toolCalls = new ArrayList<>(); - currentToken = parser.nextToken(); // continue until the end 
of delta @@ -175,88 +167,25 @@ private Iterator parse(XContentParserConf parser.skipChildren(); } - if (currentToken == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case CONTENT_FIELD: - parser.nextToken(); - if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { - content = parser.text(); - } - // ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - break; - case REFUSAL_FIELD: - parser.nextToken(); - if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { - refusal = parser.text(); - } - // ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - break; - case TOOL_CALLS_FIELD: - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - toolCalls = parseToolCalls(parser); - break; - } + if (currentToken == XContentParser.Token.FIELD_NAME && parser.currentName().equals(CONTENT_FIELD)) { + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + var content = parser.text(); + consumeUntilObjectEnd(parser); // end delta + consumeUntilObjectEnd(parser); // end choices + return content; } currentToken = parser.nextToken(); } - // consumeUntilObjectEnd(parser); // end delta consumeUntilObjectEnd(parser); // end choices - - return new StreamingChatCompletionResults.Result(content, refusal, toolCalls); - }).stream().filter(Objects::nonNull).iterator(); - } - } - - private List parseToolCalls(XContentParser parser) throws IOException { - List toolCalls = new ArrayList<>(); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - int index = -1; - String id = null; - String functionName = null; - String functionArguments = null; - - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) 
{ - switch (parser.currentName()) { - case "index": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, parser.currentToken(), parser); - index = parser.intValue(); - break; - case "id": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - id = parser.text(); - break; - case "function": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case "name": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - functionName = parser.text(); - break; - case "arguments": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - functionArguments = parser.text(); - break; - } - } - } - break; - } - } - } - toolCalls.add(new StreamingChatCompletionResults.ToolCall(index, id, functionName, functionArguments)); + return ""; // stopped + }).stream() + .filter(Objects::nonNull) + .filter(Predicate.not(String::isEmpty)) + .map(StreamingChatCompletionResults.Result::new) + .iterator(); } - return toolCalls; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..fce2556efc5e0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; + +import java.util.concurrent.Flow; + +public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction); + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(openAiProcessor); + return new StreamingUnifiedChatCompletionResults(openAiProcessor); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java new file mode 100644 index 0000000000000..5d8ac7d4555c5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -0,0 +1,262 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +/** + * Parses the OpenAI chat completion streaming responses. + * For a request like: + * + *
+ *     
+ *         {
+ *             "inputs": ["Please summarize this text: some text", "Answer the following question: Question"]
+ *         }
+ *     
+ * 
+ * + * The response would look like: + * + *
+ *     
+ *         {
+ *              "id": "chatcmpl-123",
+ *              "object": "chat.completion.chunk",
+ *              "created": 1677652288,
+ *              "model": "gpt-3.5-turbo-0613",
+ *              "system_fingerprint": "fp_44709d6fcb",
+ *              "choices": [
+ *                  {
+ *                      "index": 0,
+ *                      "delta": {
+ *                          "content": "\n\nHello there, how "
+ *                      },
+ *                      "finish_reason": null
+ *                  }
+ *              ]
+ *          }
+ *
+ *         {
+ *              "id": "chatcmpl-123",
+ *              "object": "chat.completion.chunk",
+ *              "created": 1677652288,
+ *              "model": "gpt-3.5-turbo-0613",
+ *              "system_fingerprint": "fp_44709d6fcb",
+ *              "choices": [
+ *                  {
+ *                      "index": 1,
+ *                      "delta": {
+ *                          "content": "may I assist you today?"
+ *                      },
+ *                      "finish_reason": null
+ *                  }
+ *              ]
+ *          }
+ *
+ *         {
+ *              "id": "chatcmpl-123",
+ *              "object": "chat.completion.chunk",
+ *              "created": 1677652288,
+ *              "model": "gpt-3.5-turbo-0613",
+ *              "system_fingerprint": "fp_44709d6fcb",
+ *              "choices": [
+ *                  {
+ *                      "index": 2,
+ *                      "delta": {},
+ *                      "finish_reason": "stop"
+ *                  }
+ *              ]
+ *          }
+ *
+ *          [DONE]
+ *     
+ * 
+ */ +public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { + private static final Logger log = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response"; + + private static final String CHOICES_FIELD = "choices"; + private static final String DELTA_FIELD = "delta"; + private static final String CONTENT_FIELD = "content"; + private static final String DONE_MESSAGE = "[done]"; + private static final String REFUSAL_FIELD = "refusal"; + private static final String TOOL_CALLS_FIELD = "tool_calls"; + + @Override + protected void next(Deque item) throws Exception { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + var results = new ArrayDeque(item.size()); + for (ServerSentEvent event : item) { + if (ServerSentEventField.DATA == event.name() && event.hasValue()) { + try { + var delta = parse(parserConfig, event); + delta.forEachRemaining(results::offer); + } catch (Exception e) { + log.warn("Failed to parse event from inference provider: {}", event); + throw e; + } + } + } + + if (results.isEmpty()) { + upstream().request(1); + } else { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } + } + + private Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) + throws IOException { + if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { + return Collections.emptyIterator(); + } + + System.out.println(event.value()); + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, CHOICES_FIELD, 
FAILED_TO_FIND_FIELD_TEMPLATE); + + return parseList(jsonParser, parser -> { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + positionParserAtTokenAfterField(parser, DELTA_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE); + + var currentToken = parser.currentToken(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, currentToken, parser); + + String content = null; + String refusal = null; + List toolCalls = new ArrayList<>(); + + currentToken = parser.nextToken(); + + // continue until the end of delta + while (currentToken != null && currentToken != XContentParser.Token.END_OBJECT) { + if (currentToken == XContentParser.Token.START_OBJECT || currentToken == XContentParser.Token.START_ARRAY) { + parser.skipChildren(); + } + + if (currentToken == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case CONTENT_FIELD: + parser.nextToken(); + if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { + content = parser.text(); + } + // ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + break; + case REFUSAL_FIELD: + parser.nextToken(); + if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { + refusal = parser.text(); + } + // ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + break; + case TOOL_CALLS_FIELD: + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + toolCalls = parseToolCalls(parser); + break; + } + } + + currentToken = parser.nextToken(); + } + + // consumeUntilObjectEnd(parser); // end delta + consumeUntilObjectEnd(parser); // end choices + + return new StreamingUnifiedChatCompletionResults.Result(content, refusal, toolCalls); + }).stream().filter(Objects::nonNull).iterator(); + } + } + + private List parseToolCalls(XContentParser parser) throws IOException { + List toolCalls = new ArrayList<>(); + while (parser.nextToken() != 
XContentParser.Token.END_ARRAY) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + int index = -1; + String id = null; + String functionName = null; + String functionArguments = null; + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case "index": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, parser.currentToken(), parser); + index = parser.intValue(); + break; + case "id": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + id = parser.text(); + break; + case "function": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case "name": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + functionName = parser.text(); + break; + case "arguments": + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + functionArguments = parser.text(); + break; + } + } + } + break; + } + } + } + toolCalls.add(new StreamingUnifiedChatCompletionResults.ToolCall(index, id, functionName, functionArguments)); + } + return toolCalls; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java index 80770d63ef139..b1af18d03dda4 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; @@ -27,13 +27,13 @@ public class GoogleAiStudioCompletionRequest implements GoogleAiStudioRequest { private static final String ALT_PARAM = "alt"; private static final String SSE_VALUE = "sse"; - private final DocumentsOnlyInput input; + private final ChatCompletionInput input; private final LazyInitializable uri; private final GoogleAiStudioCompletionModel model; - public GoogleAiStudioCompletionRequest(DocumentsOnlyInput input, GoogleAiStudioCompletionModel model) { + public GoogleAiStudioCompletionRequest(ChatCompletionInput input, GoogleAiStudioCompletionModel model) { this.input = Objects.requireNonNull(input); this.model = Objects.requireNonNull(model); this.uri = new LazyInitializable<>(() -> model.uri(input.stream())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java new file mode 100644 index 0000000000000..99a025e70d003 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; +import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader; + +public class OpenAiChatCompletionRequest implements OpenAiRequest { + + private final OpenAiAccount account; + private final List input; + private final OpenAiChatCompletionModel model; + private final boolean stream; + + public OpenAiChatCompletionRequest(List input, OpenAiChatCompletionModel model, boolean stream) { + this.account = OpenAiAccount.of(model, OpenAiChatCompletionRequest::buildDefaultUri); + this.input = Objects.requireNonNull(input); + this.model = Objects.requireNonNull(model); + this.stream = stream; + } + + @Override + public HttpRequest 
createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString( + new OpenAiChatCompletionRequestEntity(input, model.getServiceSettings().modelId(), model.getTaskSettings().user(), stream) + ).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(account.apiKey())); + + var org = account.organizationId(); + if (org != null) { + httpPost.setHeader(createOrgHeader(org)); + } + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + // No truncation for OpenAI chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for OpenAI chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return stream; + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(OpenAiUtils.HOST) + .setPathSegments(OpenAiUtils.VERSION_1, OpenAiUtils.CHAT_PATH, OpenAiUtils.COMPLETIONS_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 95c84859a28ef..42a5a0a146542 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -22,6 +22,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; import 
org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -73,8 +74,7 @@ public void infer( private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) { return switch (model.getTaskType()) { - // TODO implement parameters - case COMPLETION -> UnifiedChatInput.of(input, stream); + case COMPLETION -> new ChatCompletionInput(input, stream); case RERANK -> new QueryAndDocsInputs(query, input, stream); case TEXT_EMBEDDING -> new DocumentsOnlyInput(input, stream); default -> throw new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 8851fcc7e5144..acce6bb3ada9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -38,7 +38,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.OpenAiUnifiedCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import 
org.elasticsearch.xpack.inference.services.SenderService; @@ -278,7 +278,7 @@ public void doUnifiedCompletionInfer( OpenAiChatCompletionModel openAiModel = (OpenAiChatCompletionModel) model; var overriddenModel = OpenAiChatCompletionModel.of(openAiModel, inputs.getRequest()); - var requestCreator = OpenAiCompletionRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + var requestCreator = OpenAiUnifiedCompletionRequestManager.of(overriddenModel, getServiceComponents().threadPool()); var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX); var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java index b0279caa563d5..814081ba68b49 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; @@ -18,7 +19,8 @@ public void testCastToSucceeds() { InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false); assertThat(inputs.castTo(DocumentsOnlyInput.class), Matchers.instanceOf(DocumentsOnlyInput.class)); - assertThat(UnifiedChatInput.of(List.of(), false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class)); + var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null, null, null); + assertThat(new 
UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class)); assertThat( new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class), Matchers.instanceOf(QueryAndDocsInputs.class) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java index 7ffa8940ad6be..065dfee577a82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java @@ -10,7 +10,7 @@ import org.apache.http.client.methods.HttpPost; import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequest; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; @@ -72,7 +72,7 @@ public void testTruncationInfo_ReturnsNull() { assertNull(request.getTruncationInfo()); } - private static DocumentsOnlyInput listOf(String... input) { - return new DocumentsOnlyInput(List.of(input)); + private static ChatCompletionInput listOf(String... 
input) { + return new ChatCompletionInput(List.of(input)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java similarity index 91% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java index 12749dd03ca1a..e76b413f16943 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; import java.io.IOException; @@ -27,7 +26,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class OpenAiUnifiedChatCompletionRequestTests extends ESTestCase { +public class OpenAiChatCompletionRequestTests extends ESTestCase { public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOException { var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user"); @@ -76,7 +75,7 @@ public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) 
httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); + assertThat(httpPost.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER)); @@ -102,7 +101,7 @@ public void testCreateRequest_WithStreaming() throws URISyntaxException, IOExcep public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { var request = createRequest(null, null, "secret", "abcd", "model", null); var truncatedRequest = request.truncate(); - assertThat(request.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); + assertThat(request.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); var httpRequest = truncatedRequest.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -122,7 +121,7 @@ public void testTruncationInfo_ReturnsNull() { assertNull(request.getTruncationInfo()); } - public static OpenAiUnifiedChatCompletionRequest createRequest( + public static OpenAiChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -133,7 +132,7 @@ public static OpenAiUnifiedChatCompletionRequest createRequest( return createRequest(url, org, apiKey, input, model, user, false); } - public static OpenAiUnifiedChatCompletionRequest createRequest( + public static OpenAiChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -143,7 +142,7 @@ public static OpenAiUnifiedChatCompletionRequest createRequest( boolean stream ) { var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, 
apiKey, model, user); - return new OpenAiUnifiedChatCompletionRequest(UnifiedChatInput.of(List.of(input), stream), chatCompletionModel); + return new OpenAiChatCompletionRequest(List.of(input), chatCompletionModel, stream); } } From b76a47dfd975af592b7ae7ebd754db6d85d3b4e9 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 27 Nov 2024 17:26:48 -0500 Subject: [PATCH 22/53] Updated the parser, but there are some class cast exceptions to fix --- .../openai/OpenAiStreamingProcessor.java | 417 ++++++++++++++---- 1 file changed, 330 insertions(+), 87 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index 7105667cd7bce..a164e452725e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -11,6 +11,8 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -23,17 +25,14 @@ import java.io.IOException; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; -import java.util.Objects; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static 
org.elasticsearch.common.xcontent.XContentParserUtils.parseList; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; /** * Parses the OpenAI chat completion streaming responses. @@ -151,107 +150,351 @@ private Iterator parse(XContentParserConf XContentParser.Token token = jsonParser.currentToken(); ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); - positionParserAtTokenAfterField(jsonParser, CHOICES_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE); + ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser); - return parseList(jsonParser, parser -> { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + List results = new ArrayList<>(); + for (ChatCompletionChunk.Choice choice : chunk.getChoices()) { + String content = choice.getDelta().getContent(); + String refusal = choice.getDelta().getRefusal(); + List toolCalls = parseToolCalls(choice.getDelta().getToolCalls()); + results.add(new StreamingChatCompletionResults.Result(content, refusal, toolCalls)); + } - positionParserAtTokenAfterField(parser, DELTA_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE); + return results.iterator(); + } + } - var currentToken = parser.currentToken(); + private List parseToolCalls(List toolCalls) { + List parsedToolCalls = new ArrayList<>(); + for (ChatCompletionChunk.Choice.Delta.ToolCall toolCall : toolCalls) { + int index = toolCall.getIndex(); + String id = toolCall.getId(); + String functionName = toolCall.getFunction() != null ? toolCall.getFunction().getName() : null; + String functionArguments = toolCall.getFunction() != null ? 
toolCall.getFunction().getArguments() : null; + parsedToolCalls.add(new StreamingChatCompletionResults.ToolCall(index, id, functionName, functionArguments)); + } + return parsedToolCalls; + } - ensureExpectedToken(XContentParser.Token.START_OBJECT, currentToken, parser); + public static class ChatCompletionChunk { + private final String id; + private List choices; + private final String model; + private final String object; + private Usage usage; - String content = null; - String refusal = null; - List toolCalls = new ArrayList<>(); + public ChatCompletionChunk(String id, List choices, String model, String object, Usage usage) { + this.id = id; + this.choices = choices; + this.model = model; + this.object = object; + this.usage = usage; + } - currentToken = parser.nextToken(); + public ChatCompletionChunk(String id, Choice[] choices, String model, String object, Usage usage) { + this.id = id; + this.choices = Arrays.stream(choices).toList(); + this.model = model; + this.object = object; + this.usage = usage; + } + + public String getId() { + return id; + } + + public List getChoices() { + return choices; + } + + public String getModel() { + return model; + } - // continue until the end of delta - while (currentToken != null && currentToken != XContentParser.Token.END_OBJECT) { - if (currentToken == XContentParser.Token.START_OBJECT || currentToken == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); + public String getObject() { + return object; + } + + public Usage getUsage() { + return usage; + } + + public static class Choice { + private final Delta delta; + private final String finishReason; + private final int index; + + public Choice(Delta delta, String finishReason, int index) { + this.delta = delta; + this.finishReason = finishReason; + this.index = index; + } + + public Delta getDelta() { + return delta; + } + + public String getFinishReason() { + return finishReason; + } + + public int getIndex() { + return index; + } + + public static class 
Delta { + private final String content; + private final String refusal; + private final String role; + private List toolCalls; + + public Delta(String content, String refusal, String role, List toolCalls) { + this.content = content; + this.refusal = refusal; + this.role = role; + this.toolCalls = toolCalls; + } + + public String getContent() { + return content; + } + + public String getRefusal() { + return refusal; + } + + public String getRole() { + return role; + } + + public List getToolCalls() { + return toolCalls; + } + + public static class ToolCall { + private final int index; + private final String id; + private Function function; + private final String type; + + public ToolCall(int index, String id, Function function, String type) { + this.index = index; + this.id = id; + this.function = function; + this.type = type; } - if (currentToken == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case CONTENT_FIELD: - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - content = parser.text(); - break; - case REFUSAL_FIELD: - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - refusal = parser.text(); - break; - case TOOL_CALLS_FIELD: - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - toolCalls = parseToolCalls(parser); - break; - } + public int getIndex() { + return index; + } + + public String getId() { + return id; + } + + public Function getFunction() { + return function; } - currentToken = parser.nextToken(); + public String getType() { + return type; + } + + public static class Function { + private final String arguments; + private final String name; + + public Function(String arguments, String name) { + this.arguments = arguments; + this.name = name; + } + + public String getArguments() { + return arguments; + } + + public String getName() { + return name; + } 
+ } } + } + } + + public static class Usage { + private final int completionTokens; + private final int promptTokens; + private final int totalTokens; - consumeUntilObjectEnd(parser); // end delta - consumeUntilObjectEnd(parser); // end choices + public Usage(int completionTokens, int promptTokens, int totalTokens) { + this.completionTokens = completionTokens; + this.promptTokens = promptTokens; + this.totalTokens = totalTokens; + } + + public int getCompletionTokens() { + return completionTokens; + } - return new StreamingChatCompletionResults.Result(content, refusal, toolCalls); - }).stream().filter(Objects::nonNull).iterator(); + public int getPromptTokens() { + return promptTokens; + } + + public int getTotalTokens() { + return totalTokens; + } } } - private List parseToolCalls(XContentParser parser) throws IOException { - List toolCalls = new ArrayList<>(); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - int index = -1; - String id = null; - String functionName = null; - String functionArguments = null; - - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case "index": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, parser.currentToken(), parser); - index = parser.intValue(); - break; - case "id": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - id = parser.text(); - break; - case "function": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case "name": - parser.nextToken(); - 
ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - functionName = parser.text(); - break; - case "arguments": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - functionArguments = parser.text(); - break; - } - } - } - break; - } - } + public static class ChatCompletionChunkParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "chat_completion_chunk", + true, + args -> new ChatCompletionChunk( + (String) args[0], + (ChatCompletionChunk.Choice[]) args[1], + (String) args[2], + (String) args[3], + (ChatCompletionChunk.Usage) args[4] + ) + + /** + * TODO + * Caused by: java.lang.ClassCastException: class java.lang.String cannot be cast to class [Lorg.elasticsearch.xpack.inference.external.openai.OpenAiStreamingProcessor$ChatCompletionChunk$Choice; (java.lang.String is in module java.base of loader 'bootstrap'; [Lorg.elasticsearch.xpack.inference.external.openai.OpenAiStreamingProcessor$ChatCompletionChunk$Choice; is in module org.elasticsearch.inference@9.0.0-SNAPSHOT of loader jdk.internal.loader.Loader @611c3eae) + * at org.elasticsearch.inference@9.0.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.openai.OpenAiStreamingProcessor$ChatCompletionChunkParser.lambda$static$0(OpenAiStreamingProcessor.java:354) + * at org.elasticsearch.xcontent@9.0.0-SNAPSHOT/org.elasticsearch.xcontent.ConstructingObjectParser.lambda$new$2(ConstructingObjectParser.java:130) + * at org.elasticsearch.xcontent@9.0.0-SNAPSHOT/org.elasticsearch.xcontent.ConstructingObjectParser$Target.buildTarget(ConstructingObjectParser.java:555) + */ + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("id")); + PARSER.declareObjectArray( + (chunk, choices) -> chunk.choices = choices, + (p, c) -> ChoiceParser.parse(p), + new ParseField("choices") + ); + 
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("model")); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("object")); + PARSER.declareObject((chunk, usage) -> chunk.usage = usage, (p, c) -> UsageParser.parse(p), new ParseField("usage")); + } + + public static ChatCompletionChunk parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private static class ChoiceParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "choice", + true, + args -> new ChatCompletionChunk.Choice((ChatCompletionChunk.Choice.Delta) args[0], (String) args[1], (int) args[2]) + ); + + static { + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> DeltaParser.parse(p), new ParseField("delta")); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("finish_reason")); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("index")); + } + + public static ChatCompletionChunk.Choice parse(XContentParser parser) throws IOException { + return PARSER.apply(parser, null); + } + } + + private static class DeltaParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "delta", + true, + args -> new ChatCompletionChunk.Choice.Delta( + (String) args[0], + (String) args[1], + (String) args[2], + (List) args[3] + ) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("content")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("refusal")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("role")); + PARSER.declareObjectArray( + (delta, toolCalls) -> delta.toolCalls = toolCalls, + (p, c) -> ToolCallParser.parse(p), + new ParseField("tool_calls") + ); + + } + + public static 
ChatCompletionChunk.Choice.Delta parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class ToolCallParser { + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "tool_call", + true, + args -> new ChatCompletionChunk.Choice.Delta.ToolCall( + (int) args[0], + (String) args[1], + (ChatCompletionChunk.Choice.Delta.ToolCall.Function) args[2], + (String) args[3] + ) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("index")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("id")); + PARSER.declareObject( + (toolCall, function) -> toolCall.function = function, + (p, c) -> FunctionParser.parse(p), + new ParseField("function") + ); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("type")); + } + + public static ChatCompletionChunk.Choice.Delta.ToolCall parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class FunctionParser { + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "function", + true, + args -> new ChatCompletionChunk.Choice.Delta.ToolCall.Function((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("arguments")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("name")); + } + + public static ChatCompletionChunk.Choice.Delta.ToolCall.Function parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class UsageParser { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "usage", + true, + args -> new ChatCompletionChunk.Usage((int) args[0], (int) args[1], (int) args[2]) + ); + + static { + 
PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("completion_tokens")); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("prompt_tokens")); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("total_tokens")); + } + + public static ChatCompletionChunk.Usage parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); } - toolCalls.add(new StreamingChatCompletionResults.ToolCall(index, id, functionName, functionArguments)); } - return toolCalls; } } From 0dfd0815b55cc676bdddf78762db67954b758f87 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 2 Dec 2024 15:32:07 -0500 Subject: [PATCH 23/53] Refactoring tests and request entities --- .../inference/UnifiedCompletionRequest.java | 34 ++--- .../org/elasticsearch/test/ESTestCase.java | 138 +++++++++++------- .../action/InferenceActionRequestTests.java | 2 +- .../action/UnifiedCompletionRequestTests.java | 49 +++---- .../OpenAiCompletionRequestManager.java | 9 +- .../http/sender/UnifiedChatInput.java | 17 +++ ...nAiUnifiedChatCompletionRequestEntity.java | 6 +- .../completion/OpenAiChatCompletionModel.java | 4 +- ...enAiChatCompletionRequestTaskSettings.java | 5 - .../http/sender/InferenceInputsTests.java | 2 +- ...penAiChatCompletionRequestEntityTests.java | 53 ------- ...nAiUnifiedChatCompletionRequestTests.java} | 21 ++- 12 files changed, 160 insertions(+), 180 deletions(-) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/{OpenAiChatCompletionRequestTests.java => OpenAiUnifiedChatCompletionRequestTests.java} (90%) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 
5b9ec50f52aba..bce5a5601a936 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -33,13 +33,11 @@ public record UnifiedCompletionRequest( List messages, @Nullable String model, @Nullable Long maxCompletionTokens, - @Nullable Integer n, @Nullable Stop stop, @Nullable Float temperature, @Nullable ToolChoice toolChoice, @Nullable List tools, - @Nullable Float topP, - @Nullable String user + @Nullable Float topP ) implements Writeable { public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {} @@ -51,13 +49,11 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C (List) args[0], (String) args[1], (Long) args[2], - (Integer) args[3], - (Stop) args[4], - (Float) args[5], - (ToolChoice) args[6], - (List) args[7], - (Float) args[8], - (String) args[9] + (Stop) args[3], + (Float) args[4], + (ToolChoice) args[5], + (List) args[6], + (Float) args[7] ) ); @@ -65,7 +61,6 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages")); PARSER.declareString(optionalConstructorArg(), new ParseField("model")); PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens")); - PARSER.declareInt(optionalConstructorArg(), new ParseField("n")); PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), ObjectParser.ValueType.VALUE_ARRAY); PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature")); PARSER.declareField( @@ -76,7 +71,6 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C ); PARSER.declareObjectArray(optionalConstructorArg(), Tool.PARSER::apply, new ParseField("tools")); PARSER.declareFloat(optionalConstructorArg(), new ParseField("top_p")); - 
PARSER.declareString(optionalConstructorArg(), new ParseField("user")); } public static List getNamedWriteables() { @@ -90,18 +84,20 @@ public static List getNamedWriteables() { ); } + public static UnifiedCompletionRequest of(List messages) { + return new UnifiedCompletionRequest(messages, null, null, null, null, null, null, null); + } + public UnifiedCompletionRequest(StreamInput in) throws IOException { this( in.readCollectionAsImmutableList(Message::new), in.readOptionalString(), in.readOptionalVLong(), - in.readOptionalVInt(), in.readOptionalNamedWriteable(Stop.class), in.readOptionalFloat(), in.readOptionalNamedWriteable(ToolChoice.class), - in.readCollectionAsImmutableList(Tool::new), - in.readOptionalFloat(), - in.readOptionalString() + in.readOptionalCollectionAsList(Tool::new), + in.readOptionalFloat() ); } @@ -110,13 +106,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(messages); out.writeOptionalString(model); out.writeOptionalVLong(maxCompletionTokens); - out.writeOptionalVInt(n); out.writeOptionalNamedWriteable(stop); out.writeOptionalFloat(temperature); out.writeOptionalNamedWriteable(toolChoice); out.writeOptionalCollection(tools); out.writeOptionalFloat(topP); - out.writeOptionalString(user); } public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List toolCalls) @@ -155,7 +149,7 @@ public Message(StreamInput in) throws IOException { in.readString(), in.readOptionalString(), in.readOptionalString(), - in.readCollectionAsImmutableList(ToolCall::new) + in.readOptionalCollectionAsList(ToolCall::new) ); } @@ -165,7 +159,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(role); out.writeOptionalString(name); out.writeOptionalString(toolCallId); - out.writeCollection(toolCalls); + out.writeOptionalCollection(toolCalls); } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java 
b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 120e8a8494424..7ac1d49a817b4 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -254,11 +254,8 @@ public abstract class ESTestCase extends LuceneTestCase { protected static final List JAVA_TIMEZONE_IDS; protected static final List JAVA_ZONE_IDS; - private static final AtomicInteger portGenerator = new AtomicInteger(); - private static final Collection loggedLeaks = new ArrayList<>(); - private HeaderWarningAppender headerWarningAppender; @AfterClass @@ -268,13 +265,9 @@ public static void resetPortCounter() { // Allows distinguishing between parallel test processes public static final String TEST_WORKER_VM_ID; - public static final String TEST_WORKER_SYS_PROPERTY = "org.gradle.test.worker"; - public static final String DEFAULT_TEST_WORKER_ID = "--not-gradle--"; - public static final String FIPS_SYSPROP = "tests.fips.enabled"; - private static final SetOnce WARN_SECURE_RANDOM_FIPS_NOT_DETERMINISTIC = new SetOnce<>(); static { @@ -443,7 +436,6 @@ private static void setTestSysProps(Random random) { // ----------------------------------------------------------------- // Suite and test case setup/cleanup. 
// ----------------------------------------------------------------- - @Rule public RuleChain failureAndSuccessEvents = RuleChain.outerRule(new TestRuleAdapter() { @Override @@ -482,7 +474,9 @@ public static TransportAddress buildNewFakeTransportAddress() { */ protected void afterIfFailed(List errors) {} - /** called after a test is finished, but only if successful */ + /** + * called after a test is finished, but only if successful + */ protected void afterIfSuccessful() throws Exception {} /** @@ -492,6 +486,7 @@ protected void afterIfSuccessful() throws Exception {} @Target({ ElementType.TYPE }) @Inherited public @interface WithoutSecurityManager { + } private static Closeable securityManagerRestorer; @@ -668,6 +663,7 @@ protected List filteredWarnings() { /** * Convenience method to assert warnings for settings deprecations and general deprecation warnings. + * * @param settings the settings that are expected to be deprecated * @param warnings other expected general deprecation warnings */ @@ -688,6 +684,7 @@ protected final void assertSettingDeprecationsAndWarnings(final Setting[] set /** * Convenience method to assert warnings for settings deprecations and general deprecation warnings. All warnings passed to this method * are assumed to be at WARNING level. + * * @param expectedWarnings expected general deprecation warning messages. */ protected final void assertWarnings(String... expectedWarnings) { @@ -702,6 +699,7 @@ protected final void assertWarnings(String... expectedWarnings) { /** * Convenience method to assert warnings for settings deprecations and general deprecation warnings. All warnings passed to this method * are assumed to be at CRITICAL level. + * * @param expectedWarnings expected general deprecation warning messages. */ protected final void assertCriticalWarnings(String... 
expectedWarnings) { @@ -769,20 +767,19 @@ private void resetDeprecationLogger() { } private static final List statusData = new ArrayList<>(); + static { // ensure that the status logger is set to the warn level so we do not miss any warnings with our Log4j usage StatusLogger.getLogger().setLevel(Level.WARN); // Log4j will write out status messages indicating problems with the Log4j usage to the status logger; we hook into this logger and // assert that no such messages were written out as these would indicate a problem with our logging configuration StatusLogger.getLogger().registerListener(new StatusConsoleListener(Level.WARN) { - @Override public void log(StatusData data) { synchronized (statusData) { statusData.add(data); } } - }); } @@ -843,8 +840,9 @@ public final void ensureAllSearchContextsReleased() throws Exception { // mockdirectorywrappers currently set this boolean if checkindex fails // TODO: can we do this cleaner??? - - /** MockFSDirectoryService sets this: */ + /** + * MockFSDirectoryService sets this: + */ public static final List checkIndexFailures = new CopyOnWriteArrayList<>(); @Before @@ -1139,37 +1137,48 @@ public static LongStream randomLongs(long streamSize) { * Returns a random BigInteger uniformly distributed over the range 0 to (2^64 - 1) inclusive * Currently BigIntegers are only used for unsigned_long field type, where the max value is 2^64 - 1. * Modify this random generator if a wider range for BigIntegers is necessary. + * * @return a random bigInteger in the range [0 ; 2^64 - 1] */ public static BigInteger randomBigInteger() { return new BigInteger(64, random()); } - /** A random integer from 0..max (inclusive). */ + /** + * A random integer from 0..max (inclusive). + */ public static int randomInt(int max) { return RandomizedTest.randomInt(max); } - /** A random byte size value. */ + /** + * A random byte size value. 
+ */ public static ByteSizeValue randomByteSizeValue() { return ByteSizeValue.ofBytes(randomLongBetween(0L, Long.MAX_VALUE >> 16)); } - /** Pick a random object from the given array. The array must not be empty. */ + /** + * Pick a random object from the given array. The array must not be empty. + */ @SafeVarargs @SuppressWarnings("varargs") public static T randomFrom(T... array) { return randomFrom(random(), array); } - /** Pick a random object from the given array. The array must not be empty. */ + /** + * Pick a random object from the given array. The array must not be empty. + */ @SafeVarargs @SuppressWarnings("varargs") public static T randomFrom(Random random, T... array) { return RandomPicks.randomFrom(random, array); } - /** Pick a random object from the given array of suppliers. The array must not be empty. */ + /** + * Pick a random object from the given array of suppliers. The array must not be empty. + */ @SafeVarargs @SuppressWarnings("varargs") public static T randomFrom(Random random, Supplier... array) { @@ -1177,17 +1186,23 @@ public static T randomFrom(Random random, Supplier... array) { return supplier.get(); } - /** Pick a random object from the given list. */ + /** + * Pick a random object from the given list. + */ public static T randomFrom(List list) { return RandomPicks.randomFrom(random(), list); } - /** Pick a random object from the given collection. */ + /** + * Pick a random object from the given collection. + */ public static T randomFrom(Collection collection) { return randomFrom(random(), collection); } - /** Pick a random object from the given collection. */ + /** + * Pick a random object from the given collection. 
+ */ public static T randomFrom(Random random, Collection collection) { return RandomPicks.randomFrom(random, collection); } @@ -1205,27 +1220,27 @@ public static SecureString randomSecureStringOfLength(int codeUnits) { return new SecureString(randomAlpha.toCharArray()); } - public static String randomNullOrAlphaOfLength(int codeUnits) { + public static String randomAlphaOfLengthOrNull(int codeUnits) { return randomBoolean() ? null : randomAlphaOfLength(codeUnits); } - public static Long randomNullOrLong() { + public static Long randomLongOrNull() { return randomBoolean() ? null : randomLong(); } - public static Long randomNullOrPositiveLong() { - return randomBoolean() ? null : randomLongBetween(0L, Long.MAX_VALUE); + public static Long randomPositiveLongOrNull() { + return randomBoolean() ? null : randomNonNegativeLong(); } - public static Integer randomNullOrInt() { + public static Integer randomIntOrNull() { return randomBoolean() ? null : randomInt(); } - public static Integer randomNullOrPositiveInt() { - return randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE); + public static Integer randomPositiveIntOrNull() { + return randomBoolean() ? null : randomNonNegativeInt(); } - public static Float randomNullOrFloat() { + public static Float randomFloatOrNull() { return randomBoolean() ? 
null : randomFloat(); } @@ -1274,9 +1289,9 @@ public static String randomRealisticUnicodeOfCodepointLength(int codePoints) { /** * @param maxArraySize The maximum number of elements in the random array - * @param stringSize The length of each String in the array - * @param allowNull Whether the returned array may be null - * @param allowEmpty Whether the returned array may be empty (have zero elements) + * @param stringSize The length of each String in the array + * @param allowNull Whether the returned array may be null + * @param allowEmpty Whether the returned array may be empty (have zero elements) */ public static String[] generateRandomStringArray(int maxArraySize, int stringSize, boolean allowNull, boolean allowEmpty) { if (allowNull && random().nextBoolean()) { @@ -1481,8 +1496,8 @@ public static boolean waitUntil(BooleanSupplier breakSupplier) { * {@link ESTestCase#assertBusy(CheckedRunnable)} instead. * * @param breakSupplier determines whether to return immediately or continue waiting. - * @param maxWaitTime the maximum amount of time to wait - * @param unit the unit of tie for maxWaitTime + * @param maxWaitTime the maximum amount of time to wait + * @param unit the unit of tie for maxWaitTime * @return the last value returned by breakSupplier */ public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime, TimeUnit unit) { @@ -1553,7 +1568,9 @@ public static Path getResourceDataPath(Class clazz, String relativePath) { } } - /** Returns a random number of temporary paths. */ + /** + * Returns a random number of temporary paths. + */ public String[] tmpPaths() { final int numPaths = TestUtil.nextInt(random(), 1, 3); final String[] absPaths = new String[numPaths]; @@ -1590,18 +1607,24 @@ public Environment newEnvironment(Settings settings) { return TestEnvironment.newEnvironment(build); } - /** Return consistent index settings for the provided index version. */ + /** + * Return consistent index settings for the provided index version. 
+ */ public static Settings.Builder settings(IndexVersion version) { return Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version); } - /** Return consistent index settings for the provided index version, shard- and replica-count. */ + /** + * Return consistent index settings for the provided index version, shard- and replica-count. + */ public static Settings.Builder indexSettings(IndexVersion indexVersionCreated, int shards, int replicas) { return settings(indexVersionCreated).put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shards) .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, replicas); } - /** Return consistent index settings for the provided shard- and replica-count. */ + /** + * Return consistent index settings for the provided shard- and replica-count. + */ public static Settings.Builder indexSettings(int shards, int replicas) { return Settings.builder() .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shards) @@ -1705,6 +1728,7 @@ public XContentType randomVendorType() { } public static class GeohashGenerator extends CodepointSetGenerator { + private static final char[] ASCII_SET = "0123456789bcdefghjkmnpqrstuvwxyz".toCharArray(); public GeohashGenerator() { @@ -1863,6 +1887,7 @@ public static C copyNamedWriteable( /** * Same as {@link #copyNamedWriteable(NamedWriteable, NamedWriteableRegistry, Class)} but also allows to provide * a {@link TransportVersion} argument which will be used to write and read back the object. + * * @return */ @SuppressWarnings("unchecked") @@ -1989,12 +2014,16 @@ public static Script mockScript(String id) { return new Script(ScriptType.INLINE, MockScriptEngine.NAME, id, emptyMap()); } - /** Returns the suite failure marker: internal use only! */ + /** + * Returns the suite failure marker: internal use only! 
+ */ public static TestRuleMarkFailure getSuiteFailureMarker() { return suiteFailureMarker; } - /** Compares two stack traces, ignoring module (which is not yet serialized) */ + /** + * Compares two stack traces, ignoring module (which is not yet serialized) + */ public static void assertArrayEquals(StackTraceElement expected[], StackTraceElement actual[]) { assertEquals(expected.length, actual.length); for (int i = 0; i < expected.length; i++) { @@ -2002,7 +2031,9 @@ public static void assertArrayEquals(StackTraceElement expected[], StackTraceEle } } - /** Compares two stack trace elements, ignoring module (which is not yet serialized) */ + /** + * Compares two stack trace elements, ignoring module (which is not yet serialized) + */ public static void assertEquals(StackTraceElement expected, StackTraceElement actual) { assertEquals(expected.getClassName(), actual.getClassName()); assertEquals(expected.getMethodName(), actual.getMethodName()); @@ -2123,23 +2154,19 @@ public static boolean inFipsJvm() { * worker, avoiding any unexpected interactions, although if we spawn enough test workers then we will wrap around to the beginning * again. */ - /** * Defines the size of the port range assigned to each worker, which must be large enough to supply enough ports to run the tests, but * not so large that we run out of ports. See also [NOTE: Port ranges for tests]. */ private static final int PORTS_PER_WORKER = 30; - /** * Defines the minimum port that test workers should use. See also [NOTE: Port ranges for tests]. */ protected static final int MIN_PRIVATE_PORT = 13301; - /** * Defines the maximum port that test workers should use. See also [NOTE: Port ranges for tests]. */ private static final int MAX_PRIVATE_PORT = 32767; - /** * Wrap around after reaching this worker ID. 
*/ @@ -2197,6 +2224,7 @@ public static InetAddress randomIp(boolean v4) { } public static final class DeprecationWarning { + private final Level level; // Intentionally ignoring level for the sake of equality for now private final String message; @@ -2227,8 +2255,9 @@ public String toString() { /** * Call method at the beginning of a test to disable its execution * until a given Lucene version is released and integrated into Elasticsearch + * * @param luceneVersionWithFix the lucene release to wait for - * @param message an additional message or link with information on the fix + * @param message an additional message or link with information on the fix */ protected void skipTestWaitingForLuceneFix(org.apache.lucene.util.Version luceneVersionWithFix, String message) { final boolean currentVersionHasFix = IndexVersion.current().luceneVersion().onOrAfter(luceneVersionWithFix); @@ -2239,9 +2268,10 @@ protected void skipTestWaitingForLuceneFix(org.apache.lucene.util.Version lucene /** * In non-FIPS mode, get a deterministic SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). * In FIPS mode, get a non-deterministic SecureRandom DEFAULT/BCFIPS instance seeded by deterministic LuceneTestCase.random(). + * * @return SecureRandom SHA1PRNG instance. * @throws NoSuchAlgorithmException SHA1PRNG or DEFAULT algorithm not found. - * @throws NoSuchProviderException BCFIPS algorithm not found. + * @throws NoSuchProviderException BCFIPS algorithm not found. */ public static SecureRandom secureRandom() throws NoSuchAlgorithmException, NoSuchProviderException { return secureRandom(randomByteArrayOfLength(32)); @@ -2250,10 +2280,11 @@ public static SecureRandom secureRandom() throws NoSuchAlgorithmException, NoSuc /** * In non-FIPS mode, get a deterministic SecureRandom SHA1PRNG/SUN instance seeded by the input value. * In FIPS mode, get a non-deterministic SecureRandom DEFAULT/BCFIPS instance seeded by the input value. 
+ * * @param seed Byte array to use for seeding the SecureRandom instance. * @return SecureRandom SHA1PRNG or DEFAULT/BCFIPS instance, depending on FIPS mode. * @throws NoSuchAlgorithmException SHA1PRNG or DEFAULT algorithm not found. - * @throws NoSuchProviderException BCFIPS algorithm not found. + * @throws NoSuchProviderException BCFIPS algorithm not found. */ public static SecureRandom secureRandom(final byte[] seed) throws NoSuchAlgorithmException, NoSuchProviderException { return inFipsJvm() ? secureRandomFips(seed) : secureRandomNonFips(seed); @@ -2261,6 +2292,7 @@ public static SecureRandom secureRandom(final byte[] seed) throws NoSuchAlgorith /** * Returns deterministic non-FIPS SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). + * * @return Deterministic non-FIPS SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). * @throws NoSuchAlgorithmException Exception if SHA1PRNG algorithm not found, such as missing SUN provider (unlikely). */ @@ -2270,6 +2302,7 @@ protected static SecureRandom secureRandomNonFips() throws NoSuchAlgorithmExcept /** * Returns non-deterministic FIPS SecureRandom DEFAULT/BCFIPS instance. Seeded. + * * @return Non-deterministic FIPS SecureRandom DEFAULT/BCFIPS instance. Seeded. * @throws NoSuchAlgorithmException Exception if DEFAULT algorithm not found, such as missing BCFIPS provider. */ @@ -2279,6 +2312,7 @@ protected static SecureRandom secureRandomFips() throws NoSuchAlgorithmException /** * Returns deterministic non-FIPS SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). + * * @return Deterministic non-FIPS SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). * @throws NoSuchAlgorithmException Exception if SHA1PRNG algorithm not found, such as missing SUN provider (unlikely). 
*/ @@ -2290,6 +2324,7 @@ protected static SecureRandom secureRandomNonFips(final byte[] seed) throws NoSu /** * Returns non-deterministic FIPS SecureRandom DEFAULT/BCFIPS instance. Seeded. + * * @return Non-deterministic FIPS SecureRandom DEFAULT/BCFIPS instance. Seeded. * @throws NoSuchAlgorithmException Exception if DEFAULT algorithm not found, such as missing BCFIPS provider. */ @@ -2315,7 +2350,6 @@ protected static SecureRandom secureRandomFips(final byte[] seed) throws NoSuchA * in these requests. This constant can be used as a slightly more meaningful way to refer to the 30s default value in tests. */ public static final TimeValue TEST_REQUEST_TIMEOUT = TimeValue.THIRTY_SECONDS; - /** * The timeout used for the various "safe" wait methods such as {@link #safeAwait} and {@link #safeAcquire}. In tests we generally want * these things to complete almost immediately, but sometimes the CI runner executes things rather slowly so we use {@code 10s} as a @@ -2487,7 +2521,6 @@ public static Exception safeAwaitFailure(Consumer> consume * AssertionError} to trigger a test failure. * * @param responseType Class of listener response type, to aid type inference but otherwise ignored. - * * @return The exception with which the {@code listener} was completed exceptionally. */ public static Exception safeAwaitFailure(@SuppressWarnings("unused") Class responseType, Consumer> consumer) { @@ -2501,7 +2534,6 @@ public static Exception safeAwaitFailure(@SuppressWarnings("unused") Class ExpectedException safeAwaitFailure( @@ -2521,7 +2553,6 @@ public static ExpectedException * @param responseType Class of listener response type, to aid type inference but otherwise ignored. * @param exceptionType Expected unwrapped exception type. This method throws an {@link AssertionError} if a different type of exception * is seen. - * * @return The unwrapped exception with which the {@code listener} was completed exceptionally. 
*/ public static ExpectedException safeAwaitAndUnwrapFailure( @@ -2643,8 +2674,9 @@ public static void startInParallel(int numberOfTasks, IntConsumer taskFactory) { /** * Run {@code numberOfTasks} parallel tasks that were created by the given {@code taskFactory}. On of the tasks will be run on the * calling thread, the rest will be run on a new thread. + * * @param numberOfTasks number of tasks to run in parallel - * @param taskFactory task factory + * @param taskFactory task factory */ public static void runInParallel(int numberOfTasks, IntConsumer taskFactory) { final ArrayList> futures = new ArrayList<>(numberOfTasks); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index a9ca5e6da8720..0b9d32d0668c2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -42,7 +42,7 @@ protected InferenceAction.Request createTestInstance() { randomFrom(TaskType.values()), randomAlphaOfLength(6), // null, - randomNullOrAlphaOfLength(10), + randomAlphaOfLengthOrNull(10), randomList(1, 5, () -> randomAlphaOfLength(8)), randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java index 612fadfab308f..7d839df00dc4c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -50,7 +50,6 @@ public void testParseAllFields() throws IOException { } ], "max_completion_tokens": 100, - "n": 1, "stop": ["stop"], "temperature": 0.1, "tools": [ @@ -71,8 +70,7 @@ public void testParseAllFields() throws IOException { "name": "some function" } }, - "top_p": 0.2, - "user": "user" + "top_p": 0.2 } """; @@ -98,7 +96,6 @@ public void testParseAllFields() throws IOException { ), "gpt-4o", 100L, - 1, new UnifiedCompletionRequest.StopValues(List.of("stop")), 0.1F, new UnifiedCompletionRequest.ToolChoiceObject( @@ -116,8 +113,7 @@ public void testParseAllFields() throws IOException { ) ) ), - 0.2F, - "user" + 0.2F ); assertThat(request, is(expected)); @@ -165,7 +161,6 @@ public void testParsing() throws IOException { ), "gpt-4o", null, - null, new UnifiedCompletionRequest.StopString("none"), null, new UnifiedCompletionRequest.ToolChoiceString("auto"), @@ -180,7 +175,6 @@ public void testParsing() throws IOException { ) ) ), - null, null ); @@ -191,15 +185,13 @@ public void testParsing() throws IOException { public static UnifiedCompletionRequest randomUnifiedCompletionRequest() { return new UnifiedCompletionRequest( randomList(5, UnifiedCompletionRequestTests::randomMessage), - randomNullOrAlphaOfLength(10), - randomNullOrPositiveLong(), - randomNullOrPositiveInt(), - randomNullOrStop(), - randomNullOrFloat(), - randomNullOrToolChoice(), - randomList(5, UnifiedCompletionRequestTests::randomTool), - randomNullOrFloat(), - randomNullOrAlphaOfLength(10) + randomAlphaOfLengthOrNull(10), + randomPositiveLongOrNull(), + randomStopOrNull(), + randomFloatOrNull(), + randomToolChoiceOrNull(), + randomToolListOrNull(), + randomFloatOrNull() ); } @@ -207,9 +199,9 @@ public static UnifiedCompletionRequest.Message randomMessage() { return new UnifiedCompletionRequest.Message( randomContent(), randomAlphaOfLength(10), - randomNullOrAlphaOfLength(10), - 
randomNullOrAlphaOfLength(10), - randomList(10, UnifiedCompletionRequestTests::randomToolCall) + randomAlphaOfLengthOrNull(10), + randomAlphaOfLengthOrNull(10), + randomToolCallListOrNull() ); } @@ -223,6 +215,10 @@ public static UnifiedCompletionRequest.ContentObject randomContentObject() { return new UnifiedCompletionRequest.ContentObject(randomAlphaOfLength(10), randomAlphaOfLength(10)); } + public static List randomToolCallListOrNull() { + return randomBoolean() ? randomList(10, UnifiedCompletionRequestTests::randomToolCall) : null; + } + public static UnifiedCompletionRequest.ToolCall randomToolCall() { return new UnifiedCompletionRequest.ToolCall(randomAlphaOfLength(10), randomToolCallFunctionField(), randomAlphaOfLength(10)); } @@ -231,7 +227,7 @@ public static UnifiedCompletionRequest.ToolCall.FunctionField randomToolCallFunc return new UnifiedCompletionRequest.ToolCall.FunctionField(randomAlphaOfLength(10), randomAlphaOfLength(10)); } - public static UnifiedCompletionRequest.Stop randomNullOrStop() { + public static UnifiedCompletionRequest.Stop randomStopOrNull() { return randomBoolean() ? randomStop() : null; } @@ -241,7 +237,7 @@ public static UnifiedCompletionRequest.Stop randomStop() { : new UnifiedCompletionRequest.StopValues(randomList(5, () -> randomAlphaOfLength(10))); } - public static UnifiedCompletionRequest.ToolChoice randomNullOrToolChoice() { + public static UnifiedCompletionRequest.ToolChoice randomToolChoiceOrNull() { return randomBoolean() ? randomToolChoice() : null; } @@ -255,13 +251,17 @@ public static UnifiedCompletionRequest.ToolChoiceObject.FunctionField randomTool return new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomAlphaOfLength(10)); } + public static List randomToolListOrNull() { + return randomBoolean() ? 
randomList(10, UnifiedCompletionRequestTests::randomTool) : null; + } + public static UnifiedCompletionRequest.Tool randomTool() { return new UnifiedCompletionRequest.Tool(randomAlphaOfLength(10), randomToolFunctionField()); } public static UnifiedCompletionRequest.Tool.FunctionField randomToolFunctionField() { return new UnifiedCompletionRequest.Tool.FunctionField( - randomNullOrAlphaOfLength(10), + randomAlphaOfLengthOrNull(10), randomAlphaOfLength(10), null, randomOptionalBoolean() @@ -290,9 +290,6 @@ protected UnifiedCompletionRequest mutateInstance(UnifiedCompletionRequest insta @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - // List entries = new ArrayList<>(); - // entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); - // entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index f3de9108d0f28..4d730be6aa6bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -15,7 +15,6 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest; import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; 
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; @@ -26,8 +25,8 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager { private static final Logger logger = LogManager.getLogger(OpenAiCompletionRequestManager.class); - private static final ResponseHandler HANDLER = createCompletionHandler(); + static final String USER_ROLE = "user"; public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { return new OpenAiCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); @@ -48,11 +47,7 @@ public void execute( ActionListener listener ) { var chatCompletionInputs = inferenceInputs.castTo(ChatCompletionInput.class); - OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest( - chatCompletionInputs.getInputs(), - model, - chatCompletionInputs.stream() - ); + var request = new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(chatCompletionInputs, USER_ROLE), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java index a6f791f8be660..38c2a03548495 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -9,6 +9,7 @@ import org.elasticsearch.inference.UnifiedCompletionRequest; +import java.util.List; import java.util.Objects; public class UnifiedChatInput extends InferenceInputs { @@ -20,6 +21,22 @@ 
public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) { this.stream = stream; } + public UnifiedChatInput(ChatCompletionInput completionInput, String roleValue) { + this( + completionInput.getInputs(), roleValue, completionInput.stream() + ); + } + + public UnifiedChatInput(List inputs, String roleValue, boolean stream) { + this(UnifiedCompletionRequest.of(convertToMessages(inputs, roleValue)), stream); + } + + private static List convertToMessages(List inputs, String roleValue) { + return inputs.stream() + .map(value -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(value), roleValue, null, null, null)) + .toList(); + } + public UnifiedCompletionRequest getRequest() { return request; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 132a68c644eb9..3225ecd7941f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -111,9 +111,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (unifiedRequest.maxCompletionTokens() != null) { builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); } - if (unifiedRequest.n() != null) { - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, unifiedRequest.n()); - } + + builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); + if (unifiedRequest.stop() != null) { switch (unifiedRequest.stop()) { case UnifiedCompletionRequest.StopString stopString -> builder.field(STOP_FIELD, stopString.value()); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index ceda56b78d0e3..7d79d64b3a771 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -41,7 +41,6 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map< } public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) { - var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromUnifiedRequest(request); var originalModelServiceSettings = model.getServiceSettings(); var overriddenServiceSettings = new OpenAiChatCompletionServiceSettings( Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), @@ -51,13 +50,12 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Unif originalModelServiceSettings.rateLimitSettings() ); - var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings); return new OpenAiChatCompletionModel( model.getInferenceEntityId(), model.getTaskType(), model.getConfigurations().getService(), overriddenServiceSettings, - overriddenTaskSettings, + model.getTaskSettings(), model.getSecretSettings() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java index 69978b800f8b1..f23956bc21a04 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java @@ -49,9 +49,4 @@ public static OpenAiChatCompletionRequestTaskSettings fromMap(Map new OpenAiChatCompletionRequestEntity(List.of("abc"), null, "user", false)); - } - - public void testXContent_ThrowsIfMessagesAreNull() { - assertThrows(NullPointerException.class, () -> new OpenAiChatCompletionRequestEntity(null, "model", "user", false)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java similarity index 90% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java index e76b413f16943..6bfbc96b9bfe1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; import java.io.IOException; @@ -26,7 +27,7 @@ import static 
org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class OpenAiChatCompletionRequestTests extends ESTestCase { +public class OpenAiUnifiedChatCompletionRequestTests extends ESTestCase { public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOException { var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user"); @@ -41,11 +42,12 @@ public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOExceptio assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap, aMapWithSize(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("user")); assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); } public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOException { @@ -61,11 +63,12 @@ public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOExce assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap, aMapWithSize(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("user")); assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); } public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws URISyntaxException, IOException { @@ -81,10 +84,11 @@ public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER)); var requestMap = 
entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap, aMapWithSize(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); } public void testCreateRequest_WithStreaming() throws URISyntaxException, IOException { @@ -108,12 +112,13 @@ public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, var httpPost = (HttpPost) httpRequest.httpRequestBase(); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap, aMapWithSize(4)); // We do not truncate for OpenAi chat completions assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); } public void testTruncationInfo_ReturnsNull() { @@ -121,7 +126,7 @@ public void testTruncationInfo_ReturnsNull() { assertNull(request.getTruncationInfo()); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -132,7 +137,7 @@ public static OpenAiChatCompletionRequest createRequest( return createRequest(url, org, apiKey, input, model, user, false); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -142,7 +147,7 @@ public static OpenAiChatCompletionRequest createRequest( boolean stream ) { var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, apiKey, model, user); - return new OpenAiChatCompletionRequest(List.of(input), 
chatCompletionModel, stream); + return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", true), chatCompletionModel); } } From 3735bf1fa1d64cdfd0720c688b961893cf36645a Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 2 Dec 2024 15:46:28 -0500 Subject: [PATCH 24/53] Parse response from OpenAI --- ...StreamingUnifiedChatCompletionResults.java | 211 ++++++++- .../openai/OpenAiStreamingProcessor.java | 372 ++-------------- .../OpenAiUnifiedStreamingProcessor.java | 402 ++++++++++-------- 3 files changed, 454 insertions(+), 531 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java index de035f40711b1..4712a5964b483 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -16,6 +16,8 @@ import org.elasticsearch.xcontent.ToXContent; import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; @@ -24,6 +26,7 @@ import java.util.concurrent.Flow; import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.Result.RESULT; /** * Chat Completion results that only contain a Flow.Publisher. 
@@ -32,6 +35,10 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher toXContentChunked(ToXContent.Params params } } - public record Result(String delta, String refusal, List toolCalls) implements ChunkedToXContent { - - private static final String RESULT = "delta"; - private static final String REFUSAL = "refusal"; - private static final String TOOL_CALLS = "tool_calls"; + private static final String REFUSAL_FIELD = "refusal"; + private static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String FINISH_REASON_FIELD = "finish_reason"; - public Result(String delta) { - this(delta, "", List.of()); - } + public record Result( + String delta, + String refusal, + List toolCalls, + String finishReason, + String model, + String object, + ChatCompletionChunk.Usage usage + ) implements ChunkedToXContent { @Override public Iterator toXContentChunked(ToXContent.Params params) { + Iterator toolCallsIterator = Collections.emptyIterator(); + if (toolCalls != null && toolCalls.isEmpty() == false) { + toolCallsIterator = Iterators.concat( + ChunkedToXContentHelper.startArray(TOOL_CALLS_FIELD), + Iterators.flatMap(toolCalls.iterator(), d -> d.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); + } + + Iterator usageIterator = Collections.emptyIterator(); + if (usage != null) { + usageIterator = Iterators.concat( + ChunkedToXContentHelper.startObject(USAGE_FIELD), + ChunkedToXContentHelper.field("completion_tokens", usage.completionTokens()), + ChunkedToXContentHelper.field("prompt_tokens", usage.promptTokens()), + ChunkedToXContentHelper.field("total_tokens", usage.totalTokens()), + ChunkedToXContentHelper.endObject() + ); + } + return Iterators.concat( ChunkedToXContentHelper.startObject(), ChunkedToXContentHelper.field(RESULT, delta), - ChunkedToXContentHelper.field(REFUSAL, refusal), - ChunkedToXContentHelper.startArray(TOOL_CALLS), - Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)), - 
ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.field(REFUSAL_FIELD, refusal), + toolCallsIterator, + ChunkedToXContentHelper.field(FINISH_REASON_FIELD, finishReason), + ChunkedToXContentHelper.field(MODEL_FIELD, model), + ChunkedToXContentHelper.field(OBJECT_FIELD, object), + usageIterator, ChunkedToXContentHelper.endObject() ); } @@ -178,4 +211,158 @@ public String toString() { + '}'; } } + + public static class ChatCompletionChunk { + private final String id; + private List choices; + private final String model; + private final String object; + private ChatCompletionChunk.Usage usage; + + public ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage) { + this.id = id; + this.choices = choices; + this.model = model; + this.object = object; + this.usage = usage; + } + + public ChatCompletionChunk( + String id, + ChatCompletionChunk.Choice[] choices, + String model, + String object, + ChatCompletionChunk.Usage usage + ) { + this.id = id; + this.choices = Arrays.stream(choices).toList(); + this.model = model; + this.object = object; + this.usage = usage; + } + + public String getId() { + return id; + } + + public List getChoices() { + return choices; + } + + public String getModel() { + return model; + } + + public String getObject() { + return object; + } + + public ChatCompletionChunk.Usage getUsage() { + return usage; + } + + public static class Choice { + private final ChatCompletionChunk.Choice.Delta delta; + private final String finishReason; + private final int index; + + public Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) { + this.delta = delta; + this.finishReason = finishReason; + this.index = index; + } + + public ChatCompletionChunk.Choice.Delta getDelta() { + return delta; + } + + public String getFinishReason() { + return finishReason; + } + + public int getIndex() { + return index; + } + + public static class Delta { + private final String content; + 
private final String refusal; + private final String role; + private List toolCalls; + + public Delta(String content, String refusal, String role, List toolCalls) { + this.content = content; + this.refusal = refusal; + this.role = role; + this.toolCalls = toolCalls; + } + + public String getContent() { + return content; + } + + public String getRefusal() { + return refusal; + } + + public String getRole() { + return role; + } + + public List getToolCalls() { + return toolCalls; + } + + public static class ToolCall { + private final int index; + private final String id; + public ChatCompletionChunk.Choice.Delta.ToolCall.Function function; + private final String type; + + public ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type) { + this.index = index; + this.id = id; + this.function = function; + this.type = type; + } + + public int getIndex() { + return index; + } + + public String getId() { + return id; + } + + public ChatCompletionChunk.Choice.Delta.ToolCall.Function getFunction() { + return function; + } + + public String getType() { + return type; + } + + public static class Function { + private final String arguments; + private final String name; + + public Function(String arguments, String name) { + this.arguments = arguments; + this.name = name; + } + + public String getArguments() { + return arguments; + } + + public String getName() { + return name; + } + } + } + } + } + + public record Usage(int completionTokens, int promptTokens, int totalTokens) {} + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index 5f0339d402231..6e006fe255956 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -11,8 +11,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -29,15 +27,12 @@ import java.util.Iterator; import java.util.Objects; import java.util.function.Predicate; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Deque; -import java.util.Iterator; -import java.util.List; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; /** * Parses the OpenAI chat completion streaming responses. 
@@ -153,353 +148,44 @@ private Iterator parse(XContentParserConf XContentParser.Token token = jsonParser.currentToken(); ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); - ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser); - - List results = new ArrayList<>(); - for (ChatCompletionChunk.Choice choice : chunk.getChoices()) { - String content = choice.getDelta().getContent(); - String refusal = choice.getDelta().getRefusal(); - List toolCalls = parseToolCalls(choice.getDelta().getToolCalls()); - results.add(new StreamingChatCompletionResults.Result(content, refusal, toolCalls)); - } - - return results.iterator(); - } - } - - private List parseToolCalls(List toolCalls) { - List parsedToolCalls = new ArrayList<>(); - for (ChatCompletionChunk.Choice.Delta.ToolCall toolCall : toolCalls) { - int index = toolCall.getIndex(); - String id = toolCall.getId(); - String functionName = toolCall.getFunction() != null ? toolCall.getFunction().getName() : null; - String functionArguments = toolCall.getFunction() != null ? 
toolCall.getFunction().getArguments() : null; - parsedToolCalls.add(new StreamingChatCompletionResults.ToolCall(index, id, functionName, functionArguments)); - } - return parsedToolCalls; - } - - public static class ChatCompletionChunk { - private final String id; - private List choices; - private final String model; - private final String object; - private Usage usage; - - public ChatCompletionChunk(String id, List choices, String model, String object, Usage usage) { - this.id = id; - this.choices = choices; - this.model = model; - this.object = object; - this.usage = usage; - } - - public ChatCompletionChunk(String id, Choice[] choices, String model, String object, Usage usage) { - this.id = id; - this.choices = Arrays.stream(choices).toList(); - this.model = model; - this.object = object; - this.usage = usage; - } - - public String getId() { - return id; - } - - public List getChoices() { - return choices; - } - - public String getModel() { - return model; - } - - public String getObject() { - return object; - } - - public Usage getUsage() { - return usage; - } - - public static class Choice { - private final Delta delta; - private final String finishReason; - private final int index; - - public Choice(Delta delta, String finishReason, int index) { - this.delta = delta; - this.finishReason = finishReason; - this.index = index; - } - - public Delta getDelta() { - return delta; - } - - public String getFinishReason() { - return finishReason; - } - - public int getIndex() { - return index; - } - - public static class Delta { - private final String content; - private final String refusal; - private final String role; - private List toolCalls; - - public Delta(String content, String refusal, String role, List toolCalls) { - this.content = content; - this.refusal = refusal; - this.role = role; - this.toolCalls = toolCalls; - } - - public String getContent() { - return content; - } - - public String getRefusal() { - return refusal; - } + 
positionParserAtTokenAfterField(jsonParser, CHOICES_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE); - public String getRole() { - return role; - } - - public List getToolCalls() { - return toolCalls; - } + return parseList(jsonParser, parser -> { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - public static class ToolCall { - private final int index; - private final String id; - private Function function; - private final String type; + positionParserAtTokenAfterField(parser, DELTA_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE); - public ToolCall(int index, String id, Function function, String type) { - this.index = index; - this.id = id; - this.function = function; - this.type = type; - } + var currentToken = parser.currentToken(); - public int getIndex() { - return index; - } + ensureExpectedToken(XContentParser.Token.START_OBJECT, currentToken, parser); - public String getId() { - return id; - } + currentToken = parser.nextToken(); - public Function getFunction() { - return function; + // continue until the end of delta + while (currentToken != null && currentToken != XContentParser.Token.END_OBJECT) { + if (currentToken == XContentParser.Token.START_OBJECT || currentToken == XContentParser.Token.START_ARRAY) { + parser.skipChildren(); } - public String getType() { - return type; + if (currentToken == XContentParser.Token.FIELD_NAME && parser.currentName().equals(CONTENT_FIELD)) { + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); + var content = parser.text(); + consumeUntilObjectEnd(parser); // end delta + consumeUntilObjectEnd(parser); // end choices + return content; } - public static class Function { - private final String arguments; - private final String name; - - public Function(String arguments, String name) { - this.arguments = arguments; - this.name = name; - } - - public String getArguments() { - return arguments; - } - - public String getName() { - return name; - } - 
} + currentToken = parser.nextToken(); } - } - } - - public static class Usage { - private final int completionTokens; - private final int promptTokens; - private final int totalTokens; - - public Usage(int completionTokens, int promptTokens, int totalTokens) { - this.completionTokens = completionTokens; - this.promptTokens = promptTokens; - this.totalTokens = totalTokens; - } - - public int getCompletionTokens() { - return completionTokens; - } - - public int getPromptTokens() { - return promptTokens; - } - - public int getTotalTokens() { - return totalTokens; - } - } - } - - public static class ChatCompletionChunkParser { - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "chat_completion_chunk", - true, - args -> new ChatCompletionChunk( - (String) args[0], - (ChatCompletionChunk.Choice[]) args[1], - (String) args[2], - (String) args[3], - (ChatCompletionChunk.Usage) args[4] - ) - - /** - * TODO - * Caused by: java.lang.ClassCastException: class java.lang.String cannot be cast to class [Lorg.elasticsearch.xpack.inference.external.openai.OpenAiStreamingProcessor$ChatCompletionChunk$Choice; (java.lang.String is in module java.base of loader 'bootstrap'; [Lorg.elasticsearch.xpack.inference.external.openai.OpenAiStreamingProcessor$ChatCompletionChunk$Choice; is in module org.elasticsearch.inference@9.0.0-SNAPSHOT of loader jdk.internal.loader.Loader @611c3eae) - * at org.elasticsearch.inference@9.0.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.openai.OpenAiStreamingProcessor$ChatCompletionChunkParser.lambda$static$0(OpenAiStreamingProcessor.java:354) - * at org.elasticsearch.xcontent@9.0.0-SNAPSHOT/org.elasticsearch.xcontent.ConstructingObjectParser.lambda$new$2(ConstructingObjectParser.java:130) - * at org.elasticsearch.xcontent@9.0.0-SNAPSHOT/org.elasticsearch.xcontent.ConstructingObjectParser$Target.buildTarget(ConstructingObjectParser.java:555) - */ - ); - - static { - 
PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("id")); - PARSER.declareObjectArray( - (chunk, choices) -> chunk.choices = choices, - (p, c) -> ChoiceParser.parse(p), - new ParseField("choices") - ); - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("model")); - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("object")); - PARSER.declareObject((chunk, usage) -> chunk.usage = usage, (p, c) -> UsageParser.parse(p), new ParseField("usage")); - } - - public static ChatCompletionChunk parse(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } - - private static class ChoiceParser { - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "choice", - true, - args -> new ChatCompletionChunk.Choice((ChatCompletionChunk.Choice.Delta) args[0], (String) args[1], (int) args[2]) - ); - - static { - PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> DeltaParser.parse(p), new ParseField("delta")); - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("finish_reason")); - PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("index")); - } - - public static ChatCompletionChunk.Choice parse(XContentParser parser) throws IOException { - return PARSER.apply(parser, null); - } - } - - private static class DeltaParser { - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "delta", - true, - args -> new ChatCompletionChunk.Choice.Delta( - (String) args[0], - (String) args[1], - (String) args[2], - (List) args[3] - ) - ); - - static { - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("content")); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("refusal")); - 
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("role")); - PARSER.declareObjectArray( - (delta, toolCalls) -> delta.toolCalls = toolCalls, - (p, c) -> ToolCallParser.parse(p), - new ParseField("tool_calls") - ); - - } - - public static ChatCompletionChunk.Choice.Delta parse(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } - } - - private static class ToolCallParser { - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>( - "tool_call", - true, - args -> new ChatCompletionChunk.Choice.Delta.ToolCall( - (int) args[0], - (String) args[1], - (ChatCompletionChunk.Choice.Delta.ToolCall.Function) args[2], - (String) args[3] - ) - ); - - static { - PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("index")); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("id")); - PARSER.declareObject( - (toolCall, function) -> toolCall.function = function, - (p, c) -> FunctionParser.parse(p), - new ParseField("function") - ); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("type")); - } - - return parseList(jsonParser, parser -> { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - public static ChatCompletionChunk.Choice.Delta.ToolCall parse(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } - } - - private static class FunctionParser { - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>( - "function", - true, - args -> new ChatCompletionChunk.Choice.Delta.ToolCall.Function((String) args[0], (String) args[1]) - ); - static { - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("arguments")); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("name")); - } - - public static 
ChatCompletionChunk.Choice.Delta.ToolCall.Function parse(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } - } - - private static class UsageParser { - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "usage", - true, - args -> new ChatCompletionChunk.Usage((int) args[0], (int) args[1], (int) args[2]) - ); - - static { - PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("completion_tokens")); - PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("prompt_tokens")); - PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("total_tokens")); - } - - public static ChatCompletionChunk.Usage parse(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } + consumeUntilObjectEnd(parser); // end choices + return ""; // stopped + }).stream() + .filter(Objects::nonNull) + .filter(Predicate.not(String::isEmpty)) + .map(StreamingChatCompletionResults.Result::new) + .iterator(); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index 5d8ac7d4555c5..2baf0d6980054 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -11,6 +11,8 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentFactory; import 
org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -27,85 +29,12 @@ import java.util.Deque; import java.util.Iterator; import java.util.List; -import java.util.Objects; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; - -/** - * Parses the OpenAI chat completion streaming responses. - * For a request like: - * - *
- *     
- *         {
- *             "inputs": ["Please summarize this text: some text", "Answer the following question: Question"]
- *         }
- *     
- * 
- * - * The response would look like: - * - *
- *     
- *         {
- *              "id": "chatcmpl-123",
- *              "object": "chat.completion",
- *              "created": 1677652288,
- *              "model": "gpt-3.5-turbo-0613",
- *              "system_fingerprint": "fp_44709d6fcb",
- *              "choices": [
- *                  {
- *                      "index": 0,
- *                      "delta": {
- *                          "content": "\n\nHello there, how ",
- *                      },
- *                      "finish_reason": ""
- *                  }
- *              ]
- *          }
- *
- *         {
- *              "id": "chatcmpl-123",
- *              "object": "chat.completion",
- *              "created": 1677652288,
- *              "model": "gpt-3.5-turbo-0613",
- *              "system_fingerprint": "fp_44709d6fcb",
- *              "choices": [
- *                  {
- *                      "index": 1,
- *                      "delta": {
- *                          "content": "may I assist you today?",
- *                      },
- *                      "finish_reason": ""
- *                  }
- *              ]
- *          }
- *
- *         {
- *              "id": "chatcmpl-123",
- *              "object": "chat.completion",
- *              "created": 1677652288,
- *              "model": "gpt-3.5-turbo-0613",
- *              "system_fingerprint": "fp_44709d6fcb",
- *              "choices": [
- *                  {
- *                      "index": 2,
- *                      "delta": {},
- *                      "finish_reason": "stop"
- *                  }
- *              ]
- *          }
- *
- *          [DONE]
- *     
- * 
- */ + public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { - private static final Logger log = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); + private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response"; private static final String CHOICES_FIELD = "choices"; @@ -114,6 +43,9 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor item) throws Exception { @@ -126,7 +58,7 @@ protected void next(Deque item) throws Exception { var delta = parse(parserConfig, event); delta.forEachRemaining(results::offer); } catch (Exception e) { - log.warn("Failed to parse event from inference provider: {}", event); + logger.warn("Failed to parse event from inference provider: {}", event); throw e; } } @@ -145,118 +77,236 @@ private Iterator parse(XContentPar return Collections.emptyIterator(); } - System.out.println(event.value()); try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { moveToFirstToken(jsonParser); XContentParser.Token token = jsonParser.currentToken(); ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); - positionParserAtTokenAfterField(jsonParser, CHOICES_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE); - - return parseList(jsonParser, parser -> { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - - positionParserAtTokenAfterField(parser, DELTA_FIELD, FAILED_TO_FIND_FIELD_TEMPLATE); - - var currentToken = parser.currentToken(); - - ensureExpectedToken(XContentParser.Token.START_OBJECT, currentToken, parser); - - String content = null; - String refusal = null; - List toolCalls = new ArrayList<>(); - - currentToken = parser.nextToken(); - - // continue until the end of delta - while (currentToken != null && 
currentToken != XContentParser.Token.END_OBJECT) { - if (currentToken == XContentParser.Token.START_OBJECT || currentToken == XContentParser.Token.START_ARRAY) { - parser.skipChildren(); - } - - if (currentToken == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case CONTENT_FIELD: - parser.nextToken(); - if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { - content = parser.text(); - } - // ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - break; - case REFUSAL_FIELD: - parser.nextToken(); - if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { - refusal = parser.text(); - } - // ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - break; - case TOOL_CALLS_FIELD: - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - toolCalls = parseToolCalls(parser); - break; - } - } - - currentToken = parser.nextToken(); - } + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser); - // consumeUntilObjectEnd(parser); // end delta - consumeUntilObjectEnd(parser); // end choices + List results = new ArrayList<>(); + for (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice : chunk.getChoices()) { + String content = choice.getDelta().getContent(); + String refusal = choice.getDelta().getRefusal(); + List toolCalls = parseToolCalls(choice.getDelta().getToolCalls()); + results.add( + new StreamingUnifiedChatCompletionResults.Result( + content, + refusal, + toolCalls, + choice.getFinishReason(), + chunk.getModel(), + chunk.getObject(), + chunk.getUsage() + ) + ); + } - return new StreamingUnifiedChatCompletionResults.Result(content, refusal, toolCalls); - }).stream().filter(Objects::nonNull).iterator(); + return results.iterator(); } } - private List parseToolCalls(XContentParser parser) throws IOException { - List toolCalls = new 
ArrayList<>(); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - int index = -1; - String id = null; - String functionName = null; - String functionArguments = null; - - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case "index": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, parser.currentToken(), parser); - index = parser.intValue(); - break; - case "id": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - id = parser.text(); - break; - case "function": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - switch (parser.currentName()) { - case "name": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - functionName = parser.text(); - break; - case "arguments": - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser); - functionArguments = parser.text(); - break; - } - } - } - break; - } - } + private List parseToolCalls( + List toolCalls + ) { + List parsedToolCalls = new ArrayList<>(); + + if (toolCalls == null || toolCalls.isEmpty()) { + return null; + } + + for (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall : toolCalls) { + int index = toolCall.getIndex(); + String id = toolCall.getId(); + String functionName = toolCall.getFunction() != null ? toolCall.getFunction().getName() : null; + String functionArguments = toolCall.getFunction() != null ? 
toolCall.getFunction().getArguments() : null; + parsedToolCalls.add(new StreamingUnifiedChatCompletionResults.ToolCall(index, id, functionName, functionArguments)); + } + return parsedToolCalls; + } + + public static class ChatCompletionChunkParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "chat_completion_chunk", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + (String) args[0], + (List) args[1], + (String) args[2], + (String) args[3], + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage) args[4] + ) + + /** + * TODO + * Caused by: java.lang.ClassCastException: class java.lang.String cannot be cast to class + * [Lorg.elasticsearch.xpack.inference.external.openai.OpenAiStreamingProcessor$ChatCompletionChunk$Choice; + * (java.lang.String is in module java.base of loader 'bootstrap'; [Lorg.elasticsearch.xpack.inference.external.openai. + * OpenAiStreamingProcessor$ChatCompletionChunk$Choice; is in module org.elasticsearch.inference@9.0.0-SNAPSHOT of loader + * jdk.internal.loader.Loader @611c3eae) + * at org.elasticsearch.inference@9.0.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.openai. + * OpenAiStreamingProcessor$ChatCompletionChunkParser.lambda$static$0(OpenAiStreamingProcessor.java:354) + * at org.elasticsearch.xcontent@9.0.0-SNAPSHOT/org.elasticsearch.xcontent.ConstructingObjectParser. + * lambda$new$2(ConstructingObjectParser.java:130) + * at org.elasticsearch.xcontent@9.0.0-SNAPSHOT/org.elasticsearch.xcontent.ConstructingObjectParser$Target. 
+ * buildTarget(ConstructingObjectParser.java:555) + */ + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("id")); + PARSER.declareObjectArray( + ConstructingObjectParser.constructorArg(), + (p, c) -> ChatCompletionChunkParser.ChoiceParser.parse(p), + new ParseField(CHOICES_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("model")); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("object")); + PARSER.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.UsageParser.parse(p), + new ParseField("usage") + ); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private static class ChoiceParser { + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "choice", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta) args[0], + (String) args[1], + (int) args[2] + ) + ); + + static { + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (p, c) -> ChatCompletionChunkParser.DeltaParser.parse(p), + new ParseField(DELTA_FIELD) + ); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FINISH_REASON_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice parse(XContentParser parser) { + return PARSER.apply(parser, null); + } + } + + private static class DeltaParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta, + Void> PARSER = new 
ConstructingObjectParser<>( + DELTA_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + (String) args[0], + (String) args[1], + (String) args[2], + (List) args[3] + ) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD)); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD)); + PARSER.declareObjectArray( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.ToolCallParser.parse(p), + new ParseField(TOOL_CALLS_FIELD) + ); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta parse(XContentParser parser) + throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class ToolCallParser { + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall, + Void> PARSER = new ConstructingObjectParser<>( + "tool_call", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + (int) args[0], + (String) args[1], + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function) args[2], + (String) args[3] + ) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("id")); + PARSER.declareObject( + (toolCall, function) -> toolCall.function = function, + (p, c) -> ChatCompletionChunkParser.FunctionParser.parse(p), + new ParseField("function") + ); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("type")); + } + + public static 
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall parse(XContentParser parser) + throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class FunctionParser { + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function, + Void> PARSER = new ConstructingObjectParser<>( + "function", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + (String) args[0], + (String) args[1] + ) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("arguments")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("name")); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse( + XContentParser parser + ) throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class UsageParser { + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "usage", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage((int) args[0], (int) args[1], (int) args[2]) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("completion_tokens")); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("prompt_tokens")); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("total_tokens")); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); } - toolCalls.add(new StreamingUnifiedChatCompletionResults.ToolCall(index, id, functionName, functionArguments)); } - return toolCalls; } } From ecdf5c35561f2f67d63687050897a19a0a0eff04 Mon Sep 17 00:00:00 2001 From: Jonathan 
Buttner Date: Mon, 2 Dec 2024 15:49:24 -0500 Subject: [PATCH 25/53] Removing unused request classes --- .../inference/UnifiedCompletionRequest.java | 2 +- .../openai/OpenAiChatCompletionRequest.java | 99 ------------------- .../OpenAiChatCompletionRequestEntity.java | 80 --------------- ...enAiUnifiedChatCompletionRequestTests.java | 4 +- 4 files changed, 3 insertions(+), 182 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index bce5a5601a936..b1925e74e897b 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -140,7 +140,7 @@ private static Content parseContent(XContentParser parser) throws IOException { return ContentString.of(parser); } - throw new XContentParseException("Unsupported token [" + token + "]"); + throw new XContentParseException("Expected an array start token or a value string token but found token [" + token + "]"); } public Message(StreamInput in) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java deleted file mode 100644 index 99a025e70d003..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright Elasticsearch 
B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.openai; - -import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; - -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; -import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader; - -public class OpenAiChatCompletionRequest implements OpenAiRequest { - - private final OpenAiAccount account; - private final List input; - private final OpenAiChatCompletionModel model; - private final boolean stream; - - public OpenAiChatCompletionRequest(List input, OpenAiChatCompletionModel model, boolean stream) { - this.account = OpenAiAccount.of(model, OpenAiChatCompletionRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); - this.model = Objects.requireNonNull(model); - this.stream = stream; - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString( - new 
OpenAiChatCompletionRequestEntity(input, model.getServiceSettings().modelId(), model.getTaskSettings().user(), stream) - ).getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); - httpPost.setHeader(createAuthBearerHeader(account.apiKey())); - - var org = account.organizationId(); - if (org != null) { - httpPost.setHeader(createOrgHeader(org)); - } - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - // No truncation for OpenAI chat completions - return this; - } - - @Override - public boolean[] getTruncationInfo() { - // No truncation for OpenAI chat completions - return null; - } - - @Override - public String getInferenceEntityId() { - return model.getInferenceEntityId(); - } - - @Override - public boolean isStreaming() { - return stream; - } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(OpenAiUtils.HOST) - .setPathSegments(OpenAiUtils.VERSION_1, OpenAiUtils.CHAT_PATH, OpenAiUtils.COMPLETIONS_PATH) - .build(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java deleted file mode 100644 index 2332e70589104..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.openai; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -// TODO remove this -public class OpenAiChatCompletionRequestEntity implements ToXContentObject { - - private static final String MESSAGES_FIELD = "messages"; - private static final String MODEL_FIELD = "model"; - - private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; - - private static final String ROLE_FIELD = "role"; - private static final String USER_FIELD = "user"; - private static final String CONTENT_FIELD = "content"; - private static final String STREAM_FIELD = "stream"; - - private final List messages; - private final String model; - - private final String user; - private final boolean stream; - - public OpenAiChatCompletionRequestEntity(List messages, String model, String user, boolean stream) { - Objects.requireNonNull(messages); - Objects.requireNonNull(model); - - this.messages = messages; - this.model = model; - this.user = user; - this.stream = stream; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.startArray(MESSAGES_FIELD); - { - for (String message : messages) { - builder.startObject(); - - { - builder.field(ROLE_FIELD, USER_FIELD); - builder.field(CONTENT_FIELD, message); - } - - builder.endObject(); - } - } - builder.endArray(); - - builder.field(MODEL_FIELD, model); - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); - - if (Strings.isNullOrEmpty(user) == false) { - builder.field(USER_FIELD, user); - } - - if (stream) { - builder.field(STREAM_FIELD, true); - } - - builder.endObject(); - - return 
builder; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java index 6bfbc96b9bfe1..636c28126de0f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java @@ -78,7 +78,7 @@ public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(httpPost.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER)); @@ -105,7 +105,7 @@ public void testCreateRequest_WithStreaming() throws URISyntaxException, IOExcep public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { var request = createRequest(null, null, "secret", "abcd", "model", null); var truncatedRequest = request.truncate(); - assertThat(request.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(request.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); var httpRequest = truncatedRequest.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), 
instanceOf(HttpPost.class)); From f7f8a2ebe261fd185fcbf1571b77331b7a7d37c9 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 2 Dec 2024 15:51:48 -0500 Subject: [PATCH 26/53] precommit --- .../action/UnifiedCompletionRequestTests.java | 2 +- .../external/http/sender/UnifiedChatInput.java | 14 ++++++++++---- .../OpenAiChatCompletionRequestTaskSettings.java | 1 - 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java index 7d839df00dc4c..2a6c28877d6b9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -187,7 +187,7 @@ public static UnifiedCompletionRequest randomUnifiedCompletionRequest() { randomList(5, UnifiedCompletionRequestTests::randomMessage), randomAlphaOfLengthOrNull(10), randomPositiveLongOrNull(), - randomStopOrNull(), + randomStopOrNull(), randomFloatOrNull(), randomToolChoiceOrNull(), randomToolListOrNull(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java index 38c2a03548495..95f4037276363 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -22,9 +22,7 @@ public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) { } public UnifiedChatInput(ChatCompletionInput completionInput, String roleValue) { - this( - 
completionInput.getInputs(), roleValue, completionInput.stream() - ); + this(completionInput.getInputs(), roleValue, completionInput.stream()); } public UnifiedChatInput(List inputs, String roleValue, boolean stream) { @@ -33,7 +31,15 @@ public UnifiedChatInput(List inputs, String roleValue, boolean stream) { private static List convertToMessages(List inputs, String roleValue) { return inputs.stream() - .map(value -> new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString(value), roleValue, null, null, null)) + .map( + value -> new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(value), + roleValue, + null, + null, + null + ) + ) .toList(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java index f23956bc21a04..7ef7f85d71a6a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.UnifiedCompletionRequest; import java.util.Map; From 10ac1ae28880c72ffb306caf69f8b02778a6ae25 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 2 Dec 2024 16:30:13 -0500 Subject: [PATCH 27/53] Adding tests for UnifiedCompletionAction Request --- .../action/InferenceActionRequestTests.java | 1 - .../UnifiedCompletionActionRequestTests.java | 87 +++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index 0b9d32d0668c2..01c0ff88be222 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -41,7 +41,6 @@ protected InferenceAction.Request createTestInstance() { return new InferenceAction.Request( randomFrom(TaskType.values()), randomAlphaOfLength(6), - // null, randomAlphaOfLengthOrNull(10), randomList(1, 5, () -> randomAlphaOfLength(8)), randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java new file mode 100644 index 0000000000000..aad4df0a2ea5e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java @@ -0,0 +1,87 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class UnifiedCompletionActionRequestTests extends AbstractBWCWireSerializationTestCase { + + public void testValidation_ReturnsException_When_UnifiedCompletionRequestMessage_Is_Null() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.COMPLETION, + UnifiedCompletionRequest.of(null), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [messages] cannot be null;")); + } + + public void testValidation_ReturnsException_When_UnifiedCompletionRequest_Is_EmptyArray() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.COMPLETION, + UnifiedCompletionRequest.of(List.of()), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [messages] cannot be an empty array;")); + } + + public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.SPARSE_EMBEDDING, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [completion];")); + } + + @Override + protected UnifiedCompletionAction.Request 
mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return UnifiedCompletionAction.Request::new; + } + + @Override + protected UnifiedCompletionAction.Request createTestInstance() { + return new UnifiedCompletionAction.Request( + randomAlphaOfLength(10), + randomFrom(TaskType.values()), + UnifiedCompletionRequestTests.randomUnifiedCompletionRequest(), + TimeValue.timeValueMillis(randomLongBetween(1, 2048)) + ); + } + + @Override + protected UnifiedCompletionAction.Request mutateInstance(UnifiedCompletionAction.Request instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables()); + } +} From 99d202f824ecd723161fc2326a71985800017992 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 3 Dec 2024 09:31:31 -0500 Subject: [PATCH 28/53] Refactoring stop to be a list of strings --- .../inference/UnifiedCompletionRequest.java | 126 +++++++++--------- .../action/UnifiedCompletionRequestTests.java | 12 +- ...nAiUnifiedChatCompletionRequestEntity.java | 5 +- 3 files changed, 70 insertions(+), 73 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index b1925e74e897b..a0b86aa2b19c0 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -33,7 +33,7 @@ public record UnifiedCompletionRequest( List messages, @Nullable String model, @Nullable Long maxCompletionTokens, - @Nullable Stop stop, + @Nullable List stop, @Nullable Float temperature, @Nullable ToolChoice toolChoice, @Nullable List tools, 
@@ -49,7 +49,7 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C (List) args[0], (String) args[1], (Long) args[2], - (Stop) args[3], + (List) args[3], (Float) args[4], (ToolChoice) args[5], (List) args[6], @@ -61,7 +61,9 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages")); PARSER.declareString(optionalConstructorArg(), new ParseField("model")); PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens")); - PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), ObjectParser.ValueType.VALUE_ARRAY); + // PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), + // ObjectParser.ValueType.VALUE_ARRAY); + PARSER.declareStringArray(optionalConstructorArg(), new ParseField("stop")); PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature")); PARSER.declareField( optionalConstructorArg(), @@ -78,9 +80,9 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(Content.class, ContentObjects.NAME, ContentObjects::new), new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new), new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new), - new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new), - new NamedWriteableRegistry.Entry(Stop.class, StopValues.NAME, StopValues::new), - new NamedWriteableRegistry.Entry(Stop.class, StopString.NAME, StopString::new) + new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new) + // new NamedWriteableRegistry.Entry(Stop.class, StopValues.NAME, StopValues::new), + // new NamedWriteableRegistry.Entry(Stop.class, StopString.NAME, StopString::new) ); } @@ -93,7 +95,7 @@ public 
UnifiedCompletionRequest(StreamInput in) throws IOException { in.readCollectionAsImmutableList(Message::new), in.readOptionalString(), in.readOptionalVLong(), - in.readOptionalNamedWriteable(Stop.class), + in.readOptionalStringCollectionAsList(), in.readOptionalFloat(), in.readOptionalNamedWriteable(ToolChoice.class), in.readOptionalCollectionAsList(Tool::new), @@ -106,7 +108,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(messages); out.writeOptionalString(model); out.writeOptionalVLong(maxCompletionTokens); - out.writeOptionalNamedWriteable(stop); + out.writeOptionalStringCollection(stop); out.writeOptionalFloat(temperature); out.writeOptionalNamedWriteable(toolChoice); out.writeOptionalCollection(tools); @@ -279,60 +281,60 @@ public void writeTo(StreamOutput out) throws IOException { } } - private static Stop parseStop(XContentParser parser) throws IOException { - var token = parser.currentToken(); - if (token == XContentParser.Token.START_ARRAY) { - var parsedStopValues = XContentParserUtils.parseList(parser, XContentParser::text); - return new StopValues(parsedStopValues); - } else if (token == XContentParser.Token.VALUE_STRING) { - return StopString.of(parser); - } - - throw new XContentParseException("Unsupported token [" + token + "]"); - } - - public sealed interface Stop extends NamedWriteable permits StopString, StopValues {} - - public record StopString(String value) implements Stop, NamedWriteable { - public static final String NAME = "stop_string"; - - public static StopString of(XContentParser parser) throws IOException { - var content = parser.text(); - return new StopString(content); - } - - public StopString(StreamInput in) throws IOException { - this(in.readString()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(value); - } - - @Override - public String getWriteableName() { - return NAME; - } - } - - public record StopValues(List values) implements Stop, 
NamedWriteable { - public static final String NAME = "stop_values"; - - public StopValues(StreamInput in) throws IOException { - this(in.readStringCollectionAsImmutableList()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeStringCollection(values); - } - - @Override - public String getWriteableName() { - return NAME; - } - } + // private static Stop parseStop(XContentParser parser) throws IOException { + // var token = parser.currentToken(); + // if (token == XContentParser.Token.START_ARRAY) { + // var parsedStopValues = XContentParserUtils.parseList(parser, XContentParser::text); + // return new StopValues(parsedStopValues); + // } else if (token == XContentParser.Token.VALUE_STRING) { + // return StopString.of(parser); + // } + // + // throw new XContentParseException("Unsupported token [" + token + "]"); + // } + + // public sealed interface Stop extends NamedWriteable permits StopString, StopValues {} + // + // public record StopString(String value) implements Stop, NamedWriteable { + // public static final String NAME = "stop_string"; + // + // public static StopString of(XContentParser parser) throws IOException { + // var content = parser.text(); + // return new StopString(content); + // } + // + // public StopString(StreamInput in) throws IOException { + // this(in.readString()); + // } + // + // @Override + // public void writeTo(StreamOutput out) throws IOException { + // out.writeString(value); + // } + // + // @Override + // public String getWriteableName() { + // return NAME; + // } + // } + // + // public record StopValues(List values) implements Stop, NamedWriteable { + // public static final String NAME = "stop_values"; + // + // public StopValues(StreamInput in) throws IOException { + // this(in.readStringCollectionAsImmutableList()); + // } + // + // @Override + // public void writeTo(StreamOutput out) throws IOException { + // out.writeStringCollection(values); + // } + // + // @Override + // public 
String getWriteableName() { + // return NAME; + // } + // } private static ToolChoice parseToolChoice(XContentParser parser) throws IOException { var token = parser.currentToken(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java index 2a6c28877d6b9..47a0814a584b7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -96,7 +96,7 @@ public void testParseAllFields() throws IOException { ), "gpt-4o", 100L, - new UnifiedCompletionRequest.StopValues(List.of("stop")), + List.of("stop"), 0.1F, new UnifiedCompletionRequest.ToolChoiceObject( "function", @@ -161,7 +161,7 @@ public void testParsing() throws IOException { ), "gpt-4o", null, - new UnifiedCompletionRequest.StopString("none"), + List.of("none"), null, new UnifiedCompletionRequest.ToolChoiceString("auto"), List.of( @@ -227,14 +227,12 @@ public static UnifiedCompletionRequest.ToolCall.FunctionField randomToolCallFunc return new UnifiedCompletionRequest.ToolCall.FunctionField(randomAlphaOfLength(10), randomAlphaOfLength(10)); } - public static UnifiedCompletionRequest.Stop randomStopOrNull() { + public static List randomStopOrNull() { return randomBoolean() ? randomStop() : null; } - public static UnifiedCompletionRequest.Stop randomStop() { - return randomBoolean() - ? 
new UnifiedCompletionRequest.StopString(randomAlphaOfLength(10)) - : new UnifiedCompletionRequest.StopValues(randomList(5, () -> randomAlphaOfLength(10))); + public static List randomStop() { + return randomList(5, () -> randomAlphaOfLength(10)); } public static UnifiedCompletionRequest.ToolChoice randomToolChoiceOrNull() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 3225ecd7941f9..c57f09d10fe53 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -115,10 +115,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); if (unifiedRequest.stop() != null) { - switch (unifiedRequest.stop()) { - case UnifiedCompletionRequest.StopString stopString -> builder.field(STOP_FIELD, stopString.value()); - case UnifiedCompletionRequest.StopValues stopValues -> builder.field(STOP_FIELD, stopValues.values()); - } + builder.field(STOP_FIELD, unifiedRequest.stop()); } if (unifiedRequest.temperature() != null) { builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); From 6831aaae1084e1eb1898f8abc1ac239be78f203c Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Tue, 3 Dec 2024 13:11:50 -0500 Subject: [PATCH 29/53] Testing for OpenAI response parsing --- ...mingUnifiedChatCompletionResultsTests.java | 14 + .../OpenAiUnifiedStreamingProcessor.java | 61 ++- .../OpenAiUnifiedStreamingProcessorTests.java | 392 ++++++++++++++++++ 3 files changed, 435 insertions(+), 32 deletions(-) create mode 
100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java new file mode 100644 index 0000000000000..351e1d97fee9a --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.test.ESTestCase; + +public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase { + // See OpenAiUnifiedStreamingProcessorTests.java +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index 2baf0d6980054..168c6778d7a25 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -34,6 +34,7 @@ import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { + public static final String FUNCTION_FIELD = "function"; private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response"; @@ -46,6 +47,17 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor item) throws Exception { @@ -139,36 +151,21 @@ public static class ChatCompletionChunkParser { (String) args[3], (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage) args[4] ) - - /** - * TODO - * Caused by: java.lang.ClassCastException: class java.lang.String cannot be cast to class - * [Lorg.elasticsearch.xpack.inference.external.openai.OpenAiStreamingProcessor$ChatCompletionChunk$Choice; - * (java.lang.String is in module java.base of loader 'bootstrap'; [Lorg.elasticsearch.xpack.inference.external.openai. 
- * OpenAiStreamingProcessor$ChatCompletionChunk$Choice; is in module org.elasticsearch.inference@9.0.0-SNAPSHOT of loader - * jdk.internal.loader.Loader @611c3eae) - * at org.elasticsearch.inference@9.0.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.openai. - * OpenAiStreamingProcessor$ChatCompletionChunkParser.lambda$static$0(OpenAiStreamingProcessor.java:354) - * at org.elasticsearch.xcontent@9.0.0-SNAPSHOT/org.elasticsearch.xcontent.ConstructingObjectParser. - * lambda$new$2(ConstructingObjectParser.java:130) - * at org.elasticsearch.xcontent@9.0.0-SNAPSHOT/org.elasticsearch.xcontent.ConstructingObjectParser$Target. - * buildTarget(ConstructingObjectParser.java:555) - */ ); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("id")); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ID_FIELD)); PARSER.declareObjectArray( ConstructingObjectParser.constructorArg(), (p, c) -> ChatCompletionChunkParser.ChoiceParser.parse(p), new ParseField(CHOICES_FIELD) ); - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("model")); - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("object")); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_FIELD)); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(OBJECT_FIELD)); PARSER.declareObject( ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ChatCompletionChunkParser.UsageParser.parse(p), - new ParseField("usage") + new ParseField(USAGE_FIELD) ); } @@ -179,7 +176,7 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XC private static class ChoiceParser { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "choice", + CHOICE_FIELD, true, args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( 
(StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta) args[0], @@ -219,7 +216,7 @@ private static class DeltaParser { ); static { - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD)); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD)); PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD)); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD)); PARSER.declareObjectArray( @@ -251,13 +248,13 @@ private static class ToolCallParser { static { PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD)); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("id")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ID_FIELD)); PARSER.declareObject( - (toolCall, function) -> toolCall.function = function, + ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ChatCompletionChunkParser.FunctionParser.parse(p), - new ParseField("function") + new ParseField(FUNCTION_FIELD) ); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("type")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TYPE_FIELD)); } public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall parse(XContentParser parser) @@ -270,7 +267,7 @@ private static class FunctionParser { private static final ConstructingObjectParser< StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function, Void> PARSER = new ConstructingObjectParser<>( - "function", + FUNCTION_FIELD, true, args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( (String) args[0], @@ -279,8 +276,8 @@ private static class 
FunctionParser { ); static { - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("arguments")); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("name")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD)); } public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse( @@ -293,15 +290,15 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.D private static class UsageParser { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "usage", + USAGE_FIELD, true, args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage((int) args[0], (int) args[1], (int) args[2]) ); static { - PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("completion_tokens")); - PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("prompt_tokens")); - PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField("total_tokens")); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(COMPLETION_TOKENS_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(PROMPT_TOKENS_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(TOTAL_TOKENS_FIELD)); } public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage parse(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java new file mode 100644 index 
0000000000000..8f1e720a128b2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java @@ -0,0 +1,392 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; + +import java.io.IOException; +import java.util.List; + +public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase { + + public void testJsonLiteral() { + String json = """ + { + "id": "example_id", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": null, + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool_call_id", + "function": { + "arguments": "example_arguments", + "name": "example_function_name" + }, + "type": "function" + } + ] + }, + "finish_reason": "stop", + "index": 0 + } + ], + "model": "example_model", + "object": "chat.completion.chunk", + "usage": { + "completion_tokens": 50, + "prompt_tokens": 20, + "total_tokens": 70 + } + } + """; + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + 
StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals("example_id", chunk.getId()); + assertEquals("example_model", chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNotNull(chunk.getUsage()); + assertEquals(50, chunk.getUsage().completionTokens()); + assertEquals(20, chunk.getUsage().promptTokens()); + assertEquals(70, chunk.getUsage().totalTokens()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertEquals("example_content", choice.getDelta().getContent()); + assertNull(choice.getDelta().getRefusal()); + assertEquals("assistant", choice.getDelta().getRole()); + assertEquals("stop", choice.getFinishReason()); + assertEquals(0, choice.getIndex()); + + List toolCalls = choice.getDelta() + .getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(1, toolCall.getIndex()); + assertEquals("tool_call_id", toolCall.getId()); + assertEquals("example_function_name", toolCall.getFunction().getName()); + assertEquals("example_arguments", toolCall.getFunction().getArguments()); + assertEquals("function", toolCall.getType()); + } catch (IOException e) { + fail(); + } + } + + public void testJsonLiteralCornerCases() { + String json = """ + { + "id": "example_id", + "choices": [ + { + "delta": { + "content": null, + "refusal": null, + "role": "assistant", + "tool_calls": [] + }, + "finish_reason": null, + "index": 0 + }, + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "user", + "tool_calls": [ + { + "index": 1, + "function": { + "name": "example_function_name" + }, + "type": "function" + } + ] + }, + 
"finish_reason": "stop", + "index": 1 + } + ], + "model": "example_model", + "object": "chat.completion.chunk", + "usage": { + "completion_tokens": 50, + "prompt_tokens": 20, + "total_tokens": 70 + } + } + """; + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals("example_id", chunk.getId()); + assertEquals("example_model", chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNotNull(chunk.getUsage()); + assertEquals(50, chunk.getUsage().completionTokens()); + assertEquals(20, chunk.getUsage().promptTokens()); + assertEquals(70, chunk.getUsage().totalTokens()); + + List choices = chunk.getChoices(); + assertEquals(2, choices.size()); + + // First choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0); + assertNull(firstChoice.getDelta().getContent()); + assertNull(firstChoice.getDelta().getRefusal()); + assertEquals("assistant", firstChoice.getDelta().getRole()); + assertTrue(firstChoice.getDelta().getToolCalls().isEmpty()); + assertNull(firstChoice.getFinishReason()); + assertEquals(0, firstChoice.getIndex()); + + // Second choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice secondChoice = choices.get(1); + assertEquals("example_content", secondChoice.getDelta().getContent()); + assertEquals("example_refusal", secondChoice.getDelta().getRefusal()); + assertEquals("user", secondChoice.getDelta().getRole()); + assertEquals("stop", secondChoice.getFinishReason()); + assertEquals(1, secondChoice.getIndex()); + + List 
toolCalls = secondChoice.getDelta() + .getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(1, toolCall.getIndex()); + assertNull(toolCall.getId()); + assertEquals("example_function_name", toolCall.getFunction().getName()); + assertNull(toolCall.getFunction().getArguments()); + assertEquals("function", toolCall.getType()); + } catch (IOException e) { + fail(); + } + } + + public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException { + // Generate random values for the JSON fields + int toolCallIndex = randomIntBetween(0, 10); + String toolCallId = randomAlphaOfLength(5); + String toolCallFunctionName = randomAlphaOfLength(8); + String toolCallFunctionArguments = randomAlphaOfLength(10); + String toolCallType = "function"; + String toolCallJson = createToolCallJson(toolCallIndex, toolCallId, toolCallFunctionName, toolCallFunctionArguments, toolCallType); + + String choiceContent = randomAlphaOfLength(10); + String choiceRole = randomFrom("system", "user", "assistant", "tool"); + String choiceFinishReason = randomFrom("stop", "length", "tool_calls", "content_filter", "function_call", null); + int choiceIndex = randomIntBetween(0, 10); + String choiceJson = createChoiceJson(choiceContent, null, choiceRole, toolCallJson, choiceFinishReason, choiceIndex); + + int usageCompletionTokens = randomIntBetween(1, 100); + int usagePromptTokens = randomIntBetween(1, 100); + int usageTotalTokens = randomIntBetween(1, 200); + String usageJson = createUsageJson(usageCompletionTokens, usagePromptTokens, usageTotalTokens); + + String chatCompletionChunkId = randomAlphaOfLength(10); + String chatCompletionChunkModel = randomAlphaOfLength(5); + String chatCompletionChunkJson = createChatCompletionChunkJson( + chatCompletionChunkId, + choiceJson, + chatCompletionChunkModel, + "chat.completion.chunk", + usageJson + ); + + // Parse the JSON + 
XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, chatCompletionChunkJson)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals(chatCompletionChunkId, chunk.getId()); + assertEquals(chatCompletionChunkModel, chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNotNull(chunk.getUsage()); + assertEquals(usageCompletionTokens, chunk.getUsage().completionTokens()); + assertEquals(usagePromptTokens, chunk.getUsage().promptTokens()); + assertEquals(usageTotalTokens, chunk.getUsage().totalTokens()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertEquals(choiceContent, choice.getDelta().getContent()); + assertNull(choice.getDelta().getRefusal()); + assertEquals(choiceRole, choice.getDelta().getRole()); + assertEquals(choiceFinishReason, choice.getFinishReason()); + assertEquals(choiceIndex, choice.getIndex()); + + List toolCalls = choice.getDelta() + .getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(toolCallIndex, toolCall.getIndex()); + assertEquals(toolCallId, toolCall.getId()); + assertEquals(toolCallFunctionName, toolCall.getFunction().getName()); + assertEquals(toolCallFunctionArguments, toolCall.getFunction().getArguments()); + assertEquals(toolCallType, toolCall.getType()); + } + } + + public void testOpenAiUnifiedStreamingProcessorParsingWithNullFields() throws IOException { + // JSON with null fields + int choiceIndex = 
randomIntBetween(0, 10); + String choiceJson = createChoiceJson(null, null, null, "", null, choiceIndex); + + String chatCompletionChunkId = randomAlphaOfLength(10); + String chatCompletionChunkModel = randomAlphaOfLength(5); + String chatCompletionChunkJson = createChatCompletionChunkJson( + chatCompletionChunkId, + choiceJson, + chatCompletionChunkModel, + "chat.completion.chunk", + null + ); + + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, chatCompletionChunkJson)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals(chatCompletionChunkId, chunk.getId()); + assertEquals(chatCompletionChunkModel, chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNull(chunk.getUsage()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertNull(choice.getDelta().getContent()); + assertNull(choice.getDelta().getRefusal()); + assertNull(choice.getDelta().getRole()); + assertNull(choice.getFinishReason()); + assertEquals(choiceIndex, choice.getIndex()); + assertTrue(choice.getDelta().getToolCalls().isEmpty()); + } + } + + private String createToolCallJson(int index, String id, String functionName, String functionArguments, String type) { + return Strings.format(""" + { + "index": %d, + "id": "%s", + "function": { + "name": "%s", + "arguments": "%s" + }, + "type": "%s" + } + """, index, id, functionName, functionArguments, type); + } + + private String createChoiceJson(String content, String refusal, String role, String toolCallsJson, String finishReason, int 
index) { + if (role == null) { + return Strings.format( + """ + { + "delta": { + "content": %s, + "refusal": %s, + "tool_calls": [%s] + }, + "finish_reason": %s, + "index": %d + } + """, + content != null ? "\"" + content + "\"" : "null", + refusal != null ? "\"" + refusal + "\"" : "null", + toolCallsJson, + finishReason != null ? "\"" + finishReason + "\"" : "null", + index + ); + } else { + return Strings.format( + """ + { + "delta": { + "content": %s, + "refusal": %s, + "role": %s, + "tool_calls": [%s] + }, + "finish_reason": %s, + "index": %d + } + """, + content != null ? "\"" + content + "\"" : "null", + refusal != null ? "\"" + refusal + "\"" : "null", + role != null ? "\"" + role + "\"" : "null", + toolCallsJson, + finishReason != null ? "\"" + finishReason + "\"" : "null", + index + ); + } + } + + private String createChatCompletionChunkJson(String id, String choicesJson, String model, String object, String usageJson) { + if (usageJson != null) { + return Strings.format(""" + { + "id": "%s", + "choices": [%s], + "model": "%s", + "object": "%s", + "usage": %s + } + """, id, choicesJson, model, object, usageJson); + } else { + return Strings.format(""" + { + "id": "%s", + "choices": [%s], + "model": "%s", + "object": "%s" + } + """, id, choicesJson, model, object); + } + } + + private String createUsageJson(int completionTokens, int promptTokens, int totalTokens) { + return Strings.format(""" + { + "completion_tokens": %d, + "prompt_tokens": %d, + "total_tokens": %d + } + """, completionTokens, promptTokens, totalTokens); + } +} From 41f9bce5293b2ba0f06eb3d566fa40f3192713b4 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 3 Dec 2024 13:16:05 -0500 Subject: [PATCH 30/53] Refactoring transport action tests to test unified validation code --- .../action/BaseTransportInferenceAction.java | 59 ++- .../action/TransportInferenceAction.java | 5 +- ...sportUnifiedCompletionInferenceAction.java | 2 +- .../BaseTransportInferenceActionTestCase.java | 364 
++++++++++++++++++ .../action/TransportInferenceActionTests.java | 337 +--------------- ...TransportUnifiedCompletionActionTests.java | 83 ++++ 6 files changed, 504 insertions(+), 346 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 7f4ab7d9e5447..87c2e8befbd56 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; +import java.util.function.Supplier; import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; @@ -75,27 +76,43 @@ protected void doExecute(Task task, Request request, ActionListener { var service = serviceRegistry.getService(unparsedModel.service()); - if (service.isEmpty()) { - var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()); + try { + validationHelper(service::isEmpty, () -> unknownServiceException(unparsedModel.service(), request.getInferenceEntityId())); + validationHelper( + () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false, + () -> requestModelTaskTypeMismatchException(request.getTaskType(), unparsedModel.taskType()) + ); + validationHelper( + () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel), + () -> 
createInvalidTaskTypeException(request, unparsedModel) + ); + } catch (Exception e) { recordMetrics(unparsedModel, timer, e); listener.onFailure(e); return; } - if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { - // not the wildcard task type and not the model task type - var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) { - var e = createIncompatibleTaskTypeException(request, unparsedModel); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } + // if (service.isEmpty()) { + // var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()); + // recordMetrics(unparsedModel, timer, e); + // listener.onFailure(e); + // return; + // } + + // if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { + // // not the wildcard task type and not the model task type + // var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()); + // recordMetrics(unparsedModel, timer, e); + // listener.onFailure(e); + // return; + // } + + // if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) { + // var e = createInvalidTaskTypeException(request, unparsedModel); + // recordMetrics(unparsedModel, timer, e); + // listener.onFailure(e); + // return; + // } var model = service.get() .parsePersistedConfigWithSecrets( @@ -117,9 +134,15 @@ protected void doExecute(Task task, Request request, ActionListener validationFailure, Supplier exceptionCreator) { + if (validationFailure.get()) { + throw exceptionCreator.get(); + } + } + protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel); - protected abstract ElasticsearchStatusException createIncompatibleTaskTypeException(Request request, UnparsedModel unparsedModel); + 
protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel); private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { try { @@ -225,7 +248,7 @@ private static ElasticsearchStatusException unknownServiceException(String servi return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId); } - private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) { + private static ElasticsearchStatusException requestModelTaskTypeMismatchException(TaskType requested, TaskType expected) { return new ElasticsearchStatusException( "Incompatible task_type, the requested type [{}] does not match the model type [{}]", RestStatus.BAD_REQUEST, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index c4e7dfd75d218..08e6d869a553d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -51,10 +51,7 @@ protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request } @Override - protected ElasticsearchStatusException createIncompatibleTaskTypeException( - InferenceAction.Request request, - UnparsedModel unparsedModel - ) { + protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, UnparsedModel unparsedModel) { return null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index 8c79fc3e8a459..fd6b234fce7f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -53,7 +53,7 @@ protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction. } @Override - protected ElasticsearchStatusException createIncompatibleTaskTypeException( + protected ElasticsearchStatusException createInvalidTaskTypeException( UnifiedCompletionAction.Request request, UnparsedModel unparsedModel ) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java new file mode 100644 index 0000000000000..47f3a0e0b57aa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -0,0 +1,364 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Flow; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public abstract class BaseTransportInferenceActionTestCase 
extends ESTestCase { + private ModelRegistry modelRegistry; + private StreamingTaskManager streamingTaskManager; + private BaseTransportInferenceAction action; + + protected static final String serviceId = "serviceId"; + protected static final TaskType taskType = TaskType.COMPLETION; + protected static final String inferenceId = "inferenceEntityId"; + protected InferenceServiceRegistry serviceRegistry; + protected InferenceStats inferenceStats; + + @Before + public void setUp() throws Exception { + super.setUp(); + TransportService transportService = mock(); + ActionFilters actionFilters = mock(); + modelRegistry = mock(); + serviceRegistry = mock(); + inferenceStats = new InferenceStats(mock(), mock()); + streamingTaskManager = mock(); + action = createAction(transportService, actionFilters, modelRegistry, serviceRegistry, inferenceStats, streamingTaskManager); + } + + protected abstract BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ); + + protected abstract Request createRequest(); + + public void testMetricsAfterModelRegistryError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onFailure(expectedException); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + var listener = doExecute(taskType); + verify(listener).onFailure(same(expectedException)); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), nullValue()); + assertThat(attributes.get("task_type"), nullValue()); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + 
assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + protected ActionListener doExecute(TaskType taskType) { + return doExecute(taskType, false); + } + + protected ActionListener doExecute(TaskType taskType, boolean stream) { + Request request = createRequest(); + when(request.getInferenceEntityId()).thenReturn(inferenceId); + when(request.getTaskType()).thenReturn(taskType); + when(request.isStreaming()).thenReturn(stream); + ActionListener listener = mock(); + action.doExecute(mock(), request, listener); + return listener; + } + + public void testMetricsAfterMissingService() { + mockModelRegistry(taskType); + + when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); + + var listener = doExecute(taskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. ")); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + protected void mockModelRegistry(TaskType expectedTaskType) { + var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + } + + public void testMetricsAfterUnknownTaskType() { + var modelTaskType = TaskType.RERANK; + var requestTaskType = 
TaskType.SPARSE_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is( + "Incompatible task_type, the requested type [" + + requestTaskType + + "] does not match the model type [" + + modelTaskType + + "]" + ) + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testMetricsAfterInferError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + mockService(listener -> listener.onFailure(expectedException)); + + var listener = doExecute(taskType); + + verify(listener).onFailure(same(expectedException)); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamUnsupported() { + var expectedStatus = RestStatus.METHOD_NOT_ALLOWED; + var expectedError = String.valueOf(expectedStatus.getStatus()); + mockService(l -> {}); + + var listener = 
doExecute(taskType, true); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + var ese = (ElasticsearchStatusException) e; + assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "].")); + assertThat(ese.status(), is(expectedStatus)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterInferSuccess() { + mockService(listener -> listener.onResponse(mock())); + + var listener = doExecute(taskType); + + verify(listener).onResponse(any()); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferSuccess() { + mockStreamResponse(Flow.Subscriber::onComplete); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferFailure() { + var expectedException = new IllegalStateException("hello"); + var expectedError = 
expectedException.getClass().getSimpleName(); + mockStreamResponse(subscriber -> { + subscriber.subscribe(mock()); + subscriber.onError(expectedException); + }); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamCancel() { + var response = mockStreamResponse(s -> s.onSubscribe(mock())); + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscription.cancel(); + } + + @Override + public void onNext(ChunkedToXContent item) { + + } + + @Override + public void onError(Throwable throwable) { + + } + + @Override + public void onComplete() { + + } + }); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + protected Flow.Publisher mockStreamResponse(Consumer> action) { + mockService(true, Set.of(), listener -> { + Flow.Processor taskProcessor = mock(); + doAnswer(innerAns -> { + action.accept(innerAns.getArgument(0)); + return null; + }).when(taskProcessor).subscribe(any()); + when(streamingTaskManager.create(any(), any())).thenReturn(taskProcessor); + var inferenceServiceResults = mock(InferenceServiceResults.class); + when(inferenceServiceResults.publisher()).thenReturn(mock()); + listener.onResponse(inferenceServiceResults); + }); + + var listener = doExecute(taskType, true); + var 
captor = ArgumentCaptor.forClass(InferenceAction.Response.class); + verify(listener).onResponse(captor.capture()); + assertTrue(captor.getValue().isStreaming()); + assertNotNull(captor.getValue().publisher()); + return captor.getValue().publisher(); + } + + protected void mockService(Consumer> listenerAction) { + mockService(false, Set.of(), listenerAction); + } + + protected void mockService( + boolean stream, + Set supportedStreamingTasks, + Consumer> listenerAction + ) { + InferenceService service = mock(); + Model model = mockModel(); + when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); + when(service.name()).thenReturn(serviceId); + + when(service.canStream(any())).thenReturn(stream); + when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks); + doAnswer(ans -> { + listenerAction.accept(ans.getArgument(7)); + return null; + }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + doAnswer(ans -> { + listenerAction.accept(ans.getArgument(3)); + return null; + }).when(service).unifiedCompletionInfer(any(), any(), any(), any()); + mockModelAndServiceRegistry(service); + } + + protected Model mockModel() { + Model model = mock(); + ModelConfigurations modelConfigurations = mock(); + when(modelConfigurations.getService()).thenReturn(serviceId); + when(model.getConfigurations()).thenReturn(modelConfigurations); + when(model.getTaskType()).thenReturn(taskType); + when(model.getServiceSettings()).thenReturn(mock()); + return model; + } + + protected void mockModelAndServiceRegistry(InferenceService service) { + var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 0ed9cbf56b3fa..e54175cb27009 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -7,66 +7,28 @@ package org.elasticsearch.xpack.inference.action; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.Flow; -import java.util.function.Consumer; - -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.isA; -import static org.hamcrest.Matchers.nullValue; -import static org.mockito.ArgumentMatchers.any; -import static 
org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.assertArg; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -public class TransportInferenceActionTests extends ESTestCase { - private static final String serviceId = "serviceId"; - private static final TaskType taskType = TaskType.COMPLETION; - private static final String inferenceId = "inferenceEntityId"; - private ModelRegistry modelRegistry; - private InferenceServiceRegistry serviceRegistry; - private InferenceStats inferenceStats; - private StreamingTaskManager streamingTaskManager; - private TransportInferenceAction action; +public class TransportInferenceActionTests extends BaseTransportInferenceActionTestCase { - @Before - public void setUp() throws Exception { - super.setUp(); - TransportService transportService = mock(); - ActionFilters actionFilters = mock(); - modelRegistry = mock(); - serviceRegistry = mock(); - inferenceStats = new InferenceStats(mock(), mock()); - streamingTaskManager = mock(); - action = new TransportInferenceAction( + @Override + protected BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + return new TransportInferenceAction( transportService, actionFilters, modelRegistry, @@ -76,279 +38,8 @@ public void setUp() throws Exception { ); } - public void testMetricsAfterModelRegistryError() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - 
listener.onFailure(expectedException); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - var listener = doExecute(taskType); - verify(listener).onFailure(same(expectedException)); - - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), nullValue()); - assertThat(attributes.get("task_type"), nullValue()); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - private ActionListener doExecute(TaskType taskType) { - return doExecute(taskType, false); - } - - private ActionListener doExecute(TaskType taskType, boolean stream) { - InferenceAction.Request request = mock(); - when(request.getInferenceEntityId()).thenReturn(inferenceId); - when(request.getTaskType()).thenReturn(taskType); - when(request.isStreaming()).thenReturn(stream); - ActionListener listener = mock(); - action.doExecute(mock(), request, listener); - return listener; - } - - public void testMetricsAfterMissingService() { - mockModelRegistry(taskType); - - when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); - - var listener = doExecute(taskType); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. 
")); - assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); - assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); - })); - } - - private void mockModelRegistry(TaskType expectedTaskType) { - var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - } - - public void testMetricsAfterUnknownTaskType() { - var modelTaskType = TaskType.RERANK; - var requestTaskType = TaskType.SPARSE_EMBEDDING; - mockModelRegistry(modelTaskType); - when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); - - var listener = doExecute(requestTaskType); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - assertThat( - e.getMessage(), - is( - "Incompatible task_type, the requested type [" - + requestTaskType - + "] does not match the model type [" - + modelTaskType - + "]" - ) - ); - assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(modelTaskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); - assertThat(attributes.get("error.type"), 
is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); - })); - } - - public void testMetricsAfterInferError() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - mockService(listener -> listener.onFailure(expectedException)); - - var listener = doExecute(taskType); - - verify(listener).onFailure(same(expectedException)); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterStreamUnsupported() { - var expectedStatus = RestStatus.METHOD_NOT_ALLOWED; - var expectedError = String.valueOf(expectedStatus.getStatus()); - mockService(l -> {}); - - var listener = doExecute(taskType, true); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - var ese = (ElasticsearchStatusException) e; - assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "].")); - assertThat(ese.status(), is(expectedStatus)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterInferSuccess() { - mockService(listener -> listener.onResponse(mock())); - - var listener = doExecute(taskType); - - verify(listener).onResponse(any()); - 
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - public void testMetricsAfterStreamInferSuccess() { - mockStreamResponse(Flow.Subscriber::onComplete); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - public void testMetricsAfterStreamInferFailure() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - mockStreamResponse(subscriber -> { - subscriber.subscribe(mock()); - subscriber.onError(expectedException); - }); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterStreamCancel() { - var response = mockStreamResponse(s -> s.onSubscribe(mock())); - response.subscribe(new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscription.cancel(); - } - - @Override - public void onNext(ChunkedToXContent item) { - - } - - @Override - public void onError(Throwable throwable) { - - } - - @Override - 
public void onComplete() { - - } - }); - - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - private Flow.Publisher mockStreamResponse(Consumer> action) { - mockService(true, Set.of(), listener -> { - Flow.Processor taskProcessor = mock(); - doAnswer(innerAns -> { - action.accept(innerAns.getArgument(0)); - return null; - }).when(taskProcessor).subscribe(any()); - when(streamingTaskManager.create(any(), any())).thenReturn(taskProcessor); - var inferenceServiceResults = mock(InferenceServiceResults.class); - when(inferenceServiceResults.publisher()).thenReturn(mock()); - listener.onResponse(inferenceServiceResults); - }); - - var listener = doExecute(taskType, true); - var captor = ArgumentCaptor.forClass(InferenceAction.Response.class); - verify(listener).onResponse(captor.capture()); - assertTrue(captor.getValue().isStreaming()); - assertNotNull(captor.getValue().publisher()); - return captor.getValue().publisher(); - } - - private void mockService(Consumer> listenerAction) { - mockService(false, Set.of(), listenerAction); - } - - private void mockService( - boolean stream, - Set supportedStreamingTasks, - Consumer> listenerAction - ) { - InferenceService service = mock(); - Model model = mockModel(); - when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); - when(service.name()).thenReturn(serviceId); - - when(service.canStream(any())).thenReturn(stream); - when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks); - doAnswer(ans -> { - listenerAction.accept(ans.getArgument(7)); - return null; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), 
any()); - mockModelAndServiceRegistry(service); - } - - private Model mockModel() { - Model model = mock(); - ModelConfigurations modelConfigurations = mock(); - when(modelConfigurations.getService()).thenReturn(serviceId); - when(model.getConfigurations()).thenReturn(modelConfigurations); - when(model.getTaskType()).thenReturn(taskType); - when(model.getServiceSettings()).thenReturn(mock()); - return model; - } - - private void mockModelAndServiceRegistry(InferenceService service) { - var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); + @Override + protected InferenceAction.Request createRequest() { + return mock(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java new file mode 100644 index 0000000000000..bb702c6b1e538 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; + +import java.util.Optional; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TransportUnifiedCompletionActionTests extends BaseTransportInferenceActionTestCase { + + @Override + protected BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + return new TransportUnifiedCompletionInferenceAction( + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager + ); + } + + @Override + protected UnifiedCompletionAction.Request createRequest() { + return mock(); + } + + public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInferenceEndpoint() { + var modelTaskType = TaskType.TEXT_EMBEDDING; + var requestTaskType = TaskType.TEXT_EMBEDDING; + 
mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]") + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } +} From f3822465bf1a1bed93f2129b0063c99a8af116ca Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 3 Dec 2024 16:28:14 -0500 Subject: [PATCH 31/53] Fixing various tests --- .../inference/UnifiedCompletionRequest.java | 2 - .../http/sender/ChatCompletionInput.java | 8 +--- .../http/sender/DocumentsOnlyInput.java | 8 +--- .../external/http/sender/InferenceInputs.java | 10 ++++ .../http/sender/QueryAndDocsInputs.java | 8 +--- .../http/sender/UnifiedChatInput.java | 7 +-- .../inference/services/SenderService.java | 2 +- .../googleaistudio/GoogleAiStudioService.java | 3 +- .../elasticsearch/xpack/inference/Utils.java | 3 ++ ...ingleInputSenderExecutableActionTests.java | 20 ++------ .../AmazonBedrockActionCreatorTests.java | 5 +- .../AnthropicActionCreatorTests.java | 7 +-- .../AnthropicChatCompletionActionTests.java | 12 ++--- .../AzureAiStudioActionAndCreatorTests.java | 3 +- .../AzureOpenAiActionCreatorTests.java | 7 +-- .../AzureOpenAiCompletionActionTests.java | 10 ++-- 
.../cohere/CohereActionCreatorTests.java | 5 +- .../cohere/CohereCompletionActionTests.java | 18 ++++---- .../GoogleAiStudioCompletionActionTests.java | 12 ++--- .../openai/OpenAiActionCreatorTests.java | 21 +++++---- .../OpenAiChatCompletionActionTests.java | 19 ++++---- .../AmazonBedrockMockRequestSender.java | 12 ++++- .../AmazonBedrockRequestSenderTests.java | 3 +- .../http/sender/UnifiedChatInputTests.java | 46 +++++++++++++++++++ 24 files changed, 145 insertions(+), 106 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index a0b86aa2b19c0..5a566d1729933 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -81,8 +81,6 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new), new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new), new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new) - // new NamedWriteableRegistry.Entry(Stop.class, StopValues.NAME, StopValues::new), - // new NamedWriteableRegistry.Entry(Stop.class, StopString.NAME, StopString::new) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java index eb869682a6c2a..e7f3eb7dfea67 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java @@ -12,26 +12,20 @@ public class ChatCompletionInput extends InferenceInputs { private final List input; - private final boolean stream; public ChatCompletionInput(List input) { this(input, false); } public ChatCompletionInput(List input, boolean stream) { - super(); + super(stream); this.input = Objects.requireNonNull(input); - this.stream = stream; } public List getInputs() { return this.input; } - public boolean stream() { - return stream; - } - public int inputSize() { return input.size(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java index da5bae00c7831..3feb79d3de6cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java @@ -21,26 +21,20 @@ public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) { } private final List input; - private final boolean stream; public DocumentsOnlyInput(List input) { this(input, false); } public DocumentsOnlyInput(List input, boolean stream) { - super(); + super(stream); this.input = Objects.requireNonNull(input); - this.stream = stream; } public List getInputs() { return this.input; } - public boolean stream() { - return stream; - } - public int inputSize() { return input.size(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index 73719a8e57bc5..e85ea6f1d9b35 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -10,6 +10,12 @@ import org.elasticsearch.common.Strings; public abstract class InferenceInputs { + private final boolean stream; + + public InferenceInputs(boolean stream) { + this.stream = stream; + } + public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs, Class clazz) { return new IllegalArgumentException( Strings.format("Unable to convert inference inputs type: [%s] to [%s]", inferenceInputs.getClass(), clazz) @@ -24,5 +30,9 @@ public T castTo(Class clazz) { return clazz.cast(this); } + public boolean stream() { + return stream; + } + public abstract int inputSize(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 0218799ee892a..5af5245ac5b40 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -22,17 +22,15 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { private final String query; private final List chunks; - private final boolean stream; public QueryAndDocsInputs(String query, List chunks) { this(query, chunks, false); } public QueryAndDocsInputs(String query, List chunks, boolean stream) { - super(); + super(stream); this.query = Objects.requireNonNull(query); this.chunks = Objects.requireNonNull(chunks); - this.stream = stream; } public String getQuery() { @@ -43,10 +41,6 @@ public List getChunks() { return chunks; } - public boolean stream() { 
- return stream; - } - public int inputSize() { return chunks.size(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java index 95f4037276363..be647ef85e869 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -14,11 +14,10 @@ public class UnifiedChatInput extends InferenceInputs { private final UnifiedCompletionRequest request; - private final boolean stream; public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) { + super(stream); this.request = Objects.requireNonNull(request); - this.stream = stream; } public UnifiedChatInput(ChatCompletionInput completionInput, String roleValue) { @@ -47,10 +46,6 @@ public UnifiedCompletionRequest getRequest() { return request; } - public boolean stream() { - return stream; - } - public int inputSize() { return request.messages().size(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 3ba9c36106895..e9b75e9ec7796 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -75,7 +75,7 @@ private static InferenceInputs createInput(Model model, List input, @Nul return switch (model.getTaskType()) { case COMPLETION -> new ChatCompletionInput(input, stream); case RERANK -> new QueryAndDocsInputs(query, input, stream); - case TEXT_EMBEDDING -> new DocumentsOnlyInput(input, stream); + case 
TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream); default -> throw new ElasticsearchStatusException( Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), RestStatus.BAD_REQUEST diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index f10c900c3a30e..b681722a82136 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -284,9 +284,8 @@ protected void doInfer( ) { if (model instanceof GoogleAiStudioCompletionModel completionModel) { var requestManager = new GoogleAiStudioCompletionRequestManager(completionModel, getServiceComponents().threadPool()); - var docsOnly = DocumentsOnlyInput.of(inputs); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( - completionModel.uri(docsOnly.stream()), + completionModel.uri(inputs.stream()), "Google AI Studio completion" ); var action = new SingleInputSenderExecutableAction( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index 5abb9000f4d04..9395ae222e9ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import 
org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.common.Truncator; @@ -160,9 +161,11 @@ public static Model getInvalidModel(String inferenceEntityId, String serviceName var mockConfigs = mock(ModelConfigurations.class); when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId); when(mockConfigs.getService()).thenReturn(serviceName); + when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); var mockModel = mock(Model.class); when(mockModel.getConfigurations()).thenReturn(mockConfigs); + when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); return mockModel; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index d4ab9b1f1e19a..9e7c58b0ca79e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -61,25 +61,11 @@ public void testOneInputIsValid() { assertTrue("Test failed to call listener.", testRan.get()); } - public void testInvalidInputType() { - var badInput = mock(InferenceInputs.class); - var actualException = new AtomicReference(); - - executableAction.execute( - badInput, - mock(TimeValue.class), - ActionListener.wrap(shouldNotSucceed -> fail("Test failed."), actualException::set) - ); - - assertThat(actualException.get(), notNullValue()); - assertThat(actualException.get().getMessage(), is("Invalid inference input type")); - assertThat(actualException.get(), instanceOf(ElasticsearchStatusException.class)); - 
assertThat(((ElasticsearchStatusException) actualException.get()).status(), is(RestStatus.INTERNAL_SERVER_ERROR)); - } - public void testMoreThanOneInput() { var badInput = mock(DocumentsOnlyInput.class); - when(badInput.getInputs()).thenReturn(List.of("one", "two")); + var input = List.of("one", "two"); + when(badInput.getInputs()).thenReturn(input); + when(badInput.inputSize()).thenReturn(input.size()); var actualException = new AtomicReference(); executableAction.execute( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java index 87d3a82b4aae6..e7543aa6ba9e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; @@ -130,7 +131,7 @@ public void testCompletionRequestAction() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, 
listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string")))); @@ -163,7 +164,7 @@ public void testChatCompletionRequestAction_HandlesException() throws IOExceptio ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java index a3114300c5ddc..f0de37ceaaf98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils; import 
org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -49,6 +49,7 @@ import static org.mockito.Mockito.mock; public class AnthropicActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -103,7 +104,7 @@ public void testCreate_ChatCompletionModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -168,7 +169,7 @@ public void testCreate_ChatCompletionModel_FailsFromInvalidResponseFormat() thro var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java index fca2e316af17f..2065a726b7589 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.AnthropicCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -113,7 +113,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -149,7 +149,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -170,7 +170,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", 
"model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -187,7 +187,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -229,7 +229,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index 8792234102a94..210fab457de10 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -160,7 +161,7 @@ public void testChatCompletionRequestAction() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 45a2fb0954c79..7e1e3e55caed8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import 
org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; @@ -475,7 +476,7 @@ public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOExcept var action = actionCreator.create(model, taskSettingsWithUserOverride); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -531,7 +532,7 @@ public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOExceptio var action = actionCreator.create(model, requestTaskSettingsWithoutUser); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -589,7 +590,7 @@ public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat( var action = actionCreator.create(model, requestTaskSettingsWithoutUser); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> 
listener.actionGet(TIMEOUT)); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java index 4c7683c882816..dca12dfda9c98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; @@ -111,7 +111,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction("resource", "deployment", "apiversion", user, apiKey, sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -142,7 +142,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction("resource", 
"deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -177,7 +177,7 @@ public void testExecute_ThrowsException() { var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 9ec34e7d8e5c5..3a512de25a39c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -197,7 +198,7 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep var action = actionCreator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -257,7 +258,7 @@ public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOEx var action = actionCreator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index ba839e0d7c5e9..c5871adb34864 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -26,8 +26,8 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.CohereCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils; @@ -120,7 +120,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -181,7 +181,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws var action = createAction(getUrl(webServer), "secret", null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ 
-214,7 +214,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -235,7 +235,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -256,7 +256,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -270,7 +270,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), 
InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -284,7 +284,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -334,7 +334,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java index 72b5ffa45a0dd..ff17bbf66e02a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java @@ -25,7 +25,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import 
org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -128,7 +128,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -159,7 +159,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -180,7 +180,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), 
InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -260,7 +260,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index b6d7eb673b7f0..fe076eb721ea2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import 
org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -330,7 +331,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -345,11 +346,12 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -393,7 +395,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -408,10 +410,11 @@ public void 
testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(3)); + assertThat(requestMap.size(), is(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -455,7 +458,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -470,11 +473,12 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO assertNull(request.getHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -523,7 +527,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new 
ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -542,11 +546,12 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( assertNull(webServer.requests().get(0).getHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index d84b2b5bb324a..ba74d2ab42c21 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import 
org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; @@ -119,7 +119,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -134,11 +134,12 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(request.getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -159,7 +160,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -180,7 +181,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture 
listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -201,7 +202,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -215,7 +216,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -229,7 +230,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -273,7 +274,7 @@ public void 
testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java index e68beaf4c1eb5..929aefeeef6b9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; @@ -67,8 +68,15 @@ public void send( ActionListener listener ) { sendCounter++; - var docsInput = (DocumentsOnlyInput) inferenceInputs; - inputs.add(docsInput.getInputs()); + if (inferenceInputs instanceof DocumentsOnlyInput docsInput) { + inputs.add(docsInput.getInputs()); + } else if (inferenceInputs instanceof ChatCompletionInput 
chatCompletionInput) { + inputs.add(chatCompletionInput.getInputs()); + } else { + throw new IllegalArgumentException( + "Invalid inference inputs received in mock sender: " + inferenceInputs.getClass().getSimpleName() + ); + } if (results.isEmpty()) { listener.onFailure(new ElasticsearchException("No results found")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java index 7fa8a09d5bf12..a8f37aedcece3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -107,7 +108,7 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws PlainActionFuture listener = new PlainActionFuture<>(); var requestManager = new AmazonBedrockChatCompletionRequestManager(model, threadPool, new TimeValue(30, TimeUnit.SECONDS)); - sender.send(requestManager, new DocumentsOnlyInput(List.of("abc")), null, listener); + sender.send(requestManager, new ChatCompletionInput(List.of("abc")), null, listener); var result = listener.actionGet(TIMEOUT); 
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test response text")))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java new file mode 100644 index 0000000000000..42e1b18168aec --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import java.util.List; + +public class UnifiedChatInputTests extends ESTestCase { + + public void testConvertsStringInputToMessages() { + var a = new UnifiedChatInput(List.of("hello", "awesome"), "a role", true); + + assertThat(a.inputSize(), Matchers.is(2)); + assertThat( + a.getRequest(), + Matchers.is( + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("hello"), + "a role", + null, + null, + null + ), + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("awesome"), + "a role", + null, + null, + null + ) + ) + ) + ) + ); + } +} From b7d1c86f57ac83fc6fb420f15d03ffda04bfeda1 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Tue, 3 Dec 2024 17:21:39 -0500 Subject: [PATCH 32/53] Fixing license header --- .../elasticsearch/inference/UnifiedCompletionRequest.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git 
a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 5a566d1729933..8827dbd6ea919 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -1,8 +1,10 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
*/ package org.elasticsearch.inference; From 93a671afcbaaf2410391d0d87669d51d5fa5c20a Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 4 Dec 2024 10:00:29 -0500 Subject: [PATCH 33/53] Reformat streaming results --- .../xcontent/ChunkedToXContentHelper.java | 9 + .../rest/ChunkedRestResponseBodyPart.java | 4 +- .../rest/ChunkedZipResponse.java | 4 +- .../rest/StreamingXContentResponse.java | 4 +- ...StreamingUnifiedChatCompletionResults.java | 311 ++++++++---------- .../OpenAiUnifiedStreamingProcessor.java | 48 +-- .../ServerSentEventsRestActionListener.java | 17 +- .../OpenAiUnifiedStreamingProcessorTests.java | 62 ++-- 8 files changed, 203 insertions(+), 256 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java index 2e78cc6f516b1..6a5aa2943de92 100644 --- a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.xcontent.ToXContent; +import java.util.Collections; import java.util.Iterator; public enum ChunkedToXContentHelper { @@ -53,6 +54,14 @@ public static Iterator field(String name, String value) { return Iterators.single(((builder, params) -> builder.field(name, value))); } + public static Iterator optionalField(String name, String value) { + if (value == null) { + return Collections.emptyIterator(); + } else { + return field(name, value); + } + } + /** * Creates an Iterator of a single ToXContent object that serializes the given object as a single chunk. Just wraps {@link * Iterators#single}, but still useful because it avoids any type ambiguity. 
diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java index 694af7e1606cb..75bffdafd8db7 100644 --- a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java @@ -173,11 +173,11 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec target = null; return result; } catch (Exception e) { - logger.error("failure encoding chunk", e); + logger.error("failure encoding chunk 5", e); throw e; } finally { if (target != null) { - assert false : "failure encoding chunk"; + assert false : "failure encoding chunk 6"; IOUtils.closeWhileHandlingException(target); target = null; } diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java b/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java index 585f5f3b1b4d1..72d1a835102bb 100644 --- a/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java @@ -412,11 +412,11 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec return new ReleasableBytesReference(BytesArray.EMPTY, () -> {}); } } catch (Exception e) { - logger.error("failure encoding chunk", e); + logger.error("failure encoding chunk 7", e); throw e; } finally { if (targetStream != null) { - assert false : "failure encoding chunk"; + assert false : "failure encoding chunk 8"; IOUtils.closeWhileHandlingException(targetStream, Releasables.wrap(releasables)); targetStream = null; } diff --git a/server/src/main/java/org/elasticsearch/rest/StreamingXContentResponse.java b/server/src/main/java/org/elasticsearch/rest/StreamingXContentResponse.java index db33673939ae9..19da57ba56063 100644 --- a/server/src/main/java/org/elasticsearch/rest/StreamingXContentResponse.java +++ 
b/server/src/main/java/org/elasticsearch/rest/StreamingXContentResponse.java @@ -364,11 +364,11 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec return new ReleasableBytesReference(BytesArray.EMPTY, () -> {}); } } catch (Exception e) { - logger.error("failure encoding chunk", e); + logger.error("failure encoding chunk 9", e); throw e; } finally { if (targetStream != null) { - assert false : "failure encoding chunk"; + assert false : "failure encoding chunk 10"; IOUtils.closeWhileHandlingException(targetStream, Releasables.wrap(releasables)); targetStream = null; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java index 4712a5964b483..23af412ee3692 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -16,18 +16,13 @@ import org.elasticsearch.xcontent.ToXContent; import java.io.IOException; -import java.util.Arrays; import java.util.Collections; import java.util.Deque; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.concurrent.Flow; -import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION; -import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.Result.RESULT; - /** * Chat Completion results that only contain a Flow.Publisher. 
*/ @@ -35,9 +30,26 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher toXContentChunked(ToXContent.Params params throw new UnsupportedOperationException("Not implemented"); } - public record Results(Deque results) implements ChunkedToXContent { + public record Results(Deque chunks) implements ChunkedToXContent { @Override public Iterator toXContentChunked(ToXContent.Params params) { return Iterators.concat( ChunkedToXContentHelper.startObject(), - ChunkedToXContentHelper.startArray(COMPLETION), - Iterators.flatMap(results.iterator(), d -> d.toXContentChunked(params)), + ChunkedToXContentHelper.startArray(NAME), + Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params)), ChunkedToXContentHelper.endArray(), ChunkedToXContentHelper.endObject() ); } } - private static final String REFUSAL_FIELD = "refusal"; - private static final String TOOL_CALLS_FIELD = "tool_calls"; - public static final String FINISH_REASON_FIELD = "finish_reason"; - - public record Result( - String delta, - String refusal, - List toolCalls, - String finishReason, - String model, - String object, - ChatCompletionChunk.Usage usage - ) implements ChunkedToXContent { - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - Iterator toolCallsIterator = Collections.emptyIterator(); - if (toolCalls != null && toolCalls.isEmpty() == false) { - toolCallsIterator = Iterators.concat( - ChunkedToXContentHelper.startArray(TOOL_CALLS_FIELD), - Iterators.flatMap(toolCalls.iterator(), d -> d.toXContentChunked(params)), - ChunkedToXContentHelper.endArray() - ); - } - - Iterator usageIterator = Collections.emptyIterator(); - if (usage != null) { - usageIterator = Iterators.concat( - ChunkedToXContentHelper.startObject(USAGE_FIELD), - ChunkedToXContentHelper.field("completion_tokens", usage.completionTokens()), - ChunkedToXContentHelper.field("prompt_tokens", usage.promptTokens()), - ChunkedToXContentHelper.field("total_tokens", usage.totalTokens()), - 
ChunkedToXContentHelper.endObject() - ); - } - - return Iterators.concat( - ChunkedToXContentHelper.startObject(), - ChunkedToXContentHelper.field(RESULT, delta), - ChunkedToXContentHelper.field(REFUSAL_FIELD, refusal), - toolCallsIterator, - ChunkedToXContentHelper.field(FINISH_REASON_FIELD, finishReason), - ChunkedToXContentHelper.field(MODEL_FIELD, model), - ChunkedToXContentHelper.field(OBJECT_FIELD, object), - usageIterator, - ChunkedToXContentHelper.endObject() - ); - } - } - - public static class ToolCall implements ChunkedToXContent { - private final int index; + public static class ChatCompletionChunk implements ChunkedToXContent { private final String id; - private final String functionName; - private final String functionArguments; - - public ToolCall(int index, String id, String functionName, String functionArguments) { - this.index = index; - this.id = id; - this.functionName = functionName; - this.functionArguments = functionArguments; - } - - public int getIndex() { - return index; - } public String getId() { return id; } - public String getFunctionName() { - return functionName; - } - - public String getFunctionArguments() { - return functionArguments; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ToolCall toolCall = (ToolCall) o; - return index == toolCall.index - && Objects.equals(id, toolCall.id) - && Objects.equals(functionName, toolCall.functionName) - && Objects.equals(functionArguments, toolCall.functionArguments); + public List getChoices() { + return choices; } - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return Iterators.concat( - ChunkedToXContentHelper.startObject(), - ChunkedToXContentHelper.field("index", index), - ChunkedToXContentHelper.field("id", id), - ChunkedToXContentHelper.field("functionName", functionName), - ChunkedToXContentHelper.field("functionArguments", functionArguments), - 
ChunkedToXContentHelper.endObject() - ); + public String getModel() { + return model; } - @Override - public int hashCode() { - return Objects.hash(index, id, functionName, functionArguments); + public String getObject() { + return object; } - @Override - public String toString() { - return "ToolCall{" - + "index=" - + index - + ", id='" - + id - + '\'' - + ", functionName='" - + functionName - + '\'' - + ", functionArguments='" - + functionArguments - + '\'' - + '}'; + public Usage getUsage() { + return usage; } - } - public static class ChatCompletionChunk { - private final String id; - private List choices; + private final List choices; private final String model; private final String object; - private ChatCompletionChunk.Usage usage; + private final ChatCompletionChunk.Usage usage; public ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage) { this.id = id; @@ -227,61 +135,57 @@ public ChatCompletionChunk(String id, List choices, String model, String this.usage = usage; } - public ChatCompletionChunk( - String id, - ChatCompletionChunk.Choice[] choices, - String model, - String object, - ChatCompletionChunk.Usage usage - ) { - this.id = id; - this.choices = Arrays.stream(choices).toList(); - this.model = model; - this.object = object; - this.usage = usage; - } - - public String getId() { - return id; - } - - public List getChoices() { - return choices; - } - - public String getModel() { - return model; - } - - public String getObject() { - return object; - } - - public ChatCompletionChunk.Usage getUsage() { - return usage; - } - - public static class Choice { - private final ChatCompletionChunk.Choice.Delta delta; - private final String finishReason; - private final int index; + @Override + public Iterator toXContentChunked(ToXContent.Params params) { - public Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) { - this.delta = delta; - this.finishReason = finishReason; - this.index 
= index; + Iterator choicesIterator = Collections.emptyIterator(); + if (choices != null) { + choicesIterator = Iterators.concat( + ChunkedToXContentHelper.startArray(CHOICES_FIELD), + Iterators.flatMap(choices.iterator(), c -> c.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); } - public ChatCompletionChunk.Choice.Delta getDelta() { - return delta; + Iterator usageIterator = Collections.emptyIterator(); + if (usage != null) { + usageIterator = Iterators.concat( + ChunkedToXContentHelper.startObject(USAGE_FIELD), + ChunkedToXContentHelper.field("completion_tokens", usage.completionTokens()), + ChunkedToXContentHelper.field("prompt_tokens", usage.promptTokens()), + ChunkedToXContentHelper.field("total_tokens", usage.totalTokens()), + ChunkedToXContentHelper.endObject() + ); } - public String getFinishReason() { - return finishReason; - } + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field(ID_FIELD, id), + choicesIterator, + ChunkedToXContentHelper.field(MODEL_FIELD, model), + ChunkedToXContentHelper.field(OBJECT_FIELD, object), + usageIterator, + ChunkedToXContentHelper.endObject() + ); + } - public int getIndex() { - return index; + public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) { + + /* + choices: Array<{ + delta: { ... 
}; + finish_reason: string | null; + index: number; + }>; + */ + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + delta.toXContentChunked(params), + ChunkedToXContentHelper.optionalField(FINISH_REASON_FIELD, finishReason), + ChunkedToXContentHelper.field(INDEX_FIELD, index), + ChunkedToXContentHelper.endObject() + ); } public static class Delta { @@ -297,6 +201,35 @@ public Delta(String content, String refusal, String role, List toolCal this.toolCalls = toolCalls; } + /* + delta: { + content?: string | null; + refusal?: string | null; + role?: 'system' | 'user' | 'assistant' | 'tool'; + tool_calls?: Array<{ ... }>; + }; + */ + public Iterator toXContentChunked(ToXContent.Params params) { + var xContent = Iterators.concat( + ChunkedToXContentHelper.startObject(DELTA_FIELD), + ChunkedToXContentHelper.optionalField(CONTENT_FIELD, content), + ChunkedToXContentHelper.optionalField(REFUSAL_FIELD, refusal), + ChunkedToXContentHelper.optionalField(ROLE_FIELD, role) + ); + + if (toolCalls != null && toolCalls.isEmpty() == false) { + xContent = Iterators.concat( + xContent, + ChunkedToXContentHelper.startArray(TOOL_CALLS_FIELD), + Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); + } + xContent = Iterators.concat(xContent, ChunkedToXContentHelper.endObject()); + return xContent; + + } + public String getContent() { return content; } @@ -342,6 +275,40 @@ public String getType() { return type; } + /* + index: number; + id?: string; + function?: { + arguments?: string; + name?: string; + }; + type?: 'function'; + */ + public Iterator toXContentChunked(ToXContent.Params params) { + var content = Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field(INDEX_FIELD, index), + ChunkedToXContentHelper.optionalField(ID_FIELD, id) + ); + + if (function != null) { + content = Iterators.concat( + content, 
+ ChunkedToXContentHelper.startObject(FUNCTION_FIELD), + ChunkedToXContentHelper.optionalField(FUNCTION_ARGUMENTS_FIELD, function.getArguments()), + ChunkedToXContentHelper.optionalField(FUNCTION_NAME_FIELD, function.getName()), + ChunkedToXContentHelper.endObject() + ); + } + + content = Iterators.concat( + content, + ChunkedToXContentHelper.field(TYPE_FIELD, type), + ChunkedToXContentHelper.endObject() + ); + return content; + } + public static class Function { private final String arguments; private final String name; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index 168c6778d7a25..50cb5275c4b36 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -24,7 +24,6 @@ import java.io.IOException; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.Iterator; @@ -63,7 +62,7 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - var results = new ArrayDeque(item.size()); + var results = new ArrayDeque(item.size()); for (ServerSentEvent event : item) { if (ServerSentEventField.DATA == event.name() && event.hasValue()) { try { @@ -83,8 +82,10 @@ protected void next(Deque item) throws Exception { } } - private Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) - throws IOException { + private Iterator parse( + XContentParserConfiguration parserConfig, + ServerSentEvent event + ) 
throws IOException { if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { return Collections.emptyIterator(); } @@ -97,45 +98,8 @@ private Iterator parse(XContentPar StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser); - List results = new ArrayList<>(); - for (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice : chunk.getChoices()) { - String content = choice.getDelta().getContent(); - String refusal = choice.getDelta().getRefusal(); - List toolCalls = parseToolCalls(choice.getDelta().getToolCalls()); - results.add( - new StreamingUnifiedChatCompletionResults.Result( - content, - refusal, - toolCalls, - choice.getFinishReason(), - chunk.getModel(), - chunk.getObject(), - chunk.getUsage() - ) - ); - } - - return results.iterator(); - } - } - - private List parseToolCalls( - List toolCalls - ) { - List parsedToolCalls = new ArrayList<>(); - - if (toolCalls == null || toolCalls.isEmpty()) { - return null; - } - - for (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall : toolCalls) { - int index = toolCall.getIndex(); - String id = toolCall.getId(); - String functionName = toolCall.getFunction() != null ? toolCall.getFunction().getName() : null; - String functionArguments = toolCall.getFunction() != null ? 
toolCall.getFunction().getArguments() : null; - parsedToolCalls.add(new StreamingUnifiedChatCompletionResults.ToolCall(index, id, functionName, functionArguments)); + return Collections.singleton(chunk).iterator(); } - return parsedToolCalls; } public static class ChatCompletionChunkParser { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java index 3177474ea8ca6..73732216bfef5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java @@ -336,38 +336,47 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec assert target == null; target = chunkStream; + logger.warn("encodeChunk1"); // if this is the first time we are encoding this chunk, write the SSE leading bytes if (isStartOfData.compareAndSet(true, false)) { target.write(ServerSentEventSpec.BOM); target.write(event.eventType); target.write(ServerSentEventSpec.EOL); target.write(ServerSentEventSpec.DATA); + logger.warn("encodeChunk2"); } // start or continue writing this chunk while (serialization.hasNext()) { + logger.warn("encodeChunk3"); serialization.next().toXContent(builder, params); + logger.warn("encodeChunk4"); if (chunkStream.size() >= sizeHint) { break; } } + logger.warn("encodeChunk5"); if (serialization.hasNext() == false) { + logger.warn("encodeChunk6"); // SSE wants two newlines between messages builder.close(); target.write(ServerSentEventSpec.EOL); target.write(ServerSentEventSpec.EOL); target.flush(); + logger.warn("encodeChunk7"); + } final var result = new ReleasableBytesReference(chunkStream.bytes(), () -> Releasables.closeExpectNoException(chunkStream)); + logger.warn("encodeChunk8"); 
target = null; return result; } catch (Exception e) { - logger.error("failure encoding chunk", e); + logger.error("failure encoding chunk 1", e); throw e; } finally { if (target != null) { - assert false : "failure encoding chunk"; + assert false : "failure encoding chunk 2"; IOUtils.closeWhileHandlingException(target); target = null; } @@ -427,11 +436,11 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec isPartComplete = true; return new ReleasableBytesReference(chunkStream.bytes(), () -> Releasables.closeExpectNoException(chunkStream)); } catch (Exception e) { - logger.error("failure encoding chunk", e); + logger.error("failure encoding chunk 3", e); throw e; } finally { if (isPartComplete == false) { - assert false : "failure encoding chunk"; + assert false : "failure encoding chunk 4"; IOUtils.closeWhileHandlingException(chunkStream); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java index 8f1e720a128b2..ea7754279d4ee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java @@ -76,14 +76,13 @@ public void testJsonLiteral() { List choices = chunk.getChoices(); assertEquals(1, choices.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); - assertEquals("example_content", choice.getDelta().getContent()); - assertNull(choice.getDelta().getRefusal()); - assertEquals("assistant", choice.getDelta().getRole()); - assertEquals("stop", choice.getFinishReason()); - assertEquals(0, choice.getIndex()); + assertEquals("example_content", choice.delta().getContent()); 
+ assertNull(choice.delta().getRefusal()); + assertEquals("assistant", choice.delta().getRole()); + assertEquals("stop", choice.finishReason()); + assertEquals(0, choice.index()); - List toolCalls = choice.getDelta() - .getToolCalls(); + List toolCalls = choice.delta().getToolCalls(); assertEquals(1, toolCalls.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); assertEquals(1, toolCall.getIndex()); @@ -161,22 +160,22 @@ public void testJsonLiteralCornerCases() { // First choice assertions StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0); - assertNull(firstChoice.getDelta().getContent()); - assertNull(firstChoice.getDelta().getRefusal()); - assertEquals("assistant", firstChoice.getDelta().getRole()); - assertTrue(firstChoice.getDelta().getToolCalls().isEmpty()); - assertNull(firstChoice.getFinishReason()); - assertEquals(0, firstChoice.getIndex()); + assertNull(firstChoice.delta().getContent()); + assertNull(firstChoice.delta().getRefusal()); + assertEquals("assistant", firstChoice.delta().getRole()); + assertTrue(firstChoice.delta().getToolCalls().isEmpty()); + assertNull(firstChoice.finishReason()); + assertEquals(0, firstChoice.index()); // Second choice assertions StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice secondChoice = choices.get(1); - assertEquals("example_content", secondChoice.getDelta().getContent()); - assertEquals("example_refusal", secondChoice.getDelta().getRefusal()); - assertEquals("user", secondChoice.getDelta().getRole()); - assertEquals("stop", secondChoice.getFinishReason()); - assertEquals(1, secondChoice.getIndex()); + assertEquals("example_content", secondChoice.delta().getContent()); + assertEquals("example_refusal", secondChoice.delta().getRefusal()); + assertEquals("user", secondChoice.delta().getRole()); + assertEquals("stop", secondChoice.finishReason()); + assertEquals(1, secondChoice.index()); - List 
toolCalls = secondChoice.getDelta() + List toolCalls = secondChoice.delta() .getToolCalls(); assertEquals(1, toolCalls.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); @@ -240,14 +239,13 @@ public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException { List choices = chunk.getChoices(); assertEquals(1, choices.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); - assertEquals(choiceContent, choice.getDelta().getContent()); - assertNull(choice.getDelta().getRefusal()); - assertEquals(choiceRole, choice.getDelta().getRole()); - assertEquals(choiceFinishReason, choice.getFinishReason()); - assertEquals(choiceIndex, choice.getIndex()); + assertEquals(choiceContent, choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertEquals(choiceRole, choice.delta().getRole()); + assertEquals(choiceFinishReason, choice.finishReason()); + assertEquals(choiceIndex, choice.index()); - List toolCalls = choice.getDelta() - .getToolCalls(); + List toolCalls = choice.delta().getToolCalls(); assertEquals(1, toolCalls.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); assertEquals(toolCallIndex, toolCall.getIndex()); @@ -290,12 +288,12 @@ public void testOpenAiUnifiedStreamingProcessorParsingWithNullFields() throws IO List choices = chunk.getChoices(); assertEquals(1, choices.size()); StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); - assertNull(choice.getDelta().getContent()); - assertNull(choice.getDelta().getRefusal()); - assertNull(choice.getDelta().getRole()); - assertNull(choice.getFinishReason()); - assertEquals(choiceIndex, choice.getIndex()); - assertTrue(choice.getDelta().getToolCalls().isEmpty()); + assertNull(choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertNull(choice.delta().getRole()); 
+ assertNull(choice.finishReason()); + assertEquals(choiceIndex, choice.index()); + assertTrue(choice.delta().getToolCalls().isEmpty()); } } From 6e3db613ee890388dec043ecf82ec290b5e07d82 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 4 Dec 2024 11:35:23 -0500 Subject: [PATCH 34/53] Finalize response format --- ...StreamingUnifiedChatCompletionResults.java | 8 +---- .../inference/common/DelegatingProcessor.java | 9 +++++- .../OpenAiUnifiedStreamingProcessor.java | 30 ++++++++++++++++++- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java index 23af412ee3692..7008309feb809 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -89,13 +89,7 @@ public Iterator toXContentChunked(ToXContent.Params params public record Results(Deque chunks) implements ChunkedToXContent { @Override public Iterator toXContentChunked(ToXContent.Params params) { - return Iterators.concat( - ChunkedToXContentHelper.startObject(), - ChunkedToXContentHelper.startArray(NAME), - Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params)), - ChunkedToXContentHelper.endArray(), - ChunkedToXContentHelper.endObject() - ); + return Iterators.concat(Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params))); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index 03e794e42c3a2..3045246344921 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -51,7 +51,7 @@ public void request(long n) { if (isClosed.get()) { downstream.onComplete(); } else if (upstream != null) { - upstream.request(n); + onRequest(n); } else { pendingRequests.accumulateAndGet(n, Long::sum); } @@ -67,6 +67,13 @@ public void cancel() { }; } + /** + * Guaranteed to be called when the upstream is set and this processor had not been closed. + */ + protected void onRequest(long n) { + upstream.request(n); + } + protected void onCancel() {} @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index 50cb5275c4b36..4023d0d0936b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -28,6 +28,7 @@ import java.util.Deque; import java.util.Iterator; import java.util.List; +import java.util.concurrent.LinkedBlockingDeque; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; @@ -58,6 +59,17 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor buffer = new LinkedBlockingDeque<>(); + + @Override + protected void onRequest(long n) { + if (buffer.isEmpty()) { + super.onRequest(n); + } else { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); + } + } + @Override protected void next(Deque item) throws 
Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); @@ -77,8 +89,16 @@ protected void next(Deque item) throws Exception { if (results.isEmpty()) { upstream().request(1); - } else { + } + if (results.size() == 1) { downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } else { + // results > 1, but openai spec only wants 1 chunk per SSE event + var firstItem = singleItem(results.poll()); + while (results.isEmpty() == false) { + buffer.offer(results.poll()); + } + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem)); } } @@ -270,4 +290,12 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa } } } + + private Deque singleItem( + StreamingUnifiedChatCompletionResults.ChatCompletionChunk result + ) { + var deque = new ArrayDeque(2); + deque.offer(result); + return deque; + } } From 56735c663e197ebcd7587411825b5a131a070e15 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 4 Dec 2024 11:36:12 -0500 Subject: [PATCH 35/53] remove debug logs --- .../rest/ServerSentEventsRestActionListener.java | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java index 73732216bfef5..7fcff1ede2d7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java @@ -336,47 +336,39 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec assert target == null; target = chunkStream; - logger.warn("encodeChunk1"); // if this is the first time we are encoding this 
chunk, write the SSE leading bytes if (isStartOfData.compareAndSet(true, false)) { target.write(ServerSentEventSpec.BOM); target.write(event.eventType); target.write(ServerSentEventSpec.EOL); target.write(ServerSentEventSpec.DATA); - logger.warn("encodeChunk2"); } // start or continue writing this chunk while (serialization.hasNext()) { - logger.warn("encodeChunk3"); serialization.next().toXContent(builder, params); - logger.warn("encodeChunk4"); if (chunkStream.size() >= sizeHint) { break; } } - logger.warn("encodeChunk5"); if (serialization.hasNext() == false) { - logger.warn("encodeChunk6"); // SSE wants two newlines between messages builder.close(); target.write(ServerSentEventSpec.EOL); target.write(ServerSentEventSpec.EOL); target.flush(); - logger.warn("encodeChunk7"); } final var result = new ReleasableBytesReference(chunkStream.bytes(), () -> Releasables.closeExpectNoException(chunkStream)); - logger.warn("encodeChunk8"); target = null; return result; } catch (Exception e) { - logger.error("failure encoding chunk 1", e); + logger.error("failure encoding chunk", e); throw e; } finally { if (target != null) { - assert false : "failure encoding chunk 2"; + assert false : "failure encoding chunk"; IOUtils.closeWhileHandlingException(target); target = null; } @@ -436,11 +428,11 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec isPartComplete = true; return new ReleasableBytesReference(chunkStream.bytes(), () -> Releasables.closeExpectNoException(chunkStream)); } catch (Exception e) { - logger.error("failure encoding chunk 3", e); + logger.error("failure encoding chunk", e); throw e; } finally { if (isPartComplete == false) { - assert false : "failure encoding chunk 4"; + assert false : "failure encoding chunk"; IOUtils.closeWhileHandlingException(chunkStream); } } From 0fb9a178d6426830d7bebe416b813900d6897c15 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 4 Dec 2024 11:40:27 -0500 Subject: [PATCH 36/53] remove changes for 
debugging --- .../org/elasticsearch/rest/ChunkedRestResponseBodyPart.java | 4 ++-- .../main/java/org/elasticsearch/rest/ChunkedZipResponse.java | 4 ++-- .../org/elasticsearch/rest/StreamingXContentResponse.java | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java index 75bffdafd8db7..694af7e1606cb 100644 --- a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java @@ -173,11 +173,11 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec target = null; return result; } catch (Exception e) { - logger.error("failure encoding chunk 5", e); + logger.error("failure encoding chunk", e); throw e; } finally { if (target != null) { - assert false : "failure encoding chunk 6"; + assert false : "failure encoding chunk"; IOUtils.closeWhileHandlingException(target); target = null; } diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java b/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java index 72d1a835102bb..585f5f3b1b4d1 100644 --- a/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedZipResponse.java @@ -412,11 +412,11 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec return new ReleasableBytesReference(BytesArray.EMPTY, () -> {}); } } catch (Exception e) { - logger.error("failure encoding chunk 7", e); + logger.error("failure encoding chunk", e); throw e; } finally { if (targetStream != null) { - assert false : "failure encoding chunk 8"; + assert false : "failure encoding chunk"; IOUtils.closeWhileHandlingException(targetStream, Releasables.wrap(releasables)); targetStream = null; } diff --git 
a/server/src/main/java/org/elasticsearch/rest/StreamingXContentResponse.java b/server/src/main/java/org/elasticsearch/rest/StreamingXContentResponse.java index 19da57ba56063..db33673939ae9 100644 --- a/server/src/main/java/org/elasticsearch/rest/StreamingXContentResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/StreamingXContentResponse.java @@ -364,11 +364,11 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec return new ReleasableBytesReference(BytesArray.EMPTY, () -> {}); } } catch (Exception e) { - logger.error("failure encoding chunk 9", e); + logger.error("failure encoding chunk", e); throw e; } finally { if (targetStream != null) { - assert false : "failure encoding chunk 10"; + assert false : "failure encoding chunk"; IOUtils.closeWhileHandlingException(targetStream, Releasables.wrap(releasables)); targetStream = null; } From a530f02779383bc66d3156c462ba439e60a38ff6 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 4 Dec 2024 12:56:53 -0500 Subject: [PATCH 37/53] Task type and base inference action tests --- .../inference/InferenceService.java | 4 +- .../org/elasticsearch/inference/TaskType.java | 4 + .../inference/UnifiedCompletionRequest.java | 2 - .../xpack/inference/TaskTypeTests.java | 27 ++++++ .../rest/BaseInferenceActionTests.java | 43 ++++++++++ ...UnifiedCompletionInferenceActionTests.java | 83 +++++++++++++++++++ 6 files changed, 158 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 333ce76a69594..c2d690d8160ac 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ 
b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -127,9 +127,7 @@ void unifiedCompletionInfer( ); /** - * Chunk long text according to {@code chunkingOptions} or the - * model defaults if {@code chunkingOptions} contains unset - * values. + * Chunk long text. * * @param model The model * @param query Inference query, mainly for re-ranking diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java index b0e5bababbbc0..fcb8ea7213795 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -38,6 +38,10 @@ public static TaskType fromString(String name) { } public static TaskType fromStringOrStatusException(String name) { + if (name == null) { + throw new ElasticsearchStatusException("Task type must not be null", RestStatus.BAD_REQUEST); + } + try { TaskType taskType = TaskType.fromString(name); return Objects.requireNonNull(taskType); diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 8827dbd6ea919..49caa86e113fd 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -63,8 +63,6 @@ public sealed interface Content extends NamedWriteable permits ContentObjects, C PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages")); PARSER.declareString(optionalConstructorArg(), new ParseField("model")); PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens")); - // PARSER.declareField(optionalConstructorArg(), (p, c) -> parseStop(p), new ParseField("stop"), - // ObjectParser.ValueType.VALUE_ARRAY); PARSER.declareStringArray(optionalConstructorArg(), new 
ParseField("stop")); PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature")); PARSER.declareField( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java new file mode 100644 index 0000000000000..f6c058bdbb79f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +public class TaskTypeTests extends ESTestCase { + + public void testFromStringOrStatusException() { + var exception = expectThrows(ElasticsearchStatusException.class, () -> TaskType.fromStringOrStatusException(null)); + assertThat(exception.getMessage(), Matchers.is("Task type must not be null")); + + exception = expectThrows(ElasticsearchStatusException.class, () -> TaskType.fromStringOrStatusException("blah")); + assertThat(exception.getMessage(), Matchers.is("Unknown task_type [blah]")); + + assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 05a8d52be5df4..9a0213d1e18b0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -8,11 +8,14 @@ package org.elasticsearch.xpack.inference.rest; import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestRequestTests; import org.elasticsearch.rest.action.RestChunkedToXContentListener; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; @@ -26,6 +29,10 @@ import java.util.Map; import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseParams; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout; +import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; +import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -56,6 +63,42 @@ private static String route(String param) { return "_route/" + param; } + public void testParseParams_ExtractsInferenceIdAndTaskType() { + var params = parseParams( + RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id", TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString())) + ); + assertThat(params, is(new BaseInferenceAction.Params("id", TaskType.COMPLETION))); + } + + public void testParseParams_DefaultsToTaskTypeAny_WhenInferenceId_IsMissing() { + var params = parseParams( + RestRequestTests.contentRestRequest("{}", Map.of(TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString())) + ); + assertThat(params, 
is(new BaseInferenceAction.Params(TASK_TYPE_OR_INFERENCE_ID, TaskType.ANY))); + } + + public void testParseParams_ThrowsStatusException_WhenTaskTypeIsMissing() { + var e = expectThrows( + ElasticsearchStatusException.class, + () -> parseParams(RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id"))) + ); + assertThat(e.getMessage(), is("Task type must not be null")); + } + + public void testParseTimeout_ReturnsTimeout() { + var timeout = parseTimeout( + RestRequestTests.contentRestRequest("{}", Map.of(InferenceAction.Request.TIMEOUT.getPreferredName(), "4s")) + ); + + assertThat(timeout, is(TimeValue.timeValueSeconds(4))); + } + + public void testParseTimeout_ReturnsDefaultTimeout() { + var timeout = parseTimeout(RestRequestTests.contentRestRequest("{}", Map.of())); + + assertThat(timeout, is(TimeValue.timeValueSeconds(30))); + } + public void testUsesDefaultTimeout() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java new file mode 100644 index 0000000000000..264a658c8fd3b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.rest; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.junit.Before; + +import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase { + + @Before + public void setUpAction() { + controller().registerHandler(new RestUnifiedCompletionInferenceAction()); + } + + public void testStreamIsTrue() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(UnifiedCompletionAction.Request.class)); + + var request = (UnifiedCompletionAction.Request) actionRequest; + assertThat(request.isStreaming(), is(true)); + + executeCalled.set(true); + return createResponse(); + })); + + var requestBody = """ + { + "messages": [ + { + "content": "abc", + "role": "user" + } + ] + } + """; + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath("_inference/completion/test/_unified") + .withContent(new BytesArray(requestBody), XContentType.JSON) + .build(); + + final SetOnce responseSetOnce = new SetOnce<>(); + 
dispatchRequest(inferenceRequest, new AbstractRestChannel(inferenceRequest, true) { + @Override + public void sendResponse(RestResponse response) { + responseSetOnce.set(response); + } + }); + + // the response content will be null when there is no error + assertNull(responseSetOnce.get().content()); + // var responseBody = responseSetOnce.get().content().utf8ToString(); + // assertThat(Objects.requireNonNull(responseSetOnce.get().content()).utf8ToString(), equalTo(createResponse())); + assertThat(executeCalled.get(), equalTo(true)); + } + + private void dispatchRequest(final RestRequest request, final RestChannel channel) { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + controller().dispatchRequest(request, channel, threadContext); + } +} From df1b006ed14c4c2d04f351d39a269f5e8e5eab81 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 4 Dec 2024 14:36:41 -0500 Subject: [PATCH 38/53] Adding openai service tests --- ...nAiUnifiedChatCompletionRequestEntity.java | 7 +++ ...enAiUnifiedChatCompletionRequestTests.java | 46 +++++++------- .../rest/BaseInferenceActionTests.java | 2 +- ...UnifiedCompletionInferenceActionTests.java | 2 - .../services/openai/OpenAiServiceTests.java | 63 +++++++++++++++++++ 5 files changed, 96 insertions(+), 24 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index c57f09d10fe53..74339e441b305 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -43,6 +43,8 @@ public class 
OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec private static final String TOOL_FIELD = "tools"; private static final String TEXT_FIELD = "text"; private static final String TYPE_FIELD = "type"; + private static final String STREAM_OPTIONS_FIELD = "stream_options"; + private static final String INCLUDE_USAGE_FIELD = "include_usage"; private final UnifiedCompletionRequest unifiedRequest; private final boolean stream; @@ -169,6 +171,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.field(STREAM_FIELD, stream); + if (stream) { + builder.startObject(STREAM_OPTIONS_FIELD); + builder.field(INCLUDE_USAGE_FIELD, true); + builder.endObject(); + } builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java index 636c28126de0f..2be12c9b12e0b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java @@ -30,7 +30,7 @@ public class OpenAiUnifiedChatCompletionRequestTests extends ESTestCase { public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOException { - var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user"); + var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user", true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -42,16 +42,27 @@ public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOExceptio 
assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(5)); + assertRequestMapWithUser(requestMap, "user"); + } + + private void assertRequestMapWithoutUser(Map requestMap) { + assertRequestMapWithUser(requestMap, null); + } + + private void assertRequestMapWithUser(Map requestMap, @Nullable String user) { + assertThat(requestMap, aMapWithSize(user != null ? 6 : 5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("user"), is("user")); + if (user != null) { + assertThat(requestMap.get("user"), is(user)); + } assertThat(requestMap.get("n"), is(1)); assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); } public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOException { - var request = createRequest(null, "org", "secret", "abc", "model", "user"); + var request = createRequest(null, "org", "secret", "abc", "model", "user", true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -63,16 +74,12 @@ public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOExce assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(5)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("user"), is("user")); - assertThat(requestMap.get("n"), is(1)); - assertTrue((Boolean) requestMap.get("stream")); + assertRequestMapWithUser(requestMap, "user"); + } public void 
testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws URISyntaxException, IOException { - var request = createRequest(null, null, "secret", "abc", "model", null); + var request = createRequest(null, null, "secret", "abc", "model", null, true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -84,14 +91,10 @@ public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("n"), is(1)); - assertTrue((Boolean) requestMap.get("stream")); + assertRequestMapWithoutUser(requestMap); } - public void testCreateRequest_WithStreaming() throws URISyntaxException, IOException { + public void testCreateRequest_WithStreaming() throws IOException { var request = createRequest(null, null, "secret", "abc", "model", null, true); var httpRequest = request.createHttpRequest(); @@ -103,7 +106,7 @@ public void testCreateRequest_WithStreaming() throws URISyntaxException, IOExcep } public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { - var request = createRequest(null, null, "secret", "abcd", "model", null); + var request = createRequest(null, null, "secret", "abcd", "model", null, true); var truncatedRequest = request.truncate(); assertThat(request.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); @@ -112,17 +115,18 @@ public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, var httpPost = (HttpPost) httpRequest.httpRequestBase(); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + 
assertThat(requestMap, aMapWithSize(5)); // We do not truncate for OpenAi chat completions assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); } public void testTruncationInfo_ReturnsNull() { - var request = createRequest(null, null, "secret", "abcd", "model", null); + var request = createRequest(null, null, "secret", "abcd", "model", null, true); assertNull(request.getTruncationInfo()); } @@ -147,7 +151,7 @@ public static OpenAiUnifiedChatCompletionRequest createRequest( boolean stream ) { var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, apiKey, model, user); - return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", true), chatCompletionModel); + return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 9a0213d1e18b0..5528c80066b0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -74,7 +74,7 @@ public void testParseParams_DefaultsToTaskTypeAny_WhenInferenceId_IsMissing() { var params = parseParams( RestRequestTests.contentRestRequest("{}", Map.of(TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString())) ); - assertThat(params, is(new BaseInferenceAction.Params(TASK_TYPE_OR_INFERENCE_ID, TaskType.ANY))); + assertThat(params, is(new 
BaseInferenceAction.Params("completion", TaskType.ANY))); } public void testParseParams_ThrowsStatusException_WhenTaskTypeIsMissing() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java index 264a658c8fd3b..5acfe67b175df 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java @@ -71,8 +71,6 @@ public void sendResponse(RestResponse response) { // the response content will be null when there is no error assertNull(responseSetOnce.get().content()); - // var responseBody = responseSetOnce.get().content().utf8ToString(); - // assertThat(Objects.requireNonNull(responseSetOnce.get().content()).utf8ToString(), equalTo(createResponse())); assertThat(executeCalled.get(), equalTo(true)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 76b5d6fee2c59..6792298d6c1ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; @@ 
-920,6 +921,68 @@ public void testInfer_SendsRequest() throws IOException { } } + public void testUnifiedCompletionInfer() throws Exception { + // streaming response must be on a single line + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":"stop"\ + }\ + ],\ + "usage":{\ + "prompt_tokens": 16,\ + "completion_tokens": 28,\ + "total_tokens": 44,\ + "prompt_tokens_details": {\ + "cached_tokens": 0,\ + "audio_tokens": 0\ + },\ + "completion_tokens_details": {\ + "reasoning_tokens": 0,\ + "audio_tokens": 0,\ + "accepted_prediction_tokens": 0,\ + "rejected_prediction_tokens": 0\ + }\ + }\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null, null) + ) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """ + "model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """ + "usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}"""); + } + } + public void 
testInfer_StreamRequest() throws Exception { String responseJson = """ data: {\ From 0166d986792d741b6257f9ee78c7004a3ded8306 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 4 Dec 2024 14:48:38 -0500 Subject: [PATCH 39/53] Adding model tests --- .../OpenAiChatCompletionModelTests.java | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java index 61edb05f9bac0..e7ac4cf879e92 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java @@ -10,9 +10,11 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap; @@ -46,6 +48,44 @@ public void testOverrideWith_NullMap() { assertThat(overriddenModel, sameInstance(model)); } + public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { + var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + "different_model", + null, + 
null, + null, + null, + null, + null + ); + + assertThat( + OpenAiChatCompletionModel.of(model, request), + is(createChatCompletionModel("url", "org", "api_key", "different_model", "user")) + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + assertThat( + OpenAiChatCompletionModel.of(model, request), + is(createChatCompletionModel("url", "org", "api_key", "model_name", "user")) + ); + } + public static OpenAiChatCompletionModel createChatCompletionModel( String url, @Nullable String org, From 3dfb8f5fa1ac93a5fd7d78543df9458559c2a00b Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 4 Dec 2024 14:57:25 -0500 Subject: [PATCH 40/53] tests for StreamingUnifiedChatCompletionResultsTests toXContentChunked --- ...StreamingUnifiedChatCompletionResults.java | 6 +- ...mingUnifiedChatCompletionResultsTests.java | 263 +++++++++++++++++- 2 files changed, 265 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java index 7008309feb809..90038c67036c4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -145,9 +145,9 @@ public Iterator toXContentChunked(ToXContent.Params params if (usage != null) { usageIterator = 
Iterators.concat( ChunkedToXContentHelper.startObject(USAGE_FIELD), - ChunkedToXContentHelper.field("completion_tokens", usage.completionTokens()), - ChunkedToXContentHelper.field("prompt_tokens", usage.promptTokens()), - ChunkedToXContentHelper.field("total_tokens", usage.totalTokens()), + ChunkedToXContentHelper.field(COMPLETION_TOKENS_FIELD, usage.completionTokens()), + ChunkedToXContentHelper.field(PROMPT_TOKENS_FIELD, usage.promptTokens()), + ChunkedToXContentHelper.field(TOTAL_TOKENS_FIELD, usage.totalTokens()), ChunkedToXContentHelper.endObject() ); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java index 351e1d97fee9a..61a4aa949d4ca 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -3,12 +3,273 @@ * or more contributor license agreements. Licensed under the Elastic License * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
+ * + * this file was contributed to by a generative AI */ package org.elasticsearch.xpack.core.inference.results; +import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase { - // See OpenAiUnifiedStreamingProcessorTests.java + + public void testResultstoXContentChunked() throws IOException { + String expected = """ + { + "id": "chunk1", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + ], + "model": "example_model", + "object": "example_object", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 5, + "total_tokens": 15 + } + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + "chunk1", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + "example_content", + "example_refusal", + "assistant", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ) + ) + ), + "example_reason", + 0 + ) + ), + "example_model", + "example_object", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(10, 5, 15) + ); + + Deque deque = new 
ArrayDeque<>(); + deque.add(chunk); + StreamingUnifiedChatCompletionResults.Results results = new StreamingUnifiedChatCompletionResults.Results(deque); + XContentBuilder builder = JsonXContent.contentBuilder(); + results.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + + public void testChatCompletionChunkToXContentChunked() throws IOException { + String expected = """ + { + "id": "chunk1", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + ], + "model": "example_model", + "object": "example_object", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 5, + "total_tokens": 15 + } + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + "chunk1", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + "example_content", + "example_refusal", + "assistant", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ) + ) + ), + "example_reason", + 0 + ) + ), + "example_model", + "example_object", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(10, 5, 15) + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + 
chunk.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + + public void testChoiceToXContentChunked() throws IOException { + String expected = """ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + "example_content", + "example_refusal", + "assistant", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ) + ) + ), + "example_reason", + 0 + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + choice.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + + public void testToolCallToXContentChunked() throws IOException { + String expected = """ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = + 
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + toolCall.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + } From 9d81db0295762cbeef57257f2e17106b87e65961 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 4 Dec 2024 15:06:27 -0500 Subject: [PATCH 41/53] Fixing change log and removing commented out code --- docs/changelog/117589.yaml | 2 +- .../inference/UnifiedCompletionRequest.java | 55 -------- .../org/elasticsearch/test/ESTestCase.java | 122 +++++++----------- 3 files changed, 46 insertions(+), 133 deletions(-) diff --git a/docs/changelog/117589.yaml b/docs/changelog/117589.yaml index 2a2a483dc7bde..e6880fd9477b5 100644 --- a/docs/changelog/117589.yaml +++ b/docs/changelog/117589.yaml @@ -1,5 +1,5 @@ pr: 117589 -summary: "[Inference API] Add unified api for chat completions" +summary: "Add Inference Unified API for chat completions for OpenAI" area: Machine Learning type: enhancement issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 49caa86e113fd..9b632776646f8 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -279,61 +279,6 @@ public void writeTo(StreamOutput out) throws IOException { } } - // private static Stop parseStop(XContentParser parser) throws IOException { - // var token = 
parser.currentToken(); - // if (token == XContentParser.Token.START_ARRAY) { - // var parsedStopValues = XContentParserUtils.parseList(parser, XContentParser::text); - // return new StopValues(parsedStopValues); - // } else if (token == XContentParser.Token.VALUE_STRING) { - // return StopString.of(parser); - // } - // - // throw new XContentParseException("Unsupported token [" + token + "]"); - // } - - // public sealed interface Stop extends NamedWriteable permits StopString, StopValues {} - // - // public record StopString(String value) implements Stop, NamedWriteable { - // public static final String NAME = "stop_string"; - // - // public static StopString of(XContentParser parser) throws IOException { - // var content = parser.text(); - // return new StopString(content); - // } - // - // public StopString(StreamInput in) throws IOException { - // this(in.readString()); - // } - // - // @Override - // public void writeTo(StreamOutput out) throws IOException { - // out.writeString(value); - // } - // - // @Override - // public String getWriteableName() { - // return NAME; - // } - // } - // - // public record StopValues(List values) implements Stop, NamedWriteable { - // public static final String NAME = "stop_values"; - // - // public StopValues(StreamInput in) throws IOException { - // this(in.readStringCollectionAsImmutableList()); - // } - // - // @Override - // public void writeTo(StreamOutput out) throws IOException { - // out.writeStringCollection(values); - // } - // - // @Override - // public String getWriteableName() { - // return NAME; - // } - // } - private static ToolChoice parseToolChoice(XContentParser parser) throws IOException { var token = parser.currentToken(); if (token == XContentParser.Token.START_OBJECT) { diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 19bcb93c289d3..a71f61740e17b 100644 --- 
a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -254,8 +254,11 @@ public abstract class ESTestCase extends LuceneTestCase { protected static final List JAVA_TIMEZONE_IDS; protected static final List JAVA_ZONE_IDS; + private static final AtomicInteger portGenerator = new AtomicInteger(); + private static final Collection loggedLeaks = new ArrayList<>(); + private HeaderWarningAppender headerWarningAppender; @AfterClass @@ -265,9 +268,13 @@ public static void resetPortCounter() { // Allows distinguishing between parallel test processes public static final String TEST_WORKER_VM_ID; + public static final String TEST_WORKER_SYS_PROPERTY = "org.gradle.test.worker"; + public static final String DEFAULT_TEST_WORKER_ID = "--not-gradle--"; + public static final String FIPS_SYSPROP = "tests.fips.enabled"; + private static final SetOnce WARN_SECURE_RANDOM_FIPS_NOT_DETERMINISTIC = new SetOnce<>(); static { @@ -436,6 +443,7 @@ private static void setTestSysProps(Random random) { // ----------------------------------------------------------------- // Suite and test case setup/cleanup. 
// ----------------------------------------------------------------- + @Rule public RuleChain failureAndSuccessEvents = RuleChain.outerRule(new TestRuleAdapter() { @Override @@ -474,9 +482,7 @@ public static TransportAddress buildNewFakeTransportAddress() { */ protected void afterIfFailed(List errors) {} - /** - * called after a test is finished, but only if successful - */ + /** called after a test is finished, but only if successful */ protected void afterIfSuccessful() throws Exception {} /** @@ -486,7 +492,6 @@ protected void afterIfSuccessful() throws Exception {} @Target({ ElementType.TYPE }) @Inherited public @interface WithoutSecurityManager { - } private static Closeable securityManagerRestorer; @@ -663,7 +668,6 @@ protected List filteredWarnings() { /** * Convenience method to assert warnings for settings deprecations and general deprecation warnings. - * * @param settings the settings that are expected to be deprecated * @param warnings other expected general deprecation warnings */ @@ -684,7 +688,6 @@ protected final void assertSettingDeprecationsAndWarnings(final Setting[] set /** * Convenience method to assert warnings for settings deprecations and general deprecation warnings. All warnings passed to this method * are assumed to be at WARNING level. - * * @param expectedWarnings expected general deprecation warning messages. */ protected final void assertWarnings(String... expectedWarnings) { @@ -699,7 +702,6 @@ protected final void assertWarnings(String... expectedWarnings) { /** * Convenience method to assert warnings for settings deprecations and general deprecation warnings. All warnings passed to this method * are assumed to be at CRITICAL level. - * * @param expectedWarnings expected general deprecation warning messages. */ protected final void assertCriticalWarnings(String... 
expectedWarnings) { @@ -767,19 +769,20 @@ private void resetDeprecationLogger() { } private static final List statusData = new ArrayList<>(); - static { // ensure that the status logger is set to the warn level so we do not miss any warnings with our Log4j usage StatusLogger.getLogger().setLevel(Level.WARN); // Log4j will write out status messages indicating problems with the Log4j usage to the status logger; we hook into this logger and // assert that no such messages were written out as these would indicate a problem with our logging configuration StatusLogger.getLogger().registerListener(new StatusConsoleListener(Level.WARN) { + @Override public void log(StatusData data) { synchronized (statusData) { statusData.add(data); } } + }); } @@ -840,9 +843,8 @@ public final void ensureAllSearchContextsReleased() throws Exception { // mockdirectorywrappers currently set this boolean if checkindex fails // TODO: can we do this cleaner??? - /** - * MockFSDirectoryService sets this: - */ + + /** MockFSDirectoryService sets this: */ public static final List checkIndexFailures = new CopyOnWriteArrayList<>(); @Before @@ -1137,48 +1139,37 @@ public static LongStream randomLongs(long streamSize) { * Returns a random BigInteger uniformly distributed over the range 0 to (2^64 - 1) inclusive * Currently BigIntegers are only used for unsigned_long field type, where the max value is 2^64 - 1. * Modify this random generator if a wider range for BigIntegers is necessary. - * * @return a random bigInteger in the range [0 ; 2^64 - 1] */ public static BigInteger randomBigInteger() { return new BigInteger(64, random()); } - /** - * A random integer from 0..max (inclusive). - */ + /** A random integer from 0..max (inclusive). */ public static int randomInt(int max) { return RandomizedTest.randomInt(max); } - /** - * A random byte size value. - */ + /** A random byte size value. 
*/ public static ByteSizeValue randomByteSizeValue() { return ByteSizeValue.ofBytes(randomLongBetween(0L, Long.MAX_VALUE >> 16)); } - /** - * Pick a random object from the given array. The array must not be empty. - */ + /** Pick a random object from the given array. The array must not be empty. */ @SafeVarargs @SuppressWarnings("varargs") public static T randomFrom(T... array) { return randomFrom(random(), array); } - /** - * Pick a random object from the given array. The array must not be empty. - */ + /** Pick a random object from the given array. The array must not be empty. */ @SafeVarargs @SuppressWarnings("varargs") public static T randomFrom(Random random, T... array) { return RandomPicks.randomFrom(random, array); } - /** - * Pick a random object from the given array of suppliers. The array must not be empty. - */ + /** Pick a random object from the given array of suppliers. The array must not be empty. */ @SafeVarargs @SuppressWarnings("varargs") public static T randomFrom(Random random, Supplier... array) { @@ -1186,23 +1177,17 @@ public static T randomFrom(Random random, Supplier... array) { return supplier.get(); } - /** - * Pick a random object from the given list. - */ + /** Pick a random object from the given list. */ public static T randomFrom(List list) { return RandomPicks.randomFrom(random(), list); } - /** - * Pick a random object from the given collection. - */ + /** Pick a random object from the given collection. */ public static T randomFrom(Collection collection) { return randomFrom(random(), collection); } - /** - * Pick a random object from the given collection. - */ + /** Pick a random object from the given collection. 
*/ public static T randomFrom(Random random, Collection collection) { return RandomPicks.randomFrom(random, collection); } @@ -1289,9 +1274,9 @@ public static String randomRealisticUnicodeOfCodepointLength(int codePoints) { /** * @param maxArraySize The maximum number of elements in the random array - * @param stringSize The length of each String in the array - * @param allowNull Whether the returned array may be null - * @param allowEmpty Whether the returned array may be empty (have zero elements) + * @param stringSize The length of each String in the array + * @param allowNull Whether the returned array may be null + * @param allowEmpty Whether the returned array may be empty (have zero elements) */ public static String[] generateRandomStringArray(int maxArraySize, int stringSize, boolean allowNull, boolean allowEmpty) { if (allowNull && random().nextBoolean()) { @@ -1503,8 +1488,8 @@ public static boolean waitUntil(BooleanSupplier breakSupplier) { * {@link ESTestCase#assertBusy(CheckedRunnable)} instead. * * @param breakSupplier determines whether to return immediately or continue waiting. - * @param maxWaitTime the maximum amount of time to wait - * @param unit the unit of tie for maxWaitTime + * @param maxWaitTime the maximum amount of time to wait + * @param unit the unit of time for maxWaitTime * @return the last value returned by breakSupplier */ public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime, TimeUnit unit) { @@ -1575,9 +1560,7 @@ public static Path getResourceDataPath(Class clazz, String relativePath) { } } - /** - * Returns a random number of temporary paths. - */ + /** Returns a random number of temporary paths. 
*/ public String[] tmpPaths() { final int numPaths = TestUtil.nextInt(random(), 1, 3); final String[] absPaths = new String[numPaths]; @@ -1614,24 +1597,18 @@ public Environment newEnvironment(Settings settings) { return TestEnvironment.newEnvironment(build); } - /** - * Return consistent index settings for the provided index version. - */ + /** Return consistent index settings for the provided index version. */ public static Settings.Builder settings(IndexVersion version) { return Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, version); } - /** - * Return consistent index settings for the provided index version, shard- and replica-count. - */ + /** Return consistent index settings for the provided index version, shard- and replica-count. */ public static Settings.Builder indexSettings(IndexVersion indexVersionCreated, int shards, int replicas) { return settings(indexVersionCreated).put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shards) .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, replicas); } - /** - * Return consistent index settings for the provided shard- and replica-count. - */ + /** Return consistent index settings for the provided shard- and replica-count. */ public static Settings.Builder indexSettings(int shards, int replicas) { return Settings.builder() .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shards) @@ -1735,7 +1712,6 @@ public XContentType randomVendorType() { } public static class GeohashGenerator extends CodepointSetGenerator { - private static final char[] ASCII_SET = "0123456789bcdefghjkmnpqrstuvwxyz".toCharArray(); public GeohashGenerator() { @@ -1894,7 +1870,6 @@ public static C copyNamedWriteable( /** * Same as {@link #copyNamedWriteable(NamedWriteable, NamedWriteableRegistry, Class)} but also allows to provide * a {@link TransportVersion} argument which will be used to write and read back the object. 
- * * @return */ @SuppressWarnings("unchecked") @@ -2021,16 +1996,12 @@ public static Script mockScript(String id) { return new Script(ScriptType.INLINE, MockScriptEngine.NAME, id, emptyMap()); } - /** - * Returns the suite failure marker: internal use only! - */ + /** Returns the suite failure marker: internal use only! */ public static TestRuleMarkFailure getSuiteFailureMarker() { return suiteFailureMarker; } - /** - * Compares two stack traces, ignoring module (which is not yet serialized) - */ + /** Compares two stack traces, ignoring module (which is not yet serialized) */ public static void assertArrayEquals(StackTraceElement expected[], StackTraceElement actual[]) { assertEquals(expected.length, actual.length); for (int i = 0; i < expected.length; i++) { @@ -2038,9 +2009,7 @@ public static void assertArrayEquals(StackTraceElement expected[], StackTraceEle } } - /** - * Compares two stack trace elements, ignoring module (which is not yet serialized) - */ + /** Compares two stack trace elements, ignoring module (which is not yet serialized) */ public static void assertEquals(StackTraceElement expected, StackTraceElement actual) { assertEquals(expected.getClassName(), actual.getClassName()); assertEquals(expected.getMethodName(), actual.getMethodName()); @@ -2161,19 +2130,23 @@ public static boolean inFipsJvm() { * worker, avoiding any unexpected interactions, although if we spawn enough test workers then we will wrap around to the beginning * again. */ + /** * Defines the size of the port range assigned to each worker, which must be large enough to supply enough ports to run the tests, but * not so large that we run out of ports. See also [NOTE: Port ranges for tests]. */ private static final int PORTS_PER_WORKER = 30; + /** * Defines the minimum port that test workers should use. See also [NOTE: Port ranges for tests]. */ protected static final int MIN_PRIVATE_PORT = 13301; + /** * Defines the maximum port that test workers should use. 
See also [NOTE: Port ranges for tests]. */ private static final int MAX_PRIVATE_PORT = 32767; + /** * Wrap around after reaching this worker ID. */ @@ -2231,7 +2204,6 @@ public static InetAddress randomIp(boolean v4) { } public static final class DeprecationWarning { - private final Level level; // Intentionally ignoring level for the sake of equality for now private final String message; @@ -2262,9 +2234,8 @@ public String toString() { /** * Call method at the beginning of a test to disable its execution * until a given Lucene version is released and integrated into Elasticsearch - * * @param luceneVersionWithFix the lucene release to wait for - * @param message an additional message or link with information on the fix + * @param message an additional message or link with information on the fix */ protected void skipTestWaitingForLuceneFix(org.apache.lucene.util.Version luceneVersionWithFix, String message) { final boolean currentVersionHasFix = IndexVersion.current().luceneVersion().onOrAfter(luceneVersionWithFix); @@ -2275,10 +2246,9 @@ protected void skipTestWaitingForLuceneFix(org.apache.lucene.util.Version lucene /** * In non-FIPS mode, get a deterministic SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). * In FIPS mode, get a non-deterministic SecureRandom DEFAULT/BCFIPS instance seeded by deterministic LuceneTestCase.random(). - * * @return SecureRandom SHA1PRNG instance. * @throws NoSuchAlgorithmException SHA1PRNG or DEFAULT algorithm not found. - * @throws NoSuchProviderException BCFIPS algorithm not found. + * @throws NoSuchProviderException BCFIPS algorithm not found. 
*/ public static SecureRandom secureRandom() throws NoSuchAlgorithmException, NoSuchProviderException { return secureRandom(randomByteArrayOfLength(32)); @@ -2287,11 +2257,10 @@ public static SecureRandom secureRandom() throws NoSuchAlgorithmException, NoSuc /** * In non-FIPS mode, get a deterministic SecureRandom SHA1PRNG/SUN instance seeded by the input value. * In FIPS mode, get a non-deterministic SecureRandom DEFAULT/BCFIPS instance seeded by the input value. - * * @param seed Byte array to use for seeding the SecureRandom instance. * @return SecureRandom SHA1PRNG or DEFAULT/BCFIPS instance, depending on FIPS mode. * @throws NoSuchAlgorithmException SHA1PRNG or DEFAULT algorithm not found. - * @throws NoSuchProviderException BCFIPS algorithm not found. + * @throws NoSuchProviderException BCFIPS algorithm not found. */ public static SecureRandom secureRandom(final byte[] seed) throws NoSuchAlgorithmException, NoSuchProviderException { return inFipsJvm() ? secureRandomFips(seed) : secureRandomNonFips(seed); @@ -2299,7 +2268,6 @@ public static SecureRandom secureRandom(final byte[] seed) throws NoSuchAlgorith /** * Returns deterministic non-FIPS SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). - * * @return Deterministic non-FIPS SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). * @throws NoSuchAlgorithmException Exception if SHA1PRNG algorithm not found, such as missing SUN provider (unlikely). */ @@ -2309,7 +2277,6 @@ protected static SecureRandom secureRandomNonFips() throws NoSuchAlgorithmExcept /** * Returns non-deterministic FIPS SecureRandom DEFAULT/BCFIPS instance. Seeded. - * * @return Non-deterministic FIPS SecureRandom DEFAULT/BCFIPS instance. Seeded. * @throws NoSuchAlgorithmException Exception if DEFAULT algorithm not found, such as missing BCFIPS provider. 
*/ @@ -2319,7 +2286,6 @@ protected static SecureRandom secureRandomFips() throws NoSuchAlgorithmException /** * Returns deterministic non-FIPS SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). - * * @return Deterministic non-FIPS SecureRandom SHA1PRNG/SUN instance seeded by deterministic LuceneTestCase.random(). * @throws NoSuchAlgorithmException Exception if SHA1PRNG algorithm not found, such as missing SUN provider (unlikely). */ @@ -2331,7 +2297,6 @@ protected static SecureRandom secureRandomNonFips(final byte[] seed) throws NoSu /** * Returns non-deterministic FIPS SecureRandom DEFAULT/BCFIPS instance. Seeded. - * * @return Non-deterministic FIPS SecureRandom DEFAULT/BCFIPS instance. Seeded. * @throws NoSuchAlgorithmException Exception if DEFAULT algorithm not found, such as missing BCFIPS provider. */ @@ -2357,6 +2322,7 @@ protected static SecureRandom secureRandomFips(final byte[] seed) throws NoSuchA * in these requests. This constant can be used as a slightly more meaningful way to refer to the 30s default value in tests. */ public static final TimeValue TEST_REQUEST_TIMEOUT = TimeValue.THIRTY_SECONDS; + /** * The timeout used for the various "safe" wait methods such as {@link #safeAwait} and {@link #safeAcquire}. In tests we generally want * these things to complete almost immediately, but sometimes the CI runner executes things rather slowly so we use {@code 10s} as a @@ -2528,6 +2494,7 @@ public static Exception safeAwaitFailure(Consumer> consume * AssertionError} to trigger a test failure. * * @param responseType Class of listener response type, to aid type inference but otherwise ignored. + * * @return The exception with which the {@code listener} was completed exceptionally. 
*/ public static Exception safeAwaitFailure(@SuppressWarnings("unused") Class responseType, Consumer> consumer) { @@ -2541,6 +2508,7 @@ public static Exception safeAwaitFailure(@SuppressWarnings("unused") Class ExpectedException safeAwaitFailure( @@ -2560,6 +2528,7 @@ public static ExpectedException * @param responseType Class of listener response type, to aid type inference but otherwise ignored. * @param exceptionType Expected unwrapped exception type. This method throws an {@link AssertionError} if a different type of exception * is seen. + * * @return The unwrapped exception with which the {@code listener} was completed exceptionally. */ public static ExpectedException safeAwaitAndUnwrapFailure( @@ -2681,9 +2650,8 @@ public static void startInParallel(int numberOfTasks, IntConsumer taskFactory) { /** * Run {@code numberOfTasks} parallel tasks that were created by the given {@code taskFactory}. On of the tasks will be run on the * calling thread, the rest will be run on a new thread. 
- * * @param numberOfTasks number of tasks to run in parallel - * @param taskFactory task factory + * @param taskFactory task factory */ public static void runInParallel(int numberOfTasks, IntConsumer taskFactory) { final ArrayList> futures = new ArrayList<>(numberOfTasks); From 00ae5ab054964643b63dc448e8219711e8cf4ed2 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 4 Dec 2024 15:22:02 -0500 Subject: [PATCH 42/53] Switch usage to accept null --- .../openai/OpenAiUnifiedStreamingProcessor.java | 3 ++- .../openai/OpenAiUnifiedStreamingProcessorTests.java | 11 ++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index 4023d0d0936b7..01b32d12fdbee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -146,9 +146,10 @@ public static class ChatCompletionChunkParser { ); PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_FIELD)); PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(OBJECT_FIELD)); - PARSER.declareObject( + PARSER.declareObjectOrNull( ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ChatCompletionChunkParser.UsageParser.parse(p), + null, new ParseField(USAGE_FIELD) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java index ea7754279d4ee..0f127998f9c54 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java @@ -131,11 +131,7 @@ public void testJsonLiteralCornerCases() { ], "model": "example_model", "object": "chat.completion.chunk", - "usage": { - "completion_tokens": 50, - "prompt_tokens": 20, - "total_tokens": 70 - } + "usage": null } """; // Parse the JSON @@ -150,10 +146,7 @@ public void testJsonLiteralCornerCases() { assertEquals("example_id", chunk.getId()); assertEquals("example_model", chunk.getModel()); assertEquals("chat.completion.chunk", chunk.getObject()); - assertNotNull(chunk.getUsage()); - assertEquals(50, chunk.getUsage().completionTokens()); - assertEquals(20, chunk.getUsage().promptTokens()); - assertEquals(70, chunk.getUsage().totalTokens()); + assertNull(chunk.getUsage()); List choices = chunk.getChoices(); assertEquals(2, choices.size()); From 88a7eb0d139f543a05222b0f18b03291b256eb52 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 5 Dec 2024 10:08:58 -0500 Subject: [PATCH 43/53] Adding test for TestStreamingCompletionServiceExtension --- .../inference/UnifiedCompletionRequest.java | 10 ++- ...mingUnifiedChatCompletionResultsTests.java | 2 +- .../inference/InferenceBaseRestTest.java | 32 +++++++++- .../xpack/inference/InferenceCrudIT.java | 55 ++++++++++++++++ ...stStreamingCompletionServiceExtension.java | 64 ++++++++++++++++++- 5 files changed, 159 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 9b632776646f8..90c31719028de 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -203,6 
+203,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(type); } + public String toString() { + return text + ":" + type; + } + } public record ContentString(String content) implements Content, NamedWriteable { @@ -230,6 +234,10 @@ public String getWriteableName() { public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.value(content); } + + public String toString() { + return content; + } } public record ToolCall(String id, FunctionField function, String type) implements Writeable { @@ -390,7 +398,7 @@ public void writeTo(StreamOutput out) throws IOException { public record FunctionField( @Nullable String description, String name, - @Nullable Map parameters, // TODO can we parse this as a string? + @Nullable Map parameters, @Nullable Boolean strict ) implements Writeable { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java index 61a4aa949d4ca..447cf89c3045a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -21,7 +21,7 @@ public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase { - public void testResultstoXContentChunked() throws IOException { + public void testResults_toXContentChunked() throws IOException { String expected = """ { "id": "chunk1", diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java 
index 4e32ef99d06dd..07ce2fe00642b 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -21,6 +21,9 @@ import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import org.junit.ClassRule; @@ -341,10 +344,21 @@ protected Deque streamInferOnMockService(String modelId, TaskTy return callAsync(endpoint, input); } + protected Deque unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List input) + throws Exception { + var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId); + return callAsyncUnified(endpoint, input, "user"); + } + private Deque callAsync(String endpoint, List input) throws Exception { - var responseConsumer = new AsyncInferenceResponseConsumer(); var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input)); + + return execAsyncCall(request); + } + + private Deque execAsyncCall(Request request) throws Exception { + var responseConsumer = new AsyncInferenceResponseConsumer(); request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); var latch = new CountDownLatch(1); client().performRequestAsync(request, new ResponseListener() { @@ -362,6 +376,22 @@ public void onFailure(Exception exception) { return responseConsumer.events(); } + private Deque callAsyncUnified(String endpoint, List input, String role) throws Exception { + var request 
= new Request("POST", endpoint); + + request.setJsonEntity(createUnifiedJsonBody(input, role)); + return execAsyncCall(request); + } + + private String createUnifiedJsonBody(List input, String role) throws IOException { + var messages = input.stream().map(i -> Map.of("content", i, "role", role)).toList(); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + builder.field("messages", messages); + builder.endObject(); + return org.elasticsearch.common.Strings.toString(builder); + } + protected Map infer(String modelId, TaskType taskType, List input) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return inferInternal(endpoint, input, Map.of()); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index f5773e73f2b22..1e19491aeaa60 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -11,13 +11,18 @@ import org.apache.http.util.EntityUtils; import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Map; import 
java.util.Objects; @@ -481,6 +486,56 @@ public void testSupportedStream() throws Exception { } } + public void testUnifiedCompletionInference() throws Exception { + String modelId = "streaming"; + putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION)); + var singleModel = getModel(modelId); + assertEquals(modelId, singleModel.get("inference_id")); + assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type")); + + var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomUUID()).toList(); + try { + var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input); + var expectedResponses = expectedResultsIterator(input); + assertThat(events.size(), equalTo((input.size() + 1) * 2)); + events.forEach(event -> { + switch (event.name()) { + case EVENT -> assertThat(event.value(), equalToIgnoringCase("message")); + case DATA -> assertThat(event.value(), equalTo(expectedResponses.next())); + } + }); + } finally { + deleteModel(modelId); + } + } + + private static Iterator expectedResultsIterator(List input) { + return Stream.concat(input.stream().map(String::toUpperCase).map(InferenceCrudIT::expectedResult), Stream.of("[DONE]")).iterator(); + } + + private static String expectedResult(String input) { + try { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + builder.field("id", "id"); + builder.startArray("choices"); + builder.startObject(); + builder.startObject("delta"); + builder.field("content", input); + builder.endObject(); + builder.field("index", 0); + builder.endObject(); + builder.endArray(); + builder.field("model", "gpt-4o-2024-08-06"); + builder.field("object", "chat.completion.chunk"); + builder.endObject(); + + return Strings.toString(builder); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + public void testGetZeroModels() throws IOException { var models = getModels("_all", TaskType.RERANK); assertThat(models, 
empty()); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index ee4b7d6ff82b2..f7a05a27354ef 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -37,6 +37,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import java.io.IOException; import java.util.EnumSet; @@ -129,7 +130,15 @@ public void unifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); // TODO + switch (model.getConfigurations().getTaskType()) { + case COMPLETION -> listener.onResponse(makeUnifiedResults(request)); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } } private StreamingChatCompletionResults makeResults(List input) { @@ -163,6 +172,59 @@ private ChunkedToXContent completionChunk(String delta) { ); } + private StreamingUnifiedChatCompletionResults makeUnifiedResults(UnifiedCompletionRequest request) { + var responseIter = request.messages().stream().map(message -> message.content().toString().toUpperCase()).iterator(); + return new StreamingUnifiedChatCompletionResults(subscriber -> { + 
subscriber.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + if (responseIter.hasNext()) { + subscriber.onNext(unifiedCompletionChunk(responseIter.next())); + } else { + subscriber.onComplete(); + } + } + + @Override + public void cancel() {} + }); + }); + } + + /* + The response format looks like this + { + "id": "chatcmpl-AarrzyuRflye7yzDF4lmVnenGmQCF", + "choices": [ + { + "delta": { + "content": " information" + }, + "index": 0 + } + ], + "model": "gpt-4o-2024-08-06", + "object": "chat.completion.chunk" + } + */ + private ChunkedToXContent unifiedCompletionChunk(String delta) { + return params -> Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field("id", "id"), + ChunkedToXContentHelper.startArray("choices"), + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.startObject("delta"), + ChunkedToXContentHelper.field("content", delta), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.field("index", 0), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.field("model", "gpt-4o-2024-08-06"), + ChunkedToXContentHelper.field("object", "chat.completion.chunk"), + ChunkedToXContentHelper.endObject() + ); + } + @Override public void chunkedInfer( Model model, From be3a459395778b1f9ff4905ff2742ccb2e8b8845 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 5 Dec 2024 12:12:15 -0500 Subject: [PATCH 44/53] Avoid serializing empty lists + request entity tests --- ...nAiUnifiedChatCompletionRequestEntity.java | 8 +- ...ifiedChatCompletionRequestEntityTests.java | 856 ++++++++++++++++++ 2 files changed, 861 insertions(+), 3 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 74339e441b305..50339bf851f7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -69,12 +69,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws switch (message.content()) { case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); case UnifiedCompletionRequest.ContentObjects contentObjects -> { + builder.startArray(CONTENT_FIELD); for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { - builder.startObject(CONTENT_FIELD); + builder.startObject(); builder.field(TEXT_FIELD, contentObject.text()); builder.field(TYPE_FIELD, contentObject.type()); builder.endObject(); } + builder.endArray(); } } @@ -116,7 +118,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); - if (unifiedRequest.stop() != null) { + if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { builder.field(STOP_FIELD, unifiedRequest.stop()); } if (unifiedRequest.temperature() != null) { @@ -141,7 +143,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); } } - if (unifiedRequest.tools() != null) { + if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) { builder.startArray(TOOL_FIELD); for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) { 
builder.startObject(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..e0ad98996a216 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,856 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Random; + +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.hamcrest.Matchers.equalTo; + +public class OpenAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + // 1. 
Basic Serialization + // Test with minimal required fields to ensure basic serialization works. + public void testBasicSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 2. Serialization with All Fields + // Test with all possible fields populated to ensure complete serialization. 
+ public void testSerializationWithAllFields() throws IOException { + // Create a message with all fields populated + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name", + "tool_call_id", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), + "type" + ) + ) + ); + + // Create a tool with all fields populated + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with all fields populated + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList(tool), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "name": "name", + 
"tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "function_name" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "max_completion_tokens": 100, + "n": 1, + "stop": ["stop"], + "temperature": 0.9, + "tool_choice": "tool_choice", + "tools": [ + { + "type": "type", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "description": "The location to get the weather for", + "type": "string" + }, + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": 0.8, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + + } + + // 3. Serialization with Null Optional Fields + // Test with optional fields set to null to ensure they are correctly omitted from the output. 
+ public void testSerializationWithNullOptionalFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + // Create the unified request with optional fields set to null + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 4. Serialization with Empty Lists + // Test with fields that are lists set to empty lists to ensure they are correctly serialized. 
+ public void testSerializationWithEmptyLists() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + Collections.emptyList() // empty toolCalls list + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with empty lists + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + Collections.emptyList(), // empty stop list TODO when passing an empty stop-list, should it be converted to null? + null, // temperature + null, // toolChoice + Collections.emptyList(), // empty tools list TODO when passing an empty tools-list, should it be converted to null? + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "tool_calls": [] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 5. 
Serialization with Nested Objects + // Test with nested objects (e.g., toolCalls, toolChoice, tool) to ensure they are correctly serialized. + public void testSerializationWithNestedObjects() throws IOException { + Random random = Randomness.get(); + + // Generate random values + String randomContent = "Hello, world! " + random.nextInt(1000); + String randomName = "name" + random.nextInt(1000); + String randomToolCallId = "tool_call_id" + random.nextInt(1000); + String randomArguments = "arguments" + random.nextInt(1000); + String randomFunctionName = "function_name" + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + String randomModel = "model" + random.nextInt(1000); + String randomStop = "stop" + random.nextInt(1000); + float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + + // Create a message with nested toolCalls + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContent), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + randomName, + randomToolCallId, + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), + randomType + ) + ) + ); + + // Create a tool with nested function fields + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + randomType, + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with nested objects + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + randomModel, + 100L, // 
maxCompletionTokens + Collections.singletonList(randomStop), + randomTemperature, // temperature + new UnifiedCompletionRequest.ToolChoiceObject( + randomType, + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) + ), + Collections.singletonList(tool), + randomTopP // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + // Expected JSON should be dynamically generated based on random values + String expectedJson = String.format( + Locale.US, + """ + { + "messages": [ + { + "content": "%s", + "role": "user", + "name": "%s", + "tool_call_id": "%s", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "%s", + "name": "%s" + }, + "type": "%s" + } + ] + } + ], + "model": "%s", + "max_completion_tokens": 100, + "n": 1, + "stop": ["%s"], + "temperature": %.5f, + "tool_choice": { + "type": "%s", + "function": { + "name": "%s" + } + }, + "tools": [ + { + "type": "%s", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + }, + "location": { + "description": "The location to get the weather for", + "type": "string" + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": %.5f, + 
"stream": true, + "stream_options": { + "include_usage": true + } + } + """, + randomContent, + randomName, + randomToolCallId, + randomArguments, + randomFunctionName, + randomType, + randomModel, + randomStop, + randomTemperature, + randomType, + randomFunctionName, + randomType, + randomTopP + ); + assertJsonEquals(jsonString, expectedJson); + } + + // 6. Serialization with Different Content Types + // Test with different content types in messages (e.g., ContentString, ContentObjects) to ensure they are correctly serialized. + public void testSerializationWithDifferentContentTypes() throws IOException { + Random random = Randomness.get(); + + // Generate random values for ContentString + String randomContentString = "Hello, world! " + random.nextInt(1000); + + // Generate random values for ContentObjects + String randomText = "Random text " + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); + + var contentObjectsList = new ArrayList(); + contentObjectsList.add(contentObject); + UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); + + // Create messages with different content types + UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContentString), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + + UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message( + contentObjects, + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(messageWithString); + messageList.add(messageWithObjects); + + // Create the unified request with both types of messages + UnifiedCompletionRequest unifiedRequest = 
UnifiedCompletionRequest.of(messageList); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = String.format(Locale.US, """ + { + "messages": [ + { + "content": "%s", + "role": "user" + }, + { + "content": [ + { + "text": "%s", + "type": "%s" + } + ], + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, randomContentString, randomText, randomType); + assertJsonEquals(jsonString, expectedJson); + } + + // 7. Serialization with Special Characters + // Test with special characters in string fields to ensure they are correctly escaped and serialized. + public void testSerializationWithSpecialCharacters() throws IOException { + // Create a message with special characters + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world! 
\n \"Special\" characters: \t \\ /"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name\nwith\nnewlines", + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /", + "role": "user", + "name": "name\\nwith\\nnewlines", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 8. 
Serialization with Boolean Fields + // Test with boolean fields (stream) set to both true and false to ensure they are correctly serialized. + public void testSerializationWithBooleanFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Test with stream set to true + UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); + OpenAiUnifiedChatCompletionRequestEntity entityTrue = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputTrue, model); + + XContentBuilder builderTrue = JsonXContent.contentBuilder(); + entityTrue.toXContent(builderTrue, ToXContent.EMPTY_PARAMS); + + String jsonStringTrue = Strings.toString(builderTrue); + String expectedJsonTrue = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(expectedJsonTrue, jsonStringTrue); + + // Test with stream set to false + UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); + OpenAiUnifiedChatCompletionRequestEntity entityFalse = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputFalse, model); + + XContentBuilder builderFalse = JsonXContent.contentBuilder(); 
+ entityFalse.toXContent(builderFalse, ToXContent.EMPTY_PARAMS); + + String jsonStringFalse = Strings.toString(builderFalse); + String expectedJsonFalse = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": false + } + """; + assertJsonEquals(expectedJsonFalse, jsonStringFalse); + } + + // 9. Serialization with Missing Required Fields + // Test with missing required fields to ensure appropriate exceptions are thrown. + public void testSerializationWithMissingRequiredFields() { + // Create a message with missing content (required field) + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + null, // missing content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Attempt to serialize to XContent and expect an exception + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to missing required fields"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + // 10. 
Serialization with Mixed Valid and Invalid Data + // Test with a mix of valid and invalid data to ensure the serializer handles it gracefully. + public void testSerializationWithMixedValidAndInvalidData() throws IOException { + // Create a valid message + UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Valid content"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "validName", + "validToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "validId", + new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"), + "validType" + ) + ) + ); + + // Create an invalid message with null content + UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message( + null, // invalid content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "invalidName", + "invalidToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "invalidId", + new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"), + "invalidType" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(validMessage); + messageList.add(invalidMessage); + // Create the unified request with both valid and invalid messages + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model-name", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList( + new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ) + ), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + 
OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent and verify + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to invalid data"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + public static Map createParameters() { + Map parameters = new LinkedHashMap<>(); + parameters.put("type", "object"); + + Map properties = new HashMap<>(); + + Map location = new HashMap<>(); + location.put("type", "string"); + location.put("description", "The location to get the weather for"); + properties.put("location", location); + + Map unit = new HashMap<>(); + unit.put("type", "string"); + unit.put("description", "The unit to return the temperature in"); + unit.put("enum", new String[] { "F", "C" }); + properties.put("unit", unit); + + parameters.put("properties", properties); + parameters.put("additionalProperties", false); + parameters.put("required", new String[] { "location", "unit" }); + + return parameters; + } + + private void assertJsonEquals(String actual, String expected) throws IOException { + try ( + var actualParser = createParser(JsonXContent.jsonXContent, actual); + var expectedParser = createParser(JsonXContent.jsonXContent, expected) + ) { + assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered())); + } + } + +} From e40e41e4b98d3b4318b859fbadc92222dcba39fd Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 5 Dec 2024 13:34:01 -0500 Subject: [PATCH 45/53] Register named writeables from UnifiedCompletionRequest --- .../xpack/inference/InferenceNamedWriteablesProvider.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 2320cca8295d1..b83c098ca808c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; @@ -137,11 +138,18 @@ public static List getNamedWriteables() { addEisNamedWriteables(namedWriteables); addAlibabaCloudSearchNamedWriteables(namedWriteables); + addUnifiedNamedWriteables(namedWriteables); + namedWriteables.addAll(StreamingTaskManager.namedWriteables()); return namedWriteables; } + private static void addUnifiedNamedWriteables(List namedWriteables) { + var writeables = UnifiedCompletionRequest.getNamedWriteables(); + namedWriteables.addAll(writeables); + } + private static void addAmazonBedrockNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( From c1e4ba135cb6b62cab2d33eeb32819705e5590a7 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 5 Dec 2024 13:45:23 -0500 Subject: [PATCH 46/53] Removing commented code --- .../action/BaseTransportInferenceAction.java | 43 ------------------- .../OpenAiUnifiedStreamingProcessor.java | 1 - 2 files changed, 44 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 87c2e8befbd56..2a0e8e1775279 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -92,28 +92,6 @@ protected void doExecute(Task task, Request request, ActionListener listener - // ) { - // return request.isUnifiedCompletionMode() - // // TODO add parameters - // ? () -> service.completionInfer(model, null, request.getInferenceTimeout(), listener) - // : () -> service.infer( - // model, - // request.getQuery(), - // request.getInput(), - // request.isStreaming(), - // request.getTaskSettings(), - // request.getInputType(), - // request.getInferenceTimeout(), - // listener - // ); - // } - protected abstract void doInference( Model model, Request request, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index 01b32d12fdbee..9b0586fa35eca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -36,7 +36,6 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { public static final String FUNCTION_FIELD = "function"; private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); - private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed 
to find required field [%s] in OpenAI chat completions response"; private static final String CHOICES_FIELD = "choices"; private static final String DELTA_FIELD = "delta"; From 10a5b12b3341f1f23aac479ae7cdc5834401ed1e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 5 Dec 2024 14:16:55 -0500 Subject: [PATCH 47/53] Clean up and add more of an explination --- .../inference/rest/ServerSentEventsRestActionListener.java | 1 - .../xpack/inference/services/openai/OpenAiServiceTests.java | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java index 7fcff1ede2d7d..3177474ea8ca6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java @@ -358,7 +358,6 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec target.write(ServerSentEventSpec.EOL); target.write(ServerSentEventSpec.EOL); target.flush(); - } final var result = new ReleasableBytesReference(chunkStream.bytes(), () -> Releasables.closeExpectNoException(chunkStream)); target = null; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 6792298d6c1ab..159b77789482d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -922,7 +922,7 @@ public void testInfer_SendsRequest() throws 
IOException { } public void testUnifiedCompletionInfer() throws Exception { - // streaming response must be on a single line + // The escapes are because the streaming response must be on a single line String responseJson = """ data: {\ "id":"12345",\ From bc7dbb8179efdbb1886d68071a01d63bf52ab225 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 5 Dec 2024 14:52:31 -0500 Subject: [PATCH 48/53] remove duplicate test --- ...mingUnifiedChatCompletionResultsTests.java | 77 ------------------- 1 file changed, 77 deletions(-) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java index 61a4aa949d4ca..ccec8be6db76b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -101,83 +101,6 @@ public void testResultstoXContentChunked() throws IOException { assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); } - public void testChatCompletionChunkToXContentChunked() throws IOException { - String expected = """ - { - "id": "chunk1", - "choices": [ - { - "delta": { - "content": "example_content", - "refusal": "example_refusal", - "role": "assistant", - "tool_calls": [ - { - "index": 1, - "id": "tool1", - "function": { - "arguments": "example_arguments", - "name": "example_function" - }, - "type": "function" - } - ] - }, - "finish_reason": "example_reason", - "index": 0 - } - ], - "model": "example_model", - "object": "example_object", - "usage": { - "completion_tokens": 10, - "prompt_tokens": 5, - "total_tokens": 15 - } - } - """; - - StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = 
new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( - "chunk1", - List.of( - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( - "example_content", - "example_refusal", - "assistant", - List.of( - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( - 1, - "tool1", - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( - "example_arguments", - "example_function" - ), - "function" - ) - ) - ), - "example_reason", - 0 - ) - ), - "example_model", - "example_object", - new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(10, 5, 15) - ); - - XContentBuilder builder = JsonXContent.contentBuilder(); - chunk.toXContentChunked(null).forEachRemaining(xContent -> { - try { - xContent.toXContent(builder, null); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); - - assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); - } - public void testChoiceToXContentChunked() throws IOException { String expected = """ { From 357277e79cc28a94c41ff1129475b3a583d92e75 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 5 Dec 2024 15:13:21 -0500 Subject: [PATCH 49/53] remove old todos --- .../openai/OpenAiUnifiedChatCompletionRequestEntityTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java index e0ad98996a216..f945c154ea234 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -269,10 +269,10 @@ public void testSerializationWithEmptyLists() throws IOException { messageList, null, // model null, // maxCompletionTokens - Collections.emptyList(), // empty stop list TODO when passing an empty stop-list, should it be converted to null? + Collections.emptyList(), // empty stop list null, // temperature null, // toolChoice - Collections.emptyList(), // empty tools list TODO when passing an empty tools-list, should it be converted to null? + Collections.emptyList(), // empty tools list null // topP ); From 8f22f5664993db85ea9481c1514df03d5176412e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 5 Dec 2024 15:22:05 -0500 Subject: [PATCH 50/53] Refactoring some duplication --- .../inference/common/DelegatingProcessor.java | 36 ++++++++++++++++++- .../openai/OpenAiStreamingProcessor.java | 18 ++-------- .../OpenAiUnifiedStreamingProcessor.java | 17 ++------- 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index 3045246344921..65870949fcd30 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -9,7 +9,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; - +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; + +import java.io.IOException; +import java.util.ArrayDeque; 
+import java.util.Deque; +import java.util.Iterator; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -25,6 +32,33 @@ public abstract class DelegatingProcessor implements Flow.Processor private Flow.Subscriber downstream; private Flow.Subscription upstream; + public static Deque parseEvent( + Deque item, + ParseChunkFunction parseFunction, + XContentParserConfiguration parserConfig, + Logger logger + ) throws Exception { + var results = new ArrayDeque(item.size()); + for (ServerSentEvent event : item) { + if (ServerSentEventField.DATA == event.name() && event.hasValue()) { + try { + var delta = parseFunction.apply(parserConfig, event); + delta.forEachRemaining(results::offer); + } catch (Exception e) { + logger.warn("Failed to parse event from inference provider: {}", event); + throw e; + } + } + } + + return results; + } + + @FunctionalInterface + public interface ParseChunkFunction { + Iterator apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException; + } + @Override public void subscribe(Flow.Subscriber subscriber) { if (downstream != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index 6e006fe255956..48c8132035b50 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -18,10 +18,8 @@ import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import 
org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Collections; import java.util.Deque; import java.util.Iterator; @@ -115,19 +113,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - - var results = new ArrayDeque(item.size()); - for (ServerSentEvent event : item) { - if (ServerSentEventField.DATA == event.name() && event.hasValue()) { - try { - var delta = parse(parserConfig, event); - delta.forEachRemaining(results::offer); - } catch (Exception e) { - log.warn("Failed to parse event from inference provider: {}", event); - throw e; - } - } - } + var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log); if (results.isEmpty()) { upstream().request(1); @@ -136,7 +122,7 @@ protected void next(Deque item) throws Exception { } } - private Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) + private static Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException { if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { return Collections.emptyIterator(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index 9b0586fa35eca..9dfe4bf00ad08 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -20,7 +20,6 @@ import 
org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; import java.util.ArrayDeque; @@ -72,19 +71,7 @@ protected void onRequest(long n) { @Override protected void next(Deque item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - - var results = new ArrayDeque(item.size()); - for (ServerSentEvent event : item) { - if (ServerSentEventField.DATA == event.name() && event.hasValue()) { - try { - var delta = parse(parserConfig, event); - delta.forEachRemaining(results::offer); - } catch (Exception e) { - logger.warn("Failed to parse event from inference provider: {}", event); - throw e; - } - } - } + var results = parseEvent(item, OpenAiUnifiedStreamingProcessor::parse, parserConfig, logger); if (results.isEmpty()) { upstream().request(1); @@ -101,7 +88,7 @@ protected void next(Deque item) throws Exception { } } - private Iterator parse( + private static Iterator parse( XContentParserConfiguration parserConfig, ServerSentEvent event ) throws IOException { From a9b44b5ca0c837e7dbae0dbe7deea7fd291fc285 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 5 Dec 2024 15:37:28 -0500 Subject: [PATCH 51/53] Adding javadoc --- .../external/http/sender/ChatCompletionInput.java | 7 +++++++ .../external/http/sender/UnifiedChatInput.java | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java index e7f3eb7dfea67..928da95d9c2f0 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java @@ -10,6 +10,13 @@ import java.util.List; import java.util.Objects; +/** + * This class encapsulates the input text passed by the request and indicates whether the response should be streamed. + * The main difference between this class and {@link UnifiedChatInput} is this should only be used for + * {@link org.elasticsearch.inference.TaskType#COMPLETION} originating through the + * {@link org.elasticsearch.inference.InferenceService#infer} code path. These are requests sent to the + * API without using the _unified route. + */ public class ChatCompletionInput extends InferenceInputs { private final List input; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java index be647ef85e869..f89fa1ee37a6f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -7,11 +7,21 @@ package org.elasticsearch.xpack.inference.external.http.sender; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnifiedCompletionRequest; import java.util.List; import java.util.Objects; +/** + * This class encapsulates the unified request. 
+ * The main difference between this class and {@link ChatCompletionInput} is this should only be used for + * {@link org.elasticsearch.inference.TaskType#COMPLETION} originating through the + * {@link org.elasticsearch.inference.InferenceService#unifiedCompletionInfer(Model, UnifiedCompletionRequest, TimeValue, ActionListener)} + * code path. These are requests sent to the API with the _unified route. + */ public class UnifiedChatInput extends InferenceInputs { private final UnifiedCompletionRequest request; From 3c4428f9af87443e487c22a94daca3b8e8c81c94 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 6 Dec 2024 14:09:38 -0500 Subject: [PATCH 52/53] Addressing feedback --- .../inference/UnifiedCompletionRequest.java | 4 -- .../action/UnifiedCompletionAction.java | 6 ++- .../UnifiedCompletionActionRequestTests.java | 10 +++++ ...sportUnifiedCompletionInferenceAction.java | 2 +- .../inference/common/DelegatingProcessor.java | 4 +- .../OpenAiUnifiedStreamingProcessor.java | 9 ++-- ...TransportUnifiedCompletionActionTests.java | 41 +++++++++++++++++++ 7 files changed, 63 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 90c31719028de..4128eed8a3854 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -231,10 +231,6 @@ public String getWriteableName() { return NAME; } - public void toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { - builder.value(content); - } - public String toString() { return content; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java index 
39188540cc7eb..8d121463fb465 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -66,6 +66,10 @@ public UnifiedCompletionRequest getUnifiedCompletionRequest() { return unifiedCompletionRequest; } + /** + * The Unified API only supports streaming so we always return true here. + * @return true + */ public boolean isStreaming() { return true; } @@ -88,7 +92,7 @@ public ActionRequestValidationException validate() { return e; } - if (taskType != TaskType.COMPLETION) { + if (taskType.isAnyOrSame(TaskType.COMPLETION) == false) { var e = new ActionRequestValidationException(); e.addValidationError("Field [taskType] must be [completion]"); return e; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java index aad4df0a2ea5e..1872ac3caa230 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java @@ -55,6 +55,16 @@ public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() { assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [completion];")); } + public void testValidation_ReturnsNull_When_TaskType_IsAny() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.ANY, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + TimeValue.timeValueSeconds(10) + ); + assertNull(request.validate()); + } + @Override protected UnifiedCompletionAction.Request 
mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) { return instance; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java index fd6b234fce7f8..f0906231d8f42 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -49,7 +49,7 @@ public TransportUnifiedCompletionInferenceAction( @Override protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) { - return request.getTaskType() != TaskType.COMPLETION; + return request.getTaskType().isAnyOrSame(TaskType.COMPLETION) == false || unparsedModel.taskType() != TaskType.COMPLETION; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index 65870949fcd30..eda3fc0f3bfdb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -85,7 +85,7 @@ public void request(long n) { if (isClosed.get()) { downstream.onComplete(); } else if (upstream != null) { - onRequest(n); + upstreamRequest(n); } else { pendingRequests.accumulateAndGet(n, Long::sum); } @@ -104,7 +104,7 @@ public void cancel() { /** * Guaranteed to be called when the upstream is set and this processor had not been closed. 
*/ - protected void onRequest(long n) { + protected void upstreamRequest(long n) { upstream.request(n); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java index 9dfe4bf00ad08..599d71df3dcfa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -60,9 +60,9 @@ public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor buffer = new LinkedBlockingDeque<>(); @Override - protected void onRequest(long n) { + protected void upstreamRequest(long n) { if (buffer.isEmpty()) { - super.onRequest(n); + super.upstreamRequest(n); } else { downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); } @@ -75,8 +75,7 @@ protected void next(Deque item) throws Exception { if (results.isEmpty()) { upstream().request(1); - } - if (results.size() == 1) { + } else if (results.size() == 1) { downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); } else { // results > 1, but openai spec only wants 1 chunk per SSE event @@ -281,7 +280,7 @@ public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage pa private Deque singleItem( StreamingUnifiedChatCompletionResults.ChatCompletionChunk result ) { - var deque = new ArrayDeque(2); + var deque = new ArrayDeque(1); deque.offer(result); return deque; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java index 
bb702c6b1e538..4c943599ce523 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -80,4 +80,45 @@ public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInfe assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); })); } + + public void testThrows_IncompatibleTaskTypeException_WhenUsingRequestIsAny_ModelIsTextEmbedding() { + var modelTaskType = TaskType.ANY; + var requestTaskType = TaskType.TEXT_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]") + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testMetricsAfterUnifiedInferSuccess_WithRequestTaskTypeAny() { + mockModelRegistry(TaskType.COMPLETION); + mockService(listener -> listener.onResponse(mock())); + + var listener = doExecute(TaskType.ANY); + + verify(listener).onResponse(any()); + verify(inferenceStats.inferenceDuration()).record(anyLong(), 
assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } } From 7fc36ce400028f641d6a2f8de6cc1c88b97284e9 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 6 Dec 2024 14:53:12 -0500 Subject: [PATCH 53/53] Removing unused import --- .../org/elasticsearch/inference/UnifiedCompletionRequest.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 4128eed8a3854..e596be626b518 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -19,8 +19,6 @@ import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentParser;