Skip to content

Commit 97482ce

Browse files
author
Hendrik Muhs
authored
set location for packaged model correctly (#95399)
set the location for packaged models correctly, previously it used the default .ml-inference-... due to the order the fields get set. This change sets the config according to the model type correctly, with the option to use a dedicated index for packaged models in future.
1 parent 7fef10d commit 97482ce

File tree

3 files changed

+94
-23
lines changed

3 files changed

+94
-23
lines changed

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ModelPackageConfigTests.java

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,23 @@
88
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
99

1010
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.bytes.BytesReference;
1112
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.common.xcontent.XContentHelper;
14+
import org.elasticsearch.xcontent.ToXContent;
15+
import org.elasticsearch.xcontent.XContentBuilder;
16+
import org.elasticsearch.xcontent.XContentFactory;
1217
import org.elasticsearch.xcontent.XContentParser;
18+
import org.elasticsearch.xcontent.XContentType;
1319
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
20+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
1421

1522
import java.io.IOException;
1623
import java.time.Duration;
1724
import java.time.Instant;
1825
import java.util.Arrays;
1926
import java.util.Collections;
27+
import java.util.Map;
2028

2129
public class ModelPackageConfigTests extends AbstractBWCSerializationTestCase<ModelPackageConfig> {
2230

@@ -31,9 +39,9 @@ public static ModelPackageConfig randomModulePackageConfig() {
3139
: null,
3240
randomLongBetween(0, Long.MAX_VALUE - 100),
3341
randomBoolean() ? randomAlphaOfLength(10) : null,
42+
randomInferenceConfigAsMap(),
3443
randomBoolean() ? Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)) : null,
35-
randomBoolean() ? Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)) : null,
36-
randomBoolean() ? randomAlphaOfLength(10) : null,
44+
randomFrom(TrainedModelType.values()).toString(),
3745
randomBoolean() ? Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false)) : null,
3846
randomBoolean() ? randomAlphaOfLength(10) : null
3947
);
@@ -104,4 +112,30 @@ protected ModelPackageConfig mutateInstance(ModelPackageConfig instance) {
104112
protected ModelPackageConfig mutateInstanceForVersion(ModelPackageConfig instance, TransportVersion version) {
105113
return instance;
106114
}
115+
116+
private static Map<String, Object> randomInferenceConfigAsMap() {
117+
InferenceConfig inferenceConfig = randomFrom(
118+
new InferenceConfig[] {
119+
ClassificationConfigTests.randomClassificationConfig(),
120+
RegressionConfigTests.randomRegressionConfig(),
121+
NerConfigTests.createRandom(),
122+
PassThroughConfigTests.createRandom(),
123+
TextClassificationConfigTests.createRandom(),
124+
FillMaskConfigTests.createRandom(),
125+
TextEmbeddingConfigTests.createRandom(),
126+
QuestionAnsweringConfigTests.createRandom(),
127+
TextSimilarityConfigTests.createRandom(),
128+
TextExpansionConfigTests.createRandom() }
129+
);
130+
131+
try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()) {
132+
XContentBuilder content = inferenceConfig.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS);
133+
return Collections.singletonMap(
134+
inferenceConfig.getWriteableName(),
135+
XContentHelper.convertToMap(BytesReference.bytes(content), true, XContentType.JSON).v2()
136+
);
137+
} catch (IOException e) {
138+
throw new RuntimeException(e);
139+
}
140+
}
107141
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ protected void masterOperation(
313313
return;
314314
}
315315
modelPackageConfigHolder.set(resolvedModelPackageConfig);
316-
setTrainedModelConfigFieldsFromPackagedModel(trainedModelConfig, resolvedModelPackageConfig);
316+
setTrainedModelConfigFieldsFromPackagedModel(trainedModelConfig, resolvedModelPackageConfig, xContentRegistry);
317317

318318
checkModelIdAgainstTags(trainedModelConfig.getModelId(), modelIdTagCheckListener);
319319
}, listener::onFailure));
@@ -322,26 +322,6 @@ protected void masterOperation(
322322
}
323323
}
324324

