Skip to content

Commit f994811

Browse files
Marking task as failed
1 parent b83d2dd commit f994811

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ protected void onCancelled() {
152152
markAsCompleted();
153153
}
154154

155+
private void shutdownAndMarkTaskAsFailed(Exception e) {
156+
shutdown();
157+
markAsFailed(e);
158+
}
159+
155160
// default for testing
156161
void shutdown() {
157162
shutdown.set(true);
@@ -198,7 +203,7 @@ private void scheduleAuthorizationRequest() {
198203
} catch (Exception e) {
199204
logger.warn("Failed scheduling authorization request", e);
200205
// Shutdown and complete the task so it will be restarted
201-
onCancelled();
206+
shutdownAndMarkTaskAsFailed(e);
202207
}
203208
}
204209

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import static org.mockito.Mockito.doAnswer;
4444
import static org.mockito.Mockito.mock;
4545
import static org.mockito.Mockito.never;
46+
import static org.mockito.Mockito.times;
4647
import static org.mockito.Mockito.verify;
4748
import static org.mockito.Mockito.when;
4849

@@ -354,10 +355,11 @@ public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throw
354355
latch.countDown();
355356
};
356357

358+
var exception = new IllegalStateException("failing");
357359
// Simulate scheduling failure by having the settings throw an exception when queried
358360
// Throwing an exception should cause the poller to shutdown and mark itself as completed
359361
var settingsMock = mock(ElasticInferenceServiceSettings.class);
360-
when(settingsMock.isPeriodicAuthorizationEnabled()).thenThrow(new IllegalStateException("failing"));
362+
when(settingsMock.isPeriodicAuthorizationEnabled()).thenThrow(exception);
361363

362364
var poller = new AuthorizationPoller(
363365
new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()),
@@ -369,13 +371,26 @@ public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throw
369371
mockClient,
370372
callback
371373
);
372-
poller.init(mock(PersistentTasksService.class), mock(TaskManager.class), "id", 0);
374+
375+
var persistentTaskId = "id";
376+
var allocationId = 0L;
377+
378+
var mockPersistentTasksService = mock(PersistentTasksService.class);
379+
poller.init(mockPersistentTasksService, mock(TaskManager.class), persistentTaskId, allocationId);
373380
poller.start();
374381
taskQueue.runAllRunnableTasks();
375382
latch.await(TimeValue.THIRTY_SECONDS.getSeconds(), TimeUnit.SECONDS);
376383

377384
assertThat(callbackCount.get(), is(1));
378385
assertTrue(poller.isShutdown());
386+
verify(mockPersistentTasksService, times(1)).sendCompletionRequest(
387+
eq(persistentTaskId),
388+
eq(allocationId),
389+
eq(exception),
390+
eq(null),
391+
any(),
392+
any()
393+
);
379394
verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any());
380395
}
381396
}

0 commit comments

Comments
 (0)