Skip to content

Commit 9f91fa1

Browse files
Add enterprise license check for Inference API actions (elastic#119893) (elastic#120066)
* Add enterprise license check for Inference API actions * Update docs/changelog/119893.yaml * Adding missing plugin to ModelRegistryIT and removing license check from get inference services API * Fix tests * Fix basic license test * Removing unused feature flag from InferenceUpgradeTestCase --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent d93011e commit 9f91fa1

File tree

20 files changed

+323
-9
lines changed

20 files changed

+323
-9
lines changed

docs/changelog/119893.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 119893
2+
summary: Add enterprise license check for Inference API actions
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.common.settings.SecureString;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.common.util.concurrent.ThreadContext;
14+
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
16+
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
17+
import org.junit.ClassRule;
18+
19+
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.mockSparseServiceModelConfig;
20+
21+
public class InferenceBasicLicenseIT extends InferenceLicenseBaseRestTest {
22+
@ClassRule
23+
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
24+
.distribution(DistributionType.DEFAULT)
25+
.setting("xpack.license.self_generated.type", "basic")
26+
.setting("xpack.security.enabled", "true")
27+
.user("x_pack_rest_user", "x-pack-test-password")
28+
.plugin("inference-service-test")
29+
.build();
30+
31+
@Override
32+
protected String getTestRestCluster() {
33+
return cluster.getHttpAddresses();
34+
}
35+
36+
@Override
37+
protected Settings restClientSettings() {
38+
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
39+
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
40+
}
41+
42+
public void testPutModel_RestrictedWithBasicLicense() throws Exception {
43+
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
44+
var modelConfig = mockSparseServiceModelConfig(null, true);
45+
sendRestrictedRequest("PUT", endpoint, modelConfig);
46+
}
47+
48+
public void testUpdateModel_RestrictedWithBasicLicense() throws Exception {
49+
var endpoint = Strings.format("_inference/%s/%s/_update?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
50+
var requestBody = """
51+
{
52+
"task_settings": {
53+
"num_threads": 2
54+
}
55+
}
56+
""";
57+
sendRestrictedRequest("PUT", endpoint, requestBody);
58+
}
59+
60+
public void testPerformInference_RestrictedWithBasicLicense() throws Exception {
61+
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
62+
var requestBody = """
63+
{
64+
"input": ["washing", "machine"]
65+
}
66+
""";
67+
sendRestrictedRequest("POST", endpoint, requestBody);
68+
}
69+
70+
public void testGetServices_NonRestrictedWithBasicLicense() throws Exception {
71+
var endpoint = "_inference/_services";
72+
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
73+
}
74+
75+
public void testGetModels_NonRestrictedWithBasicLicense() throws Exception {
76+
var endpoint = "_inference/_all";
77+
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
78+
}
79+
80+
public void testDeleteModel_NonRestrictedWithBasicLicense() throws Exception {
81+
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
82+
sendNonRestrictedRequest("DELETE", endpoint, null, 404, true);
83+
}
84+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.client.Request;
11+
import org.elasticsearch.client.ResponseException;
12+
import org.elasticsearch.test.rest.ESRestTestCase;
13+
14+
import java.io.IOException;
15+
16+
import static org.hamcrest.Matchers.containsString;
17+
18+
public class InferenceLicenseBaseRestTest extends ESRestTestCase {
19+
protected void sendRestrictedRequest(String method, String endpoint, String body) throws IOException {
20+
var request = new Request(method, endpoint);
21+
request.setJsonEntity(body);
22+
23+
var exception = assertThrows(ResponseException.class, () -> client().performRequest(request));
24+
assertEquals(403, exception.getResponse().getStatusLine().getStatusCode());
25+
assertThat(exception.getMessage(), containsString("current license is non-compliant for [inference]"));
26+
}
27+
28+
protected void sendNonRestrictedRequest(String method, String endpoint, String body, int expectedStatusCode, boolean exceptionExpected)
29+
throws IOException {
30+
var request = new Request(method, endpoint);
31+
request.setJsonEntity(body);
32+
33+
int actualStatusCode;
34+
if (exceptionExpected) {
35+
var exception = assertThrows(ResponseException.class, () -> client().performRequest(request));
36+
actualStatusCode = exception.getResponse().getStatusLine().getStatusCode();
37+
} else {
38+
var response = client().performRequest(request);
39+
actualStatusCode = response.getStatusLine().getStatusCode();
40+
}
41+
assertEquals(expectedStatusCode, actualStatusCode);
42+
}
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.common.settings.SecureString;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.common.util.concurrent.ThreadContext;
14+
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
16+
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
17+
import org.junit.ClassRule;
18+
19+
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.mockSparseServiceModelConfig;
20+
21+
public class InferenceTrialLicenseIT extends InferenceLicenseBaseRestTest {
22+
@ClassRule
23+
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
24+
.distribution(DistributionType.DEFAULT)
25+
.setting("xpack.license.self_generated.type", "trial")
26+
.setting("xpack.security.enabled", "true")
27+
.user("x_pack_rest_user", "x-pack-test-password")
28+
.plugin("inference-service-test")
29+
.build();
30+
31+
@Override
32+
protected String getTestRestCluster() {
33+
return cluster.getHttpAddresses();
34+
}
35+
36+
@Override
37+
protected Settings restClientSettings() {
38+
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
39+
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
40+
}
41+
42+
public void testPutModel_NonRestrictedWithTrialLicense() throws Exception {
43+
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
44+
var modelConfig = mockSparseServiceModelConfig(null, true);
45+
sendNonRestrictedRequest("PUT", endpoint, modelConfig, 200, false);
46+
}
47+
48+
public void testUpdateModel_NonRestrictedWithTrialLicense() throws Exception {
49+
var endpoint = Strings.format("_inference/%s/%s/_update?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
50+
var requestBody = """
51+
{
52+
"task_settings": {
53+
"num_threads": 2
54+
}
55+
}
56+
""";
57+
sendNonRestrictedRequest("PUT", endpoint, requestBody, 404, true);
58+
}
59+
60+
public void testPerformInference_NonRestrictedWithTrialLicense() throws Exception {
61+
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
62+
var requestBody = """
63+
{
64+
"input": ["washing", "machine"]
65+
}
66+
""";
67+
sendNonRestrictedRequest("POST", endpoint, requestBody, 404, true);
68+
}
69+
70+
public void testGetServices_NonRestrictedWithBasicLicense() throws Exception {
71+
var endpoint = "_inference/_services";
72+
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
73+
}
74+
75+
public void testGetModels_NonRestrictedWithBasicLicense() throws Exception {
76+
var endpoint = "_inference/_all";
77+
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
78+
}
79+
80+
public void testDeleteModel_NonRestrictedWithBasicLicense() throws Exception {
81+
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
82+
sendNonRestrictedRequest("DELETE", endpoint, null, 404, true);
83+
}
84+
}

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
import org.elasticsearch.client.Request;
1313
import org.elasticsearch.common.Strings;
1414
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
16+
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
1517
import org.elasticsearch.test.http.MockWebServer;
16-
import org.elasticsearch.upgrades.AbstractRollingUpgradeTestCase;
18+
import org.elasticsearch.upgrades.ParameterizedRollingUpgradeTestCase;
19+
import org.junit.ClassRule;
1720

1821
import java.io.IOException;
1922
import java.util.LinkedList;
@@ -22,14 +25,28 @@
2225

2326
import static org.elasticsearch.core.Strings.format;
2427

25-
public class InferenceUpgradeTestCase extends AbstractRollingUpgradeTestCase {
28+
public class InferenceUpgradeTestCase extends ParameterizedRollingUpgradeTestCase {
2629

2730
static final String MODELS_RENAMED_TO_ENDPOINTS = "8.15.0";
2831

2932
public InferenceUpgradeTestCase(@Name("upgradedNodes") int upgradedNodes) {
3033
super(upgradedNodes);
3134
}
3235

36+
@ClassRule
37+
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
38+
.distribution(DistributionType.DEFAULT)
39+
.version(getOldClusterTestVersion())
40+
.nodes(NODE_NUM)
41+
.setting("xpack.security.enabled", "false")
42+
.setting("xpack.license.self_generated.type", "trial")
43+
.build();
44+
45+
@Override
46+
protected ElasticsearchCluster getUpgradeCluster() {
47+
return cluster;
48+
}
49+
3350
protected static String getUrl(MockWebServer webServer) {
3451
return format("http://%s:%s", webServer.getHostName(), webServer.getPort());
3552
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.plugins.Plugin;
2828
import org.elasticsearch.search.builder.SearchSourceBuilder;
2929
import org.elasticsearch.test.ESIntegTestCase;
30+
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
3031
import org.elasticsearch.xpack.inference.Utils;
3132
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
3233
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
@@ -73,7 +74,7 @@ public void setup() throws Exception {
7374

7475
@Override
7576
protected Collection<Class<? extends Plugin>> nodePlugins() {
76-
return Arrays.asList(Utils.TestInferencePlugin.class);
77+
return Arrays.asList(Utils.TestInferencePlugin.class, LocalStateCompositeXPackPlugin.class);
7778
}
7879

7980
@Override

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.threadpool.ThreadPool;
3232
import org.elasticsearch.xcontent.ToXContentObject;
3333
import org.elasticsearch.xcontent.XContentBuilder;
34+
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
3435
import org.elasticsearch.xpack.inference.InferencePlugin;
3536
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
3637
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -76,7 +77,7 @@ public void createComponents() {
7677

7778
@Override
7879
protected Collection<Class<? extends Plugin>> getPlugins() {
79-
return pluginList(ReindexPlugin.class, InferencePlugin.class);
80+
return pluginList(ReindexPlugin.class, InferencePlugin.class, LocalStateCompositeXPackPlugin.class);
8081
}
8182

8283
public void testStoreModel() throws Exception {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
import org.elasticsearch.indices.SystemIndexDescriptor;
3030
import org.elasticsearch.inference.InferenceServiceExtension;
3131
import org.elasticsearch.inference.InferenceServiceRegistry;
32+
import org.elasticsearch.license.License;
33+
import org.elasticsearch.license.LicensedFeature;
3234
import org.elasticsearch.node.PluginComponentBinding;
3335
import org.elasticsearch.plugins.ActionPlugin;
3436
import org.elasticsearch.plugins.ExtensiblePlugin;
@@ -150,6 +152,12 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
150152
Setting.Property.Dynamic
151153
);
152154

155+
public static final LicensedFeature.Momentary INFERENCE_API_FEATURE = LicensedFeature.momentary(
156+
"inference",
157+
"api",
158+
License.OperationMode.ENTERPRISE
159+
);
160+
153161
public static final String NAME = "inference";
154162
public static final String UTILITY_THREAD_POOL_NAME = "inference_utility";
155163

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@
2323
import org.elasticsearch.inference.Model;
2424
import org.elasticsearch.inference.TaskType;
2525
import org.elasticsearch.inference.UnparsedModel;
26+
import org.elasticsearch.license.LicenseUtils;
27+
import org.elasticsearch.license.XPackLicenseState;
2628
import org.elasticsearch.rest.RestStatus;
2729
import org.elasticsearch.tasks.Task;
2830
import org.elasticsearch.transport.TransportService;
31+
import org.elasticsearch.xpack.core.XPackField;
2932
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
3033
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
3134
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
@@ -38,6 +41,7 @@
3841
import java.util.stream.Collectors;
3942

4043
import static org.elasticsearch.core.Strings.format;
44+
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
4145
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
4246
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
4347

@@ -48,6 +52,7 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
4852
private static final Logger log = LogManager.getLogger(BaseTransportInferenceAction.class);
4953
private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
5054
private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
55+
private final XPackLicenseState licenseState;
5156
private final ModelRegistry modelRegistry;
5257
private final InferenceServiceRegistry serviceRegistry;
5358
private final InferenceStats inferenceStats;
@@ -57,13 +62,15 @@ public BaseTransportInferenceAction(
5762
String inferenceActionName,
5863
TransportService transportService,
5964
ActionFilters actionFilters,
65+
XPackLicenseState licenseState,
6066
ModelRegistry modelRegistry,
6167
InferenceServiceRegistry serviceRegistry,
6268
InferenceStats inferenceStats,
6369
StreamingTaskManager streamingTaskManager,
6470
Writeable.Reader<Request> requestReader
6571
) {
6672
super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE);
73+
this.licenseState = licenseState;
6774
this.modelRegistry = modelRegistry;
6875
this.serviceRegistry = serviceRegistry;
6976
this.inferenceStats = inferenceStats;
@@ -72,6 +79,11 @@ public BaseTransportInferenceAction(
7279

7380
@Override
7481
protected void doExecute(Task task, Request request, ActionListener<InferenceAction.Response> listener) {
82+
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
83+
listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
84+
return;
85+
}
86+
7587
var timer = InferenceTimer.start();
7688

7789
var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.inference.Model;
1717
import org.elasticsearch.inference.UnparsedModel;
1818
import org.elasticsearch.injection.guice.Inject;
19+
import org.elasticsearch.license.XPackLicenseState;
1920
import org.elasticsearch.transport.TransportService;
2021
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
2122
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
@@ -28,6 +29,7 @@ public class TransportInferenceAction extends BaseTransportInferenceAction<Infer
2829
public TransportInferenceAction(
2930
TransportService transportService,
3031
ActionFilters actionFilters,
32+
XPackLicenseState licenseState,
3133
ModelRegistry modelRegistry,
3234
InferenceServiceRegistry serviceRegistry,
3335
InferenceStats inferenceStats,
@@ -37,6 +39,7 @@ public TransportInferenceAction(
3739
InferenceAction.NAME,
3840
transportService,
3941
actionFilters,
42+
licenseState,
4043
modelRegistry,
4144
serviceRegistry,
4245
inferenceStats,

0 commit comments

Comments
 (0)