diff --git a/docs/changelog/127767.yaml b/docs/changelog/127767.yaml new file mode 100644 index 0000000000000..659fc31fbaf83 --- /dev/null +++ b/docs/changelog/127767.yaml @@ -0,0 +1,5 @@ +pr: 127767 +summary: Integrate `OpenAi` Chat Completion in `SageMaker` +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index fd6a052b527bd..5ad533c78ab6d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -180,6 +180,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34); public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35); public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36); + public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -264,6 +265,7 @@ static TransportVersion def(int id) { public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00); public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00); public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00); + public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_082_0_00); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 52bd95e9d2619..1160002a60ac0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -124,7 +124,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(11)); + assertThat(services.size(), equalTo(12)); var providers = providers(services); @@ -142,7 +142,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "googleaistudio", "openai", "streaming_completion_test_service", - "hugging_face" + "hugging_face", + "sagemaker" ).toArray() ) ); @@ -150,13 +151,15 @@ public void testGetServicesWithCompletionTaskType() throws IOException { public void testGetServicesWithChatCompletionTaskType() throws IOException { List services = getServices(TaskType.CHAT_COMPLETION); - assertThat(services.size(), equalTo(5)); + assertThat(services.size(), equalTo(6)); var providers = providers(services); assertThat( providers, - containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray()) + containsInAnyOrder( + List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "sagemaker").toArray() + ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index 6160f51709299..93d120d63c65c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -12,13 +12,12 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import java.io.IOException; import java.util.ArrayDeque; import java.util.Deque; -import java.util.Iterator; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Stream; /** * Processor that delegates the {@link java.util.concurrent.Flow.Subscription} to the upstream {@link java.util.concurrent.Flow.Publisher} @@ -34,19 +33,13 @@ public abstract class DelegatingProcessor implements Flow.Processor public static Deque parseEvent( Deque item, ParseChunkFunction parseFunction, - XContentParserConfiguration parserConfig, - Logger logger - ) throws Exception { + XContentParserConfiguration parserConfig + ) { var results = new ArrayDeque(item.size()); for (ServerSentEvent event : item) { if (event.hasData()) { - try { - var delta = parseFunction.apply(parserConfig, event); - delta.forEachRemaining(results::offer); - } catch (Exception e) { - logger.warn("Failed to parse event from inference provider: {}", event); - throw e; - } + var delta = parseFunction.apply(parserConfig, event); + delta.forEach(results::offer); } } @@ -55,7 +48,7 @@ public static Deque parseEvent( @FunctionalInterface public interface ParseChunkFunction { - Iterator apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException; + Stream apply(XContentParserConfiguration parserConfig, ServerSentEvent event); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java index 6a6f8d92c74ca..623e08cb58a85 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java @@ -45,10 +45,12 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment { private final boolean stream; public UnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) { - Objects.requireNonNull(unifiedChatInput); + this(Objects.requireNonNull(unifiedChatInput).getRequest(), Objects.requireNonNull(unifiedChatInput).stream()); + } - this.unifiedRequest = unifiedChatInput.getRequest(); - this.stream = unifiedChatInput.stream(); + public UnifiedChatCompletionRequestEntity(UnifiedCompletionRequest unifiedRequest, boolean stream) { + this.unifiedRequest = Objects.requireNonNull(unifiedRequest); + this.stream = stream; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiStreamingProcessor.java index fc71e656322b2..57cdcdf4ba046 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiStreamingProcessor.java @@ -9,8 +9,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -20,11 +22,10 @@ import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import java.io.IOException; -import java.util.Collections; import java.util.Deque; -import java.util.Iterator; import java.util.Objects; import java.util.function.Predicate; +import java.util.stream.Stream; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; @@ -113,7 +114,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log); + var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig); if (results.isEmpty()) { upstream().request(1); @@ -122,10 +123,9 @@ protected void next(Deque item) throws Exception { } } - private static Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) - throws IOException { + public static Stream parse(XContentParserConfiguration parserConfig, ServerSentEvent event) { if (DONE_MESSAGE.equalsIgnoreCase(event.data())) { - return Collections.emptyIterator(); + return Stream.empty(); } try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) { @@ -167,11 +167,14 @@ private static Iterator parse(XContentPar consumeUntilObjectEnd(parser); // end choices return ""; // stopped - }).stream() - .filter(Objects::nonNull) - .filter(Predicate.not(String::isEmpty)) - .map(StreamingChatCompletionResults.Result::new) - .iterator(); + }).stream().filter(Objects::nonNull).filter(Predicate.not(String::isEmpty)).map(StreamingChatCompletionResults.Result::new); + } catch (IOException e) { + throw new ElasticsearchStatusException( + "Failed to parse event from inference provider: {}", + RestStatus.INTERNAL_SERVER_ERROR, + e, + event + ); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index e60ba31823107..bbc83667c13e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -50,7 +50,6 @@ public OpenAiUnifiedChatCompletionResponseHandler( public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e)); - flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(openAiProcessor); return new StreamingUnifiedChatCompletionResults(openAiProcessor); @@ -81,6 +80,10 @@ protected static String createErrorType(ErrorResponse errorResponse) { } protected Exception buildMidStreamError(Request request, String message, Exception e) { + return buildMidStreamError(request.getInferenceEntityId(), message, e); + } + + public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) { var errorResponse = OpenAiErrorResponse.fromString(message); if (errorResponse instanceof OpenAiErrorResponse oer) { return new UnifiedChatCompletionException( @@ -88,7 +91,7 @@ protected Exception buildMidStreamError(Request request, String message, Excepti format( "%s for request from inference entity id [%s]. Error message: [%s]", SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), + inferenceEntityId, errorResponse.getErrorMessage() ), oer.type(), @@ -100,7 +103,7 @@ protected Exception buildMidStreamError(Request request, String message, Excepti } else { return new UnifiedChatCompletionException( RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), createErrorType(errorResponse), "stream_error" ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java index 983bb5efbf3fa..86b4a0a65ef2c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java @@ -22,11 +22,10 @@ import java.io.IOException; import java.util.ArrayDeque; -import java.util.Collections; import java.util.Deque; -import java.util.Iterator; import java.util.List; import java.util.function.BiFunction; +import java.util.stream.Stream; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; @@ -75,7 +74,7 @@ protected void next(Deque item) throws Exception { } else if (event.hasData()) { try { var delta = parse(parserConfig, event); - delta.forEachRemaining(results::offer); + delta.forEach(results::offer); } catch (Exception e) { logger.warn("Failed to parse event from inference provider: {}", event); throw errorParser.apply(event.data(), e); @@ -90,12 +89,12 @@ protected void next(Deque item) throws Exception { } } - private static Iterator parse( + public static Stream parse( XContentParserConfiguration parserConfig, ServerSentEvent event ) throws IOException { if (DONE_MESSAGE.equalsIgnoreCase(event.data())) { - return Collections.emptyIterator(); + return Stream.empty(); } try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) { @@ -106,7 +105,7 @@ private static Iterator taskSettings, InputType inputType, - TimeValue timeout, + @Nullable TimeValue timeout, ActionListener listener ) { if (model instanceof SageMakerModel == false) { @@ -148,7 +149,7 @@ public void infer( client.invokeStream( regionAndSecrets, request, - timeout, + timeout != null ? timeout : DEFAULT_TIMEOUT, ActionListener.wrap( response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)), e -> listener.onFailure(schema.error(sageMakerModel, e)) @@ -160,7 +161,7 @@ public void infer( client.invoke( regionAndSecrets, request, - timeout, + timeout != null ? timeout : DEFAULT_TIMEOUT, ActionListener.wrap( response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())), e -> listener.onFailure(schema.error(sageMakerModel, e)) @@ -201,7 +202,7 @@ private static ElasticsearchStatusException internalFailure(Model model, Excepti public void unifiedCompletionInfer( Model model, UnifiedCompletionRequest request, - TimeValue timeout, + @Nullable TimeValue timeout, ActionListener listener ) { if (model instanceof SageMakerModel == false) { @@ -217,7 +218,7 @@ public void unifiedCompletionInfer( client.invokeStream( regionAndSecrets, sagemakerRequest, - timeout, + timeout != null ? timeout : DEFAULT_TIMEOUT, ActionListener.wrap( response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)), e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e)) @@ -235,7 +236,7 @@ public void chunkedInfer( List input, Map taskSettings, InputType inputType, - TimeValue timeout, + @Nullable TimeValue timeout, ActionListener> listener ) { if (model instanceof SageMakerModel == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java index cf3a17a7ae70f..3ecd0388796c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java @@ -12,10 +12,12 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload; import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload; import java.util.Arrays; import java.util.EnumSet; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -39,7 +41,7 @@ public class SageMakerSchemas { /* * Add new model API to the register call. */ - schemas = register(new OpenAiTextEmbeddingPayload()); + schemas = register(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload()); streamSchemas = schemas.entrySet() .stream() @@ -88,7 +90,16 @@ public static List namedWriteables() { ) ), schemas.values().stream().flatMap(SageMakerSchema::namedWriteables) - ).toList(); + ) + // Dedupe based on Entry name, we allow Payloads to declare the same Entry but the Registry does not handle duplicates + .collect( + () -> new HashMap(), + (map, entry) -> map.putIfAbsent(entry.name, entry), + Map::putAll + ) + .values() + .stream() + .toList(); } public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchema.java index 1eb84ecede37e..10175aa8ecb3b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchema.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchema.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; @@ -66,16 +67,16 @@ private InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel mod } public InferenceServiceResults streamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) { - return streamResponse(model, response, payload::streamResponseBody, this::error); + return new StreamingChatCompletionResults(streamResponse(model, response, payload::streamResponseBody, this::error)); } - private InferenceServiceResults streamResponse( + private Flow.Publisher streamResponse( SageMakerModel model, SageMakerClient.SageMakerStream response, - CheckedBiFunction parseFunction, + CheckedBiFunction parseFunction, BiFunction errorFunction ) { - return new StreamingChatCompletionResults(downstream -> { + return downstream -> { response.responseStream().subscribe(new Flow.Subscriber<>() { private volatile Flow.Subscription upstream; @@ -118,7 +119,7 @@ public void onComplete() { downstream.onComplete(); } }); - }); + }; } public InvokeEndpointWithResponseStreamRequest chatCompletionStreamRequest(SageMakerModel model, UnifiedCompletionRequest request) { @@ -126,7 +127,9 @@ public InvokeEndpointWithResponseStreamRequest chatCompletionStreamRequest(SageM } public InferenceServiceResults chatCompletionStreamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) { - return streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError); + return new StreamingUnifiedChatCompletionResults( + streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError) + ); } public UnifiedChatCompletionException chatCompletionError(SageMakerModel model, Exception e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchemaPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchemaPayload.java index 7867e16b87733..0da33cf3c3628 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchemaPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchemaPayload.java @@ -9,9 +9,10 @@ import software.amazon.awssdk.core.SdkBytes; -import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; @@ -38,9 +39,9 @@ default EnumSet supportedTasks() { * This API would only be called for Completion task types. {@link #requestBytes(SageMakerModel, SageMakerInferenceRequest)} would * handle the request translation for both streaming and non-streaming. */ - InferenceServiceResults.Result streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception; + StreamingChatCompletionResults.Results streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception; SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception; - InferenceServiceResults.Result chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception; + StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java new file mode 100644 index 0000000000000..64b42f00d2d5b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java @@ -0,0 +1,168 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.openai.OpenAiStreamingProcessor; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedStreamingProcessor; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchemaPayload; + +import java.util.ArrayDeque; +import java.util.Map; +import java.util.stream.Stream; + +public class OpenAiCompletionPayload implements SageMakerStreamSchemaPayload { + + private static final XContent jsonXContent = JsonXContent.jsonXContent; + private static final String APPLICATION_JSON = jsonXContent.type().mediaTypeWithoutParameters(); + private static final XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + private static final String USER_FIELD = "user"; + private static final String USER_ROLE = "user"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private static final ResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + "sagemaker openai chat completion", + ((request, result) -> { + assert false : "do not call this"; + throw new UnsupportedOperationException("SageMaker should not call this object's response parser."); + }) + ); + + @Override + public SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception { + return completion(model, new UnifiedChatCompletionRequestEntity(request, true), request.maxCompletionTokens()); + } + + private SdkBytes completion(SageMakerModel model, UnifiedChatCompletionRequestEntity requestEntity, @Nullable Long maxCompletionTokens) + throws Exception { + if (model.apiTaskSettings() instanceof SageMakerOpenAiTaskSettings apiTaskSettings) { + return SdkBytes.fromUtf8String(Strings.toString((builder, params) -> { + requestEntity.toXContent(builder, params); + + if (Strings.isNullOrEmpty(apiTaskSettings.user()) == false) { + builder.field(USER_FIELD, apiTaskSettings.user()); + } + + if (maxCompletionTokens != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, maxCompletionTokens); + } + return builder; + })); + } else { + throw createUnsupportedSchemaException(model); + } + } + + @Override + public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody(SageMakerModel model, SdkBytes response) { + var serverSentEvents = serverSentEvents(response); + var results = serverSentEvents.flatMap(event -> { + if ("error".equals(event.type())) { + throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), null); + } else { + try { + return OpenAiUnifiedStreamingProcessor.parse(parserConfig, event); + } catch (Exception e) { + throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), e); + } + } + }) + .collect( + () -> new ArrayDeque(), + ArrayDeque::offer, + ArrayDeque::addAll + ); + return new StreamingUnifiedChatCompletionResults.Results(results); + } + + /* + * We should be safe to use ServerSentEventParser. It was built knowing Apache HTTP will have leftover bytes for us to manage, + * but SageMaker uses Netty and (likely, hopefully) doesn't have that problem. + */ + private Stream serverSentEvents(SdkBytes response) { + return new ServerSentEventParser().parse(response.asByteArray()).stream().filter(ServerSentEvent::hasData); + } + + @Override + public String api() { + return "openai"; + } + + @Override + public SageMakerStoredTaskSchema apiTaskSettings(Map taskSettings, ValidationException validationException) { + return SageMakerOpenAiTaskSettings.fromMap(taskSettings, validationException); + } + + @Override + public Stream namedWriteables() { + return Stream.of( + new NamedWriteableRegistry.Entry( + SageMakerStoredTaskSchema.class, + SageMakerOpenAiTaskSettings.NAME, + SageMakerOpenAiTaskSettings::new + ) + ); + } + + @Override + public String accept(SageMakerModel model) { + return APPLICATION_JSON; + } + + @Override + public String contentType(SageMakerModel model) { + return APPLICATION_JSON; + } + + @Override + public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception { + return completion( + model, + new UnifiedChatCompletionRequestEntity(new UnifiedChatInput(request.input(), USER_ROLE, request.stream())), + null + ); + } + + @Override + public InferenceServiceResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception { + return OpenAiChatCompletionResponseEntity.fromResponse(response.body().asByteArray()); + } + + @Override + public StreamingChatCompletionResults.Results streamResponseBody(SageMakerModel model, SdkBytes response) { + var serverSentEvents = serverSentEvents(response); + var results = serverSentEvents.flatMap(event -> OpenAiStreamingProcessor.parse(parserConfig, event)) + .collect(() -> new ArrayDeque(), ArrayDeque::offer, ArrayDeque::addAll); + return new StreamingChatCompletionResults.Results(results); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java index 7bd122a5922e4..276c407d694d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java @@ -40,7 +40,6 @@ import java.util.stream.Stream; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload { @@ -64,14 +63,18 @@ public SageMakerStoredServiceSchema apiServiceSettings(Map servi @Override public SageMakerStoredTaskSchema apiTaskSettings(Map taskSettings, ValidationException validationException) { - return ApiTaskSettings.fromMap(taskSettings, validationException); + return SageMakerOpenAiTaskSettings.fromMap(taskSettings, validationException); } @Override public Stream namedWriteables() { return Stream.of( new NamedWriteableRegistry.Entry(SageMakerStoredServiceSchema.class, ApiServiceSettings.NAME, ApiServiceSettings::new), - new NamedWriteableRegistry.Entry(SageMakerStoredTaskSchema.class, ApiTaskSettings.NAME, ApiTaskSettings::new) + new NamedWriteableRegistry.Entry( + SageMakerStoredTaskSchema.class, + SageMakerOpenAiTaskSettings.NAME, + SageMakerOpenAiTaskSettings::new + ) ); } @@ -88,7 +91,7 @@ public String contentType(SageMakerModel model) { @Override public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception { if (model.apiServiceSettings() instanceof ApiServiceSettings apiServiceSettings - && model.apiTaskSettings() instanceof ApiTaskSettings apiTaskSettings) { + && model.apiTaskSettings() instanceof SageMakerOpenAiTaskSettings apiTaskSettings) { try (var builder = JsonXContent.contentBuilder()) { builder.startObject(); if (request.query() != null) { @@ -178,52 +181,4 @@ public SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dime return new ApiServiceSettings(dimensions, false); } } - - record ApiTaskSettings(@Nullable String user) implements SageMakerStoredTaskSchema { - private static final String NAME = "sagemaker_openai_text_embeddings_task_settings"; - private static final String USER_FIELD = "user"; - - ApiTaskSettings(StreamInput in) throws IOException { - this(in.readOptionalString()); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_SAGEMAKER; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(user); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return user != null ? builder.field(USER_FIELD, user) : builder; - } - - @Override - public boolean isEmpty() { - return user == null; - } - - @Override - public ApiTaskSettings updatedTaskSettings(Map newSettings) { - var validationException = new ValidationException(); - var newTaskSettings = fromMap(newSettings, validationException); - validationException.throwIfValidationErrorsExist(); - - return new ApiTaskSettings(newTaskSettings.user() != null ? newTaskSettings.user() : user); - } - - static ApiTaskSettings fromMap(Map map, ValidationException exception) { - var user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, exception); - return new ApiTaskSettings(user); - } - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettings.java new file mode 100644 index 0000000000000..4eeba9f69022d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettings.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; + +import java.io.IOException; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; + +record SageMakerOpenAiTaskSettings(@Nullable String user) implements SageMakerStoredTaskSchema { + static final String NAME = "sagemaker_openai_task_settings"; + private static final String USER_FIELD = "user"; + + SageMakerOpenAiTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalString()); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(user); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return user != null ? builder.field(USER_FIELD, user) : builder; + } + + @Override + public boolean isEmpty() { + return user == null; + } + + @Override + public SageMakerOpenAiTaskSettings updatedTaskSettings(Map newSettings) { + var validationException = new ValidationException(); + var newTaskSettings = fromMap(newSettings, validationException); + validationException.throwIfValidationErrorsExist(); + + return new SageMakerOpenAiTaskSettings(newTaskSettings.user() != null ? newTaskSettings.user() : user); + } + + static SageMakerOpenAiTaskSettings fromMap(Map map, ValidationException exception) { + var user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, exception); + return new SageMakerOpenAiTaskSettings(user); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java index 4e480ed4c17b8..ebef5b6eefd9e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.VersionedNamedWriteable; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -88,27 +89,33 @@ public void testNamedWriteables() { } public final void testWithUnknownApiServiceSettings() { - SageMakerModel model = mock(); - when(model.apiServiceSettings()).thenReturn(mock()); - when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings()); - when(model.api()).thenReturn("serviceApi"); - when(model.getTaskType()).thenReturn(TaskType.ANY); + // skip the test if we don't have SageMakerStoredServiceSchema for this payload + if (randomApiServiceSettings() != SageMakerStoredServiceSchema.NO_OP) { + SageMakerModel model = mock(); + when(model.apiServiceSettings()).thenReturn(mock()); + when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings()); + when(model.api()).thenReturn("serviceApi"); + when(model.getTaskType()).thenReturn(TaskType.ANY); - var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest())); + var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest())); - assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [serviceApi] and task type [any]:")); + assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [serviceApi] and task type [any]:")); + } } public final void testWithUnknownApiTaskSettings() { - SageMakerModel model = mock(); - when(model.apiServiceSettings()).thenReturn(randomApiServiceSettings()); - when(model.apiTaskSettings()).thenReturn(mock()); - when(model.api()).thenReturn("taskApi"); - when(model.getTaskType()).thenReturn(TaskType.ANY); + // skip the test if we don't have SageMakerStoredTaskSchema for this payload + if (randomApiTaskSettings() != SageMakerStoredTaskSchema.NO_OP) { + SageMakerModel model = mock(); + when(model.apiServiceSettings()).thenReturn(randomApiServiceSettings()); + when(model.apiTaskSettings()).thenReturn(mock()); + when(model.api()).thenReturn("taskApi"); + when(model.getTaskType()).thenReturn(TaskType.ANY); - var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest())); + var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest())); - assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [taskApi] and task type [any]:")); + assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [taskApi] and task type [any]:")); + } } public final void testUpdate() throws IOException { @@ -131,12 +138,6 @@ public final void testUpdate() throws IOException { }); assertTrue("Map should be empty now that we verified all updated keys and all initial keys", updatedSettings.isEmpty()); } - if (payload instanceof SageMakerStoredTaskSchema taskSchema) { - var otherTaskSettings = randomValueOtherThan(randomApiTaskSettings(), this::randomApiTaskSettings); - var otherTaskSettingsAsMap = toMap(otherTaskSettings); - - taskSchema.updatedTaskSettings(otherTaskSettingsAsMap); - } } protected static SageMakerInferenceRequest randomRequest() { @@ -153,4 +154,8 @@ protected static SageMakerInferenceRequest randomRequest() { protected static void assertSdkBytes(SdkBytes sdkBytes, String expectedValue) { assertThat(sdkBytes.asUtf8String(), equalTo(expectedValue)); } + + protected static void assertJsonSdkBytes(SdkBytes sdkBytes, String expectedValue) throws IOException { + assertThat(sdkBytes.asUtf8String(), equalTo(XContentHelper.stripWhitespace(expectedValue))); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java index 8e3c30a95e36b..d306fc2713077 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java @@ -11,12 +11,12 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload; import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload; import java.util.stream.Stream; import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.empty; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; @@ -41,15 +41,18 @@ public static SageMakerSchema mockSchema() { private static final SageMakerSchemas schemas = new SageMakerSchemas(); public void testSupportedTaskTypes() { - assertThat(schemas.supportedTaskTypes(), containsInAnyOrder(TaskType.TEXT_EMBEDDING)); + assertThat( + schemas.supportedTaskTypes(), + containsInAnyOrder(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION) + ); } public void testSupportedStreamingTasks() { - assertThat(schemas.supportedStreamingTasks(), empty()); + assertThat(schemas.supportedStreamingTasks(), containsInAnyOrder(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)); } public void testSchemaFor() { - var payloads = Stream.of(new OpenAiTextEmbeddingPayload()); + var payloads = Stream.of(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload()); payloads.forEach(payload -> { payload.supportedTasks().forEach(taskType -> { var model = mockModel(taskType, payload.api()); @@ -59,7 +62,7 @@ public void testSchemaFor() { } public void testStreamSchemaFor() { - var payloads = Stream.of(/* For when we add support for streaming payloads */); + var payloads = Stream.of(new OpenAiCompletionPayload()); payloads.forEach(payload -> { payload.supportedTasks().forEach(taskType -> { var model = mockModel(taskType, payload.api()); @@ -77,10 +80,11 @@ private SageMakerModel mockModel(TaskType taskType, String api) { public void testMissingTaskTypeThrowsException() { var knownPayload = new OpenAiTextEmbeddingPayload(); - var unknownTaskType = TaskType.COMPLETION; + var unknownTaskType = TaskType.RERANK; var knownModel = mockModel(unknownTaskType, knownPayload.api()); assertThrows( - "Task [completion] is not compatible for service [sagemaker] and api [openai]. Supported tasks: [text_embedding]", + "Task [rerank] is not compatible for service [sagemaker] and api [openai]. " + + "Supported tasks: [text_embedding, completion, chat_completion]", ElasticsearchStatusException.class, () -> schemas.schemaFor(knownModel) ); @@ -105,7 +109,10 @@ public void testMissingStreamSchemaThrowsException() { } public void testNamedWriteables() { - var namedWriteables = Stream.of(new OpenAiTextEmbeddingPayload().namedWriteables()); + var namedWriteables = Stream.of( + new OpenAiTextEmbeddingPayload().namedWriteables(), + new OpenAiCompletionPayload().namedWriteables() + ); var expectedNamedWriteables = Stream.concat( namedWriteables.flatMap(names -> names.map(entry -> entry.name)), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayloadTests.java new file mode 100644 index 0000000000000..24be845e5ab66 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayloadTests.java @@ -0,0 +1,283 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; + +import java.io.IOException; +import java.util.List; +import java.util.Set; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class OpenAiCompletionPayloadTests extends SageMakerSchemaPayloadTestCase { + @Override + protected OpenAiCompletionPayload payload() { + return new OpenAiCompletionPayload(); + } + + @Override + protected String expectedApi() { + return "openai"; + } + + @Override + protected Set expectedSupportedTaskTypes() { + return Set.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); + } + + @Override + protected SageMakerStoredServiceSchema randomApiServiceSettings() { + return SageMakerStoredServiceSchema.NO_OP; + } + + @Override + protected SageMakerStoredTaskSchema randomApiTaskSettings() { + return SageMakerOpenAiTaskSettingsTests.randomApiTaskSettings(); + } + + public void testRequest() throws Exception { + var sdkByes = payload.requestBytes(mockModel("coolPerson"), request(false)); + assertJsonSdkBytes(sdkByes, """ + { + "messages": [ + { + "content": "hello", + "role": "user" + } + ], + "n": 1, + "stream": false, + "user": "coolPerson" + }"""); + } + + public void testRequestWithoutUser() throws Exception { + var sdkByes = payload.requestBytes(mockModel(null), request(false)); + assertJsonSdkBytes(sdkByes, """ + { + "messages": [ + { + "content": "hello", + "role": "user" + } + ], + "n": 1, + "stream": false + }"""); + } + + public void testStreamRequest() throws Exception { + var sdkByes = payload.requestBytes(mockModel("user"), request(true)); + assertJsonSdkBytes(sdkByes, """ + { + "messages":[ + { + "content": "hello", + "role": "user" + } + ], + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + }, + "user": "user" + }"""); + } + + public void testStreamRequestWithoutUser() throws Exception { + var sdkByes = payload.requestBytes(mockModel(null), request(true)); + assertJsonSdkBytes(sdkByes, """ + { + "messages":[ + { + "content": "hello", + "role": "user" + } + ], + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + }"""); + } + + private SageMakerInferenceRequest request(boolean stream) { + return new SageMakerInferenceRequest(null, null, null, List.of("hello"), stream, InputType.UNSPECIFIED); + } + + private SageMakerModel mockModel(String user) { + SageMakerModel model = mock(); + when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings(user)); + return model; + } + + public void testResponse() throws Exception { + var responseJson = """ + { + "id": "some-id", + "object": "chat.completion", + "created": 1705397787, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "result" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 46, + "completion_tokens": 39, + "total_tokens": 85 + }, + "system_fingerprint": null + } + """; + + var chatCompletionResults = (ChatCompletionResults) payload.responseBody( + mockModel(), + InvokeEndpointResponse.builder().body(SdkBytes.fromUtf8String(responseJson)).build() + ); + + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); + } + + public void testStreamResponse() throws Exception { + var responseJson = dataPayload(""" + { + "id":"12345", + "object":"chat.completion.chunk", + "created":123456789, + "model":"gpt-4o-mini", + "system_fingerprint": "123456789", + "choices":[ + { + "index":0, + "delta":{ + "content":"test" + }, + "logprobs":null, + "finish_reason":null + } + ] + } + """); + + var streamingResults = payload.streamResponseBody(mockModel(), responseJson); + + assertThat(streamingResults.results().size(), is(1)); + assertThat(streamingResults.results().iterator().next().delta(), is("test")); + } + + private SdkBytes dataPayload(String json) throws IOException { + return SdkBytes.fromUtf8String("data: " + XContentHelper.stripWhitespace(json) + "\n\n"); + } + + private SageMakerModel mockModel() { + SageMakerModel model = mock(); + when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings()); + return model; + } + + public void testChatCompletionRequest() throws Exception { + var message = new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("Hello, world!"), "user", null, null); + var unifiedRequest = new UnifiedCompletionRequest(List.of(message), null, null, null, null, null, null, null); + var sdkBytes = payload.chatCompletionRequestBytes(mockModel("coolUser"), unifiedRequest); + assertJsonSdkBytes(sdkBytes, """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + }, + "user": "coolUser" + } + """); + } + + public void testChatCompletionResponse() throws Exception { + var responseJson = """ + { + "id": "chunk1", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + ], + "model": "example_model", + "object": "example_object", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 5, + "total_tokens": 15 + } + } + """; + + var chatCompletionResponse = payload.chatCompletionResponseBody(mockModel(), dataPayload(responseJson)); + + XContentBuilder builder = JsonXContent.contentBuilder(); + chatCompletionResponse.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(XContentHelper.stripWhitespace(responseJson), Strings.toString(builder).trim()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java index 7a85a5e05fab1..35b78b004618c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java @@ -65,7 +65,7 @@ public void testContentType() { public void testRequestWithSingleInput() throws Exception { SageMakerModel model = mock(); when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false)); - when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null)); + when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings((String) null)); var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values())); var sdkByes = payload.requestBytes(model, request); @@ -76,7 +76,7 @@ public void testRequestWithSingleInput() throws Exception { public void testRequestWithArrayInput() throws Exception { SageMakerModel model = mock(); when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false)); - when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null)); + when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings((String) null)); var request = new SageMakerInferenceRequest( null, null, @@ -94,7 +94,7 @@ public void testRequestWithArrayInput() throws Exception { public void testRequestWithDimensionsNotSetByUserIgnoreDimensions() throws Exception { SageMakerModel model = mock(); when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(123, false)); - when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null)); + when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings((String) null)); var request = new SageMakerInferenceRequest( null, null, @@ -112,7 +112,7 @@ public void testRequestWithDimensionsNotSetByUserIgnoreDimensions() throws Excep public void testRequestWithOptionals() throws Exception { SageMakerModel model = mock(); when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(1234, true)); - when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings("user")); + when(model.apiTaskSettings()).thenReturn(new SageMakerOpenAiTaskSettings("user")); var request = new SageMakerInferenceRequest("query", null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values())); var sdkByes = payload.requestBytes(model, request); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettingsTests.java index 1eaaf4100f5f6..70c25fa283c76 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettingsTests.java @@ -13,26 +13,26 @@ import java.util.Map; -public class SageMakerOpenAiTaskSettingsTests extends InferenceSettingsTestCase { +public class SageMakerOpenAiTaskSettingsTests extends InferenceSettingsTestCase { @Override - protected OpenAiTextEmbeddingPayload.ApiTaskSettings fromMutableMap(Map mutableMap) { + protected SageMakerOpenAiTaskSettings fromMutableMap(Map mutableMap) { var validationException = new ValidationException(); - var settings = OpenAiTextEmbeddingPayload.ApiTaskSettings.fromMap(mutableMap, validationException); + var settings = SageMakerOpenAiTaskSettings.fromMap(mutableMap, validationException); validationException.throwIfValidationErrorsExist(); return settings; } @Override - protected Writeable.Reader instanceReader() { - return OpenAiTextEmbeddingPayload.ApiTaskSettings::new; + protected Writeable.Reader instanceReader() { + return SageMakerOpenAiTaskSettings::new; } @Override - protected OpenAiTextEmbeddingPayload.ApiTaskSettings createTestInstance() { + protected SageMakerOpenAiTaskSettings createTestInstance() { return randomApiTaskSettings(); } - static OpenAiTextEmbeddingPayload.ApiTaskSettings randomApiTaskSettings() { - return new OpenAiTextEmbeddingPayload.ApiTaskSettings(randomOptionalString()); + static SageMakerOpenAiTaskSettings randomApiTaskSettings() { + return new SageMakerOpenAiTaskSettings(randomOptionalString()); } }