Skip to content

Commit b46b990

Browse files
committed
Deleting inference endpoint if start model action returns license error
1 parent 8cc76e4 commit b46b990

File tree

3 files changed

+102
-6
lines changed

3 files changed

+102
-6
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Resp
3232
public static final String NAME = "cluster:admin/xpack/ml/inference/put";
3333
public static final String MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT =
3434
"the model id is the same as the deployment id of a current model deployment";
35+
public static final String LICENSE_NON_COMPLIANT_ERROR_MESSAGE_FRAGMENT = "current license is non-compliant";
3536

3637
private PutTrainedModelAction() {
3738
super(NAME);
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.apache.http.HttpHost;
11+
import org.elasticsearch.client.Request;
12+
import org.elasticsearch.client.ResponseException;
13+
import org.elasticsearch.client.RestClient;
14+
import org.elasticsearch.common.settings.SecureString;
15+
import org.elasticsearch.common.settings.Settings;
16+
import org.elasticsearch.common.util.concurrent.ThreadContext;
17+
import org.elasticsearch.rest.RestStatus;
18+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
19+
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
20+
import org.elasticsearch.test.rest.ESRestTestCase;
21+
import org.junit.ClassRule;
22+
23+
import java.io.IOException;
24+
25+
import static org.hamcrest.Matchers.containsString;
26+
import static org.hamcrest.Matchers.equalTo;
27+
28+
public class InsufficientLicenseIT extends ESRestTestCase {
29+
30+
private static final String PASSWORD = "secret-test-password";
31+
32+
@ClassRule
33+
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
34+
.distribution(DistributionType.DEFAULT)
35+
.setting("xpack.license.self_generated.type", "basic")
36+
.setting("xpack.security.enabled", "true")
37+
.plugin("inference-service-test")
38+
.user("x_pack_rest_user", "x-pack-test-password")
39+
.build();
40+
41+
@Override
42+
protected String getTestRestCluster() {
43+
return cluster.getHttpAddresses();
44+
}
45+
46+
@Override
47+
protected Settings restClientSettings() {
48+
// use the privileged users here but not in the tests
49+
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
50+
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
51+
}
52+
53+
public void testInsufficientLicense() throws IOException {
54+
var putRequest = new Request("PUT", "_inference/sparse_embedding/license_test");
55+
putRequest.setJsonEntity("""
56+
{
57+
"service": "elser",
58+
"service_settings": {
59+
"num_allocations": 1,
60+
"num_threads": 1
61+
}
62+
}
63+
""");
64+
var getRequest = new Request("GET", "_inference/sparse_embedding/license_test");
65+
66+
try (RestClient client = buildClient(restClientSettings(), getClusterHosts().toArray(new HttpHost[0]))) {
67+
// Creating inference endpoint will return a license error
68+
ResponseException putException = expectThrows(ResponseException.class, () -> client.performRequest(putRequest));
69+
assertThat(putException.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.FORBIDDEN.getStatus()));
70+
assertThat(putException.getMessage(), containsString("current license is non-compliant for [ml]"));
71+
72+
// Assert no inference endpoint created
73+
ResponseException getException = expectThrows(ResponseException.class, () -> client.performRequest(getRequest));
74+
assertThat(getException.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.NOT_FOUND.getStatus()));
75+
assertThat(getException.getMessage(), containsString("Inference endpoint not found [license_test]"));
76+
}
77+
}
78+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ElasticsearchSecurityException;
1213
import org.elasticsearch.ElasticsearchStatusException;
1314
import org.elasticsearch.ExceptionsHelper;
1415
import org.elasticsearch.ResourceNotFoundException;
@@ -24,6 +25,7 @@
2425
import org.elasticsearch.inference.TaskType;
2526
import org.elasticsearch.inference.UnparsedModel;
2627
import org.elasticsearch.xpack.core.ClientHelper;
28+
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
2729
import org.elasticsearch.xpack.core.ml.MachineLearningField;
2830
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
2931
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
@@ -155,9 +157,14 @@ public void putModel(Model model, ActionListener<Boolean> listener) {
155157
listener.onFailure(notElasticsearchModelException(model));
156158
return;
157159
} else if (model instanceof MultilingualE5SmallModel e5Model) {
158-
putBuiltInModel(e5Model.getServiceSettings().modelId(), listener);
160+
putBuiltInModel(e5Model.getServiceSettings().modelId(), e5Model.getInferenceEntityId(), e5Model.getTaskType(), listener);
159161
} else if (model instanceof ElserInternalModel elserModel) {
160-
putBuiltInModel(elserModel.getServiceSettings().modelId(), listener);
162+
putBuiltInModel(
163+
elserModel.getServiceSettings().modelId(),
164+
elserModel.getInferenceEntityId(),
165+
elserModel.getTaskType(),
166+
listener
167+
);
161168
} else if (model instanceof CustomElandModel) {
162169
logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland.");
163170
listener.onResponse(Boolean.TRUE);
@@ -173,7 +180,7 @@ public void putModel(Model model, ActionListener<Boolean> listener) {
173180
}
174181
}
175182

176-
protected void putBuiltInModel(String modelId, ActionListener<Boolean> listener) {
183+
protected void putBuiltInModel(String modelId, String inferenceId, TaskType taskType, ActionListener<Boolean> listener) {
177184
var input = new TrainedModelInput(List.<String>of("text_field")); // by convention text_field is used
178185
var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build();
179186
PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);
@@ -186,9 +193,19 @@ protected void putBuiltInModel(String modelId, ActionListener<Boolean> listener)
186193
if (e instanceof ElasticsearchStatusException esException
187194
&& esException.getMessage().contains(PutTrainedModelAction.MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT)) {
188195
listener.onResponse(Boolean.TRUE);
189-
} else {
190-
listener.onFailure(e);
191-
}
196+
} else if (e instanceof ElasticsearchSecurityException esException
197+
&& esException.getMessage().contains(PutTrainedModelAction.LICENSE_NON_COMPLIANT_ERROR_MESSAGE_FRAGMENT)) {
198+
var deleteRequest = new DeleteInferenceEndpointAction.Request(inferenceId, taskType, true, false);
199+
client.execute(
200+
DeleteInferenceEndpointAction.INSTANCE,
201+
deleteRequest,
202+
listener.delegateFailureAndWrap((l, response) -> {
203+
listener.onFailure(e);
204+
})
205+
);
206+
} else {
207+
listener.onFailure(e);
208+
}
192209
})
193210
);
194211
}

0 commit comments

Comments
 (0)