Skip to content

Commit ba1f616

Browse files
authored
[ML] throw when definition is requested for pytorch models (#80310) (#80338)
Since pytorch models are not built in Elasticsearch, we don't need to provide supplying the definition when retrieving the trained model. In fact, some of these definitions are so large, that returning them is prohibitive. related to #80254
1 parent fdb2ac2 commit ba1f616

File tree

2 files changed

+70
-54
lines changed

2 files changed

+70
-54
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,18 @@ public void testInferWithMissingModel() {
417417
assertThat(ex.getMessage(), containsString("Could not find trained model [missing_model]"));
418418
}
419419

420+
public void testGetPytorchModelWithDefinition() throws IOException {
421+
String model = "should-fail-get";
422+
createTrainedModel(model);
423+
putVocabulary(List.of("once", "twice"), model);
424+
putModelDefinition(model);
425+
Exception ex = expectThrows(
426+
Exception.class,
427+
() -> client().performRequest(new Request("GET", "_ml/trained_models/" + model + "?include=definition"))
428+
);
429+
assertThat(ex.getMessage(), containsString("[should-fail-get] is type [pytorch] and does not support retrieving the definition"));
430+
}
431+
420432
public void testInferencePipelineAgainstUnallocatedModel() throws IOException {
421433
String model = "not-deployed";
422434
createTrainedModel(model);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

Lines changed: 58 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@
2424
import org.elasticsearch.action.bulk.BulkResponse;
2525
import org.elasticsearch.action.index.IndexAction;
2626
import org.elasticsearch.action.index.IndexRequest;
27-
import org.elasticsearch.action.search.MultiSearchAction;
2827
import org.elasticsearch.action.search.MultiSearchRequest;
29-
import org.elasticsearch.action.search.MultiSearchRequestBuilder;
3028
import org.elasticsearch.action.search.MultiSearchResponse;
3129
import org.elasticsearch.action.search.SearchAction;
3230
import org.elasticsearch.action.search.SearchRequest;
@@ -602,26 +600,17 @@ public void getTrainedModel(
602600
}, finalListener::onFailure);
603601

604602
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelId));
605-
MultiSearchRequestBuilder multiSearchRequestBuilder = client.prepareMultiSearch()
606-
.add(
607-
client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
608-
.setQuery(queryBuilder)
609-
// use sort to get the last
610-
.addSort("_index", SortOrder.DESC)
611-
.setSize(1)
612-
.request()
613-
);
614-
615-
if (includes.isIncludeModelDefinition()) {
616-
multiSearchRequestBuilder.add(
617-
ChunkedTrainedModelRestorer.buildSearch(client, modelId, InferenceIndexConstants.INDEX_PATTERN, MAX_NUM_DEFINITION_DOCS)
618-
);
619-
}
603+
SearchRequest trainedModelConfigSearch = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
604+
.setQuery(queryBuilder)
605+
// use sort to get the last
606+
.addSort("_index", SortOrder.DESC)
607+
.setSize(1)
608+
.request();
620609

