Skip to content

Commit ad4eac9

Browse files
Add Azure OpenAI chat completion support and related request handling
1 parent d8bd6b6 commit ad4eac9

File tree

7 files changed

+208
-21
lines changed

7 files changed

+208
-21
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
package org.elasticsearch.xpack.inference.services.azureopenai;
99

1010
import org.apache.http.client.utils.URIBuilder;
11-
import org.elasticsearch.inference.Model;
1211
import org.elasticsearch.inference.ModelConfigurations;
1312
import org.elasticsearch.inference.ModelSecrets;
1413
import org.elasticsearch.inference.ServiceSettings;
1514
import org.elasticsearch.inference.TaskSettings;
1615
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
16+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1717
import org.elasticsearch.xpack.inference.services.azureopenai.action.AzureOpenAiActionVisitor;
1818
import org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils;
19+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1920

2021
import java.net.URI;
2122
import java.net.URISyntaxException;
@@ -27,7 +28,7 @@
2728

2829
import static org.elasticsearch.core.Strings.format;
2930

30-
public abstract class AzureOpenAiModel extends Model {
31+
public abstract class AzureOpenAiModel extends RateLimitGroupingModel {
3132

3233
protected URI uri;
3334
private final AzureOpenAiRateLimitServiceSettings rateLimitServiceSettings;
@@ -95,6 +96,16 @@ public AzureOpenAiRateLimitServiceSettings rateLimitServiceSettings() {
9596
return rateLimitServiceSettings;
9697
}
9798

99+
@Override
100+
public RateLimitSettings rateLimitSettings() {
101+
return rateLimitServiceSettings.rateLimitSettings();
102+
}
103+
104+
@Override
105+
public int rateLimitGroupingHash() {
106+
return Objects.hash(resourceName(), deploymentId());
107+
}
108+
98109
// TODO: can be inferred directly from modelConfigurations.getServiceSettings(); will be addressed with separate refactoring
99110
public abstract String resourceName();
100111

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,8 @@ public static Map<String, SettingsConfiguration> get() {
148148
var configurationMap = new HashMap<String, SettingsConfiguration>();
149149
configurationMap.put(
150150
API_KEY,
151-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
152-
"You must provide either an API key or an Entra ID."
153-
)
151+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
152+
.setDescription("You must provide either an API key or an Entra ID.")
154153
.setLabel("API Key")
155154
.setRequired(false)
156155
.setSensitive(true)
@@ -160,9 +159,8 @@ public static Map<String, SettingsConfiguration> get() {
160159
);
161160
configurationMap.put(
162161
ENTRA_ID,
163-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
164-
"You must provide either an API key or an Entra ID."
165-
)
162+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
163+
.setDescription("You must provide either an API key or an Entra ID.")
166164
.setLabel("Entra ID")
167165
.setRequired(false)
168166
.setSensitive(true)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@
3131
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
3232
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
3333
import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker;
34+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
35+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
3436
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
37+
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3538
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3639
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
3740
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -40,9 +43,12 @@
4043
import org.elasticsearch.xpack.inference.services.ServiceComponents;
4144
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4245
import org.elasticsearch.xpack.inference.services.azureopenai.action.AzureOpenAiActionCreator;
46+
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiChatCompletionResponseHandler;
4347
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
4448
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
4549
import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
50+
import org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiChatCompletionRequest;
51+
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
4652
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
4753

4854
import java.util.EnumSet;
@@ -51,14 +57,14 @@
5157
import java.util.Map;
5258
import java.util.Set;
5359

60+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
5461
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
5562
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
5663
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidTaskTypeException;
5764
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
5865
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
5966
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
6067
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
61-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
6268
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION;
6369
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID;
6470
import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME;
@@ -68,7 +74,16 @@ public class AzureOpenAiService extends SenderService {
6874
public static final String NAME = "azureopenai";
6975

7076
private static final String SERVICE_NAME = "Azure OpenAI";
71-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
77+
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
78+
TaskType.TEXT_EMBEDDING,
79+
TaskType.COMPLETION,
80+
TaskType.CHAT_COMPLETION
81+
);
82+
public static final String CHAT_COMPLETION_REQUEST_TYPE = "Azure OpenAI chat completions";
83+
private static final ResponseHandler CHAT_COMPLETION_HANDLER = new AzureOpenAiChatCompletionResponseHandler(
84+
CHAT_COMPLETION_REQUEST_TYPE,
85+
OpenAiChatCompletionResponseEntity::fromResponse
86+
);
7287

7388
public AzureOpenAiService(
7489
HttpRequestSender.Factory factory,
@@ -166,7 +181,7 @@ private static AzureOpenAiModel createModel(
166181
context
167182
);
168183
}
169-
case COMPLETION -> {
184+
case COMPLETION, CHAT_COMPLETION -> {
170185
return new AzureOpenAiCompletionModel(
171186
inferenceEntityId,
172187
taskType,
@@ -237,7 +252,25 @@ protected void doUnifiedCompletionInfer(
237252
TimeValue timeout,
238253
ActionListener<InferenceServiceResults> listener
239254
) {
240-
throwUnsupportedUnifiedCompletionOperation(NAME);
255+
if (model instanceof AzureOpenAiCompletionModel == false) {
256+
listener.onFailure(createInvalidModelException(model));
257+
return;
258+
}
259+
260+
AzureOpenAiCompletionModel openAiModel = (AzureOpenAiCompletionModel) model;
261+
262+
var manager = new GenericRequestManager<>(
263+
getServiceComponents().threadPool(),
264+
openAiModel,
265+
CHAT_COMPLETION_HANDLER,
266+
chatInput -> new AzureOpenAiChatCompletionRequest(chatInput, openAiModel),
267+
UnifiedChatInput.class
268+
);
269+
270+
var errorMessage = constructFailedToSendRequestMessage(CHAT_COMPLETION_REQUEST_TYPE);
271+
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
272+
273+
action.execute(inputs, timeout, listener);
241274
}
242275

243276
@Override
@@ -324,7 +357,7 @@ public TransportVersion getMinimalSupportedVersion() {
324357

325358
@Override
326359
public Set<TaskType> supportedStreamingTasks() {
327-
return COMPLETION_ONLY;
360+
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
328361
}
329362

330363
public static class Configuration {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.azureopenai.completion;
9+
10+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
11+
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorParserContract;
12+
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionErrorResponseUtils;
13+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
14+
15+
/**
16+
* Handles streaming chat completion responses and error parsing for Azure OpenAI inference endpoints.
17+
* Adapts the OpenAI handler to support Azure OpenAI's error schema.
18+
*/
19+
public class AzureOpenAiChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
20+
21+
private static final String AZURE_OPENAI_ERROR = "azure_openai_error";
22+
private static final UnifiedChatCompletionErrorParserContract AZURE_OPENAI_ERROR_PARSER = UnifiedChatCompletionErrorResponseUtils
23+
.createErrorParserWithStringify(AZURE_OPENAI_ERROR);
24+
25+
public AzureOpenAiChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
26+
super(requestType, parseFunction, AZURE_OPENAI_ERROR_PARSER::parse, AZURE_OPENAI_ERROR_PARSER);
27+
}
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.azureopenai.request;
9+
10+
import org.apache.http.client.methods.HttpPost;
11+
import org.apache.http.entity.ByteArrayEntity;
12+
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
14+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
15+
import org.elasticsearch.xpack.inference.external.request.Request;
16+
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
17+
18+
import java.net.URI;
19+
import java.nio.charset.StandardCharsets;
20+
import java.util.Objects;
21+
22+
public class AzureOpenAiChatCompletionRequest implements AzureOpenAiRequest {
23+
24+
private final UnifiedChatInput chatInput;
25+
26+
private final URI uri;
27+
28+
private final AzureOpenAiCompletionModel model;
29+
30+
public AzureOpenAiChatCompletionRequest(UnifiedChatInput chatInput, AzureOpenAiCompletionModel model) {
31+
this.chatInput = chatInput;
32+
this.model = Objects.requireNonNull(model);
33+
this.uri = model.getUri();
34+
}
35+
36+
@Override
37+
public HttpRequest createHttpRequest() {
38+
var httpPost = new HttpPost(uri);
39+
var requestEntity = Strings.toString(new AzureOpenAiChatCompletionRequestEntity(chatInput, model.getTaskSettings().user()));
40+
41+
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
42+
httpPost.setEntity(byteEntity);
43+
44+
AzureOpenAiRequest.decorateWithAuthHeader(httpPost, model.getSecretSettings());
45+
46+
return new HttpRequest(httpPost, getInferenceEntityId());
47+
}
48+
49+
@Override
50+
public URI getURI() {
51+
return this.uri;
52+
}
53+
54+
@Override
55+
public String getInferenceEntityId() {
56+
return model.getInferenceEntityId();
57+
}
58+
59+
@Override
60+
public boolean isStreaming() {
61+
return chatInput.stream();
62+
}
63+
64+
@Override
65+
public Request truncate() {
66+
// No truncation for Azure OpenAI completion
67+
return this;
68+
}
69+
70+
@Override
71+
public boolean[] getTruncationInfo() {
72+
// No truncation for Azure OpenAI completion
73+
return null;
74+
}
75+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.azureopenai.request;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.inference.UnifiedCompletionRequest;
13+
import org.elasticsearch.xcontent.ToXContentObject;
14+
import org.elasticsearch.xcontent.XContentBuilder;
15+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
16+
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
17+
18+
import java.io.IOException;
19+
20+
public class AzureOpenAiChatCompletionRequestEntity implements ToXContentObject {
21+
22+
public static final String USER_FIELD = "user";
23+
private final UnifiedChatCompletionRequestEntity requestEntity;
24+
private final String user;
25+
26+
public AzureOpenAiChatCompletionRequestEntity(UnifiedChatInput chatInput, @Nullable String user) {
27+
this.requestEntity = new UnifiedChatCompletionRequestEntity(chatInput);
28+
this.user = user;
29+
}
30+
31+
@Override
32+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
33+
builder.startObject();
34+
requestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(params));
35+
36+
if (Strings.isNullOrEmpty(user) == false) {
37+
builder.field(USER_FIELD, user);
38+
}
39+
builder.endObject();
40+
return builder;
41+
}
42+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,7 @@ public void testGetConfiguration() throws Exception {
11281128
{
11291129
"service": "azureopenai",
11301130
"name": "Azure OpenAI",
1131-
"task_types": ["text_embedding", "completion"],
1131+
"task_types": ["text_embedding", "completion", "chat_completion"],
11321132
"configurations": {
11331133
"api_key": {
11341134
"description": "You must provide either an API key or an Entra ID.",
@@ -1137,7 +1137,7 @@ public void testGetConfiguration() throws Exception {
11371137
"sensitive": true,
11381138
"updatable": true,
11391139
"type": "str",
1140-
"supported_task_types": ["text_embedding", "completion"]
1140+
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
11411141
},
11421142
"dimensions": {
11431143
"description": "The number of dimensions the resulting embeddings should have. For more information refer to https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-body-1.",
@@ -1155,7 +1155,7 @@ public void testGetConfiguration() throws Exception {
11551155
"sensitive": true,
11561156
"updatable": true,
11571157
"type": "str",
1158-
"supported_task_types": ["text_embedding", "completion"]
1158+
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
11591159
},
11601160
"rate_limit.requests_per_minute": {
11611161
"description": "The azureopenai service sets a default number of requests allowed per minute depending on the task type.",
@@ -1164,7 +1164,7 @@ public void testGetConfiguration() throws Exception {
11641164
"sensitive": false,
11651165
"updatable": false,
11661166
"type": "int",
1167-
"supported_task_types": ["text_embedding", "completion"]
1167+
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
11681168
},
11691169
"deployment_id": {
11701170
"description": "The deployment name of your deployed models.",
@@ -1173,7 +1173,7 @@ public void testGetConfiguration() throws Exception {
11731173
"sensitive": false,
11741174
"updatable": false,
11751175
"type": "str",
1176-
"supported_task_types": ["text_embedding", "completion"]
1176+
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
11771177
},
11781178
"resource_name": {
11791179
"description": "The name of your Azure OpenAI resource.",
@@ -1182,7 +1182,7 @@ public void testGetConfiguration() throws Exception {
11821182
"sensitive": false,
11831183
"updatable": false,
11841184
"type": "str",
1185-
"supported_task_types": ["text_embedding", "completion"]
1185+
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
11861186
},
11871187
"api_version": {
11881188
"description": "The Azure API version ID to use.",
@@ -1191,7 +1191,7 @@ public void testGetConfiguration() throws Exception {
11911191
"sensitive": false,
11921192
"updatable": false,
11931193
"type": "str",
1194-
"supported_task_types": ["text_embedding", "completion"]
1194+
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
11951195
}
11961196
}
11971197
}
@@ -1214,7 +1214,7 @@ public void testGetConfiguration() throws Exception {
12141214

12151215
public void testSupportsStreaming() throws IOException {
12161216
try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()), mockClusterServiceEmpty())) {
1217-
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION)));
1217+
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)));
12181218
assertFalse(service.canStream(TaskType.ANY));
12191219
}
12201220
}

0 commit comments

Comments
 (0)