Skip to content

Commit 4ecdfbb

Browse files
authored
[Inference API] Add API to get configuration of inference services (#114862)
* Adding API to get list of service configurations * Update docs/changelog/114862.yaml * Fixing some configurations * PR feedback -> Stream.of * PR feedback -> singleton * Renaming ServiceConfiguration to SettingsConfiguration. Adding TaskSettingsConfiguration * Adding task type settings configuration to response * PR feedback
1 parent e5c7fce commit 4ecdfbb

File tree

90 files changed

+7156
-27
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+7156
-27
lines changed

docs/changelog/114862.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114862
2+
summary: "[Inference API] Add API to get configuration of inference services"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,5 +469,6 @@
469469
org.elasticsearch.serverless.shardhealth,
470470
org.elasticsearch.serverless.apifiltering;
471471
exports org.elasticsearch.lucene.spatial;
472+
exports org.elasticsearch.inference.configuration;
472473

473474
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.inference;
11+
12+
import java.util.Collections;
13+
import java.util.Map;
14+
15+
public class EmptySettingsConfiguration {
16+
public static Map<String, SettingsConfiguration> get() {
17+
return Collections.emptyMap();
18+
}
19+
}

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.core.TimeValue;
1717

1818
import java.io.Closeable;
19+
import java.util.EnumSet;
1920
import java.util.List;
2021
import java.util.Map;
2122
import java.util.Set;
@@ -71,6 +72,14 @@ default void init(Client client) {}
7172
*/
7273
Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config);
7374

75+
/**
 * Returns the configuration of this inference service.
 * @return the {@link InferenceServiceConfiguration} for the service
 */
