Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/125204.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 125204
summary: Return a Conflict status code if the model deployment is stopped by a user
area: Machine Learning
type: bug
issues:
- 123745
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ final BytesReference buildControlMessage(String requestId) throws IOException {

private void processResponse(PyTorchResult result) {
if (result.isError()) {
onFailure(result.errorResult().error());
onFailure(result.errorResult());
return;
}
onSuccess(getResult(result));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult;
import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;

import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -116,8 +117,9 @@ public void onFailure(Exception e) {
getLogger().debug(() -> format("[%s] request [%s] received failure but listener already notified", deploymentId, requestId), e);
}

protected void onFailure(String errorMessage) {
onFailure(new ElasticsearchStatusException("Error in inference process: [" + errorMessage + "]", RestStatus.INTERNAL_SERVER_ERROR));
/**
 * Notifies the listener of a failed inference request, mapping the error to an
 * appropriate REST status: {@link RestStatus#CONFLICT} when the process was
 * deliberately stopped (the caller raced a user-initiated stop), otherwise
 * {@link RestStatus#INTERNAL_SERVER_ERROR} for unexpected process errors.
 *
 * @param errorResult the error reported by the native inference process
 */
protected void onFailure(ErrorResult errorResult) {
    RestStatus status;
    if (errorResult.isStopping()) {
        // The deployment was stopped by a user; this is a conflict, not a server fault.
        status = RestStatus.CONFLICT;
    } else {
        status = RestStatus.INTERNAL_SERVER_ERROR;
    }
    String message = "Error in inference process: [" + errorResult.error() + "]";
    onFailure(new ElasticsearchStatusException(message, status));
}

boolean isNotified() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ private void processResult(
NlpTask.ResultProcessor inferenceResultsProcessor
) {
if (pyTorchResult.isError()) {
onFailure(pyTorchResult.errorResult().error());
onFailure(pyTorchResult.errorResult());
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,12 @@ public void process(PyTorchProcess process) {
var errorResult = new ErrorResult(
isStopping
? "inference canceled as process is stopping"
: "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
: "inference native process died unexpectedly with failure [" + e.getMessage() + "]",
isStopping
);
notifyAndClearPendingResults(errorResult);
} finally {
notifyAndClearPendingResults(new ErrorResult("inference canceled as process is stopping"));
notifyAndClearPendingResults(new ErrorResult("inference canceled as process is stopping", true));
processorCompletionLatch.countDown();
}
logger.debug(() -> "[" + modelId + "] Results processing finished");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import java.io.IOException;

public record ErrorResult(String error) implements ToXContentObject {
public record ErrorResult(String error, boolean isStopping) implements ToXContentObject {

public static final ParseField ERROR = new ParseField("error");

Expand All @@ -23,6 +23,10 @@ public record ErrorResult(String error) implements ToXContentObject {
a -> new ErrorResult((String) a[0])
);

public ErrorResult(String error) {
this(error, false);
}

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ERROR);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,16 @@ public void testCancelPendingRequest() {
public void testPendingRequestAreCalledAtShutdown() {
var processor = new PyTorchResultProcessor("foo", s -> {});

Consumer<PyTorchResult> resultChecker = r -> {
assertTrue(r.errorResult().isStopping());
assertEquals(r.errorResult().error(), "inference canceled as process is stopping");
};

var listeners = List.of(
new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")),
new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")),
new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")),
new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping"))
new AssertingResultListener(resultChecker),
new AssertingResultListener(resultChecker),
new AssertingResultListener(resultChecker),
new AssertingResultListener(resultChecker)
);

int i = 0;
Expand All @@ -153,6 +158,33 @@ public void testPendingRequestAreCalledAtShutdown() {
}
}

/**
 * Verifies that when the native process dies with an exception (rather than a
 * graceful stop), every pending request listener is notified with a
 * non-stopping {@code ErrorResult} carrying the failure message.
 */
public void testPendingRequestAreCalledOnException() {
    var processor = new PyTorchResultProcessor("foo", s -> {});

    Consumer<PyTorchResult> resultChecker = r -> {
        // A crash is not a user-initiated stop, so isStopping must be false.
        assertFalse(r.errorResult().isStopping());
        // assertEquals takes (expected, actual); keeping that order makes failure messages accurate.
        assertEquals("inference native process died unexpectedly with failure [mocked exception]", r.errorResult().error());
    };

    var listeners = List.of(
        new AssertingResultListener(resultChecker),
        new AssertingResultListener(resultChecker),
        new AssertingResultListener(resultChecker),
        new AssertingResultListener(resultChecker)
    );

    int i = 0;
    for (var l : listeners) {
        processor.registerRequest(Integer.toString(i++), l);
    }

    processor.process(throwingNativeProcess());

    // Every registered listener must have been called despite the crash.
    for (var l : listeners) {
        assertTrue(l.hasResponse);
    }
}

public void testsHandleUnknownResult() {
var processor = new PyTorchResultProcessor("deployment-foo", settings -> {});
var listener = new AssertingResultListener(
Expand Down Expand Up @@ -379,4 +411,10 @@ private NativePyTorchProcess mockNativeProcess(Iterator<PyTorchResult> results)
when(process.readResults()).thenReturn(results);
return process;
}

/**
 * Builds a mocked native process whose result stream fails immediately,
 * simulating an unexpected process crash during result reading.
 */
private NativePyTorchProcess throwingNativeProcess() {
    var failingProcess = mock(NativePyTorchProcess.class);
    when(failingProcess.readResults()).thenThrow(new RuntimeException("mocked exception"));
    return failingProcess;
}
}
Loading