Skip to content

Commit a0a07bc

Browse files
Addressing feedback
1 parent c3badf1 commit a0a07bc

File tree

3 files changed

+92
-2
lines changed

3 files changed

+92
-2
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
import org.elasticsearch.common.Strings;
1818
import org.elasticsearch.core.TimeValue;
1919
import org.elasticsearch.persistent.AllocatedPersistentTask;
20+
import org.elasticsearch.persistent.PersistentTasksService;
2021
import org.elasticsearch.tasks.TaskId;
22+
import org.elasticsearch.tasks.TaskManager;
2123
import org.elasticsearch.threadpool.Scheduler;
2224
import org.elasticsearch.xpack.core.ClientHelper;
2325
import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction;
@@ -133,6 +135,17 @@ public void waitForAuthorizationToComplete(TimeValue waitTime) {
133135
}
134136
}
135137

138+
// Overriding so tests in the same package can access
139+
@Override
140+
protected void init(
141+
PersistentTasksService persistentTasksService,
142+
TaskManager taskManager,
143+
String persistentTaskId,
144+
long allocationId
145+
) {
146+
super.init(persistentTasksService, taskManager, persistentTaskId, allocationId);
147+
}
148+
136149
@Override
137150
protected void onCancelled() {
138151
shutdown();
@@ -142,11 +155,18 @@ protected void onCancelled() {
142155
// default for testing
143156
void shutdown() {
144157
shutdown.set(true);
145-
if (lastAuthTask.get() != null) {
146-
lastAuthTask.get().cancel();
158+
159+
var authTask = lastAuthTask.get();
160+
if (authTask != null) {
161+
authTask.cancel();
147162
}
148163
}
149164

165+
// default for testing
166+
boolean isShutdown() {
167+
return shutdown.get();
168+
}
169+
150170
private void scheduleAuthorizationRequest() {
151171
try {
152172
if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) {
@@ -177,6 +197,8 @@ private void scheduleAuthorizationRequest() {
177197
);
178198
} catch (Exception e) {
179199
logger.warn("Failed scheduling authorization request", e);
200+
// Shutdown and complete the task so it will be restarted
201+
onCancelled();
180202
}
181203
}
182204

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.cluster.service.ClusterService;
1717
import org.elasticsearch.common.Strings;
1818
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
19+
import org.elasticsearch.core.FixForMultiProject;
1920
import org.elasticsearch.core.TimeValue;
2021
import org.elasticsearch.persistent.AllocatedPersistentTask;
2122
import org.elasticsearch.persistent.ClusterPersistentTasksCustomMetadata;
@@ -93,6 +94,10 @@ protected void nodeOperation(AllocatedPersistentTask task, AuthorizationTaskPara
9394
authPoller.start();
9495
}
9596

97+
@FixForMultiProject(
98+
description = "A single cluster can have multiple projects, "
99+
+ "we'll need to either make a call per project/org or use a bulk authorization api that EIS provides"
100+
)
96101
@Override
97102
public Scope scope() {
98103
return Scope.CLUSTER;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@
1212
import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
1313
import org.elasticsearch.core.TimeValue;
1414
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.persistent.PersistentTasksService;
1516
import org.elasticsearch.tasks.TaskId;
17+
import org.elasticsearch.tasks.TaskManager;
1618
import org.elasticsearch.test.ESTestCase;
1719
import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction;
1820
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1921
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2022
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
23+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
2124
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests;
2225
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
2326
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
@@ -315,4 +318,64 @@ public void testSendsTwoAuthorizationRequests() throws InterruptedException {
315318
assertThat(callbackCount.get(), is(2));
316319
verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any());
317320
}
321+
322+
public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throws InterruptedException {
323+
var mockRegistry = mock(ModelRegistry.class);
324+
when(mockRegistry.isReady()).thenReturn(true);
325+
when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2"));
326+
327+
var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class);
328+
doAnswer(invocation -> {
329+
ActionListener<ElasticInferenceServiceAuthorizationModel> listener = invocation.getArgument(0);
330+
listener.onResponse(
331+
ElasticInferenceServiceAuthorizationModel.of(
332+
new ElasticInferenceServiceAuthorizationResponseEntity(
333+
List.of(
334+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
335+
// this is an unknown model id so it won't trigger storing an inference endpoint because
336+
// it doesn't map to a known one
337+
"abc",
338+
EnumSet.of(TaskType.SPARSE_EMBEDDING)
339+
)
340+
)
341+
)
342+
)
343+
);
344+
return Void.TYPE;
345+
}).when(mockAuthHandler).getAuthorization(any(), any());
346+
347+
var mockClient = mock(Client.class);
348+
349+
var callbackCount = new AtomicInteger(0);
350+
var latch = new CountDownLatch(1);
351+
352+
Runnable callback = () -> {
353+
callbackCount.incrementAndGet();
354+
latch.countDown();
355+
};
356+
357+
// Simulate scheduling failure by having the settings throw an exception when queried
358+
// Throwing an exception should cause the poller to shutdown and mark itself as completed
359+
var settingsMock = mock(ElasticInferenceServiceSettings.class);
360+
when(settingsMock.isPeriodicAuthorizationEnabled()).thenThrow(new IllegalStateException("failing"));
361+
362+
var poller = new AuthorizationPoller(
363+
new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()),
364+
createWithEmptySettings(taskQueue.getThreadPool()),
365+
mockAuthHandler,
366+
mock(Sender.class),
367+
settingsMock,
368+
mockRegistry,
369+
mockClient,
370+
callback
371+
);
372+
poller.init(mock(PersistentTasksService.class), mock(TaskManager.class), "id", 0);
373+
poller.start();
374+
taskQueue.runAllRunnableTasks();
375+
latch.await(TimeValue.THIRTY_SECONDS.getSeconds(), TimeUnit.SECONDS);
376+
377+
assertThat(callbackCount.get(), is(1));
378+
assertTrue(poller.isShutdown());
379+
verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any());
380+
}
318381
}

0 commit comments

Comments
 (0)