24
24
import org .elasticsearch .action .bulk .BulkResponse ;
25
25
import org .elasticsearch .action .index .IndexAction ;
26
26
import org .elasticsearch .action .index .IndexRequest ;
27
- import org .elasticsearch .action .search .MultiSearchAction ;
28
27
import org .elasticsearch .action .search .MultiSearchRequest ;
29
- import org .elasticsearch .action .search .MultiSearchRequestBuilder ;
30
28
import org .elasticsearch .action .search .MultiSearchResponse ;
31
29
import org .elasticsearch .action .search .SearchAction ;
32
30
import org .elasticsearch .action .search .SearchRequest ;
@@ -602,26 +600,17 @@ public void getTrainedModel(
602
600
}, finalListener ::onFailure );
603
601
604
602
QueryBuilder queryBuilder = QueryBuilders .constantScoreQuery (QueryBuilders .idsQuery ().addIds (modelId ));
605
- MultiSearchRequestBuilder multiSearchRequestBuilder = client .prepareMultiSearch ()
606
- .add (
607
- client .prepareSearch (InferenceIndexConstants .INDEX_PATTERN )
608
- .setQuery (queryBuilder )
609
- // use sort to get the last
610
- .addSort ("_index" , SortOrder .DESC )
611
- .setSize (1 )
612
- .request ()
613
- );
614
-
615
- if (includes .isIncludeModelDefinition ()) {
616
- multiSearchRequestBuilder .add (
617
- ChunkedTrainedModelRestorer .buildSearch (client , modelId , InferenceIndexConstants .INDEX_PATTERN , MAX_NUM_DEFINITION_DOCS )
618
- );
619
- }
603
+ SearchRequest trainedModelConfigSearch = client .prepareSearch (InferenceIndexConstants .INDEX_PATTERN )
604
+ .setQuery (queryBuilder )
605
+ // use sort to get the last
606
+ .addSort ("_index" , SortOrder .DESC )
607
+ .setSize (1 )
608
+ .request ();
620
609
621
- ActionListener <MultiSearchResponse > multiSearchResponseActionListener = ActionListener .wrap (multiSearchResponse -> {
610
+ ActionListener <SearchResponse > trainedModelSearchHandler = ActionListener .wrap (modelSearchResponse -> {
622
611
TrainedModelConfig .Builder builder ;
623
612
try {
624
- builder = handleSearchItem ( multiSearchResponse . getResponses ()[ 0 ] , modelId , this ::parseModelConfigLenientlyFromSource );
613
+ builder = handleHits ( modelSearchResponse . getHits (). getHits () , modelId , this ::parseModelConfigLenientlyFromSource ). get ( 0 );
625
614
} catch (ResourceNotFoundException ex ) {
626
615
getTrainedModelListener .onFailure (
627
616
new ResourceNotFoundException (Messages .getMessage (Messages .INFERENCE_NOT_FOUND , modelId ))
@@ -631,46 +620,58 @@ public void getTrainedModel(
631
620
getTrainedModelListener .onFailure (ex );
632
621
return ;
633
622
}
634
-
635
- if (includes .isIncludeModelDefinition ()) {
636
- try {
637
- List <TrainedModelDefinitionDoc > docs = handleSearchItems (
638
- multiSearchResponse .getResponses ()[1 ],
623
+ if (includes .isIncludeModelDefinition () == false ) {
624
+ getTrainedModelListener .onResponse (builder );
625
+ return ;
626
+ }
627
+ if (builder .getModelType () == TrainedModelType .PYTORCH && includes .isIncludeModelDefinition ()) {
628
+ finalListener .onFailure (
629
+ ExceptionsHelper .badRequestException (
630
+ "[{}] is type [{}] and does not support retrieving the definition" ,
639
631
modelId ,
640
- (bytes , resourceId ) -> ChunkedTrainedModelRestorer .parseModelDefinitionDocLenientlyFromSource (
641
- bytes ,
642
- resourceId ,
643
- xContentRegistry
644
- )
645
- );
632
+ builder .getModelType ()
633
+ )
634
+ );
635
+ return ;
636
+ }
637
+ executeAsyncWithOrigin (
638
+ client ,
639
+ ML_ORIGIN ,
640
+ SearchAction .INSTANCE ,
641
+ ChunkedTrainedModelRestorer .buildSearch (client , modelId , InferenceIndexConstants .INDEX_PATTERN , MAX_NUM_DEFINITION_DOCS ),
642
+ ActionListener .wrap (definitionSearchResponse -> {
646
643
try {
647
- BytesReference compressedData = getDefinitionFromDocs (docs , modelId );
648
- builder .setDefinitionFromBytes (compressedData );
649
- } catch (ElasticsearchException elasticsearchException ) {
650
- getTrainedModelListener .onFailure (elasticsearchException );
644
+ List <TrainedModelDefinitionDoc > docs = handleHits (
645
+ definitionSearchResponse .getHits ().getHits (),
646
+ modelId ,
647
+ (bytes , resourceId ) -> ChunkedTrainedModelRestorer .parseModelDefinitionDocLenientlyFromSource (
648
+ bytes ,
649
+ resourceId ,
650
+ xContentRegistry
651
+ )
652
+ );
653
+ try {
654
+ BytesReference compressedData = getDefinitionFromDocs (docs , modelId );
655
+ builder .setDefinitionFromBytes (compressedData );
656
+ } catch (ElasticsearchException elasticsearchException ) {
657
+ getTrainedModelListener .onFailure (elasticsearchException );
658
+ return ;
659
+ }
660
+
661
+ } catch (ResourceNotFoundException ex ) {
662
+ getTrainedModelListener .onFailure (
663
+ new ResourceNotFoundException (Messages .getMessage (Messages .MODEL_DEFINITION_NOT_FOUND , modelId ))
664
+ );
665
+ return ;
666
+ } catch (Exception ex ) {
667
+ getTrainedModelListener .onFailure (ex );
651
668
return ;
652
669
}
653
-
654
- } catch (ResourceNotFoundException ex ) {
655
- getTrainedModelListener .onFailure (
656
- new ResourceNotFoundException (Messages .getMessage (Messages .MODEL_DEFINITION_NOT_FOUND , modelId ))
657
- );
658
- return ;
659
- } catch (Exception ex ) {
660
- getTrainedModelListener .onFailure (ex );
661
- return ;
662
- }
663
- }
664
- getTrainedModelListener .onResponse (builder );
670
+ getTrainedModelListener .onResponse (builder );
671
+ }, getTrainedModelListener ::onFailure )
672
+ );
665
673
}, getTrainedModelListener ::onFailure );
666
-
667
- executeAsyncWithOrigin (
668
- client ,
669
- ML_ORIGIN ,
670
- MultiSearchAction .INSTANCE ,
671
- multiSearchRequestBuilder .request (),
672
- multiSearchResponseActionListener
673
- );
674
+ executeAsyncWithOrigin (client , ML_ORIGIN , SearchAction .INSTANCE , trainedModelConfigSearch , trainedModelSearchHandler );
674
675
}
675
676
676
677
public void getTrainedModels (
@@ -1204,6 +1205,9 @@ private static <T> List<T> handleHits(
1204
1205
String resourceId ,
1205
1206
CheckedBiFunction <BytesReference , String , T , Exception > parseLeniently
1206
1207
) throws Exception {
1208
+ if (hits .length == 0 ) {
1209
+ throw new ResourceNotFoundException (resourceId );
1210
+ }
1207
1211
List <T > results = new ArrayList <>(hits .length );
1208
1212
String initialIndex = hits [0 ].getIndex ();
1209
1213
for (SearchHit hit : hits ) {
0 commit comments