From 65f0c7e8cb453a2d2f4818b812de0ea4ca42a004 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 15 Oct 2024 10:55:06 +0200 Subject: [PATCH 1/3] Rename AzureOpenAiCompletion to AzureOpenAiChatCompletion --- .../InferenceNamedWriteablesProvider.java | 12 +++--- .../azureopenai/AzureOpenAiActionCreator.java | 10 ++--- .../azureopenai/AzureOpenAiActionVisitor.java | 4 +- ...reOpenAiChatCompletionRequestManager.java} | 18 ++++---- ... => AzureOpenAiChatCompletionRequest.java} | 12 +++--- ...ureOpenAiChatCompletionRequestEntity.java} | 6 ++- ...reOpenAiChatCompletionResponseEntity.java} | 2 +- .../azureopenai/AzureOpenAiService.java | 4 +- ...va => AzureOpenAiChatCompletionModel.java} | 41 +++++++++++------- ...nAiChatCompletionRequestTaskSettings.java} | 12 +++--- ...eOpenAiChatCompletionServiceSettings.java} | 18 ++++---- ...zureOpenAiChatCompletionTaskSettings.java} | 22 +++++----- .../AzureOpenAiActionCreatorTests.java | 8 ++-- ...AzureOpenAiChatCompletionActionTests.java} | 8 ++-- ...enAiChatCompletionRequestEntityTests.java} | 8 ++-- ...zureOpenAiChatCompletionRequestTests.java} | 12 +++--- ...nAiChatCompletionResponseEntityTests.java} | 12 +++--- .../azureopenai/AzureOpenAiServiceTests.java | 4 +- ... AzureOpenAiChatCompletionModelTests.java} | 22 +++++----- ...atCompletionRequestTaskSettingsTests.java} | 14 +++---- ...AiChatCompletionServiceSettingsTests.java} | 24 ++++++----- ...penAiChatCompletionTaskSettingsTests.java} | 42 +++++++++---------- 22 files changed, 166 insertions(+), 149 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/{AzureOpenAiCompletionRequestManager.java => AzureOpenAiChatCompletionRequestManager.java} (75%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/{AzureOpenAiCompletionRequest.java => AzureOpenAiChatCompletionRequest.java} (81%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/{AzureOpenAiCompletionRequestEntity.java => AzureOpenAiChatCompletionRequestEntity.java} (89%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureopenai/{AzureOpenAiCompletionResponseEntity.java => AzureOpenAiChatCompletionResponseEntity.java} (98%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionModel.java => AzureOpenAiChatCompletionModel.java} (67%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionRequestTaskSettings.java => AzureOpenAiChatCompletionRequestTaskSettings.java} (68%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionServiceSettings.java => AzureOpenAiChatCompletionServiceSettings.java} (88%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionTaskSettings.java => AzureOpenAiChatCompletionTaskSettings.java} (77%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/{AzureOpenAiCompletionActionTests.java => AzureOpenAiChatCompletionActionTests.java} (97%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/{AzureOpenAiCompletionRequestEntityTests.java => AzureOpenAiChatCompletionRequestEntityTests.java} (83%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/{AzureOpenAiCompletionRequestTests.java => AzureOpenAiChatCompletionRequestTests.java} (90%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/{AzureOpenAiCompletionResponseEntityTests.java => AzureOpenAiChatCompletionResponseEntityTests.java} (95%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionModelTests.java => AzureOpenAiChatCompletionModelTests.java} (83%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionRequestTaskSettingsTests.java => AzureOpenAiChatCompletionRequestTaskSettingsTests.java} (61%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionServiceSettingsTests.java => AzureOpenAiChatCompletionServiceSettingsTests.java} (66%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionTaskSettingsTests.java => AzureOpenAiChatCompletionTaskSettingsTests.java} (57%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 02bddb6076d69..6db72d5d38424 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -49,8 +49,8 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; @@ -247,15 +247,15 @@ private static void addAzureOpenAiNamedWriteables(List taskSettings) { - var overriddenModel = AzureOpenAiCompletionModel.of(model, taskSettings); - var requestCreator = new AzureOpenAiCompletionRequestManager(overriddenModel, serviceComponents.threadPool()); + public ExecutableAction create(AzureOpenAiChatCompletionModel model, Map taskSettings) { + var overriddenModel = AzureOpenAiChatCompletionModel.of(model, taskSettings); + var requestCreator = new AzureOpenAiChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool()); var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getUri(), COMPLETION_ERROR_PREFIX); return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, COMPLETION_ERROR_PREFIX); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java index f45c1d797085e..ce362bad95a59 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.inference.external.action.azureopenai; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; import java.util.Map; @@ -16,5 +16,5 @@ public interface AzureOpenAiActionVisitor { ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map taskSettings); - ExecutableAction create(AzureOpenAiCompletionModel model, Map taskSettings); + ExecutableAction create(AzureOpenAiChatCompletionModel model, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiChatCompletionRequestManager.java similarity index 75% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiChatCompletionRequestManager.java index d036559ec3dcb..6b65c4f9919a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiChatCompletionRequestManager.java @@ -15,26 +15,26 @@ import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiCompletionRequest; -import org.elasticsearch.xpack.inference.external.response.azureopenai.AzureOpenAiCompletionResponseEntity; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.azureopenai.AzureOpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModel; import java.util.Objects; import java.util.function.Supplier; -public class AzureOpenAiCompletionRequestManager extends AzureOpenAiRequestManager { +public class AzureOpenAiChatCompletionRequestManager extends AzureOpenAiRequestManager { - private static final Logger logger = LogManager.getLogger(AzureOpenAiCompletionRequestManager.class); + private static final Logger logger = LogManager.getLogger(AzureOpenAiChatCompletionRequestManager.class); private static final ResponseHandler HANDLER = createCompletionHandler(); - private final AzureOpenAiCompletionModel model; + private final AzureOpenAiChatCompletionModel model; private static ResponseHandler createCompletionHandler() { - return new AzureOpenAiResponseHandler("azure openai completion", AzureOpenAiCompletionResponseEntity::fromResponse, true); + return new AzureOpenAiResponseHandler("azure openai completion", AzureOpenAiChatCompletionResponseEntity::fromResponse, true); } - public AzureOpenAiCompletionRequestManager(AzureOpenAiCompletionModel model, ThreadPool threadPool) { + public AzureOpenAiChatCompletionRequestManager(AzureOpenAiChatCompletionModel model, ThreadPool threadPool) { super(threadPool, model); this.model = Objects.requireNonNull(model); } @@ -49,7 +49,7 @@ public void execute( var docsOnly = DocumentsOnlyInput.of(inferenceInputs); var docsInput = docsOnly.getInputs(); var stream = docsOnly.stream(); - AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model, stream); + AzureOpenAiChatCompletionRequest request = new AzureOpenAiChatCompletionRequest(docsInput, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiChatCompletionRequest.java similarity index 81% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiChatCompletionRequest.java index 41f05b500efa8..77f5244a9f51b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiChatCompletionRequest.java @@ -12,24 +12,24 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModel; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Objects; -public class AzureOpenAiCompletionRequest implements AzureOpenAiRequest { +public class AzureOpenAiChatCompletionRequest implements AzureOpenAiRequest { private final List input; private final URI uri; - private final AzureOpenAiCompletionModel model; + private final AzureOpenAiChatCompletionModel model; private final boolean stream; - public AzureOpenAiCompletionRequest(List input, AzureOpenAiCompletionModel model, boolean stream) { + public AzureOpenAiChatCompletionRequest(List input, AzureOpenAiChatCompletionModel model, boolean stream) { this.input = input; this.model = Objects.requireNonNull(model); this.uri = model.getUri(); @@ -39,7 +39,9 @@ public AzureOpenAiCompletionRequest(List input, AzureOpenAiCompletionMod @Override public HttpRequest createHttpRequest() { var httpPost = new HttpPost(uri); - var requestEntity = Strings.toString(new AzureOpenAiCompletionRequestEntity(input, model.getTaskSettings().user(), isStreaming())); + var requestEntity = Strings.toString( + new AzureOpenAiChatCompletionRequestEntity(input, model.getTaskSettings().user(), isStreaming()) + ); ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiChatCompletionRequestEntity.java similarity index 89% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiChatCompletionRequestEntity.java index 725e51c06c494..0daea672f4132 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiChatCompletionRequestEntity.java @@ -16,7 +16,9 @@ import java.util.List; import java.util.Objects; -public record AzureOpenAiCompletionRequestEntity(List messages, @Nullable String user, boolean stream) implements ToXContentObject { +public record AzureOpenAiChatCompletionRequestEntity(List messages, @Nullable String user, boolean stream) + implements + ToXContentObject { private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; @@ -30,7 +32,7 @@ public record AzureOpenAiCompletionRequestEntity(List messages, @Nullabl private static final String STREAM_FIELD = "stream"; - public AzureOpenAiCompletionRequestEntity { + public AzureOpenAiChatCompletionRequestEntity { Objects.requireNonNull(messages); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiChatCompletionResponseEntity.java similarity index 98% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiChatCompletionResponseEntity.java index ca1df7027cb40..7743db6a37cfc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiChatCompletionResponseEntity.java @@ -23,7 +23,7 @@ import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; -public class AzureOpenAiCompletionResponseEntity { +public class AzureOpenAiChatCompletionResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Azure OpenAI completions response"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e1c230b98a2f7..2f2fc0e4a0df3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -36,7 +36,7 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; @@ -147,7 +147,7 @@ private static AzureOpenAiModel createModel( ); } case COMPLETION -> { - return new AzureOpenAiCompletionModel( + return new AzureOpenAiChatCompletionModel( inferenceEntityId, taskType, NAME, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModel.java similarity index 67% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModel.java index c4146b2ba2d30..00bca77c6de61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModel.java @@ -21,18 +21,21 @@ import java.net.URISyntaxException; import java.util.Map; -public class AzureOpenAiCompletionModel extends AzureOpenAiModel { +public class AzureOpenAiChatCompletionModel extends AzureOpenAiModel { - public static AzureOpenAiCompletionModel of(AzureOpenAiCompletionModel model, Map taskSettings) { + public static AzureOpenAiChatCompletionModel of(AzureOpenAiChatCompletionModel model, Map taskSettings) { if (taskSettings == null || taskSettings.isEmpty()) { return model; } - var requestTaskSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap(taskSettings); - return new AzureOpenAiCompletionModel(model, AzureOpenAiCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + var requestTaskSettings = AzureOpenAiChatCompletionRequestTaskSettings.fromMap(taskSettings); + return new AzureOpenAiChatCompletionModel( + model, + AzureOpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings) + ); } - public AzureOpenAiCompletionModel( + public AzureOpenAiChatCompletionModel( String inferenceEntityId, TaskType taskType, String service, @@ -45,19 +48,19 @@ public AzureOpenAiCompletionModel( inferenceEntityId, taskType, service, - AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings, context), - AzureOpenAiCompletionTaskSettings.fromMap(taskSettings), + AzureOpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context), + AzureOpenAiChatCompletionTaskSettings.fromMap(taskSettings), AzureOpenAiSecretSettings.fromMap(secrets) ); } // Should only be used directly for testing - AzureOpenAiCompletionModel( + AzureOpenAiChatCompletionModel( String inferenceEntityId, TaskType taskType, String service, - AzureOpenAiCompletionServiceSettings serviceSettings, - AzureOpenAiCompletionTaskSettings taskSettings, + AzureOpenAiChatCompletionServiceSettings serviceSettings, + AzureOpenAiChatCompletionTaskSettings taskSettings, @Nullable AzureOpenAiSecretSettings secrets ) { super( @@ -72,22 +75,28 @@ public AzureOpenAiCompletionModel( } } - public AzureOpenAiCompletionModel(AzureOpenAiCompletionModel originalModel, AzureOpenAiCompletionServiceSettings serviceSettings) { + public AzureOpenAiChatCompletionModel( + AzureOpenAiChatCompletionModel originalModel, + AzureOpenAiChatCompletionServiceSettings serviceSettings + ) { super(originalModel, serviceSettings); } - private AzureOpenAiCompletionModel(AzureOpenAiCompletionModel originalModel, AzureOpenAiCompletionTaskSettings taskSettings) { + private AzureOpenAiChatCompletionModel( + AzureOpenAiChatCompletionModel originalModel, + AzureOpenAiChatCompletionTaskSettings taskSettings + ) { super(originalModel, taskSettings); } @Override - public AzureOpenAiCompletionServiceSettings getServiceSettings() { - return (AzureOpenAiCompletionServiceSettings) super.getServiceSettings(); + public AzureOpenAiChatCompletionServiceSettings getServiceSettings() { + return (AzureOpenAiChatCompletionServiceSettings) super.getServiceSettings(); } @Override - public AzureOpenAiCompletionTaskSettings getTaskSettings() { - return (AzureOpenAiCompletionTaskSettings) super.getTaskSettings(); + public AzureOpenAiChatCompletionTaskSettings getTaskSettings() { + return (AzureOpenAiChatCompletionTaskSettings) super.getTaskSettings(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionRequestTaskSettings.java similarity index 68% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionRequestTaskSettings.java index 5dd42bb1b911f..953f0ec46046e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionRequestTaskSettings.java @@ -16,13 +16,15 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; -public record AzureOpenAiCompletionRequestTaskSettings(@Nullable String user) { +public record AzureOpenAiChatCompletionRequestTaskSettings(@Nullable String user) { - public static final AzureOpenAiCompletionRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiCompletionRequestTaskSettings(null); + public static final AzureOpenAiChatCompletionRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiChatCompletionRequestTaskSettings( + null + ); - public static AzureOpenAiCompletionRequestTaskSettings fromMap(Map map) { + public static AzureOpenAiChatCompletionRequestTaskSettings fromMap(Map map) { if (map.isEmpty()) { - return AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS; + return AzureOpenAiChatCompletionRequestTaskSettings.EMPTY_SETTINGS; } ValidationException validationException = new ValidationException(); @@ -33,6 +35,6 @@ public static AzureOpenAiCompletionRequestTaskSettings fromMap(Map map, ConfigurationParseContext context) { + public static AzureOpenAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); var settings = fromMap(map, validationException, context); @@ -66,10 +66,10 @@ public static AzureOpenAiCompletionServiceSettings fromMap(Map m throw validationException; } - return new AzureOpenAiCompletionServiceSettings(settings); + return new AzureOpenAiChatCompletionServiceSettings(settings); } - private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap( + private static AzureOpenAiChatCompletionServiceSettings.CommonFields fromMap( Map map, ValidationException validationException, ConfigurationParseContext context @@ -85,7 +85,7 @@ private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap( context ); - return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings); + return new AzureOpenAiChatCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings); } private record CommonFields(String resourceName, String deploymentId, String apiVersion, RateLimitSettings rateLimitSettings) {} @@ -96,7 +96,7 @@ private record CommonFields(String resourceName, String deploymentId, String api private final RateLimitSettings rateLimitSettings; - public AzureOpenAiCompletionServiceSettings( + public AzureOpenAiChatCompletionServiceSettings( String resourceName, String deploymentId, String apiVersion, @@ -108,14 +108,14 @@ public AzureOpenAiCompletionServiceSettings( this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); } - public AzureOpenAiCompletionServiceSettings(StreamInput in) throws IOException { + public AzureOpenAiChatCompletionServiceSettings(StreamInput in) throws IOException { resourceName = in.readString(); deploymentId = in.readString(); apiVersion = in.readString(); rateLimitSettings = new RateLimitSettings(in); } - private AzureOpenAiCompletionServiceSettings(AzureOpenAiCompletionServiceSettings.CommonFields fields) { + private AzureOpenAiChatCompletionServiceSettings(AzureOpenAiChatCompletionServiceSettings.CommonFields fields) { this(fields.resourceName, fields.deploymentId, fields.apiVersion, fields.rateLimitSettings); } @@ -183,7 +183,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object object) { if (this == object) return true; if (object == null || getClass() != object.getClass()) return false; - AzureOpenAiCompletionServiceSettings that = (AzureOpenAiCompletionServiceSettings) object; + AzureOpenAiChatCompletionServiceSettings that = (AzureOpenAiChatCompletionServiceSettings) object; return Objects.equals(resourceName, that.resourceName) && Objects.equals(deploymentId, that.deploymentId) && Objects.equals(apiVersion, that.apiVersion) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionTaskSettings.java similarity index 77% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionTaskSettings.java index 3008a543b8fea..d5f4b1824f15e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionTaskSettings.java @@ -24,13 +24,13 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -public class AzureOpenAiCompletionTaskSettings implements TaskSettings { +public class AzureOpenAiChatCompletionTaskSettings implements TaskSettings { public static final String NAME = "azure_openai_completion_task_settings"; public static final String USER = "user"; - public static AzureOpenAiCompletionTaskSettings fromMap(Map map) { + public static AzureOpenAiChatCompletionTaskSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); @@ -39,24 +39,24 @@ public static AzureOpenAiCompletionTaskSettings fromMap(Map map) throw validationException; } - return new AzureOpenAiCompletionTaskSettings(user); + return new AzureOpenAiChatCompletionTaskSettings(user); } private final String user; - public static AzureOpenAiCompletionTaskSettings of( - AzureOpenAiCompletionTaskSettings originalSettings, - AzureOpenAiCompletionRequestTaskSettings requestSettings + public static AzureOpenAiChatCompletionTaskSettings of( + AzureOpenAiChatCompletionTaskSettings originalSettings, + AzureOpenAiChatCompletionRequestTaskSettings requestSettings ) { var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user(); - return new AzureOpenAiCompletionTaskSettings(userToUse); + return new AzureOpenAiChatCompletionTaskSettings(userToUse); } - public AzureOpenAiCompletionTaskSettings(@Nullable String user) { + public AzureOpenAiChatCompletionTaskSettings(@Nullable String user) { this.user = user; } - public AzureOpenAiCompletionTaskSettings(StreamInput in) throws IOException { + public AzureOpenAiChatCompletionTaskSettings(StreamInput in) throws IOException { this.user = in.readOptionalString(); } @@ -100,7 +100,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object object) { if (this == object) return true; if (object == null || getClass() != object.getClass()) return false; - AzureOpenAiCompletionTaskSettings that = (AzureOpenAiCompletionTaskSettings) object; + AzureOpenAiChatCompletionTaskSettings that = (AzureOpenAiChatCompletionTaskSettings) object; return Objects.equals(user, that.user); } @@ -111,7 +111,7 @@ public int hashCode() { @Override public TaskSettings updatedTaskSettings(Map newSettings) { - AzureOpenAiCompletionRequestTaskSettings updatedSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap( + AzureOpenAiChatCompletionRequestTaskSettings updatedSettings = AzureOpenAiChatCompletionRequestTaskSettings.fromMap( new HashMap<>(newSettings) ); return of(this, updatedSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 45a2fb0954c79..f9a5f04f47142 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -47,7 +47,7 @@ import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel; +import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettingsTests.createRequestTaskSettingsMap; import static org.hamcrest.Matchers.equalTo; @@ -438,7 +438,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { } } - public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOException { + public void testInfer_AzureOpenAiChatCompletion_WithOverriddenUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { @@ -496,7 +496,7 @@ public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOExcept } } - public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOException { + public void testInfer_AzureOpenAiChatCompletionModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var sender = createSender(senderFactory)) { @@ -551,7 +551,7 @@ public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOExceptio } } - public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat() throws IOException { + public void testInfer_AzureOpenAiChatCompletionModel_FailsFromInvalidResponseFormat() throws IOException { // timeout as zero for no retries var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiChatCompletionActionTests.java similarity index 97% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiChatCompletionActionTests.java index 4c7683c882816..0f05fdde9d597 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiChatCompletionActionTests.java @@ -25,7 +25,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiChatCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -49,7 +49,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; -import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel; +import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModelTests.createCompletionModel; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; @@ -57,7 +57,7 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -public class AzureOpenAiCompletionActionTests extends ESTestCase { +public class AzureOpenAiChatCompletionActionTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); @@ -196,7 +196,7 @@ private ExecutableAction createAction( try { var model = createCompletionModel(resourceName, deploymentId, apiVersion, user, apiKey, null, inferenceEntityId); model.setUri(new URI(getUrl(webServer))); - var requestCreator = new AzureOpenAiCompletionRequestManager(model, threadPool); + var requestCreator = new AzureOpenAiChatCompletionRequestManager(model, threadPool); var errorMessage = constructFailedToSendRequestMessage(model.getUri(), "Azure OpenAI completion"); return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, "Azure OpenAI completion"); } catch (URISyntaxException e) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiChatCompletionRequestEntityTests.java similarity index 83% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiChatCompletionRequestEntityTests.java index 6942f62756c50..19baef4192189 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiChatCompletionRequestEntityTests.java @@ -12,17 +12,17 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiCompletionRequestEntity; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiChatCompletionRequestEntity; import java.io.IOException; import java.util.List; import static org.hamcrest.CoreMatchers.is; -public class AzureOpenAiCompletionRequestEntityTests extends ESTestCase { +public class AzureOpenAiChatCompletionRequestEntityTests extends ESTestCase { public void testXContent_WritesSingleMessage_DoesNotWriteUserWhenItIsNull() throws IOException { - var entity = new AzureOpenAiCompletionRequestEntity(List.of("input"), null, false); + var entity = new AzureOpenAiChatCompletionRequestEntity(List.of("input"), null, false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -33,7 +33,7 @@ public void testXContent_WritesSingleMessage_DoesNotWriteUserWhenItIsNull() thro } public void testXContent_WritesSingleMessage_WriteUserWhenItIsNull() throws IOException { - var entity = new AzureOpenAiCompletionRequestEntity(List.of("input"), "user", false); + var entity = new AzureOpenAiChatCompletionRequestEntity(List.of("input"), "user", false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiChatCompletionRequestTests.java similarity index 90% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiChatCompletionRequestTests.java index d2761bf007927..e0e5612acc5ed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiChatCompletionRequestTests.java @@ -11,8 +11,8 @@ import org.apache.http.client.methods.HttpPost; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiCompletionRequest; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModelTests; import java.io.IOException; import java.util.List; @@ -23,7 +23,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class AzureOpenAiCompletionRequestTests extends ESTestCase { +public class AzureOpenAiChatCompletionRequestTests extends ESTestCase { public void testCreateRequest_WithApiKeyDefined() throws IOException { var input = "input"; @@ -75,7 +75,7 @@ public void testCreateRequest_WithEntraIdDefined() throws IOException { assertThat(requestMap.get("n"), is(1)); } - protected AzureOpenAiCompletionRequest createRequest( + protected AzureOpenAiChatCompletionRequest createRequest( String resource, String deployment, String apiVersion, @@ -84,7 +84,7 @@ protected AzureOpenAiCompletionRequest createRequest( String input, String user ) { - var completionModel = AzureOpenAiCompletionModelTests.createCompletionModel( + var completionModel = AzureOpenAiChatCompletionModelTests.createCompletionModel( resource, deployment, apiVersion, @@ -94,7 +94,7 @@ protected AzureOpenAiCompletionRequest createRequest( "id" ); - return new AzureOpenAiCompletionRequest(List.of(input), completionModel, false); + return new AzureOpenAiChatCompletionRequest(List.of(input), completionModel, false); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiChatCompletionResponseEntityTests.java similarity index 95% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiChatCompletionResponseEntityTests.java index ec76f43a6d52f..2b7a15c0d0fd0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiChatCompletionResponseEntityTests.java @@ -20,7 +20,7 @@ import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; -public class AzureOpenAiCompletionResponseEntityTests extends ESTestCase { +public class AzureOpenAiChatCompletionResponseEntityTests extends ESTestCase { public void testFromResponse_CreatesResultsForASingleItem() throws IOException { String responseJson = """ @@ -86,7 +86,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { } }"""; - ChatCompletionResults chatCompletionResults = AzureOpenAiCompletionResponseEntity.fromResponse( + ChatCompletionResults chatCompletionResults = AzureOpenAiChatCompletionResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -115,7 +115,7 @@ public void testFromResponse_FailsWhenChoicesFieldIsNotPresent() { var thrownException = expectThrows( IllegalStateException.class, - () -> AzureOpenAiCompletionResponseEntity.fromResponse( + () -> AzureOpenAiChatCompletionResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) @@ -143,7 +143,7 @@ public void testFromResponse_FailsWhenChoicesFieldIsNotAnArray() { var thrownException = expectThrows( ParsingException.class, - () -> AzureOpenAiCompletionResponseEntity.fromResponse( + () -> AzureOpenAiChatCompletionResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) @@ -175,7 +175,7 @@ public void testFromResponse_FailsWhenMessageDoesNotExist() { var thrownException = expectThrows( IllegalStateException.class, - () -> AzureOpenAiCompletionResponseEntity.fromResponse( + () -> AzureOpenAiChatCompletionResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) @@ -202,7 +202,7 @@ public void testFromResponse_FailsWhenMessageValueIsAString() { var thrownException = expectThrows( ParsingException.class, - () -> AzureOpenAiCompletionResponseEntity.fromResponse( + () -> AzureOpenAiChatCompletionResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 900b666c0b8fb..d156be1396e9f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -39,7 +39,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests; import org.hamcrest.CoreMatchers; @@ -1458,7 +1458,7 @@ public void testInfer_StreamRequest() throws Exception { private InferenceServiceResults streamChatCompletion() throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - var model = AzureOpenAiCompletionModelTests.createCompletionModel( + var model = AzureOpenAiChatCompletionModelTests.createCompletionModel( "resource", "deployment", "apiversion", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java similarity index 83% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModelTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java index 93d948a5bdcf3..dc4e3895bf297 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java @@ -22,7 +22,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; -public class AzureOpenAiCompletionModelTests extends ESTestCase { +public class AzureOpenAiChatCompletionModelTests extends ESTestCase { public void testOverrideWith_UpdatedTaskSettings_OverridesUser() { var resource = "resource"; @@ -37,7 +37,7 @@ public void testOverrideWith_UpdatedTaskSettings_OverridesUser() { var model = createCompletionModel(resource, deploymentId, apiVersion, user, apiKey, entraId, inferenceEntityId); var requestTaskSettingsMap = taskSettingsMap(userOverride); - var overriddenModel = AzureOpenAiCompletionModel.of(model, requestTaskSettingsMap); + var overriddenModel = AzureOpenAiChatCompletionModel.of(model, requestTaskSettingsMap); assertThat( overriddenModel, @@ -48,14 +48,14 @@ public void testOverrideWith_UpdatedTaskSettings_OverridesUser() { public void testOverrideWith_EmptyMap_OverridesNothing() { var model = createCompletionModel("resource", "deployment", "api version", "user", "api key", "entra id", "inference entity id"); var requestTaskSettingsMap = Map.of(); - var overriddenModel = AzureOpenAiCompletionModel.of(model, requestTaskSettingsMap); + var overriddenModel = AzureOpenAiChatCompletionModel.of(model, requestTaskSettingsMap); assertThat(overriddenModel, sameInstance(model)); } public void testOverrideWith_NullMap_OverridesNothing() { var model = createCompletionModel("resource", "deployment", "api version", "user", "api key", "entra id", "inference entity id"); - var overriddenModel = AzureOpenAiCompletionModel.of(model, null); + var overriddenModel = AzureOpenAiChatCompletionModel.of(model, null); assertThat(overriddenModel, sameInstance(model)); } @@ -71,10 +71,10 @@ public void testOverrideWith_UpdatedServiceSettings_OverridesApiVersion() { var apiVersion = "api version"; var updatedApiVersion = "updated api version"; - var updatedServiceSettings = new AzureOpenAiCompletionServiceSettings(resource, deploymentId, updatedApiVersion, null); + var updatedServiceSettings = new AzureOpenAiChatCompletionServiceSettings(resource, deploymentId, updatedApiVersion, null); var model = createCompletionModel(resource, deploymentId, apiVersion, user, apiKey, entraId, inferenceEntityId); - var overriddenModel = new AzureOpenAiCompletionModel(model, updatedServiceSettings); + var overriddenModel = new AzureOpenAiChatCompletionModel(model, updatedServiceSettings); assertThat( overriddenModel, @@ -99,7 +99,7 @@ public void testBuildUriString() throws URISyntaxException { ); } - public static AzureOpenAiCompletionModel createModelWithRandomValues() { + public static AzureOpenAiChatCompletionModel createModelWithRandomValues() { return createCompletionModel( randomAlphaOfLength(10), randomAlphaOfLength(10), @@ -111,7 +111,7 @@ public static AzureOpenAiCompletionModel createModelWithRandomValues() { ); } - public static AzureOpenAiCompletionModel createCompletionModel( + public static AzureOpenAiChatCompletionModel createCompletionModel( String resourceName, String deploymentId, String apiVersion, @@ -123,12 +123,12 @@ public static AzureOpenAiCompletionModel createCompletionModel( var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null; var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null; - return new AzureOpenAiCompletionModel( + return new AzureOpenAiChatCompletionModel( inferenceEntityId, TaskType.COMPLETION, "service", - new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null), - new AzureOpenAiCompletionTaskSettings(user), + new AzureOpenAiChatCompletionServiceSettings(resourceName, deploymentId, apiVersion, null), + new AzureOpenAiChatCompletionTaskSettings(user), new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionRequestTaskSettingsTests.java similarity index 61% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionRequestTaskSettingsTests.java index 51963c275a08a..83007b7e39b16 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionRequestTaskSettingsTests.java @@ -17,27 +17,27 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -public class AzureOpenAiCompletionRequestTaskSettingsTests extends ESTestCase { +public class AzureOpenAiChatCompletionRequestTaskSettingsTests extends ESTestCase { public void testFromMap_ReturnsEmptySettings_WhenMapIsEmpty() { - var settings = AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of())); - assertThat(settings, is(AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS)); + var settings = AzureOpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of())); + assertThat(settings, is(AzureOpenAiChatCompletionRequestTaskSettings.EMPTY_SETTINGS)); } public void testFromMap_ReturnsEmptySettings_WhenMapDoesNotContainKnownFields() { - var settings = AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))); - assertThat(settings, is(AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS)); + var settings = AzureOpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))); + assertThat(settings, is(AzureOpenAiChatCompletionRequestTaskSettings.EMPTY_SETTINGS)); } public void testFromMap_ReturnsUser() { - var settings = AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); + var settings = AzureOpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); assertThat(settings.user(), is("user")); } public void testFromMap_WhenUserIsEmpty_ThrowsValidationException() { var exception = expectThrows( ValidationException.class, - () -> AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, ""))) + () -> AzureOpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, ""))) ); assertThat(exception.getMessage(), containsString("[user] must be a non-empty string")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionServiceSettingsTests.java similarity index 66% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionServiceSettingsTests.java index 797cad8f300ae..c71d7270f4dec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionServiceSettingsTests.java @@ -22,14 +22,15 @@ import static org.hamcrest.Matchers.is; -public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSerializingTestCase { +public class AzureOpenAiChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase< + AzureOpenAiChatCompletionServiceSettings> { - private static AzureOpenAiCompletionServiceSettings createRandom() { + private static AzureOpenAiChatCompletionServiceSettings createRandom() { var resourceName = randomAlphaOfLength(8); var deploymentId = randomAlphaOfLength(8); var apiVersion = randomAlphaOfLength(8); - return new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null); + return new AzureOpenAiChatCompletionServiceSettings(resourceName, deploymentId, apiVersion, null); } public void testFromMap_Request_CreatesSettingsCorrectly() { @@ -37,7 +38,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { var deploymentId = "this-deployment"; var apiVersion = "2024-01-01"; - var serviceSettings = AzureOpenAiCompletionServiceSettings.fromMap( + var serviceSettings = AzureOpenAiChatCompletionServiceSettings.fromMap( new HashMap<>( Map.of( AzureOpenAiServiceFields.RESOURCE_NAME, @@ -51,11 +52,11 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { ConfigurationParseContext.PERSISTENT ); - assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null))); + assertThat(serviceSettings, is(new AzureOpenAiChatCompletionServiceSettings(resourceName, deploymentId, apiVersion, null))); } public void testToXContent_WritesAllValues() throws IOException { - var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null); + var entity = new AzureOpenAiChatCompletionServiceSettings("resource", "deployment", "2024", null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -66,17 +67,18 @@ public void testToXContent_WritesAllValues() throws IOException { } @Override - protected Writeable.Reader instanceReader() { - return AzureOpenAiCompletionServiceSettings::new; + protected Writeable.Reader instanceReader() { + return AzureOpenAiChatCompletionServiceSettings::new; } @Override - protected AzureOpenAiCompletionServiceSettings createTestInstance() { + protected AzureOpenAiChatCompletionServiceSettings createTestInstance() { return createRandom(); } @Override - protected AzureOpenAiCompletionServiceSettings mutateInstance(AzureOpenAiCompletionServiceSettings instance) throws IOException { - return randomValueOtherThan(instance, AzureOpenAiCompletionServiceSettingsTests::createRandom); + protected AzureOpenAiChatCompletionServiceSettings mutateInstance(AzureOpenAiChatCompletionServiceSettings instance) + throws IOException { + return randomValueOtherThan(instance, AzureOpenAiChatCompletionServiceSettingsTests::createRandom); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionTaskSettingsTests.java similarity index 57% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionTaskSettingsTests.java index 9d77abfe6d512..2c069c065a483 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionTaskSettingsTests.java @@ -21,15 +21,15 @@ import static org.hamcrest.Matchers.is; -public class AzureOpenAiCompletionTaskSettingsTests extends AbstractWireSerializingTestCase { +public class AzureOpenAiChatCompletionTaskSettingsTests extends AbstractWireSerializingTestCase { - public static AzureOpenAiCompletionTaskSettings createRandomWithUser() { - return new AzureOpenAiCompletionTaskSettings(randomAlphaOfLength(15)); + public static AzureOpenAiChatCompletionTaskSettings createRandomWithUser() { + return new AzureOpenAiChatCompletionTaskSettings(randomAlphaOfLength(15)); } - public static AzureOpenAiCompletionTaskSettings createRandom() { + public static AzureOpenAiChatCompletionTaskSettings createRandom() { var user = randomBoolean() ? randomAlphaOfLength(15) : null; - return new AzureOpenAiCompletionTaskSettings(user); + return new AzureOpenAiChatCompletionTaskSettings(user); } public void testIsEmpty() { @@ -41,7 +41,7 @@ public void testIsEmpty() { public void testUpdatedTaskSettings() { var initialSettings = createRandom(); var newSettings = createRandom(); - AzureOpenAiCompletionTaskSettings updatedSettings = (AzureOpenAiCompletionTaskSettings) initialSettings.updatedTaskSettings( + AzureOpenAiChatCompletionTaskSettings updatedSettings = (AzureOpenAiChatCompletionTaskSettings) initialSettings.updatedTaskSettings( newSettings.user() == null ? Map.of() : Map.of(AzureOpenAiServiceFields.USER, newSettings.user()) ); @@ -52,8 +52,8 @@ public void testFromMap_WithUser() { var user = "user"; assertThat( - new AzureOpenAiCompletionTaskSettings(user), - is(AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, user)))) + new AzureOpenAiChatCompletionTaskSettings(user), + is(AzureOpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, user)))) ); } @@ -70,16 +70,16 @@ public void testFromMap_UserIsEmptyString() { } public void testFromMap_MissingUser_DoesNotThrowException() { - var taskSettings = AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of())); + var taskSettings = AzureOpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of())); assertNull(taskSettings.user()); } public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { - var taskSettings = AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); + var taskSettings = AzureOpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); - var overriddenTaskSettings = AzureOpenAiCompletionTaskSettings.of( + var overriddenTaskSettings = AzureOpenAiChatCompletionTaskSettings.of( taskSettings, - AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS + AzureOpenAiChatCompletionRequestTaskSettings.EMPTY_SETTINGS ); assertThat(overriddenTaskSettings, is(taskSettings)); } @@ -88,28 +88,28 @@ public void testOverrideWith_UsesOverriddenSettings() { var user = "user"; var userOverride = "user override"; - var taskSettings = AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, user))); + var taskSettings = AzureOpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, user))); - var requestTaskSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap( + var requestTaskSettings = AzureOpenAiChatCompletionRequestTaskSettings.fromMap( new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, userOverride)) ); - var overriddenTaskSettings = AzureOpenAiCompletionTaskSettings.of(taskSettings, requestTaskSettings); - assertThat(overriddenTaskSettings, is(new AzureOpenAiCompletionTaskSettings(userOverride))); + var overriddenTaskSettings = AzureOpenAiChatCompletionTaskSettings.of(taskSettings, requestTaskSettings); + assertThat(overriddenTaskSettings, is(new AzureOpenAiChatCompletionTaskSettings(userOverride))); } @Override - protected Writeable.Reader instanceReader() { - return AzureOpenAiCompletionTaskSettings::new; + protected Writeable.Reader instanceReader() { + return AzureOpenAiChatCompletionTaskSettings::new; } @Override - protected AzureOpenAiCompletionTaskSettings createTestInstance() { + protected AzureOpenAiChatCompletionTaskSettings createTestInstance() { return createRandomWithUser(); } @Override - protected AzureOpenAiCompletionTaskSettings mutateInstance(AzureOpenAiCompletionTaskSettings instance) throws IOException { - return randomValueOtherThan(instance, AzureOpenAiCompletionTaskSettingsTests::createRandomWithUser); + protected AzureOpenAiChatCompletionTaskSettings mutateInstance(AzureOpenAiChatCompletionTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, AzureOpenAiChatCompletionTaskSettingsTests::createRandomWithUser); } } From fdd4ed89ef633bc4af3fe025b021b55caab8fe89 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 21 Oct 2024 09:49:22 +0200 Subject: [PATCH 2/3] WIP --- .../InferenceNamedWriteablesProvider.java | 6 ++--- .../AzureOpenAiChatCompletionModel.java | 10 ++++---- ...AzureOpenAiCompletionServiceSettings.java} | 18 +++++++------- .../AzureOpenAiChatCompletionModelTests.java | 4 ++-- ...OpenAiCompletionServiceSettingsTests.java} | 24 +++++++++---------- 5 files changed, 31 insertions(+), 31 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiChatCompletionServiceSettings.java => AzureOpenAiCompletionServiceSettings.java} (88%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiChatCompletionServiceSettingsTests.java => AzureOpenAiCompletionServiceSettingsTests.java} (67%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 6db72d5d38424..8d86fcbe732cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -49,7 +49,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings; @@ -247,8 +247,8 @@ private static void addAzureOpenAiNamedWriteables(List map, ConfigurationParseContext context) { + public static AzureOpenAiCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); var settings = fromMap(map, validationException, context); @@ -66,10 +66,10 @@ public static AzureOpenAiChatCompletionServiceSettings fromMap(Map map, ValidationException validationException, ConfigurationParseContext context @@ -85,7 +85,7 @@ private static AzureOpenAiChatCompletionServiceSettings.CommonFields fromMap( context ); - return new AzureOpenAiChatCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings); + return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings); } private record CommonFields(String resourceName, String deploymentId, String apiVersion, RateLimitSettings rateLimitSettings) {} @@ -96,7 +96,7 @@ private record CommonFields(String resourceName, String deploymentId, String api private final RateLimitSettings rateLimitSettings; - public AzureOpenAiChatCompletionServiceSettings( + public AzureOpenAiCompletionServiceSettings( String resourceName, String deploymentId, String apiVersion, @@ -108,14 +108,14 @@ public AzureOpenAiChatCompletionServiceSettings( this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); } - public AzureOpenAiChatCompletionServiceSettings(StreamInput in) throws IOException { + public AzureOpenAiCompletionServiceSettings(StreamInput in) throws IOException { resourceName = in.readString(); deploymentId = in.readString(); apiVersion = in.readString(); rateLimitSettings = new RateLimitSettings(in); } - private AzureOpenAiChatCompletionServiceSettings(AzureOpenAiChatCompletionServiceSettings.CommonFields fields) { + private AzureOpenAiCompletionServiceSettings(AzureOpenAiCompletionServiceSettings.CommonFields fields) { this(fields.resourceName, fields.deploymentId, fields.apiVersion, fields.rateLimitSettings); } @@ -183,7 +183,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object object) { if (this == object) return true; if (object == null || getClass() != object.getClass()) return false; - AzureOpenAiChatCompletionServiceSettings that = (AzureOpenAiChatCompletionServiceSettings) object; + AzureOpenAiCompletionServiceSettings that = (AzureOpenAiCompletionServiceSettings) object; return Objects.equals(resourceName, that.resourceName) && Objects.equals(deploymentId, that.deploymentId) && Objects.equals(apiVersion, that.apiVersion) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java index dc4e3895bf297..19667c6079b58 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java @@ -71,7 +71,7 @@ public void testOverrideWith_UpdatedServiceSettings_OverridesApiVersion() { var apiVersion = "api version"; var updatedApiVersion = "updated api version"; - var updatedServiceSettings = new AzureOpenAiChatCompletionServiceSettings(resource, deploymentId, updatedApiVersion, null); + var updatedServiceSettings = new AzureOpenAiCompletionServiceSettings(resource, deploymentId, updatedApiVersion, null); var model = createCompletionModel(resource, deploymentId, apiVersion, user, apiKey, entraId, inferenceEntityId); var overriddenModel = new AzureOpenAiChatCompletionModel(model, updatedServiceSettings); @@ -127,7 +127,7 @@ public static AzureOpenAiChatCompletionModel createCompletionModel( inferenceEntityId, TaskType.COMPLETION, "service", - new AzureOpenAiChatCompletionServiceSettings(resourceName, deploymentId, apiVersion, null), + new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null), new AzureOpenAiChatCompletionTaskSettings(user), new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java similarity index 67% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionServiceSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java index c71d7270f4dec..5e2c92d8ed1d1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java @@ -22,15 +22,15 @@ import static org.hamcrest.Matchers.is; -public class AzureOpenAiChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase< - AzureOpenAiChatCompletionServiceSettings> { +public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSerializingTestCase< + AzureOpenAiCompletionServiceSettings> { - private static AzureOpenAiChatCompletionServiceSettings createRandom() { + private static AzureOpenAiCompletionServiceSettings createRandom() { var resourceName = randomAlphaOfLength(8); var deploymentId = randomAlphaOfLength(8); var apiVersion = randomAlphaOfLength(8); - return new AzureOpenAiChatCompletionServiceSettings(resourceName, deploymentId, apiVersion, null); + return new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null); } public void testFromMap_Request_CreatesSettingsCorrectly() { @@ -38,7 +38,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { var deploymentId = "this-deployment"; var apiVersion = "2024-01-01"; - var serviceSettings = AzureOpenAiChatCompletionServiceSettings.fromMap( + var serviceSettings = AzureOpenAiCompletionServiceSettings.fromMap( new HashMap<>( Map.of( AzureOpenAiServiceFields.RESOURCE_NAME, @@ -52,11 +52,11 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { ConfigurationParseContext.PERSISTENT ); - assertThat(serviceSettings, is(new AzureOpenAiChatCompletionServiceSettings(resourceName, deploymentId, apiVersion, null))); + assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null))); } public void testToXContent_WritesAllValues() throws IOException { - var entity = new AzureOpenAiChatCompletionServiceSettings("resource", "deployment", "2024", null); + var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -67,18 +67,18 @@ public void testToXContent_WritesAllValues() throws IOException { } @Override - protected Writeable.Reader instanceReader() { - return AzureOpenAiChatCompletionServiceSettings::new; + protected Writeable.Reader instanceReader() { + return AzureOpenAiCompletionServiceSettings::new; } @Override - protected AzureOpenAiChatCompletionServiceSettings createTestInstance() { + protected AzureOpenAiCompletionServiceSettings createTestInstance() { return createRandom(); } @Override - protected AzureOpenAiChatCompletionServiceSettings mutateInstance(AzureOpenAiChatCompletionServiceSettings instance) + protected AzureOpenAiCompletionServiceSettings mutateInstance(AzureOpenAiCompletionServiceSettings instance) throws IOException { - return randomValueOtherThan(instance, AzureOpenAiChatCompletionServiceSettingsTests::createRandom); + return randomValueOtherThan(instance, AzureOpenAiCompletionServiceSettingsTests::createRandom); } } From 4ed3bf8ec6396ecaa1ad11b1943852b4ab0bf0da Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Mon, 21 Oct 2024 09:49:50 +0200 Subject: [PATCH 3/3] Redo --- .../InferenceNamedWriteablesProvider.java | 6 ++--- .../AzureOpenAiChatCompletionModel.java | 10 ++++---- ...eOpenAiChatCompletionServiceSettings.java} | 18 +++++++------- .../AzureOpenAiChatCompletionModelTests.java | 4 ++-- ...AiChatCompletionServiceSettingsTests.java} | 24 +++++++++---------- 5 files changed, 31 insertions(+), 31 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionServiceSettings.java => AzureOpenAiChatCompletionServiceSettings.java} (88%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/{AzureOpenAiCompletionServiceSettingsTests.java => AzureOpenAiChatCompletionServiceSettingsTests.java} (67%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 8d86fcbe732cc..6db72d5d38424 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -49,7 +49,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; -import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings; @@ -247,8 +247,8 @@ private static void addAzureOpenAiNamedWriteables(List map, ConfigurationParseContext context) { + public static AzureOpenAiChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); var settings = fromMap(map, validationException, context); @@ -66,10 +66,10 @@ public static AzureOpenAiCompletionServiceSettings fromMap(Map m throw validationException; } - return new AzureOpenAiCompletionServiceSettings(settings); + return new AzureOpenAiChatCompletionServiceSettings(settings); } - private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap( + private static AzureOpenAiChatCompletionServiceSettings.CommonFields fromMap( Map map, ValidationException validationException, ConfigurationParseContext context @@ -85,7 +85,7 @@ private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap( context ); - return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings); + return new AzureOpenAiChatCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings); } private record CommonFields(String resourceName, String deploymentId, String apiVersion, RateLimitSettings rateLimitSettings) {} @@ -96,7 +96,7 @@ private record CommonFields(String resourceName, String deploymentId, String api private final RateLimitSettings rateLimitSettings; - public AzureOpenAiCompletionServiceSettings( + public AzureOpenAiChatCompletionServiceSettings( String resourceName, String deploymentId, String apiVersion, @@ -108,14 +108,14 @@ public AzureOpenAiCompletionServiceSettings( this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); } - public AzureOpenAiCompletionServiceSettings(StreamInput in) throws IOException { + public AzureOpenAiChatCompletionServiceSettings(StreamInput in) throws IOException { resourceName = in.readString(); deploymentId = in.readString(); apiVersion = in.readString(); rateLimitSettings = new RateLimitSettings(in); } - private AzureOpenAiCompletionServiceSettings(AzureOpenAiCompletionServiceSettings.CommonFields fields) { + private AzureOpenAiChatCompletionServiceSettings(AzureOpenAiChatCompletionServiceSettings.CommonFields fields) { this(fields.resourceName, fields.deploymentId, fields.apiVersion, fields.rateLimitSettings); } @@ -183,7 +183,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object object) { if (this == object) return true; if (object == null || getClass() != object.getClass()) return false; - AzureOpenAiCompletionServiceSettings that = (AzureOpenAiCompletionServiceSettings) object; + AzureOpenAiChatCompletionServiceSettings that = (AzureOpenAiChatCompletionServiceSettings) object; return Objects.equals(resourceName, that.resourceName) && Objects.equals(deploymentId, that.deploymentId) && Objects.equals(apiVersion, that.apiVersion) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java index 19667c6079b58..dc4e3895bf297 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionModelTests.java @@ -71,7 +71,7 @@ public void testOverrideWith_UpdatedServiceSettings_OverridesApiVersion() { var apiVersion = "api version"; var updatedApiVersion = "updated api version"; - var updatedServiceSettings = new AzureOpenAiCompletionServiceSettings(resource, deploymentId, updatedApiVersion, null); + var updatedServiceSettings = new AzureOpenAiChatCompletionServiceSettings(resource, deploymentId, updatedApiVersion, null); var model = createCompletionModel(resource, deploymentId, apiVersion, user, apiKey, entraId, inferenceEntityId); var overriddenModel = new AzureOpenAiChatCompletionModel(model, updatedServiceSettings); @@ -127,7 +127,7 @@ public static AzureOpenAiChatCompletionModel createCompletionModel( inferenceEntityId, TaskType.COMPLETION, "service", - new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null), + new AzureOpenAiChatCompletionServiceSettings(resourceName, deploymentId, apiVersion, null), new AzureOpenAiChatCompletionTaskSettings(user), new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionServiceSettingsTests.java similarity index 67% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionServiceSettingsTests.java index 5e2c92d8ed1d1..c71d7270f4dec 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiChatCompletionServiceSettingsTests.java @@ -22,15 +22,15 @@ import static org.hamcrest.Matchers.is; -public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSerializingTestCase< - AzureOpenAiCompletionServiceSettings> { +public class AzureOpenAiChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase< + AzureOpenAiChatCompletionServiceSettings> { - private static AzureOpenAiCompletionServiceSettings createRandom() { + private static AzureOpenAiChatCompletionServiceSettings createRandom() { var resourceName = randomAlphaOfLength(8); var deploymentId = randomAlphaOfLength(8); var apiVersion = randomAlphaOfLength(8); - return new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null); + return new AzureOpenAiChatCompletionServiceSettings(resourceName, deploymentId, apiVersion, null); } public void testFromMap_Request_CreatesSettingsCorrectly() { @@ -38,7 +38,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { var deploymentId = "this-deployment"; var apiVersion = "2024-01-01"; - var serviceSettings = AzureOpenAiCompletionServiceSettings.fromMap( + var serviceSettings = AzureOpenAiChatCompletionServiceSettings.fromMap( new HashMap<>( Map.of( AzureOpenAiServiceFields.RESOURCE_NAME, @@ -52,11 +52,11 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { ConfigurationParseContext.PERSISTENT ); - assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null))); + assertThat(serviceSettings, is(new AzureOpenAiChatCompletionServiceSettings(resourceName, deploymentId, apiVersion, null))); } public void testToXContent_WritesAllValues() throws IOException { - var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null); + var entity = new AzureOpenAiChatCompletionServiceSettings("resource", "deployment", "2024", null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -67,18 +67,18 @@ public void testToXContent_WritesAllValues() throws IOException { } @Override - protected Writeable.Reader instanceReader() { - return AzureOpenAiCompletionServiceSettings::new; + protected Writeable.Reader instanceReader() { + return AzureOpenAiChatCompletionServiceSettings::new; } @Override - protected AzureOpenAiCompletionServiceSettings createTestInstance() { + protected AzureOpenAiChatCompletionServiceSettings createTestInstance() { return createRandom(); } @Override - protected AzureOpenAiCompletionServiceSettings mutateInstance(AzureOpenAiCompletionServiceSettings instance) + protected AzureOpenAiChatCompletionServiceSettings mutateInstance(AzureOpenAiChatCompletionServiceSettings instance) throws IOException { - return randomValueOtherThan(instance, AzureOpenAiCompletionServiceSettingsTests::createRandom); + return randomValueOtherThan(instance, AzureOpenAiChatCompletionServiceSettingsTests::createRandom); } }