Skip to content

Commit 06e1621

Browse files
[ML] Adds a new field supported_task_types in the configuration response (#120150)
* Adding new field to settings class * adding new available_for_task_types field * Update docs/changelog/120150.yaml * Delete docs/changelog/120150.yaml * Fixing tests and task types * Renaming field to supported_task_types * Pulling in chat_completion addition * Addressing feedback
1 parent 0392de0 commit 06e1621

File tree

46 files changed

+325
-173
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+325
-173
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceServiceConfiguration.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ public Builder setName(String name) {
191191
}
192192

193193
public Builder setTaskTypes(EnumSet<TaskType> taskTypes) {
194-
this.taskTypes = taskTypes;
194+
this.taskTypes = TaskType.copyOf(taskTypes);
195195
return this;
196196
}
197197

server/src/main/java/org/elasticsearch/inference/SettingsConfiguration.java

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@
2828
import org.elasticsearch.xcontent.XContentType;
2929

3030
import java.io.IOException;
31+
import java.util.EnumSet;
3132
import java.util.HashMap;
33+
import java.util.List;
3234
import java.util.Map;
3335
import java.util.Objects;
3436
import java.util.Optional;
37+
import java.util.Set;
3538

3639
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
3740
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -50,6 +53,7 @@ public class SettingsConfiguration implements Writeable, ToXContentObject {
5053
private final boolean sensitive;
5154
private final boolean updatable;
5255
private final SettingsConfigurationFieldType type;
56+
private final EnumSet<TaskType> supportedTaskTypes;
5357

5458
/**
5559
* Constructs a new {@link SettingsConfiguration} instance with specified properties.
@@ -61,6 +65,7 @@ public class SettingsConfiguration implements Writeable, ToXContentObject {
6165
* @param sensitive A boolean indicating whether the configuration contains sensitive information.
6266
* @param updatable A boolean indicating whether the configuration can be updated.
6367
* @param type The type of the configuration field, defined by {@link SettingsConfigurationFieldType}.
68+
* @param supportedTaskTypes The task types that support this field.
6469
*/
6570
private SettingsConfiguration(
6671
Object defaultValue,
@@ -69,7 +74,8 @@ private SettingsConfiguration(
6974
boolean required,
7075
boolean sensitive,
7176
boolean updatable,
72-
SettingsConfigurationFieldType type
77+
SettingsConfigurationFieldType type,
78+
EnumSet<TaskType> supportedTaskTypes
7379
) {
7480
this.defaultValue = defaultValue;
7581
this.description = description;
@@ -78,6 +84,7 @@ private SettingsConfiguration(
7884
this.sensitive = sensitive;
7985
this.updatable = updatable;
8086
this.type = type;
87+
this.supportedTaskTypes = supportedTaskTypes;
8188
}
8289

8390
public SettingsConfiguration(StreamInput in) throws IOException {
@@ -88,6 +95,7 @@ public SettingsConfiguration(StreamInput in) throws IOException {
8895
this.sensitive = in.readBoolean();
8996
this.updatable = in.readBoolean();
9097
this.type = in.readEnum(SettingsConfigurationFieldType.class);
98+
this.supportedTaskTypes = in.readEnumSet(TaskType.class);
9199
}
92100

93101
static final ParseField DEFAULT_VALUE_FIELD = new ParseField("default_value");
@@ -97,14 +105,23 @@ public SettingsConfiguration(StreamInput in) throws IOException {
97105
static final ParseField SENSITIVE_FIELD = new ParseField("sensitive");
98106
static final ParseField UPDATABLE_FIELD = new ParseField("updatable");
99107
static final ParseField TYPE_FIELD = new ParseField("type");
108+
static final ParseField SUPPORTED_TASK_TYPES = new ParseField("supported_task_types");
100109

101110
@SuppressWarnings("unchecked")
102111
private static final ConstructingObjectParser<SettingsConfiguration, Void> PARSER = new ConstructingObjectParser<>(
103112
"service_configuration",
104113
true,
105114
args -> {
106115
int i = 0;
107-
return new SettingsConfiguration.Builder().setDefaultValue(args[i++])
116+
117+
EnumSet<TaskType> supportedTaskTypes = EnumSet.noneOf(TaskType.class);
118+
var supportedTaskTypesListOfStrings = (List<String>) args[i++];
119+
120+
for (var supportedTaskTypeString : supportedTaskTypesListOfStrings) {
121+
supportedTaskTypes.add(TaskType.fromString(supportedTaskTypeString));
122+
}
123+
124+
return new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue(args[i++])
108125
.setDescription((String) args[i++])
109126
.setLabel((String) args[i++])
110127
.setRequired((Boolean) args[i++])
@@ -116,6 +133,7 @@ public SettingsConfiguration(StreamInput in) throws IOException {
116133
);
117134

118135
static {
136+
PARSER.declareStringArray(constructorArg(), SUPPORTED_TASK_TYPES);
119137
PARSER.declareField(optionalConstructorArg(), (p, c) -> {
120138
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
121139
return p.text();
@@ -169,28 +187,8 @@ public SettingsConfigurationFieldType getType() {
169187
return type;
170188
}
171189

172-
/**
173-
* Parses a configuration value from a parser context.
174-
* This method can parse strings, numbers, booleans, objects, and null values, matching the types commonly
175-
* supported in {@link SettingsConfiguration}.
176-
*
177-
* @param p the {@link org.elasticsearch.xcontent.XContentParser} instance from which to parse the configuration value.
178-
*/
179-
public static Object parseConfigurationValue(XContentParser p) throws IOException {
180-
181-
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
182-
return p.text();
183-
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
184-
return p.numberValue();
185-
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
186-
return p.booleanValue();
187-
} else if (p.currentToken() == XContentParser.Token.START_OBJECT) {
188-
// Crawler expects the value to be an object
189-
return p.map();
190-
} else if (p.currentToken() == XContentParser.Token.VALUE_NULL) {
191-
return null;
192-
}
193-
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
190+
public Set<TaskType> getSupportedTaskTypes() {
191+
return supportedTaskTypes;
194192
}
195193

196194
@Override
@@ -211,6 +209,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
211209
if (type != null) {
212210
builder.field(TYPE_FIELD.getPreferredName(), type.toString());
213211
}
212+
builder.field(SUPPORTED_TASK_TYPES.getPreferredName(), supportedTaskTypes);
214213
}
215214
builder.endObject();
216215
return builder;
@@ -237,6 +236,7 @@ public void writeTo(StreamOutput out) throws IOException {
237236
out.writeBoolean(sensitive);
238237
out.writeBoolean(updatable);
239238
out.writeEnum(type);
239+
out.writeEnumSet(supportedTaskTypes);
240240
}
241241

242242
public Map<String, Object> toMap() {
@@ -253,6 +253,7 @@ public Map<String, Object> toMap() {
253253

254254
Optional.ofNullable(type).ifPresent(t -> map.put(TYPE_FIELD.getPreferredName(), t.toString()));
255255

256+
map.put(SUPPORTED_TASK_TYPES.getPreferredName(), supportedTaskTypes);
256257
return map;
257258
}
258259

@@ -267,12 +268,13 @@ public boolean equals(Object o) {
267268
&& Objects.equals(defaultValue, that.defaultValue)
268269
&& Objects.equals(description, that.description)
269270
&& Objects.equals(label, that.label)
270-
&& type == that.type;
271+
&& type == that.type
272+
&& Objects.equals(supportedTaskTypes, that.supportedTaskTypes);
271273
}
272274

273275
@Override
274276
public int hashCode() {
275-
return Objects.hash(defaultValue, description, label, required, sensitive, updatable, type);
277+
return Objects.hash(defaultValue, description, label, required, sensitive, updatable, type, supportedTaskTypes);
276278
}
277279

278280
public static class Builder {
@@ -284,6 +286,11 @@ public static class Builder {
284286
private boolean sensitive;
285287
private boolean updatable;
286288
private SettingsConfigurationFieldType type;
289+
private final EnumSet<TaskType> supportedTaskTypes;
290+
291+
public Builder(EnumSet<TaskType> supportedTaskTypes) {
292+
this.supportedTaskTypes = TaskType.copyOf(Objects.requireNonNull(supportedTaskTypes));
293+
}
287294

288295
public Builder setDefaultValue(Object defaultValue) {
289296
this.defaultValue = defaultValue;
@@ -321,7 +328,7 @@ public Builder setType(SettingsConfigurationFieldType type) {
321328
}
322329

323330
public SettingsConfiguration build() {
324-
return new SettingsConfiguration(defaultValue, description, label, required, sensitive, updatable, type);
331+
return new SettingsConfiguration(defaultValue, description, label, required, sensitive, updatable, type, supportedTaskTypes);
325332
}
326333
}
327334
}

server/src/main/java/org/elasticsearch/inference/TaskType.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.rest.RestStatus;
1717

1818
import java.io.IOException;
19+
import java.util.EnumSet;
1920
import java.util.Locale;
2021
import java.util.Objects;
2122

@@ -78,4 +79,14 @@ public void writeTo(StreamOutput out) throws IOException {
7879
public static String unsupportedTaskTypeErrorMsg(TaskType taskType, String serviceName) {
7980
return "The [" + serviceName + "] service does not support task type [" + taskType + "]";
8081
}
82+
83+
/**
84+
* Copies a {@link EnumSet<TaskType>} if non-empty, otherwise returns an empty {@link EnumSet<TaskType>}. This is essentially the same
85+
* as {@link EnumSet#copyOf(EnumSet)}, except it does not throw for an empty set.
86+
* @param taskTypes task types to copy
87+
* @return a copy of the passed in {@link EnumSet<TaskType>}
88+
*/
89+
public static EnumSet<TaskType> copyOf(EnumSet<TaskType> taskTypes) {
90+
return taskTypes.isEmpty() ? EnumSet.noneOf(TaskType.class) : EnumSet.copyOf(taskTypes);
91+
}
8192
}

server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTestUtils.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,18 @@
1111

1212
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
1313

14+
import java.util.EnumSet;
15+
1416
import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
1517
import static org.elasticsearch.test.ESTestCase.randomBoolean;
1618
import static org.elasticsearch.test.ESTestCase.randomInt;
1719

1820
public class SettingsConfigurationTestUtils {
1921

2022
public static SettingsConfiguration getRandomSettingsConfigurationField() {
21-
return new SettingsConfiguration.Builder().setDefaultValue(randomAlphaOfLength(10))
23+
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
24+
randomAlphaOfLength(10)
25+
)
2226
.setDescription(randomAlphaOfLength(10))
2327
.setLabel(randomAlphaOfLength(10))
2428
.setRequired(randomBoolean())

server/src/test/java/org/elasticsearch/inference/SettingsConfigurationTests.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ public void testToXContent() throws IOException {
3434
"required": true,
3535
"sensitive": false,
3636
"updatable": true,
37-
"type": "str"
37+
"type": "str",
38+
"supported_task_types": ["text_embedding", "completion", "sparse_embedding", "rerank"]
3839
}
3940
""");
4041

@@ -56,7 +57,8 @@ public void testToXContent_WithNumericSelectOptions() throws IOException {
5657
"required": true,
5758
"sensitive": false,
5859
"updatable": true,
59-
"type": "str"
60+
"type": "str",
61+
"supported_task_types": ["text_embedding"]
6062
}
6163
""");
6264

@@ -74,7 +76,8 @@ public void testToXContentCrawlerConfig_WithNullValue() throws IOException {
7476
String content = XContentHelper.stripWhitespace("""
7577
{
7678
"label": "nextSyncConfig",
77-
"value": null
79+
"value": null,
80+
"supported_task_types": ["text_embedding", "completion", "sparse_embedding", "rerank"]
7881
}
7982
""");
8083

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ public static InferenceServiceConfiguration get() {
257257

258258
configurationMap.put(
259259
"model",
260-
new SettingsConfiguration.Builder().setDescription("")
260+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING)).setDescription("")
261261
.setLabel("Model")
262262
.setRequired(true)
263263
.setSensitive(true)

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ public static InferenceServiceConfiguration get() {
171171

172172
configurationMap.put(
173173
"model",
174-
new SettingsConfiguration.Builder().setDescription("")
174+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.RERANK)).setDescription("")
175175
.setLabel("Model")
176176
.setRequired(true)
177177
.setSensitive(true)

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ public static InferenceServiceConfiguration get() {
205205

206206
configurationMap.put(
207207
"model",
208-
new SettingsConfiguration.Builder().setDescription("")
208+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
209209
.setLabel("Model")
210210
.setRequired(true)
211211
.setSensitive(false)
@@ -215,7 +215,7 @@ public static InferenceServiceConfiguration get() {
215215

216216
configurationMap.put(
217217
"hidden_field",
218-
new SettingsConfiguration.Builder().setDescription("")
218+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription("")
219219
.setLabel("Hidden Field")
220220
.setRequired(true)
221221
.setSensitive(false)

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ public static InferenceServiceConfiguration get() {
257257

258258
configurationMap.put(
259259
"model_id",
260-
new SettingsConfiguration.Builder().setDescription("")
260+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION)).setDescription("")
261261
.setLabel("Model ID")
262262
.setRequired(true)
263263
.setSensitive(true)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,9 @@ public static InferenceServiceConfiguration get() {
379379

380380
configurationMap.put(
381381
SERVICE_ID,
382-
new SettingsConfiguration.Builder().setDescription("The name of the model service to use for the {infer} task.")
382+
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
383+
"The name of the model service to use for the {infer} task."
384+
)
383385
.setLabel("Project ID")
384386
.setRequired(true)
385387
.setSensitive(false)
@@ -390,7 +392,7 @@ public static InferenceServiceConfiguration get() {
390392

391393
configurationMap.put(
392394
HOST,
393-
new SettingsConfiguration.Builder().setDescription(
395+
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
394396
"The name of the host address used for the {infer} task. You can find the host address at "
395397
+ "https://opensearch.console.aliyun.com/cn-shanghai/rag/api-key[ the API keys section] "
396398
+ "of the documentation."
@@ -405,7 +407,7 @@ public static InferenceServiceConfiguration get() {
405407

406408
configurationMap.put(
407409
HTTP_SCHEMA_NAME,
408-
new SettingsConfiguration.Builder().setDescription("")
410+
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("")
409411
.setLabel("HTTP Schema")
410412
.setRequired(true)
411413
.setSensitive(false)
@@ -416,7 +418,9 @@ public static InferenceServiceConfiguration get() {
416418

417419
configurationMap.put(
418420
WORKSPACE_NAME,
419-
new SettingsConfiguration.Builder().setDescription("The name of the workspace used for the {infer} task.")
421+
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
422+
"The name of the workspace used for the {infer} task."
423+
)
420424
.setLabel("Workspace")
421425
.setRequired(true)
422426
.setSensitive(false)
@@ -426,9 +430,12 @@ public static InferenceServiceConfiguration get() {
426430
);
427431

428432
configurationMap.putAll(
429-
DefaultSecretSettings.toSettingsConfigurationWithDescription("A valid API key for the AlibabaCloud AI Search API.")
433+
DefaultSecretSettings.toSettingsConfigurationWithDescription(
434+
"A valid API key for the AlibabaCloud AI Search API.",
435+
supportedTaskTypes
436+
)
430437
);
431-
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration());
438+
configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));
432439

433440
return new InferenceServiceConfiguration.Builder().setService(NAME)
434441
.setName(SERVICE_NAME)

0 commit comments

Comments
 (0)