Commit 9d3e1f3

jan-elastic authored and davidkyle committed
Fix inference stats for cancellations (elastic#112233)
* Fix inference stats for cancellations
* Fix PyTorchResultProcessorTests
* Refactor onCancel
* Update x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractPyTorchAction.java

Co-authored-by: David Kyle <[email protected]>
1 parent 71f5845 · commit 9d3e1f3

File tree: 4 files changed, +38 -15 lines


x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractPyTorchAction.java

Lines changed: 9 additions & 3 deletions
@@ -65,12 +65,18 @@ public final void init() {
     }
 
     void onTimeout() {
+        onTimeout(new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.REQUEST_TIMEOUT, timeout));
+    }
+
+    void onCancel() {
+        onTimeout(new ElasticsearchStatusException("inference task cancelled", RestStatus.BAD_REQUEST));
+    }
+
+    void onTimeout(Exception e) {
         if (notified.compareAndSet(false, true)) {
             processContext.getTimeoutCount().incrementAndGet();
             processContext.getResultProcessor().ignoreResponseWithoutNotifying(String.valueOf(requestId));
-            listener.onFailure(
-                new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.REQUEST_TIMEOUT, timeout)
-            );
+            listener.onFailure(e);
             return;
         }
         getLogger().debug("[{}] request [{}] received timeout after [{}] but listener already alerted", deploymentId, requestId, timeout);
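Note: this refactor funnels both failure modes through the single onTimeout(Exception) overload, so the existing compareAndSet guard now protects cancellations as well as timeouts, and the listener is alerted at most once even if the two race. Below is a minimal, self-contained sketch of that notify-once pattern; the class and method names are illustrative, not taken from the Elasticsearch source.

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

// Sketch of the notify-once guard AbstractPyTorchAction relies on.
class NotifyOnceSketch {
    private final AtomicBoolean notified = new AtomicBoolean(false);
    private final Consumer<Exception> listener;

    NotifyOnceSketch(Consumer<Exception> listener) {
        this.listener = listener;
    }

    void fail(Exception e) {
        // Exactly one caller can win the compareAndSet, so a timeout and a
        // cancellation racing each other cannot both notify the listener.
        if (notified.compareAndSet(false, true)) {
            listener.accept(e);
        }
    }
}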

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java

Lines changed: 6 additions & 5 deletions
@@ -84,12 +84,11 @@ protected void doRun() throws Exception {
             logger.debug(() -> format("[%s] skipping inference on request [%s] as it has timed out", getDeploymentId(), getRequestId()));
             return;
         }
+        final String requestIdStr = String.valueOf(getRequestId());
         if (isCancelled()) {
-            onFailure("inference task cancelled");
+            onCancel();
             return;
         }
-
-        final String requestIdStr = String.valueOf(getRequestId());
         try {
             String inputText = input.extractInput(getProcessContext().getModelInput().get());
             if (prefixType != TrainedModelPrefixStrings.PrefixType.NONE) {
@@ -141,7 +140,7 @@ protected void doRun() throws Exception {
 
         // Tokenization is non-trivial, so check for cancellation one last time before sending request to the native process
         if (isCancelled()) {
-            onFailure("inference task cancelled");
+            onCancel();
             return;
         }
         getProcessContext().getResultProcessor()
@@ -196,9 +195,11 @@ private void processResult(
             return;
         }
         if (isCancelled()) {
-            onFailure("inference task cancelled");
+            onCancel();
             return;
         }
+
+        getProcessContext().getResultProcessor().updateStats(pyTorchResult);
         InferenceResults results = inferenceResultsProcessor.processResult(
             tokenization,
             pyTorchResult.inferenceResult(),
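Note: the effect of the last hunk is that updateStats now runs only after the cancellation check passes, so a cancelled request never contributes to the timing statistics. A condensed, illustrative view of the new ordering in processResult follows (bodies elided, names as in the diff above; this is not the full method).

void processResult(PyTorchResult pyTorchResult /*, ... */) {
    if (isCancelled()) {
        onCancel();   // fails the listener with 400 Bad Request
        return;       // and skips the stats update entirely
    }
    // Stats are recorded only for results that will actually be returned.
    getProcessContext().getResultProcessor().updateStats(pyTorchResult);
    // ... build InferenceResults and notify the listener ...
}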

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

Lines changed: 7 additions & 7 deletions
@@ -153,18 +153,12 @@ private void notifyAndClearPendingResults(ErrorResult errorResult) {
     void processInferenceResult(PyTorchResult result) {
         PyTorchInferenceResult inferenceResult = result.inferenceResult();
         assert inferenceResult != null;
-        Long timeMs = result.timeMs();
-        if (timeMs == null) {
-            assert false : "time_ms should be set for an inference result";
-            timeMs = 0L;
-        }
 
         logger.debug(() -> format("[%s] Parsed inference result with id [%s]", modelId, result.requestId()));
         PendingResult pendingResult = pendingResults.remove(result.requestId());
         if (pendingResult == null) {
             logger.debug(() -> format("[%s] no pending result for inference [%s]", modelId, result.requestId()));
         } else {
-            updateStats(timeMs, Boolean.TRUE.equals(result.isCacheHit()));
             pendingResult.listener.onResponse(result);
         }
     }
@@ -273,7 +267,13 @@ private static LongSummaryStatistics cloneSummaryStats(LongSummaryStatistics stats) {
         return new LongSummaryStatistics(stats.getCount(), stats.getMin(), stats.getMax(), stats.getSum());
     }
 
-    private synchronized void updateStats(long timeMs, boolean isCacheHit) {
+    public synchronized void updateStats(PyTorchResult result) {
+        Long timeMs = result.timeMs();
+        if (timeMs == null) {
+            assert false : "time_ms should be set for an inference result";
+            timeMs = 0L;
+        }
+        boolean isCacheHit = Boolean.TRUE.equals(result.isCacheHit());
         timingStats.accept(timeMs);
 
         lastResultTimeMs = currentTimeMsSupplier.getAsLong();
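Note: moving the time_ms null check into updateStats keeps the defensive handling next to the only consumer of the value, and Boolean.TRUE.equals(result.isCacheHit()) tolerates a missing cache-hit flag. A tiny standalone demonstration of that null-safe idiom (variable names here are illustrative):

public class NullSafeBooleanDemo {
    public static void main(String[] args) {
        Boolean cacheHit = null; // e.g. the flag was absent from the result
        // Boolean.TRUE.equals(...) is null-safe; direct unboxing would throw
        // a NullPointerException here.
        boolean isCacheHit = Boolean.TRUE.equals(cacheHit);
        System.out.println(isCacheHit); // prints: false
    }
}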

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java

Lines changed: 16 additions & 0 deletions
@@ -210,6 +210,8 @@ public void testsStats() {
         var c = wrapInferenceResult("c", true, 200L); // cache hit
 
         processor.processInferenceResult(a);
+        processor.updateStats(a);
+
         var stats = processor.getResultStats();
         assertThat(stats.errorCount(), equalTo(0));
         assertThat(stats.cacheHitCount(), equalTo(0L));
@@ -220,6 +222,8 @@ public void testsStats() {
         assertThat(stats.timingStatsExcludingCacheHits().getSum(), equalTo(1000L));
 
         processor.processInferenceResult(b);
+        processor.updateStats(b);
+
         stats = processor.getResultStats();
         assertThat(stats.errorCount(), equalTo(0));
         assertThat(stats.cacheHitCount(), equalTo(0L));
@@ -230,6 +234,8 @@ public void testsStats() {
         assertThat(stats.timingStatsExcludingCacheHits().getSum(), equalTo(1900L));
 
         processor.processInferenceResult(c);
+        processor.updateStats(c);
+
         stats = processor.getResultStats();
         assertThat(stats.errorCount(), equalTo(0));
         assertThat(stats.cacheHitCount(), equalTo(1L));
@@ -284,6 +290,9 @@ public void testsTimeDependentStats() {
         processor.processInferenceResult(wrapInferenceResult("foo0", false, 200L));
         processor.processInferenceResult(wrapInferenceResult("foo1", false, 200L));
         processor.processInferenceResult(wrapInferenceResult("foo2", false, 200L));
+        processor.updateStats(wrapInferenceResult("foo0", false, 200L));
+        processor.updateStats(wrapInferenceResult("foo1", false, 200L));
+        processor.updateStats(wrapInferenceResult("foo2", false, 200L));
 
         // first call has no results as is in the same period
         var stats = processor.getResultStats();
@@ -299,6 +308,7 @@ public void testsTimeDependentStats() {
 
         // 2nd period
         processor.processInferenceResult(wrapInferenceResult("foo3", false, 100L));
+        processor.updateStats(wrapInferenceResult("foo3", false, 100L));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -311,6 +321,7 @@ public void testsTimeDependentStats() {
 
         // 4th period
         processor.processInferenceResult(wrapInferenceResult("foo4", false, 300L));
+        processor.updateStats(wrapInferenceResult("foo4", false, 300L));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -320,6 +331,8 @@ public void testsTimeDependentStats() {
         // 7th period
         processor.processInferenceResult(wrapInferenceResult("foo5", false, 410L));
         processor.processInferenceResult(wrapInferenceResult("foo6", false, 390L));
+        processor.updateStats(wrapInferenceResult("foo5", false, 410L));
+        processor.updateStats(wrapInferenceResult("foo6", false, 390L));
         stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
         assertThat(stats.recentStats().avgInferenceTime(), nullValue());
@@ -333,6 +346,9 @@ public void testsTimeDependentStats() {
         processor.processInferenceResult(wrapInferenceResult("foo7", false, 510L));
         processor.processInferenceResult(wrapInferenceResult("foo8", false, 500L));
         processor.processInferenceResult(wrapInferenceResult("foo9", false, 490L));
+        processor.updateStats(wrapInferenceResult("foo7", false, 510L));
+        processor.updateStats(wrapInferenceResult("foo8", false, 500L));
+        processor.updateStats(wrapInferenceResult("foo9", false, 490L));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));
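Note: with stats updates decoupled from result delivery, the tests now pair every processInferenceResult call with an explicit updateStats call. The pattern, extracted for clarity (wrapInferenceResult(id, isCacheHit, timeMs) is a helper defined in this test class, and processor is the PyTorchResultProcessor under test):

// Delivering a result and counting it are now two separate steps.
var result = wrapInferenceResult("foo", false, 200L);
processor.processInferenceResult(result); // routes the result to its pending listener
processor.updateStats(result);            // records timing and cache-hit stats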
