2424import org .elasticsearch .action .bulk .BulkResponse ;
2525import org .elasticsearch .action .index .IndexAction ;
2626import org .elasticsearch .action .index .IndexRequest ;
27- import org .elasticsearch .action .search .MultiSearchAction ;
2827import org .elasticsearch .action .search .MultiSearchRequest ;
29- import org .elasticsearch .action .search .MultiSearchRequestBuilder ;
3028import org .elasticsearch .action .search .MultiSearchResponse ;
3129import org .elasticsearch .action .search .SearchAction ;
3230import org .elasticsearch .action .search .SearchRequest ;
@@ -602,26 +600,17 @@ public void getTrainedModel(
602600 }, finalListener ::onFailure );
603601
604602 QueryBuilder queryBuilder = QueryBuilders .constantScoreQuery (QueryBuilders .idsQuery ().addIds (modelId ));
605- MultiSearchRequestBuilder multiSearchRequestBuilder = client .prepareMultiSearch ()
606- .add (
607- client .prepareSearch (InferenceIndexConstants .INDEX_PATTERN )
608- .setQuery (queryBuilder )
609- // use sort to get the last
610- .addSort ("_index" , SortOrder .DESC )
611- .setSize (1 )
612- .request ()
613- );
614-
615- if (includes .isIncludeModelDefinition ()) {
616- multiSearchRequestBuilder .add (
617- ChunkedTrainedModelRestorer .buildSearch (client , modelId , InferenceIndexConstants .INDEX_PATTERN , MAX_NUM_DEFINITION_DOCS )
618- );
619- }
603+ SearchRequest trainedModelConfigSearch = client .prepareSearch (InferenceIndexConstants .INDEX_PATTERN )
604+ .setQuery (queryBuilder )
605+ // use sort to get the last
606+ .addSort ("_index" , SortOrder .DESC )
607+ .setSize (1 )
608+ .request ();
620609
621- ActionListener <MultiSearchResponse > multiSearchResponseActionListener = ActionListener .wrap (multiSearchResponse -> {
610+ ActionListener <SearchResponse > trainedModelSearchHandler = ActionListener .wrap (modelSearchResponse -> {
622611 TrainedModelConfig .Builder builder ;
623612 try {
624- builder = handleSearchItem ( multiSearchResponse . getResponses ()[ 0 ] , modelId , this ::parseModelConfigLenientlyFromSource );
613+ builder = handleHits ( modelSearchResponse . getHits (). getHits () , modelId , this ::parseModelConfigLenientlyFromSource ). get ( 0 );
625614 } catch (ResourceNotFoundException ex ) {
626615 getTrainedModelListener .onFailure (
627616 new ResourceNotFoundException (Messages .getMessage (Messages .INFERENCE_NOT_FOUND , modelId ))
@@ -631,46 +620,58 @@ public void getTrainedModel(
631620 getTrainedModelListener .onFailure (ex );
632621 return ;
633622 }
634-
635- if (includes .isIncludeModelDefinition ()) {
636- try {
637- List <TrainedModelDefinitionDoc > docs = handleSearchItems (
638- multiSearchResponse .getResponses ()[1 ],
623+ if (includes .isIncludeModelDefinition () == false ) {
624+ getTrainedModelListener .onResponse (builder );
625+ return ;
626+ }
627+ if (builder .getModelType () == TrainedModelType .PYTORCH && includes .isIncludeModelDefinition ()) {
628+ finalListener .onFailure (
629+ ExceptionsHelper .badRequestException (
630+ "[{}] is type [{}] and does not support retrieving the definition" ,
639631 modelId ,
640- (bytes , resourceId ) -> ChunkedTrainedModelRestorer .parseModelDefinitionDocLenientlyFromSource (
641- bytes ,
642- resourceId ,
643- xContentRegistry
644- )
645- );
632+ builder .getModelType ()
633+ )
634+ );
635+ return ;
636+ }
637+ executeAsyncWithOrigin (
638+ client ,
639+ ML_ORIGIN ,
640+ SearchAction .INSTANCE ,
641+ ChunkedTrainedModelRestorer .buildSearch (client , modelId , InferenceIndexConstants .INDEX_PATTERN , MAX_NUM_DEFINITION_DOCS ),
642+ ActionListener .wrap (definitionSearchResponse -> {
646643 try {
647- BytesReference compressedData = getDefinitionFromDocs (docs , modelId );
648- builder .setDefinitionFromBytes (compressedData );
649- } catch (ElasticsearchException elasticsearchException ) {
650- getTrainedModelListener .onFailure (elasticsearchException );
644+ List <TrainedModelDefinitionDoc > docs = handleHits (
645+ definitionSearchResponse .getHits ().getHits (),
646+ modelId ,
647+ (bytes , resourceId ) -> ChunkedTrainedModelRestorer .parseModelDefinitionDocLenientlyFromSource (
648+ bytes ,
649+ resourceId ,
650+ xContentRegistry
651+ )
652+ );
653+ try {
654+ BytesReference compressedData = getDefinitionFromDocs (docs , modelId );
655+ builder .setDefinitionFromBytes (compressedData );
656+ } catch (ElasticsearchException elasticsearchException ) {
657+ getTrainedModelListener .onFailure (elasticsearchException );
658+ return ;
659+ }
660+
661+ } catch (ResourceNotFoundException ex ) {
662+ getTrainedModelListener .onFailure (
663+ new ResourceNotFoundException (Messages .getMessage (Messages .MODEL_DEFINITION_NOT_FOUND , modelId ))
664+ );
665+ return ;
666+ } catch (Exception ex ) {
667+ getTrainedModelListener .onFailure (ex );
651668 return ;
652669 }
653-
654- } catch (ResourceNotFoundException ex ) {
655- getTrainedModelListener .onFailure (
656- new ResourceNotFoundException (Messages .getMessage (Messages .MODEL_DEFINITION_NOT_FOUND , modelId ))
657- );
658- return ;
659- } catch (Exception ex ) {
660- getTrainedModelListener .onFailure (ex );
661- return ;
662- }
663- }
664- getTrainedModelListener .onResponse (builder );
670+ getTrainedModelListener .onResponse (builder );
671+ }, getTrainedModelListener ::onFailure )
672+ );
665673 }, getTrainedModelListener ::onFailure );
666-
667- executeAsyncWithOrigin (
668- client ,
669- ML_ORIGIN ,
670- MultiSearchAction .INSTANCE ,
671- multiSearchRequestBuilder .request (),
672- multiSearchResponseActionListener
673- );
674+ executeAsyncWithOrigin (client , ML_ORIGIN , SearchAction .INSTANCE , trainedModelConfigSearch , trainedModelSearchHandler );
674675 }
675676
676677 public void getTrainedModels (
@@ -1204,6 +1205,9 @@ private static <T> List<T> handleHits(
12041205 String resourceId ,
12051206 CheckedBiFunction <BytesReference , String , T , Exception > parseLeniently
12061207 ) throws Exception {
1208+ if (hits .length == 0 ) {
1209+ throw new ResourceNotFoundException (resourceId );
1210+ }
12071211 List <T > results = new ArrayList <>(hits .length );
12081212 String initialIndex = hits [0 ].getIndex ();
12091213 for (SearchHit hit : hits ) {
0 commit comments