Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
Expand Down Expand Up @@ -247,15 +247,15 @@ private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.En
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AzureOpenAiCompletionServiceSettings.NAME,
AzureOpenAiCompletionServiceSettings::new
AzureOpenAiChatCompletionServiceSettings.NAME,
AzureOpenAiChatCompletionServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AzureOpenAiCompletionTaskSettings.NAME,
AzureOpenAiCompletionTaskSettings::new
AzureOpenAiChatCompletionTaskSettings.NAME,
AzureOpenAiChatCompletionTaskSettings::new
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiChatCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;

import java.util.Map;
Expand Down Expand Up @@ -48,9 +48,9 @@ public ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map<String, Obj
}

@Override
public ExecutableAction create(AzureOpenAiCompletionModel model, Map<String, Object> taskSettings) {
var overriddenModel = AzureOpenAiCompletionModel.of(model, taskSettings);
var requestCreator = new AzureOpenAiCompletionRequestManager(overriddenModel, serviceComponents.threadPool());
public ExecutableAction create(AzureOpenAiChatCompletionModel model, Map<String, Object> taskSettings) {
var overriddenModel = AzureOpenAiChatCompletionModel.of(model, taskSettings);
var requestCreator = new AzureOpenAiChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getUri(), COMPLETION_ERROR_PREFIX);
return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, COMPLETION_ERROR_PREFIX);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
package org.elasticsearch.xpack.inference.external.action.azureopenai;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;

import java.util.Map;

public interface AzureOpenAiActionVisitor {
ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map<String, Object> taskSettings);

