Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/137434.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 137434
summary: Require basic licence for the Elastic Inference Service
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public final class XPackField {
public static final String UPGRADE = "upgrade";
// inside of YAML settings we still use xpack so that we do not have to handle issues with dashes
public static final String SETTINGS_NAME = "xpack";
/** Name constant for the EIS feature. */
public static final String ELASTIC_INFERENCE_SERVICE = "Elastic Inference Service";
/** Name constant for the eql feature. */
public static final String EQL = "eql";
/** Name constant for the esql feature. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase {

private static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
.setting("xpack.license.self_generated.type", "trial")
.setting("xpack.license.self_generated.type", "basic")
.setting("xpack.security.enabled", "true")
// Adding both settings unless one feature flag is disabled in a particular environment
.setting("xpack.inference.elastic.url", mockEISServer::getUrl)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,6 @@ public void testUpdateModel_RestrictedWithBasicLicense() throws Exception {
sendRestrictedRequest("PUT", endpoint, requestBody);
}

/**
 * Issues a POST to the inference endpoint for a {@link TaskType#SPARSE_EMBEDDING} endpoint with the
 * fixed id {@code endpoint-id} (with the {@code error_trace} query parameter appended), via
 * {@code sendRestrictedRequest}.
 * NOTE(review): presumably the helper asserts that the request is rejected under a basic licence;
 * confirm against its definition, which is not visible here.
 */
public void testPerformInference_RestrictedWithBasicLicense() throws Exception {
// Path has the shape _inference/<task type>/<endpoint id>?error_trace
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
// Minimal request body: two input strings.
var requestBody = """
{
"input": ["washing", "machine"]
}
""";
sendRestrictedRequest("POST", endpoint, requestBody);
}

public void testGetServices_NonRestrictedWithBasicLicense() throws Exception {
var endpoint = "_inference/_services";
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;

import static org.elasticsearch.xpack.inference.InferencePlugin.EIS_INFERENCE_FEATURE;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;

/**
 * Static helpers that select the licensed feature, and the matching compliance exception,
 * for a given inference service name.
 * <p>
 * A service whose name equals {@link ElasticInferenceService#NAME} is checked against
 * {@code EIS_INFERENCE_FEATURE}; every other service name (including {@code null}) is checked
 * against {@code INFERENCE_API_FEATURE}.
 */
public class InferenceLicenceCheck {

    /** Not instantiable; this class only exposes static methods. */
    private InferenceLicenceCheck() {}

    /**
     * Checks whether the current licence state permits use of the named service.
     *
     * @param serviceName  name of the inference service being used
     * @param licenseState the licence state to check the feature against
     * @return the result of checking the feature that applies to {@code serviceName}
     */
    public static boolean isServiceLicenced(String serviceName, XPackLicenseState licenseState) {
        return isElasticInferenceService(serviceName)
            ? EIS_INFERENCE_FEATURE.check(licenseState)
            : INFERENCE_API_FEATURE.check(licenseState);
    }

    /**
     * Builds the licence compliance exception for the named service.
     *
     * @param serviceName name of the inference service being used
     * @return a compliance exception created for {@link XPackField#ELASTIC_INFERENCE_SERVICE} when the
     *         service is the Elastic Inference Service, otherwise for {@link XPackField#INFERENCE}
     */
    public static ElasticsearchSecurityException complianceException(String serviceName) {
        final String featureName = isElasticInferenceService(serviceName)
            ? XPackField.ELASTIC_INFERENCE_SERVICE
            : XPackField.INFERENCE;
        return LicenseUtils.newComplianceException(featureName);
    }

    /** Returns true when {@code serviceName} equals the Elastic Inference Service name; null-safe. */
    private static boolean isElasticInferenceService(String serviceName) {
        return ElasticInferenceService.NAME.equals(serviceName);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ public class InferencePlugin extends Plugin
License.OperationMode.ENTERPRISE
);

public static final LicensedFeature.Momentary EIS_INFERENCE_FEATURE = LicensedFeature.momentary(
"inference",
"eis",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit confusing to me: the ElasticInferenceService.NAME is elastic, but our license service name is eis. So the user configures elastic to get the license with name eis.

Could we name this elastic to match the name of the ElasticInferenceService.NAME?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed this to the proper name "Elastic Inference Service"

The error the user sees if the licence is not compatible is either:

"current license is non-compliant for [inference]"

Or

"current license is non-compliant for [Elastic Inference Service]"

License.OperationMode.BASIC
);

public static final String X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER = "X-elastic-product-use-case";

public static final String NAME = "inference";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.telemetry.InferenceStats;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
Expand All @@ -51,7 +50,6 @@
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes;
import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;

/**
* Base class for transport actions that handle inference requests.
Expand Down Expand Up @@ -112,16 +110,17 @@ protected abstract void doInference(

@Override
protected void doExecute(Task task, Request request, ActionListener<InferenceAction.Response> listener) {
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
return;
}

var timer = InferenceTimer.start();

var getModelListener = ActionListener.wrap((Model model) -> {
var serviceName = model.getConfigurations().getService();

if (InferenceLicenceCheck.isServiceLicenced(serviceName, licenseState) == false) {
listener.onFailure(InferenceLicenceCheck.complianceException(serviceName));
return;
}

try {
validateRequest(request, model);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,18 @@
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
Expand All @@ -53,7 +52,6 @@
import java.util.Map;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;

public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
Expand Down Expand Up @@ -106,11 +104,6 @@ protected void masterOperation(
ClusterState state,
ActionListener<PutInferenceModelAction.Response> listener
) throws Exception {
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
return;
}

if (modelRegistry.containsDefaultConfigId(request.getInferenceEntityId())) {
listener.onFailure(
new ElasticsearchStatusException(
Expand All @@ -136,6 +129,11 @@ protected void masterOperation(
return;
}

if (InferenceLicenceCheck.isServiceLicenced(serviceName, licenseState) == false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I wonder if we should move the reserved ID check to below this check?

listener.onFailure(InferenceLicenceCheck.complianceException(serviceName));
return;
}

if (List.of(OLD_ELSER_SERVICE_NAME, ElasticsearchInternalService.NAME).contains(serviceName)) {
// required for BWC of elser service in elasticsearch service TODO remove when elser service deprecated
requestAsMap.put(ModelConfigurations.SERVICE, serviceName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,20 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
Expand All @@ -61,7 +60,6 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.resolveTaskType;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS;

Expand Down Expand Up @@ -113,11 +111,6 @@ protected void masterOperation(
ClusterState state,
ActionListener<UpdateInferenceModelAction.Response> masterListener
) {
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
masterListener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
return;
}

var bodyTaskType = request.getContentAsSettings().taskType();
var resolvedTaskType = resolveTaskType(request.getTaskType(), bodyTaskType != null ? bodyTaskType.toString() : null);

Expand All @@ -137,10 +130,16 @@ protected void masterOperation(
unparsedModel.service()
)
);
} else {
service.set(optionalService.get());
listener.onResponse(unparsedModel);
return;
}

if (InferenceLicenceCheck.isServiceLicenced(optionalService.get().name(), licenseState) == false) {
listener.onFailure(InferenceLicenceCheck.complianceException(optionalService.get().name()));
return;
}

service.set(optionalService.get());
listener.onResponse(unparsedModel);
})
.<Boolean>andThen((listener, existingUnparsedModel) -> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.inference.telemetry.InferenceStats;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
Expand All @@ -57,10 +56,10 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.InferenceException;
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils;
Expand All @@ -78,7 +77,6 @@
import java.util.stream.Collectors;

import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;

Expand Down Expand Up @@ -383,6 +381,19 @@ public void onFailure(Exception exc) {
modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
return;
}

if (InferenceLicenceCheck.isServiceLicenced(inferenceProvider.service.name(), licenseState) == false) {
try (onFinish) {
for (FieldInferenceRequest request : requests) {
addInferenceResponseFailure(
request.bulkItemIndex,
InferenceLicenceCheck.complianceException(inferenceProvider.service.name())
);
}
return;
}
}
Comment on lines 385 to 393
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@carlosdelest what do you think of this? It seems OK to me. Both the old and new license check failures are per bulk item request. Now we just delay it until we have the inference provider loaded.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense - we're checking the inference provider service when we have it. As requests are grouped by inference provider, we can fail all of them at once in case it's not compliant. 👍


final List<ChunkInferenceInput> inputs = requests.stream()
.map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings))
.collect(Collectors.toList());
Expand Down Expand Up @@ -571,11 +582,6 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
break;
}

if (INFERENCE_API_FEATURE.check(licenseState) == false) {
addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
break;
}

List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
int offsetAdjustment = 0;
for (String v : values) {
Expand Down
Loading