Skip to content

Commit 1ca8b32

Browse files
Functionality for filtering task types based on acl info for EIS
1 parent 430c9fa commit 1ca8b32

File tree

12 files changed

+353
-36
lines changed

12 files changed

+353
-36
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ default void init(Client client) {}
7878
* Whether this service should be hidden from the API. Should be used for services
7979
* that are not ready to be used.
8080
*/
81-
default Boolean hideFromConfigurationApi() {
82-
return Boolean.FALSE;
81+
default boolean hideFromConfigurationApi() {
82+
return false;
8383
}
8484

8585
/**

server/src/main/java/org/elasticsearch/inference/InferenceServiceConfiguration.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@
2424
import org.elasticsearch.xcontent.XContentType;
2525

2626
import java.io.IOException;
27-
import java.util.ArrayList;
2827
import java.util.EnumSet;
2928
import java.util.HashMap;
3029
import java.util.List;
3130
import java.util.Map;
3231
import java.util.Objects;
33-
import java.util.stream.Collectors;
3432

3533
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
3634

@@ -80,14 +78,11 @@ public InferenceServiceConfiguration(StreamInput in) throws IOException {
8078
private static final ConstructingObjectParser<InferenceServiceConfiguration, Void> PARSER = new ConstructingObjectParser<>(
8179
"inference_service_configuration",
8280
true,
83-
args -> {
84-
List<String> taskTypes = (ArrayList<String>) args[2];
85-
return new InferenceServiceConfiguration.Builder().setService((String) args[0])
86-
.setName((String) args[1])
87-
.setTaskTypes(EnumSet.copyOf(taskTypes.stream().map(TaskType::fromString).collect(Collectors.toList())))
88-
.setConfigurations((Map<String, SettingsConfiguration>) args[3])
89-
.build();
90-
}
81+
args -> new InferenceServiceConfiguration.Builder().setService((String) args[0])
82+
.setName((String) args[1])
83+
.setTaskTypes((List<String>) args[2])
84+
.setConfigurations((Map<String, SettingsConfiguration>) args[3])
85+
.build()
9186
);
9287

9388
static {
@@ -195,6 +190,16 @@ public Builder setTaskTypes(EnumSet<TaskType> taskTypes) {
195190
return this;
196191
}
197192

193+
public Builder setTaskTypes(List<String> taskTypes) {
194+
var enumTaskTypes = EnumSet.noneOf(TaskType.class);
195+
196+
for (var supportedTaskTypeString : taskTypes) {
197+
enumTaskTypes.add(TaskType.fromString(supportedTaskTypeString));
198+
}
199+
this.taskTypes = enumTaskTypes;
200+
return this;
201+
}
202+
198203
public Builder setConfigurations(Map<String, SettingsConfiguration> configurations) {
199204
this.configurations = configurations;
200205
return this;

server/src/test/java/org/elasticsearch/inference/InferenceServiceConfigurationTests.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,47 @@ public void testToXContent() throws IOException {
6666
assertToXContentEquivalent(originalBytes, toXContent(parsed, XContentType.JSON, humanReadable), XContentType.JSON);
6767
}
6868

69+
public void testToXContent_EmptyTaskTypes() throws IOException {
70+
String content = XContentHelper.stripWhitespace("""
71+
{
72+
"service": "some_provider",
73+
"name": "Some Provider",
74+
"task_types": [],
75+
"configurations": {
76+
"text_field_configuration": {
77+
"description": "Wow, this tooltip is useful.",
78+
"label": "Very important field",
79+
"required": true,
80+
"sensitive": true,
81+
"updatable": false,
82+
"type": "str"
83+
},
84+
"numeric_field_configuration": {
85+
"default_value": 3,
86+
"description": "Wow, this tooltip is useful.",
87+
"label": "Very important numeric field",
88+
"required": true,
89+
"sensitive": false,
90+
"updatable": true,
91+
"type": "int"
92+
}
93+
}
94+
}
95+
""");
96+
97+
InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes(
98+
new BytesArray(content),
99+
XContentType.JSON
100+
);
101+
boolean humanReadable = true;
102+
BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable);
103+
InferenceServiceConfiguration parsed;
104+
try (XContentParser parser = createParser(XContentType.JSON.xContent(), originalBytes)) {
105+
parsed = InferenceServiceConfiguration.fromXContent(parser);
106+
}
107+
assertToXContentEquivalent(originalBytes, toXContent(parsed, XContentType.JSON, humanReadable), XContentType.JSON);
108+
}
109+
69110
public void testToMap() {
70111
InferenceServiceConfiguration configField = InferenceServiceConfigurationTestUtils.getRandomServiceConfigurationField();
71112
Map<String, Object> configFieldAsMap = configField.toMap();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.indices.SystemIndexDescriptor;
3030
import org.elasticsearch.inference.InferenceServiceExtension;
3131
import org.elasticsearch.inference.InferenceServiceRegistry;
32+
import org.elasticsearch.inference.TaskType;
3233
import org.elasticsearch.license.License;
3334
import org.elasticsearch.license.LicensedFeature;
3435
import org.elasticsearch.license.XPackLicenseState;
@@ -110,6 +111,7 @@
110111
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
111112
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
112113
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
114+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceACL;
113115
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
114116
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
115117
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
@@ -125,6 +127,7 @@
125127

126128
import java.util.ArrayList;
127129
import java.util.Collection;
130+
import java.util.EnumSet;
128131
import java.util.List;
129132
import java.util.Map;
130133
import java.util.function.Predicate;
@@ -274,7 +277,12 @@ public Collection<?> createComponents(PluginServices services) {
274277

275278
ElasticInferenceServiceSettings inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
276279
String elasticInferenceUrl = this.getElasticInferenceServiceUrl(inferenceServiceSettings);
277-
elasticInferenceServiceComponents.set(new ElasticInferenceServiceComponents(elasticInferenceUrl));
280+
elasticInferenceServiceComponents.set(
281+
new ElasticInferenceServiceComponents(
282+
elasticInferenceUrl,
283+
new ElasticInferenceServiceACL(Map.of("model-abc", EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION)))
284+
)
285+
);
278286

279287
inferenceServices.add(
280288
() -> List.of(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,21 +69,39 @@ public class ElasticInferenceService extends SenderService {
6969

7070
private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
7171

72-
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION);
72+
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION);
7373
private static final String SERVICE_NAME = "Elastic";
7474

75+
private final Configuration configuration;
76+
private final EnumSet<TaskType> enabledTaskTypes;
77+
7578
public ElasticInferenceService(
7679
HttpRequestSender.Factory factory,
7780
ServiceComponents serviceComponents,
7881
ElasticInferenceServiceComponents elasticInferenceServiceComponents
7982
) {
8083
super(factory, serviceComponents);
8184
this.elasticInferenceServiceComponents = elasticInferenceServiceComponents;
85+
enabledTaskTypes = enabledTaskTypes(this.elasticInferenceServiceComponents.acl());
86+
configuration = new Configuration(enabledTaskTypes);
87+
}
88+
89+
private static EnumSet<TaskType> enabledTaskTypes(ElasticInferenceServiceACL acl) {
90+
var implementedTaskTypes = EnumSet.copyOf(IMPLEMENTED_TASK_TYPES);
91+
implementedTaskTypes.retainAll(acl.enabledTaskTypes());
92+
return implementedTaskTypes;
8293
}
8394

8495
@Override
8596
public Set<TaskType> supportedStreamingTasks() {
86-
return COMPLETION_ONLY;
97+
var enabledStreamingTaskTypes = EnumSet.of(TaskType.COMPLETION);
98+
enabledStreamingTaskTypes.retainAll(enabledTaskTypes);
99+
100+
if (enabledStreamingTaskTypes.isEmpty() == false) {
101+
enabledStreamingTaskTypes.add(TaskType.ANY);
102+
}
103+
104+
return enabledStreamingTaskTypes;
87105
}
88106

89107
@Override
@@ -202,12 +220,17 @@ public void parseRequestConfig(
202220

203221
@Override
204222
public InferenceServiceConfiguration getConfiguration() {
205-
return Configuration.get();
223+
return configuration.get();
206224
}
207225

208226
@Override
209227
public EnumSet<TaskType> supportedTaskTypes() {
210-
return supportedTaskTypes;
228+
return enabledTaskTypes;
229+
}
230+
231+
@Override
232+
public boolean hideFromConfigurationApi() {
233+
return enabledTaskTypes.isEmpty();
211234
}
212235

213236
private static ElasticInferenceServiceModel createModel(
@@ -349,12 +372,17 @@ private TraceContext getCurrentTraceInfo() {
349372
}
350373

351374
public static class Configuration {
352-
public static InferenceServiceConfiguration get() {
353-
return configuration.getOrCompute();
375+
376+
private final EnumSet<TaskType> enabledTaskTypes;
377+
private final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration;
378+
379+
public Configuration(EnumSet<TaskType> enabledTaskTypes) {
380+
this.enabledTaskTypes = enabledTaskTypes;
381+
configuration = initConfiguration();
354382
}
355383

356-
private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
357-
() -> {
384+
private LazyInitializable<InferenceServiceConfiguration, RuntimeException> initConfiguration() {
385+
return new LazyInitializable<>(() -> {
358386
var configurationMap = new HashMap<String, SettingsConfiguration>();
359387

360388
configurationMap.put(
@@ -383,10 +411,14 @@ public static InferenceServiceConfiguration get() {
383411

384412
return new InferenceServiceConfiguration.Builder().setService(NAME)
385413
.setName(SERVICE_NAME)
386-
.setTaskTypes(supportedTaskTypes)
414+
.setTaskTypes(enabledTaskTypes)
387415
.setConfigurations(configurationMap)
388416
.build();
389-
}
390-
);
417+
});
418+
}
419+
420+
public InferenceServiceConfiguration get() {
421+
return configuration.getOrCompute();
422+
}
391423
}
392424
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
12+
import java.util.EnumSet;
13+
import java.util.Map;
14+
import java.util.Objects;
15+
import java.util.Set;
16+
import java.util.stream.Collectors;
17+
18+
/**
19+
* Provides a structure for governing which models (if any) a cluster has access to according to the upstream Elastic Inference Service.
20+
* @param enabledModels a mapping of model ids to a set of {@link TaskType} to indicate which models are available and for which task types
21+
*/
22+
public record ElasticInferenceServiceACL(Map<String, EnumSet<TaskType>> enabledModels) {
23+
24+
/**
25+
* Returns an object indicating that the cluster has no access to EIS.
26+
*/
27+
public static ElasticInferenceServiceACL newDisabledService() {
28+
return new ElasticInferenceServiceACL();
29+
}
30+
31+
public ElasticInferenceServiceACL {
32+
Objects.requireNonNull(enabledModels);
33+
}
34+
35+
private ElasticInferenceServiceACL() {
36+
this(Map.of());
37+
}
38+
39+
public boolean isEnabled() {
40+
return enabledModels.isEmpty() == false;
41+
}
42+
43+
public EnumSet<TaskType> enabledTaskTypes() {
44+
return enabledModels.values().stream().flatMap(Set::stream).collect(Collectors.toCollection(() -> EnumSet.noneOf(TaskType.class)));
45+
}
46+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77

88
package org.elasticsearch.xpack.inference.services.elastic;
99

10-
public record ElasticInferenceServiceComponents(String elasticInferenceServiceUrl) {}
10+
public record ElasticInferenceServiceComponents(String elasticInferenceServiceUrl, ElasticInferenceServiceACL acl) {}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ public InferenceServiceConfiguration getConfiguration() {
151151
}
152152

153153
@Override
154-
public Boolean hideFromConfigurationApi() {
155-
return Boolean.TRUE;
154+
public boolean hideFromConfigurationApi() {
155+
return true;
156156
}
157157

158158
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic;
9+
10+
import org.elasticsearch.inference.TaskType;
11+
import org.elasticsearch.test.ESTestCase;
12+
13+
import java.util.EnumSet;
14+
import java.util.Map;
15+
16+
import static org.hamcrest.Matchers.is;
17+
18+
public class ElasticInferenceServiceACLTests extends ESTestCase {
19+
public static ElasticInferenceServiceACL createEnabledAcl() {
20+
return new ElasticInferenceServiceACL(Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING)));
21+
}
22+
23+
public void testIsEnabled_ReturnsFalse_WithEmptyMap() {
24+
assertFalse(ElasticInferenceServiceACL.newDisabledService().isEnabled());
25+
}
26+
27+
public void testEnabledTaskTypes_MergesFromSeparateModels() {
28+
assertThat(
29+
new ElasticInferenceServiceACL(
30+
Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING), "model-2", EnumSet.of(TaskType.SPARSE_EMBEDDING))
31+
).enabledTaskTypes(),
32+
is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))
33+
);
34+
}
35+
36+
public void testEnabledTaskTypes_FromSingleEntry() {
37+
assertThat(
38+
new ElasticInferenceServiceACL(Map.of("model-1", EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)))
39+
.enabledTaskTypes(),
40+
is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING))
41+
);
42+
}
43+
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur
2727
new ElasticInferenceServiceSparseEmbeddingsServiceSettings(ElserModels.ELSER_V2_MODEL, maxInputTokens, null),
2828
EmptyTaskSettings.INSTANCE,
2929
EmptySecretSettings.INSTANCE,
30-
new ElasticInferenceServiceComponents(url)
30+
new ElasticInferenceServiceComponents(url, ElasticInferenceServiceACLTests.createEnabledAcl())
3131
);
3232
}
3333
}

0 commit comments

Comments
 (0)