621-
ActionListener<MultiSearchResponse> multiSearchResponseActionListener = ActionListener.wrap(multiSearchResponse -> {
610+
ActionListener<SearchResponse> trainedModelSearchHandler = ActionListener.wrap(modelSearchResponse -> {
622611
TrainedModelConfig.Builder builder;
623612
try {
624-
builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseModelConfigLenientlyFromSource);
613+
builder = handleHits(modelSearchResponse.getHits().getHits(), modelId, this::parseModelConfigLenientlyFromSource).get(0);
625614
} catch (ResourceNotFoundException ex) {
626615
getTrainedModelListener.onFailure(
627616
new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))
@@ -631,46 +620,58 @@ public void getTrainedModel(
631620
getTrainedModelListener.onFailure(ex);
632621
return;
633622
}
634-
635-
if (includes.isIncludeModelDefinition()) {
636-
try {
637-
List<TrainedModelDefinitionDoc> docs = handleSearchItems(
638-
multiSearchResponse.getResponses()[1],
623+
if (includes.isIncludeModelDefinition() == false) {
624+
getTrainedModelListener.onResponse(builder);
625+
return;
626+
}
627+
if (builder.getModelType() == TrainedModelType.PYTORCH && includes.isIncludeModelDefinition()) {
628+
finalListener.onFailure(
629+
ExceptionsHelper.badRequestException(
630+
"[{}] is type [{}] and does not support retrieving the definition",
639631
modelId,
640-
(bytes, resourceId) -> ChunkedTrainedModelRestorer.parseModelDefinitionDocLenientlyFromSource(
641-
bytes,
642-
resourceId,
643-
xContentRegistry
644-
)
645-
);
632+
builder.getModelType()
633+
)
634+
);
635+
return;
636+
}
637+
executeAsyncWithOrigin(
638+
client,
639+
ML_ORIGIN,
640+
SearchAction.INSTANCE,
641+
ChunkedTrainedModelRestorer.buildSearch(client, modelId, InferenceIndexConstants.INDEX_PATTERN, MAX_NUM_DEFINITION_DOCS),
642+
ActionListener.wrap(definitionSearchResponse -> {
646643
try {
647-
BytesReference compressedData = getDefinitionFromDocs(docs, modelId);
648-
builder.setDefinitionFromBytes(compressedData);
649-
} catch (ElasticsearchException elasticsearchException) {
650-
getTrainedModelListener.onFailure(elasticsearchException);
644+
List<TrainedModelDefinitionDoc> docs = handleHits(
645+
definitionSearchResponse.getHits().getHits(),
646+
modelId,
647+
(bytes, resourceId) -> ChunkedTrainedModelRestorer.parseModelDefinitionDocLenientlyFromSource(
648+
bytes,
649+
resourceId,
650+
xContentRegistry
651+
)
652+
);
653+
try {
654+
BytesReference compressedData = getDefinitionFromDocs(docs, modelId);
655+
builder.setDefinitionFromBytes(compressedData);
656+
} catch (ElasticsearchException elasticsearchException) {
657+
getTrainedModelListener.onFailure(elasticsearchException);
658+
return;
659+
}
660+
661+
} catch (ResourceNotFoundException ex) {
662+
getTrainedModelListener.onFailure(
663+
new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))
664+
);
665+
return;
666+
} catch (Exception ex) {
667+
getTrainedModelListener.onFailure(ex);
651668
return;
652669
}
653-
654-
} catch (ResourceNotFoundException ex) {
655-
getTrainedModelListener.onFailure(
656-
new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))
657-
);
658-
return;
659-
} catch (Exception ex) {
660-
getTrainedModelListener.onFailure(ex);
661-
return;
662-
}
663-
}
664-
getTrainedModelListener.onResponse(builder);
670+
getTrainedModelListener.onResponse(builder);
671+
}, getTrainedModelListener::onFailure)
672+
);
665673
}, getTrainedModelListener::onFailure);
666-
667-
executeAsyncWithOrigin(
668-
client,
669-
ML_ORIGIN,
670-
MultiSearchAction.INSTANCE,
671-
multiSearchRequestBuilder.request(),
672-
multiSearchResponseActionListener
673-
);
674+
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, trainedModelConfigSearch, trainedModelSearchHandler);
674675
}
675676

676677
public void getTrainedModels(
@@ -1204,6 +1205,9 @@ private static <T> List<T> handleHits(
12041205
String resourceId,
12051206
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently
12061207
) throws Exception {
1208+
if (hits.length == 0) {
1209+
throw new ResourceNotFoundException(resourceId);
1210+
}
12071211
List<T> results = new ArrayList<>(hits.length);
12081212
String initialIndex = hits[0].getIndex();
12091213
for (SearchHit hit : hits) {

0 commit comments

Comments
 (0)