diff --git a/docs/changelog/125204.yaml b/docs/changelog/125204.yaml new file mode 100644 index 0000000000000..de0ca932aafe0 --- /dev/null +++ b/docs/changelog/125204.yaml @@ -0,0 +1,6 @@ +pr: 125204 +summary: Return a Conflict status code if the model deployment is stopped by a user +area: Machine Learning +type: bug +issues: + - 123745 diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractControlMessagePyTorchAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractControlMessagePyTorchAction.java index fd9ebbd6cee70..a8df1a11f4869 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractControlMessagePyTorchAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractControlMessagePyTorchAction.java @@ -87,7 +87,7 @@ final BytesReference buildControlMessage(String requestId) throws IOException { private void processResponse(PyTorchResult result) { if (result.isError()) { - onFailure(result.errorResult().error()); + onFailure(result.errorResult()); return; } onSuccess(getResult(result)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractPyTorchAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractPyTorchAction.java index 8aa3b310da21a..7c7feff12aed9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractPyTorchAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractPyTorchAction.java @@ -16,6 +16,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult; import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable; import java.util.concurrent.atomic.AtomicBoolean; @@ -116,8 +117,9 @@ public void onFailure(Exception e) { getLogger().debug(() -> format("[%s] request [%s] received failure but listener already notified", deploymentId, requestId), e); } - protected void onFailure(String errorMessage) { - onFailure(new ElasticsearchStatusException("Error in inference process: [" + errorMessage + "]", RestStatus.INTERNAL_SERVER_ERROR)); + protected void onFailure(ErrorResult errorResult) { + var restStatus = errorResult.isStopping() ? RestStatus.CONFLICT : RestStatus.INTERNAL_SERVER_ERROR; + onFailure(new ElasticsearchStatusException("Error in inference process: [" + errorResult.error() + "]", restStatus)); } boolean isNotified() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java index a0fc00af3f859..87337439ce47d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java @@ -178,7 +178,7 @@ private void processResult( NlpTask.ResultProcessor inferenceResultsProcessor ) { if (pyTorchResult.isError()) { - onFailure(pyTorchResult.errorResult().error()); + onFailure(pyTorchResult.errorResult()); return; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java index 7dd0dae1e3ad7..68389e6ca7165 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -130,11 +130,12 @@ public void process(PyTorchProcess process) { var errorResult = new ErrorResult( isStopping ? "inference canceled as process is stopping" - : "inference native process died unexpectedly with failure [" + e.getMessage() + "]" + : "inference native process died unexpectedly with failure [" + e.getMessage() + "]", + isStopping ); notifyAndClearPendingResults(errorResult); } finally { - notifyAndClearPendingResults(new ErrorResult("inference canceled as process is stopping")); + notifyAndClearPendingResults(new ErrorResult("inference canceled as process is stopping", true)); processorCompletionLatch.countDown(); } logger.debug(() -> "[" + modelId + "] Results processing finished"); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java index d5fd64c74c54e..4461c7d1b6562 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java @@ -14,7 +14,7 @@ import java.io.IOException; -public record ErrorResult(String error) implements ToXContentObject { +public record ErrorResult(String error, boolean isStopping) implements ToXContentObject { public static final ParseField ERROR = new ParseField("error"); @@ -23,6 +23,10 @@ public record ErrorResult(String error) implements ToXContentObject { a -> new ErrorResult((String) a[0]) ); + public ErrorResult(String error) { + this(error, false); + } + static { PARSER.declareString(ConstructingObjectParser.constructorArg(), ERROR); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java index 69a9d4e9430c3..abf926a67f1b3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java @@ -134,11 +134,16 @@ public void testCancelPendingRequest() { public void testPendingRequestAreCalledAtShutdown() { var processor = new PyTorchResultProcessor("foo", s -> {}); + Consumer resultChecker = r -> { + assertTrue(r.errorResult().isStopping()); + assertEquals(r.errorResult().error(), "inference canceled as process is stopping"); + }; + var listeners = List.of( - new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")), - new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")), - new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")), - new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")) + new AssertingResultListener(resultChecker), + new AssertingResultListener(resultChecker), + new AssertingResultListener(resultChecker), + new AssertingResultListener(resultChecker) ); int i = 0; @@ -153,6 +158,33 @@ public void testPendingRequestAreCalledAtShutdown() { } } + public void testPendingRequestAreCalledOnException() { + var processor = new PyTorchResultProcessor("foo", s -> {}); + + Consumer resultChecker = r -> { + assertFalse(r.errorResult().isStopping()); + assertEquals(r.errorResult().error(), "inference native process died unexpectedly with failure [mocked exception]"); + }; + + var listeners = List.of( + new AssertingResultListener(resultChecker), + new AssertingResultListener(resultChecker), + new AssertingResultListener(resultChecker), + new AssertingResultListener(resultChecker) + ); + + int i = 0; + for (var l : listeners) { + processor.registerRequest(Integer.toString(i++), l); + } + + processor.process(throwingNativeProcess()); + + for (var l : listeners) { + assertTrue(l.hasResponse); + } + } + public void testsHandleUnknownResult() { var processor = new PyTorchResultProcessor("deployment-foo", settings -> {}); var listener = new AssertingResultListener( @@ -379,4 +411,10 @@ private NativePyTorchProcess mockNativeProcess(Iterator results) when(process.readResults()).thenReturn(results); return process; } + + private NativePyTorchProcess throwingNativeProcess() { + var process = mock(NativePyTorchProcess.class); + when(process.readResults()).thenThrow(new RuntimeException("mocked exception")); + return process; + } }