ExecutableAction create(AzureOpenAiCompletionModel model, Map<String, Object> taskSettings);
ExecutableAction create(AzureOpenAiChatCompletionModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,26 @@
import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.azureopenai.AzureOpenAiCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.azureopenai.AzureOpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModel;

import java.util.Objects;
import java.util.function.Supplier;

public class AzureOpenAiCompletionRequestManager extends AzureOpenAiRequestManager {
public class AzureOpenAiChatCompletionRequestManager extends AzureOpenAiRequestManager {

private static final Logger logger = LogManager.getLogger(AzureOpenAiCompletionRequestManager.class);
private static final Logger logger = LogManager.getLogger(AzureOpenAiChatCompletionRequestManager.class);

private static final ResponseHandler HANDLER = createCompletionHandler();

private final AzureOpenAiCompletionModel model;
private final AzureOpenAiChatCompletionModel model;

private static ResponseHandler createCompletionHandler() {
return new AzureOpenAiResponseHandler("azure openai completion", AzureOpenAiCompletionResponseEntity::fromResponse, true);
return new AzureOpenAiResponseHandler("azure openai completion", AzureOpenAiChatCompletionResponseEntity::fromResponse, true);
}

public AzureOpenAiCompletionRequestManager(AzureOpenAiCompletionModel model, ThreadPool threadPool) {
public AzureOpenAiChatCompletionRequestManager(AzureOpenAiChatCompletionModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = Objects.requireNonNull(model);
}
Expand All @@ -49,7 +49,7 @@ public void execute(
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
var docsInput = docsOnly.getInputs();
var stream = docsOnly.stream();
AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model, stream);
AzureOpenAiChatCompletionRequest request = new AzureOpenAiChatCompletionRequest(docsInput, model, stream);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

public class AzureOpenAiCompletionRequest implements AzureOpenAiRequest {
public class AzureOpenAiChatCompletionRequest implements AzureOpenAiRequest {

private final List<String> input;

private final URI uri;

private final AzureOpenAiCompletionModel model;
private final AzureOpenAiChatCompletionModel model;

private final boolean stream;

public AzureOpenAiCompletionRequest(List<String> input, AzureOpenAiCompletionModel model, boolean stream) {
public AzureOpenAiChatCompletionRequest(List<String> input, AzureOpenAiChatCompletionModel model, boolean stream) {
this.input = input;
this.model = Objects.requireNonNull(model);
this.uri = model.getUri();
Expand All @@ -39,7 +39,9 @@ public AzureOpenAiCompletionRequest(List<String> input, AzureOpenAiCompletionMod
@Override
public HttpRequest createHttpRequest() {
var httpPost = new HttpPost(uri);
var requestEntity = Strings.toString(new AzureOpenAiCompletionRequestEntity(input, model.getTaskSettings().user(), isStreaming()));
var requestEntity = Strings.toString(
new AzureOpenAiChatCompletionRequestEntity(input, model.getTaskSettings().user(), isStreaming())
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import java.util.List;
import java.util.Objects;

public record AzureOpenAiCompletionRequestEntity(List<String> messages, @Nullable String user, boolean stream) implements ToXContentObject {
public record AzureOpenAiChatCompletionRequestEntity(List<String> messages, @Nullable String user, boolean stream)
implements
ToXContentObject {

private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n";

Expand All @@ -30,7 +32,7 @@ public record AzureOpenAiCompletionRequestEntity(List<String> messages, @Nullabl

private static final String STREAM_FIELD = "stream";

public AzureOpenAiCompletionRequestEntity {
public AzureOpenAiChatCompletionRequestEntity {
Objects.requireNonNull(messages);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;

public class AzureOpenAiCompletionResponseEntity {
public class AzureOpenAiChatCompletionResponseEntity {

private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Azure OpenAI completions response";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;

Expand Down Expand Up @@ -147,7 +147,7 @@ private static AzureOpenAiModel createModel(
);
}
case COMPLETION -> {
return new AzureOpenAiCompletionModel(
return new AzureOpenAiChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,21 @@
import java.net.URISyntaxException;
import java.util.Map;

public class AzureOpenAiCompletionModel extends AzureOpenAiModel {
public class AzureOpenAiChatCompletionModel extends AzureOpenAiModel {

public static AzureOpenAiCompletionModel of(AzureOpenAiCompletionModel model, Map<String, Object> taskSettings) {
public static AzureOpenAiChatCompletionModel of(AzureOpenAiChatCompletionModel model, Map<String, Object> taskSettings) {
if (taskSettings == null || taskSettings.isEmpty()) {
return model;
}

var requestTaskSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap(taskSettings);
return new AzureOpenAiCompletionModel(model, AzureOpenAiCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
var requestTaskSettings = AzureOpenAiChatCompletionRequestTaskSettings.fromMap(taskSettings);
return new AzureOpenAiChatCompletionModel(
model,
AzureOpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)
);
}

public AzureOpenAiCompletionModel(
public AzureOpenAiChatCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
Expand All @@ -45,19 +48,19 @@ public AzureOpenAiCompletionModel(
inferenceEntityId,
taskType,
service,
AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings, context),
AzureOpenAiCompletionTaskSettings.fromMap(taskSettings),
AzureOpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
AzureOpenAiChatCompletionTaskSettings.fromMap(taskSettings),
AzureOpenAiSecretSettings.fromMap(secrets)
);
}

// Should only be used directly for testing
AzureOpenAiCompletionModel(
AzureOpenAiChatCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
AzureOpenAiCompletionServiceSettings serviceSettings,
AzureOpenAiCompletionTaskSettings taskSettings,
AzureOpenAiChatCompletionServiceSettings serviceSettings,
AzureOpenAiChatCompletionTaskSettings taskSettings,
@Nullable AzureOpenAiSecretSettings secrets
) {
super(
Expand All @@ -72,22 +75,28 @@ public AzureOpenAiCompletionModel(
}
}

public AzureOpenAiCompletionModel(AzureOpenAiCompletionModel originalModel, AzureOpenAiCompletionServiceSettings serviceSettings) {
public AzureOpenAiChatCompletionModel(
AzureOpenAiChatCompletionModel originalModel,
AzureOpenAiChatCompletionServiceSettings serviceSettings
) {
super(originalModel, serviceSettings);
}

private AzureOpenAiCompletionModel(AzureOpenAiCompletionModel originalModel, AzureOpenAiCompletionTaskSettings taskSettings) {
private AzureOpenAiChatCompletionModel(
AzureOpenAiChatCompletionModel originalModel,
AzureOpenAiChatCompletionTaskSettings taskSettings
) {
super(originalModel, taskSettings);
}

@Override
public AzureOpenAiCompletionServiceSettings getServiceSettings() {
return (AzureOpenAiCompletionServiceSettings) super.getServiceSettings();
public AzureOpenAiChatCompletionServiceSettings getServiceSettings() {
return (AzureOpenAiChatCompletionServiceSettings) super.getServiceSettings();
}

@Override
public AzureOpenAiCompletionTaskSettings getTaskSettings() {
return (AzureOpenAiCompletionTaskSettings) super.getTaskSettings();
public AzureOpenAiChatCompletionTaskSettings getTaskSettings() {
return (AzureOpenAiChatCompletionTaskSettings) super.getTaskSettings();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER;

public record AzureOpenAiCompletionRequestTaskSettings(@Nullable String user) {
public record AzureOpenAiChatCompletionRequestTaskSettings(@Nullable String user) {

public static final AzureOpenAiCompletionRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiCompletionRequestTaskSettings(null);
public static final AzureOpenAiChatCompletionRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiChatCompletionRequestTaskSettings(
null
);

public static AzureOpenAiCompletionRequestTaskSettings fromMap(Map<String, Object> map) {
public static AzureOpenAiChatCompletionRequestTaskSettings fromMap(Map<String, Object> map) {
if (map.isEmpty()) {
return AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS;
return AzureOpenAiChatCompletionRequestTaskSettings.EMPTY_SETTINGS;
}

ValidationException validationException = new ValidationException();
Expand All @@ -33,6 +35,6 @@ public static AzureOpenAiCompletionRequestTaskSettings fromMap(Map<String, Objec
throw validationException;
}

return new AzureOpenAiCompletionRequestTaskSettings(user);
return new AzureOpenAiChatCompletionRequestTaskSettings(user);
}
}
Loading