Skip to content

Commit ade5b13

Browse files
authored
Add LTR License Check on PUT for Enterprise Licensing (#111248) (#111460)
* add isLicenseAllowedForAction trained model config * fixup tests - trial is allowed * fix license tests * update tests for validate model static method * add validateModel test; update license check
1 parent 5d0eb2a commit ade5b13

File tree

5 files changed

+177
-55
lines changed

5 files changed

+177
-55
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import org.elasticsearch.common.Strings;
1313
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
1414
import org.elasticsearch.core.Nullable;
15+
import org.elasticsearch.license.License;
16+
import org.elasticsearch.rest.RestRequest;
1517
import org.elasticsearch.xcontent.ParseField;
1618
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
1719
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
@@ -22,6 +24,7 @@
2224
import java.util.Arrays;
2325

2426
import static org.elasticsearch.action.ValidateActions.addValidationError;
27+
import static org.elasticsearch.xpack.core.ml.MachineLearningField.ML_API_FEATURE;
2528

2629
public interface InferenceConfig extends NamedXContentObject, VersionedNamedWriteable {
2730

@@ -114,4 +117,12 @@ default ElasticsearchStatusException incompatibleUpdateException(String updateNa
114117
updateName
115118
);
116119
}
120+
121+
default License.OperationMode getMinLicenseSupported() {
122+
return ML_API_FEATURE.getMinimumOperationMode();
123+
}
124+
125+
default License.OperationMode getMinLicenseSupportedForAction(RestRequest.Method method) {
126+
return getMinLicenseSupported();
127+
}
117128
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearningToRankConfig.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import org.elasticsearch.common.io.stream.StreamOutput;
1414
import org.elasticsearch.index.query.QueryRewriteContext;
1515
import org.elasticsearch.index.query.Rewriteable;
16+
import org.elasticsearch.license.License;
17+
import org.elasticsearch.rest.RestRequest;
1618
import org.elasticsearch.xcontent.ObjectParser;
1719
import org.elasticsearch.xcontent.ParseField;
1820
import org.elasticsearch.xcontent.XContentBuilder;
@@ -226,6 +228,14 @@ public TransportVersion getMinimalSupportedTransportVersion() {
226228
return MIN_SUPPORTED_TRANSPORT_VERSION;
227229
}
228230

231+
@Override
232+
public License.OperationMode getMinLicenseSupportedForAction(RestRequest.Method method) {
233+
if (method == RestRequest.Method.PUT) {
234+
return License.OperationMode.ENTERPRISE;
235+
}
236+
return super.getMinLicenseSupportedForAction(method);
237+
}
238+
229239
@Override
230240
public LearningToRankConfig rewrite(QueryRewriteContext ctx) throws IOException {
231241
if (this.featureExtractorBuilders.isEmpty()) {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearningToRankConfigTests.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import org.elasticsearch.common.settings.Settings;
1515
import org.elasticsearch.core.Tuple;
1616
import org.elasticsearch.index.query.QueryRewriteContext;
17+
import org.elasticsearch.license.License;
18+
import org.elasticsearch.rest.RestRequest;
1719
import org.elasticsearch.search.SearchModule;
1820
import org.elasticsearch.xcontent.ConstructingObjectParser;
1921
import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -36,6 +38,7 @@
3638
import java.util.stream.Stream;
3739

3840
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
41+
import static org.hamcrest.Matchers.is;
3942

4043
public class LearningToRankConfigTests extends InferenceConfigItemTestCase<LearningToRankConfig> {
4144
private boolean lenient;
@@ -140,6 +143,16 @@ public void testDuplicateFeatureNames() {
140143
expectThrows(IllegalArgumentException.class, () -> builder.build());
141144
}
142145

146+
public void testLicenseSupport_ForPutAction_RequiresEnterprise() {
147+
var config = randomLearningToRankConfig();
148+
assertThat(config.getMinLicenseSupportedForAction(RestRequest.Method.PUT), is(License.OperationMode.ENTERPRISE));
149+
}
150+
151+
public void testLicenseSupport_ForGetAction_RequiresPlatinum() {
152+
var config = randomLearningToRankConfig();
153+
assertThat(config.getMinLicenseSupportedForAction(RestRequest.Method.GET), is(License.OperationMode.PLATINUM));
154+
}
155+
143156
@Override
144157
protected NamedXContentRegistry xContentRegistry() {
145158
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

Lines changed: 82 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package org.elasticsearch.xpack.ml.action;
88

99
import org.elasticsearch.ElasticsearchException;
10+
import org.elasticsearch.ElasticsearchSecurityException;
1011
import org.elasticsearch.ElasticsearchStatusException;
1112
import org.elasticsearch.ResourceNotFoundException;
1213
import org.elasticsearch.TransportVersion;
@@ -37,6 +38,7 @@
3738
import org.elasticsearch.license.License;
3839
import org.elasticsearch.license.LicenseUtils;
3940
import org.elasticsearch.license.XPackLicenseState;
41+
import org.elasticsearch.rest.RestRequest;
4042
import org.elasticsearch.rest.RestStatus;
4143
import org.elasticsearch.search.builder.SearchSourceBuilder;
4244
import org.elasticsearch.tasks.Task;
@@ -143,61 +145,7 @@ protected void masterOperation(
143145
// NOTE: hasModelDefinition is false if we don't parse it. But, if the fully parsed model was already provided, continue
144146
boolean hasModelDefinition = config.getModelDefinition() != null;
145147
if (hasModelDefinition) {
146-
try {
147-
config.getModelDefinition().getTrainedModel().validate();
148-
} catch (ElasticsearchException ex) {
149-
finalResponseListener.onFailure(
150-
ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", ex, config.getModelId())
151-
);
152-
return;
153-
}
154-
155-
TrainedModelType trainedModelType = TrainedModelType.typeFromTrainedModel(config.getModelDefinition().getTrainedModel());
156-
if (trainedModelType == null) {
157-
finalResponseListener.onFailure(
158-
ExceptionsHelper.badRequestException(
159-
"Unknown trained model definition class [{}]",
160-
config.getModelDefinition().getTrainedModel().getName()
161-
)
162-
);
163-
return;
164-
}
165-
166-
if (config.getModelType() == null) {
167-
// Set the model type from the definition
168-
config = new TrainedModelConfig.Builder(config).setModelType(trainedModelType).build();
169-
} else if (trainedModelType != config.getModelType()) {
170-
finalResponseListener.onFailure(
171-
ExceptionsHelper.badRequestException(
172-
"{} [{}] does not match the model definition type [{}]",
173-
TrainedModelConfig.MODEL_TYPE.getPreferredName(),
174-
config.getModelType(),
175-
trainedModelType
176-
)
177-
);
178-
return;
179-
}
180-
181-
if (config.getInferenceConfig().isTargetTypeSupported(config.getModelDefinition().getTrainedModel().targetType()) == false) {
182-
finalResponseListener.onFailure(
183-
ExceptionsHelper.badRequestException(
184-
"Model [{}] inference config type [{}] does not support definition target type [{}]",
185-
config.getModelId(),
186-
config.getInferenceConfig().getName(),
187-
config.getModelDefinition().getTrainedModel().targetType()
188-
)
189-
);
190-
return;
191-
}
192-
193-
TransportVersion minCompatibilityVersion = config.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion();
194-
if (state.getMinTransportVersion().before(minCompatibilityVersion)) {
195-
finalResponseListener.onFailure(
196-
ExceptionsHelper.badRequestException(
197-
"Cannot create model [{}] while cluster upgrade is in progress.",
198-
config.getModelId()
199-
)
200-
);
148+
if (validateModelDefinition(config, state, licenseState, finalResponseListener) == false) {
201149
return;
202150
}
203151
}
@@ -507,6 +455,85 @@ private void checkTagsAgainstModelIds(List<String> tags, ActionListener<Void> li
507455
);
508456
}
509457

458+
public static boolean validateModelDefinition(
459+
TrainedModelConfig config,
460+
ClusterState state,
461+
XPackLicenseState licenseState,
462+
ActionListener<Response> finalResponseListener
463+
) {
464+
try {
465+
config.getModelDefinition().getTrainedModel().validate();
466+
} catch (ElasticsearchException ex) {
467+
finalResponseListener.onFailure(
468+
ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", ex, config.getModelId())
469+
);
470+
return false;
471+
}
472+
473+
TrainedModelType trainedModelType = TrainedModelType.typeFromTrainedModel(config.getModelDefinition().getTrainedModel());
474+
if (trainedModelType == null) {
475+
finalResponseListener.onFailure(
476+
ExceptionsHelper.badRequestException(
477+
"Unknown trained model definition class [{}]",
478+
config.getModelDefinition().getTrainedModel().getName()
479+
)
480+
);
481+
return false;
482+
}
483+
484+
var configModelType = config.getModelType();
485+
if (configModelType == null) {
486+
// Set the model type from the definition
487+
config = new TrainedModelConfig.Builder(config).setModelType(trainedModelType).build();
488+
} else if (trainedModelType != configModelType) {
489+
finalResponseListener.onFailure(
490+
ExceptionsHelper.badRequestException(
491+
"{} [{}] does not match the model definition type [{}]",
492+
TrainedModelConfig.MODEL_TYPE.getPreferredName(),
493+
configModelType,
494+
trainedModelType
495+
)
496+
);
497+
return false;
498+
}
499+
500+
var inferenceConfig = config.getInferenceConfig();
501+
if (inferenceConfig.isTargetTypeSupported(config.getModelDefinition().getTrainedModel().targetType()) == false) {
502+
finalResponseListener.onFailure(
503+
ExceptionsHelper.badRequestException(
504+
"Model [{}] inference config type [{}] does not support definition target type [{}]",
505+
config.getModelId(),
506+
config.getInferenceConfig().getName(),
507+
config.getModelDefinition().getTrainedModel().targetType()
508+
)
509+
);
510+
return false;
511+
}
512+
513+
var minLicenseSupported = inferenceConfig.getMinLicenseSupportedForAction(RestRequest.Method.PUT);
514+
if (licenseState.isAllowedByLicense(minLicenseSupported) == false) {
515+
finalResponseListener.onFailure(
516+
new ElasticsearchSecurityException(
517+
"Model of type [{}] requires [{}] license level",
518+
RestStatus.FORBIDDEN,
519+
config.getInferenceConfig().getName(),
520+
minLicenseSupported
521+
)
522+
);
523+
return false;
524+
}
525+
526+
TransportVersion minCompatibilityVersion = config.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion();
527+
if (state.getMinTransportVersion().before(minCompatibilityVersion)) {
528+
finalResponseListener.onFailure(
529+
ExceptionsHelper.badRequestException("Cannot create model [{}] while cluster upgrade is in progress.", config.getModelId())
530+
);
531+
return false;
532+
}
533+
534+
return true;
535+
}
536+
510537
@Override
511538
protected ClusterBlockException checkBlock(Request request, ClusterState state) {
512539
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelActionTests.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,21 @@
88
package org.elasticsearch.xpack.ml.action;
99

1010
import org.elasticsearch.ElasticsearchException;
11+
import org.elasticsearch.ElasticsearchSecurityException;
1112
import org.elasticsearch.action.ActionListener;
1213
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
1314
import org.elasticsearch.action.admin.cluster.node.tasks.list.TransportListTasksAction;
1415
import org.elasticsearch.action.support.ActionFilters;
1516
import org.elasticsearch.action.support.PlainActionFuture;
1617
import org.elasticsearch.client.internal.Client;
18+
import org.elasticsearch.cluster.ClusterState;
1719
import org.elasticsearch.cluster.service.ClusterService;
1820
import org.elasticsearch.common.bytes.BytesReference;
1921
import org.elasticsearch.common.xcontent.XContentHelper;
2022
import org.elasticsearch.core.TimeValue;
23+
import org.elasticsearch.license.License;
24+
import org.elasticsearch.license.XPackLicenseState;
25+
import org.elasticsearch.license.internal.XPackLicenseStatus;
2126
import org.elasticsearch.rest.RestStatus;
2227
import org.elasticsearch.test.ESTestCase;
2328
import org.elasticsearch.threadpool.TestThreadPool;
@@ -35,11 +40,13 @@
3540
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
3641
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
3742
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
43+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
3844
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
3945
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
4046
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
4147
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigTests;
4248
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
49+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
4350
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
4451
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfigTests;
4552
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests;
@@ -50,6 +57,7 @@
5057
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests;
5158
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigTests;
5259
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigTests;
60+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
5361
import org.junit.After;
5462
import org.junit.Before;
5563
import org.mockito.ArgumentCaptor;
@@ -60,10 +68,12 @@
6068
import java.util.Map;
6169
import java.util.concurrent.TimeUnit;
6270
import java.util.concurrent.atomic.AtomicBoolean;
71+
import java.util.concurrent.atomic.AtomicInteger;
6372

6473
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.getTaskInfoListOfOne;
6574
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockClientWithTasksResponse;
6675
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockListTasksClient;
76+
import static org.hamcrest.Matchers.instanceOf;
6777
import static org.hamcrest.Matchers.is;
6878
import static org.mockito.ArgumentMatchers.any;
6979
import static org.mockito.ArgumentMatchers.same;
@@ -73,6 +83,7 @@
7383
import static org.mockito.Mockito.mock;
7484
import static org.mockito.Mockito.spy;
7585
import static org.mockito.Mockito.verify;
86+
import static org.mockito.Mockito.when;
7687

7788
public class TransportPutTrainedModelActionTests extends ESTestCase {
7889
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
@@ -273,6 +284,56 @@ public void testVerifyMlNodesAndModelArchitectures_GivenArchitecturesMatch_ThenT
273284
ensureNoWarnings();
274285
}
275286

287+
public void testValidateModelDefinition_FailsWhenLicenseIsNotSupported() throws IOException {
288+
ModelPackageConfig packageConfig = ModelPackageConfigTests.randomModulePackageConfig();
289+
290+
TrainedModelConfig.Builder trainedModelConfigBuilder = new TrainedModelConfig.Builder().setModelId(
291+
"." + packageConfig.getPackagedModelId()
292+
).setInput(TrainedModelInputTests.createRandomInput());
293+
294+
TransportPutTrainedModelAction.setTrainedModelConfigFieldsFromPackagedModel(
295+
trainedModelConfigBuilder,
296+
packageConfig,
297+
xContentRegistry()
298+
);
299+
300+
var mockTrainedModelDefinition = mock(TrainedModelDefinition.class);
301+
when(mockTrainedModelDefinition.getTrainedModel()).thenReturn(mock(LangIdentNeuralNetwork.class));
302+
var trainedModelConfig = trainedModelConfigBuilder.setLicenseLevel("basic").build();
303+
304+
var mockModelInferenceConfig = spy(new LearningToRankConfig(1, List.of(), Map.of()));
305+
when(mockModelInferenceConfig.isTargetTypeSupported(any())).thenReturn(true);
306+
307+
var mockTrainedModelConfig = spy(trainedModelConfig);
308+
when(mockTrainedModelConfig.getModelType()).thenReturn(TrainedModelType.LANG_IDENT);
309+
when(mockTrainedModelConfig.getModelDefinition()).thenReturn(mockTrainedModelDefinition);
310+
when(mockTrainedModelConfig.getInferenceConfig()).thenReturn(mockModelInferenceConfig);
311+
312+
ActionListener<PutTrainedModelAction.Response> responseListener = ActionListener.wrap(
313+
response -> fail("Expected exception, but got response: " + response),
314+
exception -> {
315+
assertThat(exception, instanceOf(ElasticsearchSecurityException.class));
316+
assertThat(exception.getMessage(), is("Model of type [learning_to_rank] requires [ENTERPRISE] license level"));
317+
}
318+
);
319+
320+
var mockClusterState = mock(ClusterState.class);
321+
322+
AtomicInteger currentTime = new AtomicInteger(100);
323+
var mockXPackLicenseStatus = new XPackLicenseStatus(License.OperationMode.BASIC, true, "");
324+
var mockLicenseState = new XPackLicenseState(currentTime::get, mockXPackLicenseStatus);
325+
326+
assertThat(
327+
TransportPutTrainedModelAction.validateModelDefinition(
328+
mockTrainedModelConfig,
329+
mockClusterState,
330+
mockLicenseState,
331+
responseListener
332+
),
333+
is(false)
334+
);
335+
}
336+
276337
private static void prepareGetTrainedModelResponse(Client client, List<TrainedModelConfig> trainedModels) {
277338
doAnswer(invocationOnMock -> {
278339
@SuppressWarnings("unchecked")

0 commit comments

Comments
 (0)