|
12 | 12 | import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; |
13 | 13 | import org.elasticsearch.core.TimeValue; |
14 | 14 | import org.elasticsearch.inference.TaskType; |
| 15 | +import org.elasticsearch.persistent.PersistentTasksService; |
15 | 16 | import org.elasticsearch.tasks.TaskId; |
| 17 | +import org.elasticsearch.tasks.TaskManager; |
16 | 18 | import org.elasticsearch.test.ESTestCase; |
17 | 19 | import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; |
18 | 20 | import org.elasticsearch.xpack.inference.external.http.sender.Sender; |
19 | 21 | import org.elasticsearch.xpack.inference.registry.ModelRegistry; |
20 | 22 | import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; |
| 23 | +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; |
21 | 24 | import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; |
22 | 25 | import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; |
23 | 26 | import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; |
@@ -315,4 +318,64 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException { |
315 | 318 | assertThat(callbackCount.get(), is(2)); |
316 | 319 | verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); |
317 | 320 | } |
| 321 | + |
| 322 | + public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throws InterruptedException { |
| 323 | + var mockRegistry = mock(ModelRegistry.class); |
| 324 | + when(mockRegistry.isReady()).thenReturn(true); |
| 325 | + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); |
| 326 | + |
| 327 | + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); |
| 328 | + doAnswer(invocation -> { |
| 329 | + ActionListener<ElasticInferenceServiceAuthorizationModel> listener = invocation.getArgument(0); |
| 330 | + listener.onResponse( |
| 331 | + ElasticInferenceServiceAuthorizationModel.of( |
| 332 | + new ElasticInferenceServiceAuthorizationResponseEntity( |
| 333 | + List.of( |
| 334 | + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( |
| 335 | + // this is an unknown model id so it won't trigger storing an inference endpoint because |
| 336 | + // it doesn't map to a known one |
| 337 | + "abc", |
| 338 | + EnumSet.of(TaskType.SPARSE_EMBEDDING) |
| 339 | + ) |
| 340 | + ) |
| 341 | + ) |
| 342 | + ) |
| 343 | + ); |
| 344 | + return Void.TYPE; |
| 345 | + }).when(mockAuthHandler).getAuthorization(any(), any()); |
| 346 | + |
| 347 | + var mockClient = mock(Client.class); |
| 348 | + |
| 349 | + var callbackCount = new AtomicInteger(0); |
| 350 | + var latch = new CountDownLatch(1); |
| 351 | + |
| 352 | + Runnable callback = () -> { |
| 353 | + callbackCount.incrementAndGet(); |
| 354 | + latch.countDown(); |
| 355 | + }; |
| 356 | + |
| 357 | + // Simulate scheduling failure by having the settings throw an exception when queried |
| 358 | + // Throwing an exception should cause the poller to shutdown and mark itself as completed |
| 359 | + var settingsMock = mock(ElasticInferenceServiceSettings.class); |
| 360 | + when(settingsMock.isPeriodicAuthorizationEnabled()).thenThrow(new IllegalStateException("failing")); |
| 361 | + |
| 362 | + var poller = new AuthorizationPoller( |
| 363 | + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), |
| 364 | + createWithEmptySettings(taskQueue.getThreadPool()), |
| 365 | + mockAuthHandler, |
| 366 | + mock(Sender.class), |
| 367 | + settingsMock, |
| 368 | + mockRegistry, |
| 369 | + mockClient, |
| 370 | + callback |
| 371 | + ); |
| 372 | + poller.init(mock(PersistentTasksService.class), mock(TaskManager.class), "id", 0); |
| 373 | + poller.start(); |
| 374 | + taskQueue.runAllRunnableTasks(); |
| 375 | + latch.await(TimeValue.THIRTY_SECONDS.getSeconds(), TimeUnit.SECONDS); |
| 376 | + |
| 377 | + assertThat(callbackCount.get(), is(1)); |
| 378 | + assertTrue(poller.isShutdown()); |
| 379 | + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); |
| 380 | + } |
318 | 381 | } |
0 commit comments