Skip to content

Commit fcf3550

Browse files
Add Azure OpenAI chat completion support (#138726)
Extends the existing Azure OpenAI inference provider integration, allowing the chat_completion task to be executed as part of the inference API with the azureopenai provider. Other changes: * Fix parameter naming in UnifiedCompletionRequest for max completion tokens
1 parent 51a7b97 commit fcf3550

File tree

15 files changed

+712
-24
lines changed

15 files changed

+712
-24
lines changed

docs/changelog/138726.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 138726
2+
summary: Added Azure OpenAI chat_completion support to the Inference Plugin
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

rest-api-spec/src/main/resources/rest-api-spec/api/inference.put_azureopenai.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"description": "The task type",
2828
"options": [
2929
"completion",
30+
"chat_completion",
3031
"text_embedding"
3132
]
3233
},

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ public static Params withMaxCompletionTokens(String modelId, Params params) {
135135
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
136136
* - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
137137
*/
138-
public static Params withMaxCompletionTokensTokens(Params params) {
138+
public static Params withMaxCompletionTokens(Params params) {
139139
return new DelegatingMapParams(Map.of(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD), params);
140140
}
141141

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
193193
containsInAnyOrder(
194194
List.of(
195195
"ai21",
196+
"azureopenai",
196197
"llama",
197198
"deepseek",
198199
"elastic",

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
package org.elasticsearch.xpack.inference.services.azureopenai;
99

1010
import org.apache.http.client.utils.URIBuilder;
11-
import org.elasticsearch.inference.Model;
1211
import org.elasticsearch.inference.ModelConfigurations;
1312
import org.elasticsearch.inference.ModelSecrets;
1413
import org.elasticsearch.inference.ServiceSettings;
1514
import org.elasticsearch.inference.TaskSettings;
1615
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
16+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1717
import org.elasticsearch.xpack.inference.services.azureopenai.action.AzureOpenAiActionVisitor;
1818
import org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils;
19+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1920

2021
import java.net.URI;
2122
import java.net.URISyntaxException;
@@ -27,7 +28,7 @@
2728

2829
import static org.elasticsearch.core.Strings.format;
2930

30-
public abstract class AzureOpenAiModel extends Model {
31+
public abstract class AzureOpenAiModel extends RateLimitGroupingModel {
3132

3233
protected URI uri;
3334
private final AzureOpenAiRateLimitServiceSettings rateLimitServiceSettings;
@@ -95,6 +96,16 @@ public AzureOpenAiRateLimitServiceSettings rateLimitServiceSettings() {
9596
return rateLimitServiceSettings;
9697
}
9798

99+
@Override
100+
public RateLimitSettings rateLimitSettings() {
101+
return rateLimitServiceSettings.rateLimitSettings();
102+
}
103+
104+
@Override
105+
public int rateLimitGroupingHash() {
106+
return Objects.hash(resourceName(), deploymentId());
107+
}
108+
98109
// TODO: can be inferred directly from modelConfigurations.getServiceSettings(); will be addressed with separate refactoring
99110
public abstract String resourceName();
100111

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,8 @@ public static Map<String, SettingsConfiguration> get() {
148148
var configurationMap = new HashMap<String, SettingsConfiguration>();
149149
configurationMap.put(
150150
API_KEY,
151-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
152-
"You must provide either an API key or an Entra ID."
153-
)
151+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
152+
.setDescription("You must provide either an API key or an Entra ID.")
154153
.setLabel("API Key")
155154
.setRequired(false)
156155
.setSensitive(true)
@@ -160,9 +159,8 @@ public static Map<String, SettingsConfiguration> get() {
160159
);
161160
configurationMap.put(
162161
ENTRA_ID,
163-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
164-
"You must provide either an API key or an Entra ID."
165-
)
162+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
163+
.setDescription("You must provide either an API key or an Entra ID.")
166164
.setLabel("Entra ID")
167165
.setRequired(false)
168166
.setSensitive(true)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@
3131
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
3232
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
3333
import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker;
34+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
35+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
3436
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
37+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3538
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3639
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3740
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -40,9 +43,12 @@
4043
import org.elasticsearch.xpack.inference.services.ServiceComponents;
4144
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4245
import org.elasticsearch.xpack.inference.services.azureopenai.action.AzureOpenAiActionCreator;
46+
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionResponseHandler;
4347
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
4448
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
4549
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
50+
import org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiChatCompletionRequest;
51+
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
4652
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
4753

4854
import java.util.EnumSet;
@@ -51,14 +57,14 @@
5157
import java.util.Map;
5258
import java.util.Set;
5359

60+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
5461
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
5562
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
5663
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
5764
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
5865
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
5966
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
6067
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
61-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
6268
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION;
6369
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID;
6470
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME;
@@ -68,7 +74,16 @@ public class AzureOpenAiService extends SenderService {
6874
public static final String NAME = "azureopenai";
6975

7076
private static final String SERVICE_NAME = "Azure OpenAI";
71-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
77+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
78+
TaskType.TEXT_EMBEDDING,
79+
TaskType.COMPLETION,
80+
TaskType.CHAT_COMPLETION
81+
);
82+
public static final String CHAT_COMPLETION_REQUEST_TYPE = "Azure OpenAI chat completions";
83+
private static final ResponseHandler CHAT_COMPLETION_HANDLER = new AzureOpenAiChatCompletionResponseHandler(
84+
CHAT_COMPLETION_REQUEST_TYPE,
85+
OpenAiChatCompletionResponseEntity::fromResponse
86+
);
7287

7388
public AzureOpenAiService(
7489
HttpRequestSender.Factory factory,
@@ -166,7 +181,7 @@ private static AzureOpenAiModel createModel(
166181
context
167182
);
168183
}
169-
case COMPLETION -> {
184+
case COMPLETION, CHAT_COMPLETION -> {
170185
return new AzureOpenAiCompletionModel(
171186
inferenceEntityId,
172187
taskType,
@@ -237,7 +252,25 @@ protected void doUnifiedCompletionInfer(
237252
TimeValue timeout,
238253
ActionListener<InferenceServiceResults> listener
239254
) {
240-
throwUnsupportedUnifiedCompletionOperation(NAME);
255+
if (model instanceof AzureOpenAiCompletionModel == false) {
256+
listener.onFailure(createInvalidModelException(model));
257+
return;
258+
}
259+
260+
AzureOpenAiCompletionModel openAiModel = (AzureOpenAiCompletionModel) model;
261+
262+
var manager = new GenericRequestManager<>(
263+
getServiceComponents().threadPool(),
264+
openAiModel,
265+
CHAT_COMPLETION_HANDLER,
266+
chatInput -> new AzureOpenAiChatCompletionRequest(chatInput, openAiModel),
267+
UnifiedChatInput.class
268+
);
269+
270+
var errorMessage = constructFailedToSendRequestMessage(CHAT_COMPLETION_REQUEST_TYPE);
271+
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
272+
273+
action.execute(inputs, timeout, listener);
241274
}
242275

243276
@Override
@@ -324,7 +357,7 @@ public TransportVersion getMinimalSupportedVersion() {
324357

325358
@Override
326359
public Set<TaskType> supportedStreamingTasks() {
327-
return COMPLETION_ONLY;
360+
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
328361
}
329362

330363
public static class Configuration {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.azureopenai.completion;
9+
10+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
11+
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract;
12+
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils;
13+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
14+
15+
/**
16+
* Handles streaming chat completion responses and error parsing for Azure OpenAI inference endpoints.
17+
* Adapts the OpenAI handler to support Azure OpenAI's error schema.
18+
*/
19+
public class AzureOpenAiChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
20+
21+
private static final String AZURE_OPENAI_ERROR = "azure_openai_error";
22+
private static final UnifiedChatCompletionErrorParserContract AZURE_OPENAI_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils
23+
.createErrorParserWithStringify(AZURE_OPENAI_ERROR);
24+
25+
public AzureOpenAiChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
26+
super(requestType, parseFunction, AZURE_OPENAI_ERROR_PARSER::parse, AZURE_OPENAI_ERROR_PARSER);
27+
}
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.azureopenai.request;

import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

/**
 * HTTP request for the Azure OpenAI chat completions endpoint.
 * Serializes the unified chat input as the POST body, targets the model's
 * pre-built URI, and attaches the model's authentication header (API key or
 * Entra ID) via {@link AzureOpenAiRequest#decorateWithAuthHeader}.
 */
public class AzureOpenAiChatCompletionRequest implements AzureOpenAiRequest {

    private final UnifiedChatInput chatInput;

    private final AzureOpenAiCompletionModel model;

    /**
     * @param chatInput the unified chat completion input to serialize; must not be null
     * @param model     the Azure OpenAI completion model supplying URI, auth and task settings; must not be null
     */
    public AzureOpenAiChatCompletionRequest(UnifiedChatInput chatInput, AzureOpenAiCompletionModel model) {
        // Validate both constructor arguments eagerly so a null chatInput fails here
        // rather than later inside createHttpRequest()/isStreaming().
        this.chatInput = Objects.requireNonNull(chatInput);
        this.model = Objects.requireNonNull(model);
    }

    @Override
    public HttpRequest createHttpRequest() {
        var httpPost = new HttpPost(getURI());
        // Body is the unified chat request, optionally augmented with the configured "user" field.
        var requestEntity = Strings.toString(new AzureOpenAiChatCompletionRequestEntity(chatInput, model.getTaskSettings().user()));

        ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
        httpPost.setEntity(byteEntity);

        AzureOpenAiRequest.decorateWithAuthHeader(httpPost, model.getSecretSettings());

        return new HttpRequest(httpPost, getInferenceEntityId());
    }

    @Override
    public URI getURI() {
        return model.getUri();
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public boolean isStreaming() {
        return chatInput.stream();
    }

    @Override
    public Request truncate() {
        // No truncation for Azure OpenAI completion
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        // No truncation for Azure OpenAI completion
        return null;
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.azureopenai.request;

import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;

import java.io.IOException;

/**
 * XContent body for an Azure OpenAI chat completion request.
 * Delegates serialization of the unified chat fields and appends the optional
 * "user" identifier from the task settings when one is configured.
 */
public class AzureOpenAiChatCompletionRequestEntity implements ToXContentObject {

    public static final String USER_FIELD = "user";
    private final UnifiedChatCompletionRequestEntity unifiedRequest;
    private final String user;

    /**
     * @param chatInput the unified chat completion input to serialize
     * @param user      optional end-user identifier forwarded to Azure OpenAI; may be null
     */
    public AzureOpenAiChatCompletionRequestEntity(UnifiedChatInput chatInput, @Nullable String user) {
        this.unifiedRequest = new UnifiedChatCompletionRequestEntity(chatInput);
        this.user = user;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // Serialize the unified fields, mapping the max-tokens parameter to "max_completion_tokens".
        unifiedRequest.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokens(params));

        // Only emit "user" when a non-empty value was configured.
        if (Strings.isNullOrEmpty(user) == false) {
            builder.field(USER_FIELD, user);
        }
        builder.endObject();
        return builder;
    }
}

0 commit comments

Comments
 (0)