Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/138726.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 138726
summary: Added Azure OpenAI chat_completion support to the Inference Plugin
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
containsInAnyOrder(
List.of(
"ai21",
"azureopenai",
"llama",
"deepseek",
"elastic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
package org.elasticsearch.xpack.inference.services.azureopenai;

import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.azureopenai.action.AzureOpenAiActionVisitor;
import org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.net.URI;
import java.net.URISyntaxException;
Expand All @@ -27,7 +28,7 @@

import static org.elasticsearch.core.Strings.format;

public abstract class AzureOpenAiModel extends Model {
public abstract class AzureOpenAiModel extends RateLimitGroupingModel {

protected URI uri;
private final AzureOpenAiRateLimitServiceSettings rateLimitServiceSettings;
Expand Down Expand Up @@ -95,6 +96,16 @@ public AzureOpenAiRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

/** Rate limit settings for this model, sourced from the service settings. */
@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings.rateLimitSettings();
}

/**
 * Grouping key for rate limiting: models pointing at the same Azure resource and
 * deployment hash to the same group (presumably so they share one rate limit bucket —
 * NOTE(review): confirm against RateLimitGroupingModel's contract).
 */
@Override
public int rateLimitGroupingHash() {
return Objects.hash(resourceName(), deploymentId());
}

// TODO: can be inferred directly from modelConfigurations.getServiceSettings(); will be addressed with separate refactoring
public abstract String resourceName();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,8 @@ public static Map<String, SettingsConfiguration> get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
API_KEY,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
"You must provide either an API key or an Entra ID."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
.setDescription("You must provide either an API key or an Entra ID.")
.setLabel("API Key")
.setRequired(false)
.setSensitive(true)
Expand All @@ -160,9 +159,8 @@ public static Map<String, SettingsConfiguration> get() {
);
configurationMap.put(
ENTRA_ID,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
"You must provide either an API key or an Entra ID."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
.setDescription("You must provide either an API key or an Entra ID.")
.setLabel("Entra ID")
.setRequired(false)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
Expand All @@ -40,9 +43,12 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.azureopenai.action.AzureOpenAiActionCreator;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.EnumSet;
Expand All @@ -51,14 +57,14 @@
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION;
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID;
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME;
Expand All @@ -68,7 +74,16 @@ public class AzureOpenAiService extends SenderService {
public static final String NAME = "azureopenai";

private static final String SERVICE_NAME = "Azure OpenAI";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
);
public static final String CHAT_COMPLETION_REQUEST_TYPE = "Azure OpenAI chat completions";
private static final ResponseHandler CHAT_COMPLETION_HANDLER = new AzureOpenAiChatCompletionResponseHandler(
CHAT_COMPLETION_REQUEST_TYPE,
OpenAiChatCompletionResponseEntity::fromResponse
);

public AzureOpenAiService(
HttpRequestSender.Factory factory,
Expand Down Expand Up @@ -166,7 +181,7 @@ private static AzureOpenAiModel createModel(
context
);
}
case COMPLETION -> {
case COMPLETION, CHAT_COMPLETION -> {
return new AzureOpenAiCompletionModel(
inferenceEntityId,
taskType,
Expand Down Expand Up @@ -237,7 +252,25 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof AzureOpenAiCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}

AzureOpenAiCompletionModel openAiModel = (AzureOpenAiCompletionModel) model;

var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
openAiModel,
CHAT_COMPLETION_HANDLER,
chatInput -> new AzureOpenAiChatCompletionRequest(chatInput, openAiModel),
UnifiedChatInput.class
);

var errorMessage = constructFailedToSendRequestMessage(CHAT_COMPLETION_REQUEST_TYPE);
var action = new SenderExecutableAction(getSender(), manager, errorMessage);

action.execute(inputs, timeout, listener);
}

@Override
Expand Down Expand Up @@ -324,7 +357,7 @@ public TransportVersion getMinimalSupportedVersion() {

/** Streaming is supported for both completion task types. */
@Override
public Set<TaskType> supportedStreamingTasks() {
    return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
}

public static class Configuration {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.azureopenai.completion;

import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract;
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;

/**
 * Handles streaming chat completion responses and error parsing for Azure OpenAI inference endpoints.
 * Adapts the OpenAI unified chat completion handler to use an Azure OpenAI specific error parser.
 */
public class AzureOpenAiChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {

    // Error type label reported for failures from this handler.
    private static final String AZURE_OPENAI_ERROR = "azure_openai_error";
    private static final UnifiedChatCompletionErrorParserContract AZURE_OPENAI_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils
        .createErrorParserWithStringify(AZURE_OPENAI_ERROR);

    /**
     * @param requestType human-readable description of the request (e.g. "Azure OpenAI chat completions")
     * @param parseFunction parses a successful response into inference results
     */
    public AzureOpenAiChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, AZURE_OPENAI_ERROR_PARSER::parse, AZURE_OPENAI_ERROR_PARSER);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.azureopenai.request;

import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

/**
 * HTTP request for an Azure OpenAI chat completion call. Serializes the unified chat input
 * as the JSON body and decorates the request with the model's auth header.
 */
public class AzureOpenAiChatCompletionRequest implements AzureOpenAiRequest {

    private final UnifiedChatInput chatInput;

    private final URI uri;

    private final AzureOpenAiCompletionModel model;

    /**
     * @param chatInput the unified chat completion input (messages and streaming flag); must not be null
     * @param model the Azure OpenAI completion model supplying the endpoint URI, task settings and secrets; must not be null
     */
    public AzureOpenAiChatCompletionRequest(UnifiedChatInput chatInput, AzureOpenAiCompletionModel model) {
        // Fail fast on null inputs; previously only `model` was checked and a null
        // chatInput would surface later in createHttpRequest()/isStreaming().
        this.chatInput = Objects.requireNonNull(chatInput);
        this.model = Objects.requireNonNull(model);
        this.uri = model.getUri();
    }

    @Override
    public HttpRequest createHttpRequest() {
        var httpPost = new HttpPost(uri);
        // Body is the unified chat payload plus the optional "user" task setting.
        var requestEntity = Strings.toString(new AzureOpenAiChatCompletionRequestEntity(chatInput, model.getTaskSettings().user()));

        ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
        httpPost.setEntity(byteEntity);

        AzureOpenAiRequest.decorateWithAuthHeader(httpPost, model.getSecretSettings());

        return new HttpRequest(httpPost, getInferenceEntityId());
    }

    @Override
    public URI getURI() {
        return this.uri;
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    /** Streaming is determined by the incoming chat input, not by the model. */
    @Override
    public boolean isStreaming() {
        return chatInput.stream();
    }

    @Override
    public Request truncate() {
        // No truncation for Azure OpenAI completion
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        // No truncation for Azure OpenAI completion
        return null;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.azureopenai.request;

import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;

import java.io.IOException;

/**
 * Serializable body of an Azure OpenAI chat completion request: the unified chat completion
 * payload, optionally extended with the "user" field from the task settings.
 */
public class AzureOpenAiChatCompletionRequestEntity implements ToXContentObject {

    public static final String USER_FIELD = "user";

    private final UnifiedChatCompletionRequestEntity unifiedRequest;
    private final String user;

    public AzureOpenAiChatCompletionRequestEntity(UnifiedChatInput chatInput, @Nullable String user) {
        this.unifiedRequest = new UnifiedChatCompletionRequestEntity(chatInput);
        this.user = user;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        // Delegate the bulk of the payload to the shared unified entity, using the
        // params produced by UnifiedCompletionRequest.withMaxCompletionTokensTokens.
        var adjustedParams = UnifiedCompletionRequest.withMaxCompletionTokensTokens(params);
        unifiedRequest.toXContent(builder, adjustedParams);

        // Only emit "user" when a non-empty value was configured.
        if (Strings.isNullOrEmpty(user) == false) {
            builder.field(USER_FIELD, user);
        }

        builder.endObject();
        return builder;
    }
}
Loading