Skip to content

Commit 9600127

Browse files
jonathan-buttnerelasticsearchmachine
andauthored
[ML] Adding custom headers support openai text embeddings (elastic#134960)
* Adding custom headers support openai text embeddings * Update docs/changelog/134960.yaml * Adding headers to the service api result * [CI] Auto commit changes from spotless * Addressing feedback * Adding transport version change * [CI] Auto commit changes from spotless * Cleaning up helpers * [CI] Auto commit changes from spotless --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent b34d068 commit 9600127

20 files changed

+554
-584
lines changed

docs/changelog/134960.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 134960
2+
summary: Adding custom headers support openai text embeddings
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9169000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
security_stats_endpoint,9168000
1+
inference_api_openai_embeddings_headers,9169000

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,8 @@ public static InferenceServiceConfiguration get() {
485485

486486
configurationMap.put(
487487
HEADERS,
488-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)).setDescription(
489-
"Custom headers to include in the requests to OpenAI."
490-
)
488+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
489+
.setDescription("Custom headers to include in the requests to OpenAI.")
491490
.setLabel("Custom Headers")
492491
.setRequired(false)
493492
.setSensitive(false)
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.openai;
9+
10+
import org.elasticsearch.common.ValidationException;
11+
import org.elasticsearch.core.Nullable;
12+
import org.elasticsearch.inference.ModelConfigurations;
13+
import org.elasticsearch.inference.TaskSettings;
14+
import org.elasticsearch.xcontent.XContentBuilder;
15+
16+
import java.io.IOException;
17+
import java.util.HashMap;
18+
import java.util.Map;
19+
import java.util.Objects;
20+
21+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
22+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
23+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
24+
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS;
25+
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;
26+
27+
public abstract class OpenAiTaskSettings<T extends OpenAiTaskSettings<T>> implements TaskSettings {
28+
private static final Settings EMPTY_SETTINGS = new Settings(null, null);
29+
30+
private final Settings taskSettings;
31+
32+
public OpenAiTaskSettings(Map<String, Object> map) {
33+
this(fromMap(map));
34+
}
35+
36+
public record Settings(@Nullable String user, @Nullable Map<String, String> headers) {}
37+
38+
public static Settings createSettings(String user, Map<String, String> stringHeaders) {
39+
if (user == null && stringHeaders == null) {
40+
return EMPTY_SETTINGS;
41+
} else {
42+
return new Settings(user, stringHeaders);
43+
}
44+
}
45+
46+
private static Settings fromMap(Map<String, Object> map) {
47+
if (map.isEmpty()) {
48+
return EMPTY_SETTINGS;
49+
}
50+
51+
ValidationException validationException = new ValidationException();
52+
53+
String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
54+
Map<String, Object> headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException);
55+
var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null);
56+
57+
if (validationException.validationErrors().isEmpty() == false) {
58+
throw validationException;
59+
}
60+
61+
return createSettings(user, stringHeaders);
62+
}
63+
64+
public OpenAiTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
65+
this(new Settings(user, headers));
66+
}
67+
68+
protected OpenAiTaskSettings(Settings taskSettings) {
69+
this.taskSettings = Objects.requireNonNull(taskSettings);
70+
}
71+
72+
public String user() {
73+
return taskSettings.user();
74+
}
75+
76+
public Map<String, String> headers() {
77+
return taskSettings.headers();
78+
}
79+
80+
@Override
81+
public boolean isEmpty() {
82+
return taskSettings.user() == null && (taskSettings.headers() == null || taskSettings.headers().isEmpty());
83+
}
84+
85+
@Override
86+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
87+
builder.startObject();
88+
89+
if (taskSettings.user() != null) {
90+
builder.field(USER, taskSettings.user());
91+
}
92+
93+
if (taskSettings.headers() != null && taskSettings.headers().isEmpty() == false) {
94+
builder.field(HEADERS, taskSettings.headers());
95+
}
96+
97+
builder.endObject();
98+
99+
return builder;
100+
}
101+
102+
@Override
103+
public boolean equals(Object o) {
104+
if (this == o) return true;
105+
if (o == null || getClass() != o.getClass()) return false;
106+
OpenAiTaskSettings<?> that = (OpenAiTaskSettings<?>) o;
107+
return Objects.equals(taskSettings, that.taskSettings);
108+
}
109+
110+
@Override
111+
public int hashCode() {
112+
return Objects.hash(taskSettings);
113+
}
114+
115+
@Override
116+
public T updatedTaskSettings(Map<String, Object> newSettings) {
117+
Settings updatedSettings = fromMap(new HashMap<>(newSettings));
118+
119+
var userToUse = updatedSettings.user() == null ? taskSettings.user() : updatedSettings.user();
120+
var headersToUse = updatedSettings.headers() == null ? taskSettings.headers() : updatedSettings.headers();
121+
return create(userToUse, headersToUse);
122+
}
123+
124+
protected abstract T create(@Nullable String user, @Nullable Map<String, String> headers);
125+
126+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map<
3535
return model;
3636
}
3737

38-
var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(taskSettings);
39-
return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
38+
return new OpenAiChatCompletionModel(model, model.getTaskSettings().updatedTaskSettings(taskSettings));
4039
}
4140

