Skip to content

Commit 54c35e0

Browse files
authored
[ML] Require basic licence for the Elastic Inference Service (#137434)
1 parent c6bdd28 commit 54c35e0

File tree

14 files changed

+384
-106
lines changed

14 files changed

+384
-106
lines changed

docs/changelog/137434.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137434
2+
summary: Require basic licence for the Elastic Inference Service
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackField.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ public final class XPackField {
3232
public static final String UPGRADE = "upgrade";
3333
// inside of YAML settings we still use xpack do not having handle issues with dashes
3434
public static final String SETTINGS_NAME = "xpack";
35+
/** Name constant for the EIS feature. */
36+
public static final String ELASTIC_INFERENCE_SERVICE = "Elastic Inference Service";
3537
/** Name constant for the eql feature. */
3638
public static final String EQL = "eql";
3739
/** Name constant for the esql feature. */

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase {
3434

3535
private static ElasticsearchCluster cluster = ElasticsearchCluster.local()
3636
.distribution(DistributionType.DEFAULT)
37-
.setting("xpack.license.self_generated.type", "trial")
37+
.setting("xpack.license.self_generated.type", "basic")
3838
.setting("xpack.security.enabled", "true")
3939
// Adding both settings unless one feature flag is disabled in a particular environment
4040
.setting("xpack.inference.elastic.url", mockEISServer::getUrl)

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,28 +45,6 @@ public void testPutModel_RestrictedWithBasicLicense() throws Exception {
4545
sendRestrictedRequest("PUT", endpoint, modelConfig);
4646
}
4747

48-
public void testUpdateModel_RestrictedWithBasicLicense() throws Exception {
49-
var endpoint = Strings.format("_inference/%s/%s/_update?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
50-
var requestBody = """
51-
{
52-
"task_settings": {
53-
"num_threads": 2
54-
}
55-
}
56-
""";
57-
sendRestrictedRequest("PUT", endpoint, requestBody);
58-
}
59-
60-
public void testPerformInference_RestrictedWithBasicLicense() throws Exception {
61-
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
62-
var requestBody = """
63-
{
64-
"input": ["washing", "machine"]
65-
}
66-
""";
67-
sendRestrictedRequest("POST", endpoint, requestBody);
68-
}
69-
7048
public void testGetServices_NonRestrictedWithBasicLicense() throws Exception {
7149
var endpoint = "_inference/_services";
7250
sendNonRestrictedRequest("GET", endpoint, null, 200, false);

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435
4747
public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase {
4848
public static final String INDEX_NAME = "test-index";
49+
private static final String SPARSE_INFERENCE_ID = "sparse-endpoint";
50+
private static final String DENSE_INFERENCE_ID = "dense-endpoint";
4951

5052
private final boolean useLegacyFormat;
5153

@@ -61,9 +63,9 @@ public static Iterable<Object[]> parameters() {
6163
@Before
6264
public void setup() throws Exception {
6365
ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class);
64-
Utils.storeSparseModel("sparse-endpoint", modelRegistry);
66+
Utils.storeSparseModel(SPARSE_INFERENCE_ID, modelRegistry);
6567
Utils.storeDenseModel(
66-
"dense-endpoint",
68+
DENSE_INFERENCE_ID,
6769
modelRegistry,
6870
randomIntBetween(1, 100),
6971
// dot product means that we need normalized vectors; it's not worth doing that in this test
@@ -92,27 +94,20 @@ public Settings indexSettings() {
9294
}
9395

9496
public void testLicenseInvalidForInference() {
95-
prepareCreate(INDEX_NAME).setMapping(
96-
String.format(
97-
Locale.ROOT,
98-
"""
99-
{
100-
"properties": {
101-
"sparse_field": {
102-
"type": "semantic_text",
103-
"inference_id": "%s"
104-
},
105-
"dense_field": {
106-
"type": "semantic_text",
107-
"inference_id": "%s"
108-
}
109-
}
97+
prepareCreate(INDEX_NAME).setMapping(String.format(Locale.ROOT, """
98+
{
99+
"properties": {
100+
"sparse_field": {
101+
"type": "semantic_text",
102+
"inference_id": "%s"
103+
},
104+
"dense_field": {
105+
"type": "semantic_text",
106+
"inference_id": "%s"
110107
}
111-
""",
112-
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
113-
TestDenseInferenceServiceExtension.TestInferenceService.NAME
114-
)
115-
).get();
108+
}
109+
}
110+
""", SPARSE_INFERENCE_ID, DENSE_INFERENCE_ID)).get();
116111

117112
BulkRequestBuilder bulkRequest = client().prepareBulk();
118113
int totalBulkReqs = randomIntBetween(2, 100);
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.ElasticsearchSecurityException;
11+
import org.elasticsearch.license.LicenseUtils;
12+
import org.elasticsearch.license.XPackLicenseState;
13+
import org.elasticsearch.xpack.core.XPackField;
14+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
15+
16+
import static org.elasticsearch.xpack.inference.InferencePlugin.EIS_INFERENCE_FEATURE;
17+
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
18+
19+
public class InferenceLicenceCheck {
20+
21+
private InferenceLicenceCheck() {}
22+
23+
public static boolean isServiceLicenced(String serviceName, XPackLicenseState licenseState) {
24+
if (ElasticInferenceService.NAME.equals(serviceName)) {
25+
return EIS_INFERENCE_FEATURE.check(licenseState);
26+
} else {
27+
return INFERENCE_API_FEATURE.check(licenseState);
28+
}
29+
}
30+
31+
public static ElasticsearchSecurityException complianceException(String serviceName) {
32+
if (ElasticInferenceService.NAME.equals(serviceName)) {
33+
return LicenseUtils.newComplianceException(XPackField.ELASTIC_INFERENCE_SERVICE);
34+
} else {
35+
return LicenseUtils.newComplianceException(XPackField.INFERENCE);
36+
}
37+
}
38+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ public class InferencePlugin extends Plugin
203203
License.OperationMode.ENTERPRISE
204204
);
205205

206+
public static final LicensedFeature.Momentary EIS_INFERENCE_FEATURE = LicensedFeature.momentary(
207+
"inference",
208+
"Elastic Inference Service",
209+
License.OperationMode.BASIC
210+
);
211+
206212
public static final String X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER = "X-elastic-product-use-case";
207213

208214
public static final String NAME = "inference";

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,14 @@
2525
import org.elasticsearch.inference.Model;
2626
import org.elasticsearch.inference.TaskType;
2727
import org.elasticsearch.inference.telemetry.InferenceStats;
28-
import org.elasticsearch.license.LicenseUtils;
2928
import org.elasticsearch.license.XPackLicenseState;
3029
import org.elasticsearch.rest.RestStatus;
3130
import org.elasticsearch.tasks.Task;
3231
import org.elasticsearch.threadpool.ThreadPool;
3332
import org.elasticsearch.transport.TransportService;
34-
import org.elasticsearch.xpack.core.XPackField;
3533
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
3634
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
35+
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
3736
import org.elasticsearch.xpack.inference.InferencePlugin;
3837
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
3938
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
@@ -51,7 +50,6 @@
5150
import static org.elasticsearch.core.Strings.format;
5251
import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes;
5352
import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes;
54-
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
5553

5654
/**
5755
* Base class for transport actions that handle inference requests.
@@ -112,16 +110,17 @@ protected abstract void doInference(
112110

113111
@Override
114112
protected void doExecute(Task task, Request request, ActionListener<InferenceAction.Response> listener) {
115-
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
116-
listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
117-
return;
118-
}
119113

120114
var timer = InferenceTimer.start();
121115

122116
var getModelListener = ActionListener.wrap((Model model) -> {
123117
var serviceName = model.getConfigurations().getService();
124118

119+
if (InferenceLicenceCheck.isServiceLicenced(serviceName, licenseState) == false) {
120+
listener.onFailure(InferenceLicenceCheck.complianceException(serviceName));
121+
return;
122+
}
123+
125124
try {
126125
validateRequest(request, model);
127126
} catch (Exception e) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.action.support.ActionFilters;
1515
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
1616
import org.elasticsearch.client.internal.Client;
17-
import org.elasticsearch.client.internal.OriginSettingClient;
1817
import org.elasticsearch.cluster.ClusterState;
1918
import org.elasticsearch.cluster.block.ClusterBlockException;
2019
import org.elasticsearch.cluster.block.ClusterBlockLevel;
@@ -34,19 +33,18 @@
3433
import org.elasticsearch.inference.ModelConfigurations;
3534
import org.elasticsearch.inference.TaskType;
3635
import org.elasticsearch.injection.guice.Inject;
37-
import org.elasticsearch.license.LicenseUtils;
3836
import org.elasticsearch.license.XPackLicenseState;
3937
import org.elasticsearch.rest.RestStatus;
4038
import org.elasticsearch.tasks.Task;
4139
import org.elasticsearch.threadpool.ThreadPool;
4240
import org.elasticsearch.transport.TransportService;
4341
import org.elasticsearch.xcontent.XContentParser;
4442
import org.elasticsearch.xcontent.XContentParserConfiguration;
45-
import org.elasticsearch.xpack.core.XPackField;
4643
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
4744
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
4845
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
4946
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
47+
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
5048
import org.elasticsearch.xpack.inference.InferencePlugin;
5149
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
5250
import org.elasticsearch.xpack.inference.services.ServiceUtils;
@@ -60,8 +58,6 @@
6058
import java.util.Set;
6159

6260
import static org.elasticsearch.core.Strings.format;
63-
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
64-
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
6561
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
6662
import static org.elasticsearch.xpack.inference.common.SemanticTextInfoExtractor.getModelSettingsForIndicesReferencingInferenceEndpoints;
6763
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.canMergeModelSettings;
@@ -76,7 +72,6 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
7672
private final XPackLicenseState licenseState;
7773
private final ModelRegistry modelRegistry;
7874
private final InferenceServiceRegistry serviceRegistry;
79-
private final OriginSettingClient client;
8075
private volatile boolean skipValidationAndStart;
8176
private final ProjectResolver projectResolver;
8277

@@ -110,7 +105,6 @@ public TransportPutInferenceModelAction(
110105
clusterService.getClusterSettings()
111106
.addSettingsUpdateConsumer(InferencePlugin.SKIP_VALIDATE_AND_START, this::setSkipValidationAndStart);
112107
this.projectResolver = projectResolver;
113-
this.client = new OriginSettingClient(client, INFERENCE_ORIGIN);
114108
}
115109

116110
@Override
@@ -120,11 +114,6 @@ protected void masterOperation(
120114
ClusterState state,
121115
ActionListener<PutInferenceModelAction.Response> listener
122116
) throws Exception {
123-
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
124-
listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
125-
return;
126-
}
127-
128117
if (modelRegistry.containsDefaultConfigId(request.getInferenceEntityId())) {
129118
listener.onFailure(
130119
new ElasticsearchStatusException(
@@ -150,6 +139,11 @@ protected void masterOperation(
150139
return;
151140
}
152141

142+
if (InferenceLicenceCheck.isServiceLicenced(serviceName, licenseState) == false) {
143+
listener.onFailure(InferenceLicenceCheck.complianceException(serviceName));
144+
return;
145+
}
146+
153147
if (List.of(OLD_ELSER_SERVICE_NAME, ElasticsearchInternalService.NAME).contains(serviceName)) {
154148
// required for BWC of elser service in elasticsearch service TODO remove when elser service deprecated
155149
requestAsMap.put(ModelConfigurations.SERVICE, serviceName);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,20 @@
3535
import org.elasticsearch.inference.TaskType;
3636
import org.elasticsearch.inference.UnparsedModel;
3737
import org.elasticsearch.injection.guice.Inject;
38-
import org.elasticsearch.license.LicenseUtils;
3938
import org.elasticsearch.license.XPackLicenseState;
4039
import org.elasticsearch.rest.RestStatus;
4140
import org.elasticsearch.tasks.Task;
4241
import org.elasticsearch.threadpool.ThreadPool;
4342
import org.elasticsearch.transport.TransportService;
4443
import org.elasticsearch.xcontent.XContentParser;
4544
import org.elasticsearch.xcontent.XContentParserConfiguration;
46-
import org.elasticsearch.xpack.core.XPackField;
4745
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
4846
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
4947
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
5048
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
5149
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
5250
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
51+
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
5352
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
5453
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
5554
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
@@ -61,7 +60,6 @@
6160
import java.util.Optional;
6261
import java.util.concurrent.atomic.AtomicReference;
6362

64-
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
6563
import static org.elasticsearch.xpack.inference.services.ServiceUtils.resolveTaskType;
6664
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS;
6765

@@ -113,11 +111,6 @@ protected void masterOperation(
113111
ClusterState state,
114112
ActionListener<UpdateInferenceModelAction.Response> masterListener
115113
) {
116-
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
117-
masterListener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
118-
return;
119-
}
120-
121114
var bodyTaskType = request.getContentAsSettings().taskType();
122115
var resolvedTaskType = resolveTaskType(request.getTaskType(), bodyTaskType != null ? bodyTaskType.toString() : null);
123116

@@ -137,10 +130,16 @@ protected void masterOperation(
137130
unparsedModel.service()
138131
)
139132
);
140-
} else {
141-
service.set(optionalService.get());
142-
listener.onResponse(unparsedModel);
133+
return;
143134
}
135+
136+
if (InferenceLicenceCheck.isServiceLicenced(optionalService.get().name(), licenseState) == false) {
137+
listener.onFailure(InferenceLicenceCheck.complianceException(optionalService.get().name()));
138+
return;
139+
}
140+
141+
service.set(optionalService.get());
142+
listener.onResponse(unparsedModel);
144143
})
145144
.<Boolean>andThen((listener, existingUnparsedModel) -> {
146145

0 commit comments

Comments
 (0)