Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Resp
public static final String NAME = "cluster:admin/xpack/ml/inference/put";
public static final String MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT =
"the model id is the same as the deployment id of a current model deployment";
public static final String LICENSE_NON_COMPLIANT_ERROR_MESSAGE_FRAGMENT = "current license is non-compliant";

private PutTrainedModelAction() {
super(NAME);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.apache.http.HttpHost;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.junit.ClassRule;

import java.io.IOException;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;

public class InsufficientLicenseIT extends ESRestTestCase {

private static final String PASSWORD = "secret-test-password";

@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
.setting("xpack.license.self_generated.type", "basic")
.setting("xpack.security.enabled", "true")
.plugin("inference-service-test")
.user("x_pack_rest_user", "x-pack-test-password")
.build();

@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}

@Override
protected Settings restClientSettings() {
// use the privileged users here but not in the tests
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
}

public void testInsufficientLicense() throws IOException {
var putRequest = new Request("PUT", "_inference/sparse_embedding/license_test");
putRequest.setJsonEntity("""
{
"service": "elser",
"service_settings": {
"num_allocations": 1,
"num_threads": 1
}
}
""");
var getRequest = new Request("GET", "_inference/sparse_embedding/license_test");

try (RestClient client = buildClient(restClientSettings(), getClusterHosts().toArray(new HttpHost[0]))) {
// Creating inference endpoint will return a license error
ResponseException putException = expectThrows(ResponseException.class, () -> client.performRequest(putRequest));
assertThat(putException.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.FORBIDDEN.getStatus()));
assertThat(putException.getMessage(), containsString("current license is non-compliant for [ml]"));

// Assert no inference endpoint created
ResponseException getException = expectThrows(ResponseException.class, () -> client.performRequest(getRequest));
assertThat(getException.getResponse().getStatusLine().getStatusCode(), equalTo(RestStatus.NOT_FOUND.getStatus()));
assertThat(getException.getMessage(), containsString("Inference endpoint not found [license_test]"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
Expand All @@ -23,7 +24,9 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
Expand Down Expand Up @@ -155,9 +158,14 @@ public void putModel(Model model, ActionListener<Boolean> listener) {
listener.onFailure(notElasticsearchModelException(model));
return;
} else if (model instanceof MultilingualE5SmallModel e5Model) {
putBuiltInModel(e5Model.getServiceSettings().modelId(), listener);
putBuiltInModel(e5Model.getServiceSettings().modelId(), e5Model.getInferenceEntityId(), e5Model.getTaskType(), listener);
} else if (model instanceof ElserInternalModel elserModel) {
putBuiltInModel(elserModel.getServiceSettings().modelId(), listener);
putBuiltInModel(
elserModel.getServiceSettings().modelId(),
elserModel.getInferenceEntityId(),
elserModel.getTaskType(),
listener
);
} else if (model instanceof CustomElandModel) {
logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland.");
listener.onResponse(Boolean.TRUE);
Expand All @@ -173,7 +181,7 @@ public void putModel(Model model, ActionListener<Boolean> listener) {
}
}

protected void putBuiltInModel(String modelId, ActionListener<Boolean> listener) {
protected void putBuiltInModel(String modelId, String inferenceId, TaskType taskType, ActionListener<Boolean> listener) {
var input = new TrainedModelInput(List.<String>of("text_field")); // by convention text_field is used
var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build();
PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);
Expand All @@ -186,9 +194,26 @@ protected void putBuiltInModel(String modelId, ActionListener<Boolean> listener)
if (e instanceof ElasticsearchStatusException esException
&& esException.getMessage().contains(PutTrainedModelAction.MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT)) {
listener.onResponse(Boolean.TRUE);
} else {
listener.onFailure(e);
}
} else if (e instanceof ElasticsearchSecurityException esException
&& esException.getMessage().contains(PutTrainedModelAction.LICENSE_NON_COMPLIANT_ERROR_MESSAGE_FRAGMENT)) {
var deleteRequest = new DeleteInferenceEndpointAction.Request(inferenceId, taskType, true, false);
client.execute(
DeleteInferenceEndpointAction.INSTANCE,
deleteRequest,
ActionListener.wrap(
r -> listener.onFailure(e),
e1 -> listener.onFailure(
new ElasticsearchStatusException(
"Failed to delete the inference endpoint after failing to start the trained model due to "
+ "non-compliant license",
RestStatus.FORBIDDEN
)
)
)
);
} else {
listener.onFailure(e);
}
})
);
}
Expand Down