Skip to content

Commit 0146002

Browse files
author
Hendrik Muhs
authored
[ML] implement check for the minimum version for a packaged trained model (#95361)
check the minimum_version field of a packaged model, only allow model install if all nodes fulfill the requirement
1 parent 72dabbb commit 0146002

File tree

3 files changed

+170
-1
lines changed

3 files changed

+170
-1
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.action;
9+
10+
import org.elasticsearch.Version;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.cluster.ClusterState;
13+
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
15+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
16+
17+
import static org.elasticsearch.core.Strings.format;
18+
19+
/**
20+
* {@TrainedModelValidator} analyzes a trained model config to find various issues w.r.t. the various combinations of model types,
21+
* packages, etc.
22+
*/
23+
final class TrainedModelValidator {
24+
25+
static void validatePackage(
26+
TrainedModelConfig.Builder trainedModelConfig,
27+
ModelPackageConfig resolvedModelPackageConfig,
28+
ClusterState state
29+
) {
30+
validateMinimumVersion(resolvedModelPackageConfig, state);
31+
// check that user doesn't try to override parts that are
32+
trainedModelConfig.validateNoPackageOverrides();
33+
}
34+
35+
static void validateMinimumVersion(ModelPackageConfig resolvedModelPackageConfig, ClusterState state) {
36+
Version minimumVersion;
37+
38+
// {@link Version#fromString} interprets an empty string as current version, so we have to check ourselves
39+
if (Strings.isNullOrEmpty(resolvedModelPackageConfig.getMinimumVersion())) {
40+
throw new ActionRequestValidationException().addValidationError(
41+
format(
42+
"Invalid model package configuration for [%s], missing minimum_version property",
43+
resolvedModelPackageConfig.getPackagedModelId()
44+
)
45+
);
46+
}
47+
48+
try {
49+
minimumVersion = Version.fromString(resolvedModelPackageConfig.getMinimumVersion());
50+
} catch (IllegalArgumentException e) {
51+
throw new ActionRequestValidationException().addValidationError(
52+
format(
53+
"Invalid model package configuration for [%s], failed to parse the minimum_version property",
54+
resolvedModelPackageConfig.getPackagedModelId()
55+
)
56+
);
57+
}
58+
59+
if (state.nodes().getMinNodeVersion().before(minimumVersion)) {
60+
throw new ActionRequestValidationException().addValidationError(
61+
format(
62+
"The model [%s] requires that all nodes are at least version [%s]",
63+
resolvedModelPackageConfig.getPackagedModelId(),
64+
resolvedModelPackageConfig.getMinimumVersion()
65+
)
66+
);
67+
}
68+
}
69+
70+
private TrainedModelValidator() {}
71+
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ protected void masterOperation(
307307
if (config.isPackagedModel()) {
308308
resolvePackageConfig(config.getModelId(), ActionListener.wrap(resolvedModelPackageConfig -> {
309309
try {
310-
trainedModelConfig.validateNoPackageOverrides();
310+
TrainedModelValidator.validatePackage(trainedModelConfig, resolvedModelPackageConfig, state);
311311
} catch (ValidationException e) {
312312
listener.onFailure(e);
313313
return;
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.action;
9+
10+
import org.elasticsearch.Version;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.cluster.ClusterState;
13+
import org.elasticsearch.cluster.node.DiscoveryNodes;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
16+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfigTests;
17+
18+
import static org.mockito.Mockito.mock;
19+
import static org.mockito.Mockito.when;
20+
21+
public class TrainedModelValidatorTests extends ESTestCase {
22+
23+
public void testValidateMinimumVersion() {
24+
final ModelPackageConfig packageConfig = new ModelPackageConfig.Builder(ModelPackageConfigTests.randomModulePackageConfig())
25+
.setMinimumVersion("99.9.9")
26+
.build();
27+
28+
ClusterState state = mock(ClusterState.class);
29+
DiscoveryNodes nodes = mock(DiscoveryNodes.class);
30+
when(state.nodes()).thenReturn(nodes);
31+
when(nodes.getMinNodeVersion()).thenReturn(Version.CURRENT);
32+
33+
Exception e = expectThrows(
34+
ActionRequestValidationException.class,
35+
() -> TrainedModelValidator.validateMinimumVersion(packageConfig, state)
36+
);
37+
38+
assertEquals(
39+
"Validation Failed: 1: The model ["
40+
+ packageConfig.getPackagedModelId()
41+
+ "] requires that all nodes are at least version [99.9.9];",
42+
e.getMessage()
43+
);
44+
45+
final ModelPackageConfig packageConfigCurrent = new ModelPackageConfig.Builder(ModelPackageConfigTests.randomModulePackageConfig())
46+
.setMinimumVersion(Version.CURRENT.toString())
47+
.build();
48+
TrainedModelValidator.validateMinimumVersion(packageConfigCurrent, state);
49+
50+
when(nodes.getMinNodeVersion()).thenReturn(Version.V_8_7_0);
51+
52+
e = expectThrows(
53+
ActionRequestValidationException.class,
54+
() -> TrainedModelValidator.validateMinimumVersion(packageConfigCurrent, state)
55+
);
56+
57+
assertEquals(
58+
"Validation Failed: 1: The model ["
59+
+ packageConfigCurrent.getPackagedModelId()
60+
+ "] requires that all nodes are at least version ["
61+
+ Version.CURRENT
62+
+ "];",
63+
e.getMessage()
64+
);
65+
66+
final ModelPackageConfig packageConfigBroken = new ModelPackageConfig.Builder(ModelPackageConfigTests.randomModulePackageConfig())
67+
.setMinimumVersion("_broken_version_")
68+
.build();
69+
70+
e = expectThrows(
71+
ActionRequestValidationException.class,
72+
() -> TrainedModelValidator.validateMinimumVersion(packageConfigBroken, state)
73+
);
74+
75+
assertEquals(
76+
"Validation Failed: 1: Invalid model package configuration for ["
77+
+ packageConfigBroken.getPackagedModelId()
78+
+ "], failed to parse the minimum_version property;",
79+
e.getMessage()
80+
);
81+
82+
final ModelPackageConfig packageConfigVersionMissing = new ModelPackageConfig.Builder(
83+
ModelPackageConfigTests.randomModulePackageConfig()
84+
).setMinimumVersion("").build();
85+
86+
e = expectThrows(
87+
ActionRequestValidationException.class,
88+
() -> TrainedModelValidator.validateMinimumVersion(packageConfigVersionMissing, state)
89+
);
90+
91+
assertEquals(
92+
"Validation Failed: 1: Invalid model package configuration for ["
93+
+ packageConfigVersionMissing.getPackagedModelId()
94+
+ "], missing minimum_version property;",
95+
e.getMessage()
96+
);
97+
}
98+
}

0 commit comments

Comments
 (0)