@@ -1068,74 +1068,88 @@ public void getInferenceStats(String[] modelIds, @Nullable TaskId parentTaskId,
10681068 delegate ,
10691069 client .admin ().cluster ()::health
10701070 );
1071- }).<List <InferenceStats >>andThen ((delegate , clusterHealthResponse ) -> {
1072- if (clusterHealthResponse .isTimedOut ()) {
1073- logger .error (
1074- "getInferenceStats Timed out waiting for index [{}] to be available, this will probably cause the request to fail" ,
1075- MlStatsIndex .indexPattern ()
1076- );
1077- }
1078-
1079- MultiSearchRequest multiSearchRequest = new MultiSearchRequest ();
1080- Arrays .stream (modelIds ).map (TrainedModelProvider ::buildStatsSearchRequest ).forEach (multiSearchRequest ::add );
1081- if (multiSearchRequest .requests ().isEmpty ()) {
1082- listener .onResponse (Collections .emptyList ());
1083- return ;
1084- }
1085- if (parentTaskId != null ) {
1086- multiSearchRequest .setParentTask (parentTaskId );
1087- }
1088- executeAsyncWithOrigin (
1071+ })
1072+ .<List <InferenceStats >>andThen (
1073+ client .threadPool ().executor (MachineLearning .UTILITY_THREAD_POOL_NAME ),
10891074 client .threadPool ().getThreadContext (),
1090- ML_ORIGIN ,
1091- multiSearchRequest ,
1092- ActionListener .<MultiSearchResponse >wrap (responses -> {
1093- List <InferenceStats > allStats = new ArrayList <>(modelIds .length );
1094- int modelIndex = 0 ;
1095- assert responses .getResponses ().length == modelIds .length
1096- : "mismatch between search response size and models requested" ;
1097- for (MultiSearchResponse .Item response : responses .getResponses ()) {
1098- if (response .isFailure ()) {
1099- if (ExceptionsHelper .unwrapCause (response .getFailure ()) instanceof ResourceNotFoundException ) {
1100- modelIndex ++;
1101- continue ;
1102- }
1103- logger .error (
1104- () -> "[" + Strings .arrayToCommaDelimitedString (modelIds ) + "] search failed for models" ,
1105- response .getFailure ()
1106- );
1107- listener .onFailure (
1108- ExceptionsHelper .serverError (
1109- "Searching for stats for models [{}] failed" ,
1110- response .getFailure (),
1111- Strings .arrayToCommaDelimitedString (modelIds )
1112- )
1113- );
1114- return ;
1115- }
1116- try {
1117- InferenceStats inferenceStats = handleMultiNodeStatsResponse (response .getResponse (), modelIds [modelIndex ++]);
1118- if (inferenceStats != null ) {
1119- allStats .add (inferenceStats );
1120- }
1121- } catch (Exception e ) {
1122- listener .onFailure (e );
1123- return ;
1124- }
1075+ (delegate , clusterHealthResponse ) -> {
1076+ if (clusterHealthResponse .isTimedOut ()) {
1077+ logger .error (
1078+ "getInferenceStats Timed out waiting for index [{}] to be available, "
1079+ + "this will probably cause the request to fail" ,
1080+ MlStatsIndex .indexPattern ()
1081+ );
11251082 }
1126- listener . onResponse ( allStats );
1127- }, e -> {
1128- Throwable unwrapped = ExceptionsHelper . unwrapCause ( e );
1129- if (unwrapped instanceof ResourceNotFoundException ) {
1130- listener .onResponse (Collections .emptyList ());
1083+
1084+ MultiSearchRequest multiSearchRequest = new MultiSearchRequest ();
1085+ Arrays . stream ( modelIds ). map ( TrainedModelProvider :: buildStatsSearchRequest ). forEach ( multiSearchRequest :: add );
1086+ if (multiSearchRequest . requests (). isEmpty () ) {
1087+ delegate .onResponse (Collections .emptyList ());
11311088 return ;
11321089 }
1133- listener .onFailure ((Exception ) unwrapped );
1134- }),
1135- client ::multiSearch
1136- );
1090+ if (parentTaskId != null ) {
1091+ multiSearchRequest .setParentTask (parentTaskId );
1092+ }
1093+ executeAsyncWithOrigin (
1094+ client .threadPool ().getThreadContext (),
1095+ ML_ORIGIN ,
1096+ multiSearchRequest ,
1097+ ActionListener .<MultiSearchResponse >wrap (responses -> {
1098+ List <InferenceStats > allStats = new ArrayList <>(modelIds .length );
1099+ int modelIndex = 0 ;
1100+ assert responses .getResponses ().length == modelIds .length
1101+ : "mismatch between search response size and models requested" ;
1102+ for (MultiSearchResponse .Item response : responses .getResponses ()) {
1103+ if (response .isFailure ()) {
1104+ if (ExceptionsHelper .unwrapCause (response .getFailure ()) instanceof ResourceNotFoundException ) {
1105+ modelIndex ++;
1106+ continue ;
1107+ }
1108+ logger .error (
1109+ () -> "[" + Strings .arrayToCommaDelimitedString (modelIds ) + "] search failed for models" ,
1110+ response .getFailure ()
1111+ );
1112+ delegate .onFailure (
1113+ ExceptionsHelper .serverError (
1114+ "Searching for stats for models [{}] failed" ,
1115+ response .getFailure (),
1116+ Strings .arrayToCommaDelimitedString (modelIds )
1117+ )
1118+ );
1119+ return ;
1120+ }
1121+ try {
1122+ InferenceStats inferenceStats = handleMultiNodeStatsResponse (
1123+ response .getResponse (),
1124+ modelIds [modelIndex ++]
1125+ );
1126+ if (inferenceStats != null ) {
1127+ allStats .add (inferenceStats );
1128+ }
1129+ } catch (Exception e ) {
1130+ delegate .onFailure (e );
1131+ return ;
1132+ }
1133+ }
1134+ delegate .onResponse (allStats );
1135+ }, e -> {
1136+ Throwable unwrapped = ExceptionsHelper .unwrapCause (e );
1137+ if (unwrapped instanceof ResourceNotFoundException ) {
1138+ delegate .onResponse (Collections .emptyList ());
1139+ return ;
1140+ }
1141+ delegate .onFailure ((Exception ) unwrapped );
1142+ }),
1143+ client ::multiSearch
1144+ );
11371145
1138- }).addListener (listener );
1146+ }
1147+ )
1148+ .addListener (
1149+ listener ,
1150+ client .threadPool ().executor (MachineLearning .UTILITY_THREAD_POOL_NAME ),
1151+ client .threadPool ().getThreadContext ()
1152+ );
11391153 }
11401154
11411155 private static SearchRequest buildStatsSearchRequest (String modelId ) {
0 commit comments