InferenceServiceConfiguration getConfiguration();
76+
77+
/**
78+
* Returns the task types supported by the service.
79+
* @return the set of {@link TaskType} values this service supports.
80+
*/
81+
EnumSet<TaskType> supportedTaskTypes();
82+
7483
/**
7584
* Perform inference on the model.
7685
*
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.inference;
11+
12+
import org.elasticsearch.ElasticsearchParseException;
13+
import org.elasticsearch.common.bytes.BytesReference;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.common.io.stream.Writeable;
17+
import org.elasticsearch.common.xcontent.XContentHelper;
18+
import org.elasticsearch.xcontent.ConstructingObjectParser;
19+
import org.elasticsearch.xcontent.ParseField;
20+
import org.elasticsearch.xcontent.ToXContentObject;
21+
import org.elasticsearch.xcontent.XContentBuilder;
22+
import org.elasticsearch.xcontent.XContentParser;
23+
import org.elasticsearch.xcontent.XContentParserConfiguration;
24+
import org.elasticsearch.xcontent.XContentType;
25+
26+
import java.io.IOException;
27+
import java.util.ArrayList;
28+
import java.util.HashMap;
29+
import java.util.List;
30+
import java.util.Map;
31+
import java.util.Objects;
32+
33+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
34+
35+
/**
36+
* Represents the configuration field settings for an inference provider.
37+
*/
38+
public class InferenceServiceConfiguration implements Writeable, ToXContentObject {
39+
40+
private final String provider;
41+
private final List<TaskSettingsConfiguration> taskTypes;
42+
private final Map<String, SettingsConfiguration> configuration;
43+
44+
/**
45+
* Constructs a new {@link InferenceServiceConfiguration} instance with specified properties.
46+
*
47+
* @param provider The name of the service provider.
48+
* @param taskTypes A list of {@link TaskSettingsConfiguration} supported by the service provider.
49+
* @param configuration The configuration of the service provider, defined by {@link SettingsConfiguration}.
50+
*/
51+
private InferenceServiceConfiguration(
52+
String provider,
53+
List<TaskSettingsConfiguration> taskTypes,
54+
Map<String, SettingsConfiguration> configuration
55+
) {
56+
this.provider = provider;
57+
this.taskTypes = taskTypes;
58+
this.configuration = configuration;
59+
}
60+
61+
public InferenceServiceConfiguration(StreamInput in) throws IOException {
62+
this.provider = in.readString();
63+
this.taskTypes = in.readCollectionAsList(TaskSettingsConfiguration::new);
64+
this.configuration = in.readMap(SettingsConfiguration::new);
65+
}
66+
67+
static final ParseField PROVIDER_FIELD = new ParseField("provider");
68+
static final ParseField TASK_TYPES_FIELD = new ParseField("task_types");
69+
static final ParseField CONFIGURATION_FIELD = new ParseField("configuration");
70+
71+
@SuppressWarnings("unchecked")
72+
private static final ConstructingObjectParser<InferenceServiceConfiguration, Void> PARSER = new ConstructingObjectParser<>(
73+
"inference_service_configuration",
74+
true,
75+
args -> {
76+
List<String> taskTypes = (ArrayList<String>) args[1];
77+
return new InferenceServiceConfiguration.Builder().setProvider((String) args[0])
78+
.setTaskTypes((List<TaskSettingsConfiguration>) args[1])
79+
.setConfiguration((Map<String, SettingsConfiguration>) args[2])
80+
.build();
81+
}
82+
);
83+
84+
static {
85+
PARSER.declareString(constructorArg(), PROVIDER_FIELD);
86+
PARSER.declareObjectArray(constructorArg(), (p, c) -> TaskSettingsConfiguration.fromXContent(p), TASK_TYPES_FIELD);
87+
PARSER.declareObject(constructorArg(), (p, c) -> p.map(), CONFIGURATION_FIELD);
88+
}
89+
90+
public String getProvider() {
91+
return provider;
92+
}
93+
94+
public List<TaskSettingsConfiguration> getTaskTypes() {
95+
return taskTypes;
96+
}
97+
98+
public Map<String, SettingsConfiguration> getConfiguration() {
99+
return configuration;
100+
}
101+
102+
@Override
103+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
104+
builder.startObject();
105+
{
106+
builder.field(PROVIDER_FIELD.getPreferredName(), provider);
107+
builder.field(TASK_TYPES_FIELD.getPreferredName(), taskTypes);
108+
builder.field(CONFIGURATION_FIELD.getPreferredName(), configuration);
109+
}
110+
builder.endObject();
111+
return builder;
112+
}
113+
114+
public static InferenceServiceConfiguration fromXContent(XContentParser parser) throws IOException {
115+
return PARSER.parse(parser, null);
116+
}
117+
118+
public static InferenceServiceConfiguration fromXContentBytes(BytesReference source, XContentType xContentType) {
119+
try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) {
120+
return InferenceServiceConfiguration.fromXContent(parser);
121+
} catch (IOException e) {
122+
throw new ElasticsearchParseException("failed to parse inference service configuration", e);
123+
}
124+
}
125+
126+
@Override
127+
public void writeTo(StreamOutput out) throws IOException {
128+
out.writeString(provider);
129+
out.writeCollection(taskTypes);
130+
out.writeMapValues(configuration);
131+
}
132+
133+
public Map<String, Object> toMap() {
134+
Map<String, Object> map = new HashMap<>();
135+
136+
map.put(PROVIDER_FIELD.getPreferredName(), provider);
137+
map.put(TASK_TYPES_FIELD.getPreferredName(), taskTypes);
138+
map.put(CONFIGURATION_FIELD.getPreferredName(), configuration);
139+
140+
return map;
141+
}
142+
143+
@Override
144+
public boolean equals(Object o) {
145+
if (this == o) return true;
146+
if (o == null || getClass() != o.getClass()) return false;
147+
InferenceServiceConfiguration that = (InferenceServiceConfiguration) o;
148+
return provider.equals(that.provider)
149+
&& Objects.equals(taskTypes, that.taskTypes)
150+
&& Objects.equals(configuration, that.configuration);
151+
}
152+
153+
@Override
154+
public int hashCode() {
155+
return Objects.hash(provider, taskTypes, configuration);
156+
}
157+
158+
public static class Builder {
159+
160+
private String provider;
161+
private List<TaskSettingsConfiguration> taskTypes;
162+
private Map<String, SettingsConfiguration> configuration;
163+
164+
public Builder setProvider(String provider) {
165+
this.provider = provider;
166+
return this;
167+
}
168+
169+
public Builder setTaskTypes(List<TaskSettingsConfiguration> taskTypes) {
170+
this.taskTypes = taskTypes;
171+
return this;
172+
}
173+
174+
public Builder setConfiguration(Map<String, SettingsConfiguration> configuration) {
175+
this.configuration = configuration;
176+
return this;
177+
}
178+
179+
public InferenceServiceConfiguration build() {
180+
return new InferenceServiceConfiguration(provider, taskTypes, configuration);
181+
}
182+
}
183+
}

0 commit comments

Comments
 (0)