Skip to content

Commit deb1f22

Browse files
authored
[ML] Return a Conflict status code if the model deployment is stopped by a user (#125204) (#125486)
1 parent ab8c388 commit deb1f22

File tree

7 files changed

+62
-11
lines changed

7 files changed

+62
-11
lines changed

docs/changelog/125204.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 125204
2+
summary: Return a Conflict status code if the model deployment is stopped by a user
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 123745

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractControlMessagePyTorchAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ final BytesReference buildControlMessage(String requestId) throws IOException {
8787

8888
private void processResponse(PyTorchResult result) {
8989
if (result.isError()) {
90-
onFailure(result.errorResult().error());
90+
onFailure(result.errorResult());
9191
return;
9292
}
9393
onSuccess(getResult(result));

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractPyTorchAction.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.threadpool.ThreadPool;
1717
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1818
import org.elasticsearch.xpack.ml.MachineLearning;
19+
import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult;
1920
import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;
2021

2122
import java.util.concurrent.atomic.AtomicBoolean;
@@ -116,8 +117,9 @@ public void onFailure(Exception e) {
116117
getLogger().debug(() -> format("[%s] request [%s] received failure but listener already notified", deploymentId, requestId), e);
117118
}
118119

119-
protected void onFailure(String errorMessage) {
120-
onFailure(new ElasticsearchStatusException("Error in inference process: [" + errorMessage + "]", RestStatus.INTERNAL_SERVER_ERROR));
120+
protected void onFailure(ErrorResult errorResult) {
121+
var restStatus = errorResult.isStopping() ? RestStatus.CONFLICT : RestStatus.INTERNAL_SERVER_ERROR;
122+
onFailure(new ElasticsearchStatusException("Error in inference process: [" + errorResult.error() + "]", restStatus));
121123
}
122124

123125
boolean isNotified() {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ private void processResult(
178178
NlpTask.ResultProcessor inferenceResultsProcessor
179179
) {
180180
if (pyTorchResult.isError()) {
181-
onFailure(pyTorchResult.errorResult().error());
181+
onFailure(pyTorchResult.errorResult());
182182
return;
183183
}
184184

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,12 @@ public void process(PyTorchProcess process) {
130130
var errorResult = new ErrorResult(
131131
isStopping
132132
? "inference canceled as process is stopping"
133-
: "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
133+
: "inference native process died unexpectedly with failure [" + e.getMessage() + "]",
134+
isStopping
134135
);
135136
notifyAndClearPendingResults(errorResult);
136137
} finally {
137-
notifyAndClearPendingResults(new ErrorResult("inference canceled as process is stopping"));
138+
notifyAndClearPendingResults(new ErrorResult("inference canceled as process is stopping", true));
138139
processorCompletionLatch.countDown();
139140
}
140141
logger.debug(() -> "[" + modelId + "] Results processing finished");

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import java.io.IOException;
1616

17-
public record ErrorResult(String error) implements ToXContentObject {
17+
public record ErrorResult(String error, boolean isStopping) implements ToXContentObject {
1818

1919
public static final ParseField ERROR = new ParseField("error");
2020

@@ -23,6 +23,10 @@ public record ErrorResult(String error) implements ToXContentObject {
2323
a -> new ErrorResult((String) a[0])
2424
);
2525

26+
public ErrorResult(String error) {
27+
this(error, false);
28+
}
29+
2630
static {
2731
PARSER.declareString(ConstructingObjectParser.constructorArg(), ERROR);
2832
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,16 @@ public void testCancelPendingRequest() {
134134
public void testPendingRequestAreCalledAtShutdown() {
135135
var processor = new PyTorchResultProcessor("foo", s -> {});
136136

137+
Consumer<PyTorchResult> resultChecker = r -> {
138+
assertTrue(r.errorResult().isStopping());
139+
assertEquals(r.errorResult().error(), "inference canceled as process is stopping");
140+
};
141+
137142
var listeners = List.of(
138-
new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")),
139-
new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")),
140-
new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping")),
141-
new AssertingResultListener(r -> assertEquals(r.errorResult().error(), "inference canceled as process is stopping"))
143+
new AssertingResultListener(resultChecker),
144+
new AssertingResultListener(resultChecker),
145+
new AssertingResultListener(resultChecker),
146+
new AssertingResultListener(resultChecker)
142147
);
143148

144149
int i = 0;
@@ -153,6 +158,33 @@ public void testPendingRequestAreCalledAtShutdown() {
153158
}
154159
}
155160

161+
public void testPendingRequestAreCalledOnException() {
162+
var processor = new PyTorchResultProcessor("foo", s -> {});
163+
164+
Consumer<PyTorchResult> resultChecker = r -> {
165+
assertFalse(r.errorResult().isStopping());
166+
assertEquals(r.errorResult().error(), "inference native process died unexpectedly with failure [mocked exception]");
167+
};
168+
169+
var listeners = List.of(
170+
new AssertingResultListener(resultChecker),
171+
new AssertingResultListener(resultChecker),
172+
new AssertingResultListener(resultChecker),
173+
new AssertingResultListener(resultChecker)
174+
);
175+
176+
int i = 0;
177+
for (var l : listeners) {
178+
processor.registerRequest(Integer.toString(i++), l);
179+
}
180+
181+
processor.process(throwingNativeProcess());
182+
183+
for (var l : listeners) {
184+
assertTrue(l.hasResponse);
185+
}
186+
}
187+
156188
public void testsHandleUnknownResult() {
157189
var processor = new PyTorchResultProcessor("deployment-foo", settings -> {});
158190
var listener = new AssertingResultListener(
@@ -379,4 +411,10 @@ private NativePyTorchProcess mockNativeProcess(Iterator<PyTorchResult> results)
379411
when(process.readResults()).thenReturn(results);
380412
return process;
381413
}
414+
415+
private NativePyTorchProcess throwingNativeProcess() {
416+
var process = mock(NativePyTorchProcess.class);
417+
when(process.readResults()).thenThrow(new RuntimeException("mocked exception"));
418+
return process;
419+
}
382420
}

0 commit comments

Comments (0)