Skip to content

Commit d97e769

Browse files
author
Max Hniebergall
committed
improvements from review
1 parent a2e763b commit d97e769

File tree

1 file changed

+77
-63
lines changed

1 file changed

+77
-63
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

Lines changed: 77 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,74 +1068,88 @@ public void getInferenceStats(String[] modelIds, @Nullable TaskId parentTaskId,
10681068
delegate,
10691069
client.admin().cluster()::health
10701070
);
1071-
}).<List<InferenceStats>>andThen((delegate, clusterHealthResponse) -> {
1072-
if (clusterHealthResponse.isTimedOut()) {
1073-
logger.error(
1074-
"getInferenceStats Timed out waiting for index [{}] to be available, this will probably cause the request to fail",
1075-
MlStatsIndex.indexPattern()
1076-
);
1077-
}
1078-
1079-
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
1080-
Arrays.stream(modelIds).map(TrainedModelProvider::buildStatsSearchRequest).forEach(multiSearchRequest::add);
1081-
if (multiSearchRequest.requests().isEmpty()) {
1082-
listener.onResponse(Collections.emptyList());
1083-
return;
1084-
}
1085-
if (parentTaskId != null) {
1086-
multiSearchRequest.setParentTask(parentTaskId);
1087-
}
1088-
executeAsyncWithOrigin(
1071+
})
1072+
.<List<InferenceStats>>andThen(
1073+
client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME),
10891074
client.threadPool().getThreadContext(),
1090-
ML_ORIGIN,
1091-
multiSearchRequest,
1092-
ActionListener.<MultiSearchResponse>wrap(responses -> {
1093-
List<InferenceStats> allStats = new ArrayList<>(modelIds.length);
1094-
int modelIndex = 0;
1095-
assert responses.getResponses().length == modelIds.length
1096-
: "mismatch between search response size and models requested";
1097-
for (MultiSearchResponse.Item response : responses.getResponses()) {
1098-
if (response.isFailure()) {
1099-
if (ExceptionsHelper.unwrapCause(response.getFailure()) instanceof ResourceNotFoundException) {
1100-
modelIndex++;
1101-
continue;
1102-
}
1103-
logger.error(
1104-
() -> "[" + Strings.arrayToCommaDelimitedString(modelIds) + "] search failed for models",
1105-
response.getFailure()
1106-
);
1107-
listener.onFailure(
1108-
ExceptionsHelper.serverError(
1109-
"Searching for stats for models [{}] failed",
1110-
response.getFailure(),
1111-
Strings.arrayToCommaDelimitedString(modelIds)
1112-
)
1113-
);
1114-
return;
1115-
}
1116-
try {
1117-
InferenceStats inferenceStats = handleMultiNodeStatsResponse(response.getResponse(), modelIds[modelIndex++]);
1118-
if (inferenceStats != null) {
1119-
allStats.add(inferenceStats);
1120-
}
1121-
} catch (Exception e) {
1122-
listener.onFailure(e);
1123-
return;
1124-
}
1075+
(delegate, clusterHealthResponse) -> {
1076+
if (clusterHealthResponse.isTimedOut()) {
1077+
logger.error(
1078+
"getInferenceStats Timed out waiting for index [{}] to be available, "
1079+
+ "this will probably cause the request to fail",
1080+
MlStatsIndex.indexPattern()
1081+
);
11251082
}
1126-
listener.onResponse(allStats);
1127-
}, e -> {
1128-
Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
1129-
if (unwrapped instanceof ResourceNotFoundException) {
1130-
listener.onResponse(Collections.emptyList());
1083+
1084+
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
1085+
Arrays.stream(modelIds).map(TrainedModelProvider::buildStatsSearchRequest).forEach(multiSearchRequest::add);
1086+
if (multiSearchRequest.requests().isEmpty()) {
1087+
delegate.onResponse(Collections.emptyList());
11311088
return;
11321089
}
1133-
listener.onFailure((Exception) unwrapped);
1134-
}),
1135-
client::multiSearch
1136-
);
1090+
if (parentTaskId != null) {
1091+
multiSearchRequest.setParentTask(parentTaskId);
1092+
}
1093+
executeAsyncWithOrigin(
1094+
client.threadPool().getThreadContext(),
1095+
ML_ORIGIN,
1096+
multiSearchRequest,
1097+
ActionListener.<MultiSearchResponse>wrap(responses -> {
1098+
List<InferenceStats> allStats = new ArrayList<>(modelIds.length);
1099+
int modelIndex = 0;
1100+
assert responses.getResponses().length == modelIds.length
1101+
: "mismatch between search response size and models requested";
1102+
for (MultiSearchResponse.Item response : responses.getResponses()) {
1103+
if (response.isFailure()) {
1104+
if (ExceptionsHelper.unwrapCause(response.getFailure()) instanceof ResourceNotFoundException) {
1105+
modelIndex++;
1106+
continue;
1107+
}
1108+
logger.error(
1109+
() -> "[" + Strings.arrayToCommaDelimitedString(modelIds) + "] search failed for models",
1110+
response.getFailure()
1111+
);
1112+
delegate.onFailure(
1113+
ExceptionsHelper.serverError(
1114+
"Searching for stats for models [{}] failed",
1115+
response.getFailure(),
1116+
Strings.arrayToCommaDelimitedString(modelIds)
1117+
)
1118+
);
1119+
return;
1120+
}
1121+
try {
1122+
InferenceStats inferenceStats = handleMultiNodeStatsResponse(
1123+
response.getResponse(),
1124+
modelIds[modelIndex++]
1125+
);
1126+
if (inferenceStats != null) {
1127+
allStats.add(inferenceStats);
1128+
}
1129+
} catch (Exception e) {
1130+
delegate.onFailure(e);
1131+
return;
1132+
}
1133+
}
1134+
delegate.onResponse(allStats);
1135+
}, e -> {
1136+
Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
1137+
if (unwrapped instanceof ResourceNotFoundException) {
1138+
delegate.onResponse(Collections.emptyList());
1139+
return;
1140+
}
1141+
delegate.onFailure((Exception) unwrapped);
1142+
}),
1143+
client::multiSearch
1144+
);
11371145

1138-
}).addListener(listener);
1146+
}
1147+
)
1148+
.addListener(
1149+
listener,
1150+
client.threadPool().executor(MachineLearning.UTILITY_THREAD_POOL_NAME),
1151+
client.threadPool().getThreadContext()
1152+
);
11391153
}
11401154

11411155
private static SearchRequest buildStatsSearchRequest(String modelId) {

0 commit comments

Comments
 (0)