
Commit fb10f12

[ML] PyTorchModelIT: investigate single processor mode failures (#91547)
Starts the priority process worker only after the model has been loaded. Also adds debug logging and some renaming for clarity.
1 parent fe0e5c4 commit fb10f12
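
The functional change is small but easy to miss in the diff: the priority process worker, which previously started as part of ProcessContext.startProcess(), now starts only after the model state has loaded. A condensed view of the updated DeploymentManager.startAndLoad (names and flow as in the diff below):

    private void startAndLoad(ProcessContext processContext, TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
        try {
            // Start the native PyTorch process; the priority worker is no longer started here.
            processContext.startProcess();
            // Start the priority process worker only once the model state has loaded successfully,
            // then complete the caller's listener; load failures are propagated unchanged.
            processContext.loadModel(modelLocation, ActionListener.wrap(success -> {
                processContext.startPriorityProcessWorker();
                loadedListener.onResponse(success);
            }, loadedListener::onFailure));
        } catch (Exception e) {
            loadedListener.onFailure(e);
        }
    }

The diff also renames ProcessContext.executorService to priorityProcessWorker and enables DEBUG logging for org.elasticsearch.xpack.ml.inference.pytorch in the integration test.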

3 files changed (+43, −34 lines)


x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

Lines changed: 19 additions & 17 deletions
@@ -124,6 +124,7 @@ public void setLogging() throws IOException {
             {"persistent" : {
             "logger.org.elasticsearch.xpack.ml.inference.assignment" : "DEBUG",
             "logger.org.elasticsearch.xpack.ml.inference.deployment" : "DEBUG",
+            "logger.org.elasticsearch.xpack.ml.inference.pytorch" : "DEBUG",
             "logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG"
             }}""");
         client().performRequest(loggingSettings);
@@ -139,6 +140,7 @@ public void cleanup() throws Exception {
             "logger.org.elasticsearch.xpack.ml.inference.assignment": null,
             "logger.org.elasticsearch.xpack.ml.inference.deployment" : null,
             "logger.org.elasticsearch.xpack.ml.process.logging" : null,
+            "logger.org.elasticsearch.xpack.ml.inference.pytorch" : null,
             "xpack.ml.max_lazy_ml_nodes": null
             }}""");
         client().performRequest(loggingSettings);
@@ -293,14 +295,14 @@ public void testDeploymentStats() throws IOException {

     @SuppressWarnings("unchecked")
     public void testLiveDeploymentStats() throws IOException {
-        String modelA = "model_a";
+        String modelId = "live_deployment_stats";

-        createTrainedModel(modelA);
-        putVocabulary(List.of("once", "twice"), modelA);
-        putModelDefinition(modelA);
-        startDeployment(modelA, AllocationStatus.State.FULLY_ALLOCATED.toString());
+        createTrainedModel(modelId);
+        putVocabulary(List.of("once", "twice"), modelId);
+        putModelDefinition(modelId);
+        startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());
         {
-            Response noInferenceCallsStatsResponse = getTrainedModelStats(modelA);
+            Response noInferenceCallsStatsResponse = getTrainedModelStats(modelId);
             List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(noInferenceCallsStatsResponse).get(
                 "trained_model_stats"
             );
@@ -321,17 +323,17 @@ public void testLiveDeploymentStats() throws IOException {
             }
         }

-        infer("once", modelA);
-        infer("twice", modelA);
+        infer("once", modelId);
+        infer("twice", modelId);
         // By making this request 3 times at least one of the responses must come from the cache because the cluster has 2 ML nodes
-        infer("three times", modelA);
-        infer("three times", modelA);
-        infer("three times", modelA);
+        infer("three times", modelId);
+        infer("three times", modelId);
+        infer("three times", modelId);
         {
-            Response postInferStatsResponse = getTrainedModelStats(modelA);
+            Response postInferStatsResponse = getTrainedModelStats(modelId);
             List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(postInferStatsResponse).get("trained_model_stats");
             assertThat(stats, hasSize(1));
-            assertThat(XContentMapValues.extractValue("deployment_stats.model_id", stats.get(0)), equalTo(modelA));
+            assertThat(XContentMapValues.extractValue("deployment_stats.model_id", stats.get(0)), equalTo(modelId));
             assertThat(XContentMapValues.extractValue("model_size_stats.model_size_bytes", stats.get(0)), equalTo((int) RAW_MODEL_SIZE));
             List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue(
                 "deployment_stats.nodes",
@@ -748,12 +750,12 @@ public void testStoppingDeploymentShouldTriggerRebalance() throws Exception {
             }}""");
         client().performRequest(loggingSettings);

-        String modelId1 = "model_1";
+        String modelId1 = "stopping_triggers_rebalance_1";
         createTrainedModel(modelId1);
         putModelDefinition(modelId1);
         putVocabulary(List.of("these", "are", "my", "words"), modelId1);

-        String modelId2 = "model_2";
+        String modelId2 = "stopping_triggers_rebalance_2";
         createTrainedModel(modelId2);
         putModelDefinition(modelId2);
         putVocabulary(List.of("these", "are", "my", "words"), modelId2);
@@ -826,12 +828,12 @@ public void testStartDeployment_GivenNoProcessorsLeft_AndLazyStartEnabled() thro
             }}""");
         client().performRequest(loggingSettings);

-        String modelId1 = "model_1";
+        String modelId1 = "start_no_processors_left_lazy_start_1";
         createTrainedModel(modelId1);
         putModelDefinition(modelId1);
         putVocabulary(List.of("these", "are", "my", "words"), modelId1);

-        String modelId2 = "model_2";
+        String modelId2 = "start_no_processors_left_lazy_start_2";
         createTrainedModel(modelId2);
         putModelDefinition(modelId2);
         putVocabulary(List.of("these", "are", "my", "words"), modelId2);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 21 additions & 14 deletions
@@ -107,7 +107,7 @@ public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
                     stats.timingStats().getAverage(),
                     stats.timingStatsExcludingCacheHits().getAverage(),
                     stats.lastUsed(),
-                    processContext.executorService.queueSize() + stats.numberOfPendingResults(),
+                    processContext.priorityProcessWorker.queueSize() + stats.numberOfPendingResults(),
                     stats.errorCount(),
                     stats.cacheHitCount(),
                     processContext.rejectedExecutionCount.intValue(),
@@ -130,7 +130,7 @@ ProcessContext addProcessContext(Long id, ProcessContext processContext) {
     private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
         logger.info("[{}] Starting model deployment", task.getModelId());

-        ProcessContext processContext = new ProcessContext(task, executorServiceForProcess);
+        ProcessContext processContext = new ProcessContext(task);
         if (addProcessContext(task.getId(), processContext) != null) {
             finalListener.onFailure(
                 ExceptionsHelper.serverError("[{}] Could not create inference process as one already exists", task.getModelId())
@@ -232,7 +232,10 @@ Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
     private void startAndLoad(ProcessContext processContext, TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
         try {
             processContext.startProcess();
-            processContext.loadModel(modelLocation, loadedListener);
+            processContext.loadModel(modelLocation, ActionListener.wrap(success -> {
+                processContext.startPriorityProcessWorker();
+                loadedListener.onResponse(success);
+            }, loadedListener::onFailure));
         } catch (Exception e) {
             loadedListener.onFailure(e);
         }
@@ -332,13 +335,13 @@ public void clearCache(TrainedModelDeploymentTask task, TimeValue timeout, Actio
         executePyTorchAction(processContext, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, controlMessageAction);
     }

-    public void executePyTorchAction(
+    void executePyTorchAction(
         ProcessContext processContext,
         PriorityProcessWorkerExecutorService.RequestPriority priority,
         AbstractPyTorchAction<?> action
     ) {
         try {
-            processContext.getExecutorService().executeWithPriority(action, priority, action.getRequestId());
+            processContext.getPriorityProcessWorker().executeWithPriority(action, priority, action.getRequestId());
         } catch (EsRejectedExecutionException e) {
             processContext.getRejectedExecutionCount().incrementAndGet();
             action.onFailure(e);
@@ -376,21 +379,21 @@ class ProcessContext {
         private final SetOnce<TrainedModelInput> modelInput = new SetOnce<>();
         private final PyTorchResultProcessor resultProcessor;
         private final PyTorchStateStreamer stateStreamer;
-        private final PriorityProcessWorkerExecutorService executorService;
+        private final PriorityProcessWorkerExecutorService priorityProcessWorker;
         private volatile Instant startTime;
         private volatile Integer numThreadsPerAllocation;
         private volatile Integer numAllocations;
         private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
         private final AtomicInteger timeoutCount = new AtomicInteger();

-        ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
+        ProcessContext(TrainedModelDeploymentTask task) {
             this.task = Objects.requireNonNull(task);
             resultProcessor = new PyTorchResultProcessor(task.getModelId(), threadSettings -> {
                 this.numThreadsPerAllocation = threadSettings.numThreadsPerAllocation();
                 this.numAllocations = threadSettings.numAllocations();
             });
-            this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
-            this.executorService = new PriorityProcessWorkerExecutorService(
+            this.stateStreamer = new PyTorchStateStreamer(client, executorServiceForProcess, xContentRegistry);
+            this.priorityProcessWorker = new PriorityProcessWorkerExecutorService(
                 threadPool.getThreadContext(),
                 "inference process",
                 task.getParams().getQueueCapacity()
@@ -404,12 +407,15 @@ PyTorchResultProcessor getResultProcessor() {
         synchronized void startProcess() {
             process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, onProcessCrash()));
             startTime = Instant.now();
-            executorServiceForProcess.submit(executorService::start);
+        }
+
+        void startPriorityProcessWorker() {
+            executorServiceForProcess.submit(priorityProcessWorker::start);
         }

         synchronized void stopProcess() {
             resultProcessor.stop();
-            executorService.shutdown();
+            priorityProcessWorker.shutdown();
             try {
                 if (process.get() == null) {
                     return;
@@ -430,7 +436,7 @@ private Consumer<String> onProcessCrash() {
             return reason -> {
                 logger.error("[{}] inference process crashed due to reason [{}]", task.getModelId(), reason);
                 resultProcessor.stop();
-                executorService.shutdownWithError(new IllegalStateException(reason));
+                priorityProcessWorker.shutdownWithError(new IllegalStateException(reason));
                 processContextByAllocation.remove(task.getId());
                 if (nlpTaskProcessor.get() != null) {
                     nlpTaskProcessor.get().close();
@@ -441,6 +447,7 @@ private Consumer<String> onProcessCrash() {

         void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
             if (modelLocation instanceof IndexLocation indexLocation) {
+                logger.debug("[{}] loading model state", task.getModelId());
                 process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener);
             } else {
                 listener.onFailure(
@@ -455,8 +462,8 @@ AtomicInteger getTimeoutCount() {
         }

         // accessor used for mocking in tests
-        PriorityProcessWorkerExecutorService getExecutorService() {
-            return executorService;
+        PriorityProcessWorkerExecutorService getPriorityProcessWorker() {
+            return priorityProcessWorker;
         }

         // accessor used for mocking in tests

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

Lines changed: 3 additions & 3 deletions
@@ -79,19 +79,19 @@ public void testRejectedExecution() {
             mock(PyTorchProcessFactory.class)
         );

-        PriorityProcessWorkerExecutorService executorService = new PriorityProcessWorkerExecutorService(
+        PriorityProcessWorkerExecutorService priorityExecutorService = new PriorityProcessWorkerExecutorService(
             tp.getThreadContext(),
             "test reject",
             10
         );
-        executorService.shutdown();
+        priorityExecutorService.shutdown();

        AtomicInteger rejectedCount = new AtomicInteger();

        DeploymentManager.ProcessContext context = mock(DeploymentManager.ProcessContext.class);
        PyTorchResultProcessor resultProcessor = new PyTorchResultProcessor("1", threadSettings -> {});
        when(context.getResultProcessor()).thenReturn(resultProcessor);
-        when(context.getExecutorService()).thenReturn(executorService);
+        when(context.getPriorityProcessWorker()).thenReturn(priorityExecutorService);
        when(context.getRejectedExecutionCount()).thenReturn(rejectedCount);

        deploymentManager.addProcessContext(taskId, context);