325-
private void setTrainedModelConfigFieldsFromPackagedModel(
326-
TrainedModelConfig.Builder trainedModelConfig,
327-
ModelPackageConfig resolvedModelPackageConfig
328-
) throws IOException {
329-
trainedModelConfig.setDescription(resolvedModelPackageConfig.getDescription());
330-
trainedModelConfig.setModelType(TrainedModelType.fromString(resolvedModelPackageConfig.getModelType()));
331-
trainedModelConfig.setMetadata(resolvedModelPackageConfig.getMetadata());
332-
trainedModelConfig.setInferenceConfig(
333-
parseInferenceConfigFromModelPackage(
334-
resolvedModelPackageConfig.getInferenceConfigSource(),
335-
xContentRegistry,
336-
LoggingDeprecationHandler.INSTANCE
337-
)
338-
);
339-
trainedModelConfig.setTags(resolvedModelPackageConfig.getTags());
340-
trainedModelConfig.setModelPackageConfig(
341-
new ModelPackageConfig.Builder(resolvedModelPackageConfig).resetPackageOnlyFields().build()
342-
);
343-
}
344-
345325
private void triggerModelFetchIfNecessary(
346326
String modelId,
347327
ModelPackageConfig modelPackageConfig,
@@ -435,6 +415,29 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
435415
}
436416
}
437417

418+
static void setTrainedModelConfigFieldsFromPackagedModel(
419+
TrainedModelConfig.Builder trainedModelConfig,
420+
ModelPackageConfig resolvedModelPackageConfig,
421+
NamedXContentRegistry xContentRegistry
422+
) throws IOException {
423+
trainedModelConfig.setDescription(resolvedModelPackageConfig.getDescription());
424+
trainedModelConfig.setModelType(TrainedModelType.fromString(resolvedModelPackageConfig.getModelType()));
425+
trainedModelConfig.setMetadata(resolvedModelPackageConfig.getMetadata());
426+
trainedModelConfig.setInferenceConfig(
427+
parseInferenceConfigFromModelPackage(
428+
resolvedModelPackageConfig.getInferenceConfigSource(),
429+
xContentRegistry,
430+
LoggingDeprecationHandler.INSTANCE
431+
)
432+
);
433+
trainedModelConfig.setTags(resolvedModelPackageConfig.getTags());
434+
trainedModelConfig.setModelPackageConfig(
435+
new ModelPackageConfig.Builder(resolvedModelPackageConfig).resetPackageOnlyFields().build()
436+
);
437+
438+
trainedModelConfig.setLocation(trainedModelConfig.getModelType().getDefaultLocation(trainedModelConfig.getModelId()));
439+
}
440+
438441
static InferenceConfig parseInferenceConfigFromModelPackage(
439442
Map<String, Object> source,
440443
NamedXContentRegistry namedXContentRegistry,

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelActionTests.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,14 @@
1717
import org.elasticsearch.xcontent.XContentFactory;
1818
import org.elasticsearch.xcontent.XContentType;
1919
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
20+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
21+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
22+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
2023
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
2124
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigTests;
2225
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
26+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
27+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfigTests;
2328
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests;
2429
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigTests;
2530
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigTests;
@@ -67,6 +72,35 @@ public void testParseInferenceConfigFromModelPackage() throws IOException {
6772
assertEquals(inferenceConfig, parsedInferenceConfig);
6873
}
6974

75+
public void testSetTrainedModelConfigFieldsFromPackagedModel() throws IOException {
76+
ModelPackageConfig packageConfig = ModelPackageConfigTests.randomModulePackageConfig();
77+
78+
TrainedModelConfig.Builder trainedModelConfigBuilder = new TrainedModelConfig.Builder().setModelId(
79+
"." + packageConfig.getPackagedModelId()
80+
).setInput(TrainedModelInputTests.createRandomInput());
81+
82+
TransportPutTrainedModelAction.setTrainedModelConfigFieldsFromPackagedModel(
83+
trainedModelConfigBuilder,
84+
packageConfig,
85+
xContentRegistry()
86+
);
87+
88+
TrainedModelConfig trainedModelConfig = trainedModelConfigBuilder.build();
89+
90+
assertEquals(packageConfig.getModelType(), trainedModelConfig.getModelType().toString());
91+
assertEquals(packageConfig.getDescription(), trainedModelConfig.getDescription());
92+
assertEquals(packageConfig.getMetadata(), trainedModelConfig.getMetadata());
93+
assertEquals(packageConfig.getTags(), trainedModelConfig.getTags());
94+
95+
// fully tested in {@link #testParseInferenceConfigFromModelPackage}
96+
assertNotNull(trainedModelConfig.getInferenceConfig());
97+
98+
assertEquals(
99+
TrainedModelType.fromString(packageConfig.getModelType()).getDefaultLocation(trainedModelConfig.getModelId()),
100+
trainedModelConfig.getLocation()
101+
);
102+
}
103+
70104
@Override
71105
protected NamedXContentRegistry xContentRegistry() {
72106
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());

0 commit comments

Comments
 (0)