4241
public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) {
@@ -73,7 +72,7 @@ public OpenAiChatCompletionModel(
7372
taskType,
7473
service,
7574
OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
76-
OpenAiChatCompletionTaskSettings.fromMap(taskSettings),
75+
new OpenAiChatCompletionTaskSettings(taskSettings),
7776
DefaultSecretSettings.fromMap(secrets)
7877
);
7978
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java

Lines changed: 0 additions & 57 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java

Lines changed: 17 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -9,100 +9,44 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
12-
import org.elasticsearch.common.ValidationException;
1312
import org.elasticsearch.common.io.stream.StreamInput;
1413
import org.elasticsearch.common.io.stream.StreamOutput;
1514
import org.elasticsearch.core.Nullable;
16-
import org.elasticsearch.inference.ModelConfigurations;
17-
import org.elasticsearch.inference.TaskSettings;
18-
import org.elasticsearch.xcontent.XContentBuilder;
15+
import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettings;
1916

2017
import java.io.IOException;
21-
import java.util.HashMap;
2218
import java.util.Map;
23-
import java.util.Objects;
2419

2520
import static org.elasticsearch.TransportVersions.INFERENCE_API_OPENAI_HEADERS;
26-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
27-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
28-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
29-
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS;
30-
import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;
3121

32-
public class OpenAiChatCompletionTaskSettings implements TaskSettings {
22+
public class OpenAiChatCompletionTaskSettings extends OpenAiTaskSettings<OpenAiChatCompletionTaskSettings> {
3323

3424
public static final String NAME = "openai_completion_task_settings";
3525

36-
public static OpenAiChatCompletionTaskSettings fromMap(Map<String, Object> map) {
37-
ValidationException validationException = new ValidationException();
38-
39-
String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
40-
var headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException);
41-
var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null);
42-
43-
if (validationException.validationErrors().isEmpty() == false) {
44-
throw validationException;
45-
}
46-
47-
return new OpenAiChatCompletionTaskSettings(user, stringHeaders);
26+
public OpenAiChatCompletionTaskSettings(Map<String, Object> map) {
27+
super(map);
4828
}
4929

50-
private final String user;
51-
@Nullable
52-
private final Map<String, String> headers;
53-
5430
public OpenAiChatCompletionTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
55-
this.user = user;
56-
this.headers = headers;
31+
super(user, headers);
5732
}
5833

5934
public OpenAiChatCompletionTaskSettings(StreamInput in) throws IOException {
60-
this.user = in.readOptionalString();
35+
super(readTaskSettingsFromStream(in));
36+
}
37+
38+
private static Settings readTaskSettingsFromStream(StreamInput in) throws IOException {
39+
var user = in.readOptionalString();
40+
41+
Map<String, String> headers;
6142

6243
if (in.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) {
6344
headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString);
6445
} else {
6546
headers = null;
6647
}
67-
}
68-
69-
@Override
70-
public boolean isEmpty() {
71-
return user == null && (headers == null || headers.isEmpty());
72-
}
73-
74-
public static OpenAiChatCompletionTaskSettings of(
75-
OpenAiChatCompletionTaskSettings originalSettings,
76-
OpenAiChatCompletionRequestTaskSettings requestSettings
77-
) {
78-
var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user();
79-
var headersToUse = requestSettings.headers() == null ? originalSettings.headers : requestSettings.headers();
80-
return new OpenAiChatCompletionTaskSettings(userToUse, headersToUse);
81-
}
82-
83-
public String user() {
84-
return user;
85-
}
8648

87-
public Map<String, String> headers() {
88-
return headers;
89-
}
90-
91-
@Override
92-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
93-
builder.startObject();
94-
95-
if (user != null) {
96-
builder.field(USER, user);
97-
}
98-
99-
if (headers != null && headers.isEmpty() == false) {
100-
builder.field(HEADERS, headers);
101-
}
102-
103-
builder.endObject();
104-
105-
return builder;
49+
return createSettings(user, headers);
10650
}
10751

10852
@Override
@@ -117,30 +61,14 @@ public TransportVersion getMinimalSupportedVersion() {
11761

11862
@Override
11963
public void writeTo(StreamOutput out) throws IOException {
120-
out.writeOptionalString(user);
64+
out.writeOptionalString(user());
12165
if (out.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) {
122-
out.writeOptionalMap(headers, StreamOutput::writeString, StreamOutput::writeString);
66+
out.writeOptionalMap(headers(), StreamOutput::writeString, StreamOutput::writeString);
12367
}
12468
}
12569

12670
@Override
127-
public boolean equals(Object object) {
128-
if (this == object) return true;
129-
if (object == null || getClass() != object.getClass()) return false;
130-
OpenAiChatCompletionTaskSettings that = (OpenAiChatCompletionTaskSettings) object;
131-
return Objects.equals(user, that.user) && Objects.equals(headers, that.headers);
132-
}
133-
134-
@Override
135-
public int hashCode() {
136-
return Objects.hash(user, headers);
137-
}
138-
139-
@Override
140-
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
141-
OpenAiChatCompletionRequestTaskSettings updatedSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(
142-
new HashMap<>(newSettings)
143-
);
144-
return of(this, updatedSettings);
71+
protected OpenAiChatCompletionTaskSettings create(@Nullable String user, @Nullable Map<String, String> headers) {
72+
return new OpenAiChatCompletionTaskSettings(user, headers);
14573
}
14674
}

0 commit comments

Comments
 (0)