Skip to content

Commit b66dfc1

Browse files
committed
Deleting inference endpoint if start model action returns license error
1 parent 7d0d50d commit b66dfc1

File tree

3 files changed

+110
-6
lines changed

3 files changed

+110
-6
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Resp
3232
public static final String NAME = "cluster:admin/xpack/ml/inference/put";
3333
public static final String MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT =
3434
"the model id is the same as the deployment id of a current model deployment";
35+
public static final String LICENSE_NON_COMPLIANT_ERROR_MESSAGE_FRAGMENT = "current license is non-compliant";
3536

3637
private PutTrainedModelAction() {
3738
super(NAME);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.apache.http.HttpHost;
11+
import org.elasticsearch.client.Request;
12+
import org.elasticsearch.client.ResponseException;
13+
import org.elasticsearch.client.RestClient;
14+
import org.elasticsearch.common.settings.SecureString;
15+
import org.elasticsearch.common.settings.Settings;
16+
import org.elasticsearch.common.util.concurrent.ThreadContext;
17+
import org.elasticsearch.rest.RestStatus;
18+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
19+
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
20+
import org.elasticsearch.test.rest.ESRestTestCase;
21+
import org.junit.ClassRule;
22+
23+
import java.io.IOException;
24+
25+
import static org.hamcrest.Matchers.containsString;
26+
import static org.hamcrest.Matchers.equalTo;
27+
28+
public class InsufficientLicenseIT extends ESRestTestCase {
29+
30+
private static final String PASSWORD = "secret-test-password";
31+
32+
@ClassRule
33+
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
34+
.distribution(DistributionType.DEFAULT)
35+
.setting("xpack.license.self_generated.type", "basic")
36+
.setting("xpack.security.enabled", "true")
37+
.plugin("inference-service-test")
38+
.user("x_pack_rest_user", "x-pack-test-password")
39+
.build();
40+
41+
@Override
42+
protected String getTestRestCluster() {
43+
return cluster.getHttpAddresses();
44+
}
45+
46+
@Override
47+
protected Settings restClientSettings() {
48+
// use the privileged users here but not in the tests
49+
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
50+
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
51+
}
52+
53+
public void testInsufficientLicense() throws IOException {
54+
var putRequest = new Request("PUT", "_inference/sparse_embedding/license_test");
55+
putRequest.setJsonEntity("""
56+
{
57+
"service": "elser",
58+
"service_settings": {
59+
"num_allocations": 1,
60+
"num_threads": 1
61+
}
62+
}
63+
""");
64+
var getRequest = new Request("GET", "_inference/sparse_embedding/license_test");
65+
66+
try (RestClient client = buildClient(restClientSettings(), getClusterHosts().toArray(new HttpHost[0]))) {
67+
// Creating inference endpoint will return a license error
68+
ResponseException putException = expectThrows(ResponseException.class, () -> client.performRequest(putRequest));
69+
assertThat(putException.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.FORBIDDEN.getStatus()));
70+
assertThat(putException.getMessage(), containsString("current license is non-compliant for [ml]"));
71+
72+
// Assert no inference endpoint created
73+
ResponseException getException = expectThrows(ResponseException.class, () -> client.performRequest(getRequest));
74+
assertThat(getException.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.NOT_FOUND.getStatus()));
75+
assertThat(getException.getMessage(), containsString("Inference endpoint not found [license_test]"));
76+
}
77+
}
78+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ElasticsearchSecurityException;
1213
import org.elasticsearch.ElasticsearchStatusException;
1314
import org.elasticsearch.ExceptionsHelper;
1415
import org.elasticsearch.ResourceNotFoundException;
@@ -23,7 +24,9 @@
2324
import org.elasticsearch.inference.Model;
2425
import org.elasticsearch.inference.TaskType;
2526
import org.elasticsearch.inference.UnparsedModel;
27+
import org.elasticsearch.rest.RestStatus;
2628
import org.elasticsearch.xpack.core.ClientHelper;
29+
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
2730
import org.elasticsearch.xpack.core.ml.MachineLearningField;
2831
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
2932
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
@@ -155,9 +158,14 @@ public void putModel(Model model, ActionListener<Boolean> listener) {
155158
listener.onFailure(notElasticsearchModelException(model));
156159
return;
157160
} else if (model instanceof MultilingualE5SmallModel e5Model) {
158-
putBuiltInModel(e5Model.getServiceSettings().modelId(), listener);
161+
putBuiltInModel(e5Model.getServiceSettings().modelId(), e5Model.getInferenceEntityId(), e5Model.getTaskType(), listener);
159162
} else if (model instanceof ElserInternalModel elserModel) {
160-
putBuiltInModel(elserModel.getServiceSettings().modelId(), listener);
163+
putBuiltInModel(
164+
elserModel.getServiceSettings().modelId(),
165+
elserModel.getInferenceEntityId(),
166+
elserModel.getTaskType(),
167+
listener
168+
);
161169
} else if (model instanceof CustomElandModel) {
162170
logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland.");
163171
listener.onResponse(Boolean.TRUE);
@@ -173,7 +181,7 @@ public void putModel(Model model, ActionListener<Boolean> listener) {
173181
}
174182
}
175183

176-
protected void putBuiltInModel(String modelId, ActionListener<Boolean> listener) {
184+
protected void putBuiltInModel(String modelId, String inferenceId, TaskType taskType, ActionListener<Boolean> listener) {
177185
var input = new TrainedModelInput(List.<String>of("text_field")); // by convention text_field is used
178186
var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build();
179187
PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);
@@ -186,9 +194,26 @@ protected void putBuiltInModel(String modelId, ActionListener<Boolean> listener)
186194
if (e instanceof ElasticsearchStatusException esException
187195
&& esException.getMessage().contains(PutTrainedModelAction.MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT)) {
188196
listener.onResponse(Boolean.TRUE);
189-
} else {
190-
listener.onFailure(e);
191-
}
197+
} else if (e instanceof ElasticsearchSecurityException esException
198+
&& esException.getMessage().contains(PutTrainedModelAction.LICENSE_NON_COMPLIANT_ERROR_MESSAGE_FRAGMENT)) {
199+
var deleteRequest = new DeleteInferenceEndpointAction.Request(inferenceId, taskType, true, false);
200+
client.execute(
201+
DeleteInferenceEndpointAction.INSTANCE,
202+
deleteRequest,
203+
ActionListener.wrap(
204+
r -> listener.onFailure(e),
205+
e1 -> listener.onFailure(
206+
new ElasticsearchStatusException(
207+
"Failed to delete the inference endpoint after failing to start the trained model due to "
208+
+ "non-compliant license",
209+
RestStatus.FORBIDDEN
210+
)
211+
)
212+
)
213+
);
214+
} else {
215+
listener.onFailure(e);
216+
}
192217
})
193218
);
194219
}

0 commit comments

Comments
 (0)