Skip to content

Commit 3961e74

Browse files
More tests
1 parent 35146e5 commit 3961e74

File tree

10 files changed

+559
-29
lines changed

10 files changed

+559
-29
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,10 +337,8 @@ public Collection<?> createComponents(PluginServices services) {
337337

338338
var eisComponents = new ElasticInferenceServiceComponents(inferenceServiceSettings.getElasticInferenceServiceUrl());
339339

340-
var authTaskExecutor = new AuthorizationTaskExecutor(
341-
services.client(),
340+
var authTaskExecutor = AuthorizationTaskExecutor.create(
342341
services.clusterService(),
343-
services.threadPool(),
344342
new AuthorizationPoller.Parameters(
345343
serviceComponents.get(),
346344
authorizationHandler,
@@ -351,7 +349,6 @@ public Collection<?> createComponents(PluginServices services) {
351349
services.client()
352350
)
353351
);
354-
authTaskExecutor.init();
355352
authorizationTaskExecutorRef.set(authTaskExecutor);
356353

357354
var sageMakerSchemas = new SageMakerSchemas();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,18 @@ public RateLimitSettings rateLimitSettings() {
5353
public ElasticInferenceServiceComponents elasticInferenceServiceComponents() {
5454
return elasticInferenceServiceComponents;
5555
}
56+
57+
@Override
58+
public boolean equals(Object o) {
59+
if (o == null || getClass() != o.getClass()) return false;
60+
if (super.equals(o) == false) return false;
61+
ElasticInferenceServiceModel that = (ElasticInferenceServiceModel) o;
62+
return Objects.equals(rateLimitServiceSettings, that.rateLimitServiceSettings)
63+
&& Objects.equals(elasticInferenceServiceComponents, that.elasticInferenceServiceComponents);
64+
}
65+
66+
@Override
67+
public int hashCode() {
68+
return Objects.hash(super.hashCode(), rateLimitServiceSettings, elasticInferenceServiceComponents);
69+
}
5670
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,23 @@ public record Parameters(
6464
ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler,
6565
Sender sender,
6666
ElasticInferenceServiceSettings elasticInferenceServiceSettings,
67-
ElasticInferenceServiceComponents components,
67+
ElasticInferenceServiceComponents eisComponents,
6868
ModelRegistry modelRegistry,
6969
Client client
7070
) {}
7171

72-
public AuthorizationPoller(TaskFields taskFields, Parameters parameters) {
72+
public static AuthorizationPoller create(TaskFields taskFields, Parameters parameters) {
73+
return new AuthorizationPoller(Objects.requireNonNull(taskFields), Objects.requireNonNull(parameters));
74+
}
75+
76+
private AuthorizationPoller(TaskFields taskFields, Parameters parameters) {
7377
this(
7478
taskFields,
7579
parameters.serviceComponents,
7680
parameters.authorizationRequestHandler,
7781
parameters.sender,
7882
parameters.elasticInferenceServiceSettings,
79-
parameters.components,
83+
parameters.eisComponents,
8084
parameters.modelRegistry,
8185
parameters.client,
8286
null
@@ -109,7 +113,7 @@ public AuthorizationPoller(TaskFields taskFields, Parameters parameters) {
109113

110114
public void start() {
111115
if (initialized.compareAndSet(false, true)) {
112-
logger.debug("Initializing authorization logic");
116+
logger.debug("Initializing EIS authorization logic");
113117
serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest);
114118
}
115119
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.ResourceAlreadyExistsException;
1313
import org.elasticsearch.action.ActionListener;
14-
import org.elasticsearch.client.internal.Client;
1514
import org.elasticsearch.cluster.ClusterChangedEvent;
1615
import org.elasticsearch.cluster.ClusterStateListener;
1716
import org.elasticsearch.cluster.service.ClusterService;
17+
import org.elasticsearch.common.Strings;
1818
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1919
import org.elasticsearch.core.TimeValue;
2020
import org.elasticsearch.persistent.AllocatedPersistentTask;
@@ -25,14 +25,14 @@
2525
import org.elasticsearch.persistent.PersistentTasksExecutor;
2626
import org.elasticsearch.persistent.PersistentTasksService;
2727
import org.elasticsearch.tasks.TaskId;
28-
import org.elasticsearch.threadpool.ThreadPool;
2928
import org.elasticsearch.transport.RemoteTransportException;
3029
import org.elasticsearch.xcontent.NamedXContentRegistry;
3130
import org.elasticsearch.xcontent.ParseField;
3231

3332
import java.util.List;
3433
import java.util.Map;
3534
import java.util.Objects;
35+
import java.util.concurrent.atomic.AtomicReference;
3636

3737
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
3838
import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME;
@@ -44,27 +44,43 @@ public class AuthorizationTaskExecutor extends PersistentTasksExecutor<Authoriza
4444
private final ClusterService clusterService;
4545
private final PersistentTasksService persistentTasksService;
4646
private final AuthorizationPoller.Parameters pollerParameters;
47+
private final AtomicReference<AuthorizationPoller> currentTask = new AtomicReference<>();
4748

48-
public AuthorizationTaskExecutor(
49-
Client client,
49+
public static AuthorizationTaskExecutor create(ClusterService clusterService, AuthorizationPoller.Parameters parameters) {
50+
Objects.requireNonNull(clusterService);
51+
Objects.requireNonNull(parameters);
52+
53+
var executor = new AuthorizationTaskExecutor(
54+
clusterService,
55+
new PersistentTasksService(clusterService, parameters.serviceComponents().threadPool(), parameters.client()),
56+
parameters
57+
);
58+
executor.init();
59+
return executor;
60+
}
61+
62+
// default for testing
63+
AuthorizationTaskExecutor(
5064
ClusterService clusterService,
51-
ThreadPool threadPool,
65+
PersistentTasksService persistentTasksService,
5266
AuthorizationPoller.Parameters pollerParameters
5367
) {
54-
super(TASK_NAME, threadPool.executor(UTILITY_THREAD_POOL_NAME));
68+
super(TASK_NAME, pollerParameters.serviceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME));
5569
this.clusterService = Objects.requireNonNull(clusterService);
56-
this.persistentTasksService = new PersistentTasksService(clusterService, threadPool, client);
70+
this.persistentTasksService = Objects.requireNonNull(persistentTasksService);
5771
this.pollerParameters = Objects.requireNonNull(pollerParameters);
5872
}
5973

60-
public void init() {
61-
clusterService.addListener(this);
74+
// default for testing
75+
void init() {
76+
// If the EIS url is not configured, then we won't be able to interact with the service, so don't start the task.
77+
if (Strings.isNullOrEmpty(pollerParameters.eisComponents().elasticInferenceServiceUrl()) == false) {
78+
clusterService.addListener(this);
79+
}
6280
}
6381

6482
@Override
6583
protected void nodeOperation(AllocatedPersistentTask task, AuthorizationTaskParams params, PersistentTaskState state) {
66-
// TODO remove
67-
logger.warn("Starting authorization poller task");
6884
var authPoller = (AuthorizationPoller) task;
6985
authPoller.start();
7086
}
@@ -83,7 +99,7 @@ protected AuthorizationPoller createTask(
8399
PersistentTasksCustomMetadata.PersistentTask<AuthorizationTaskParams> taskInProgress,
84100
Map<String, String> headers
85101
) {
86-
return new AuthorizationPoller(
102+
return AuthorizationPoller.create(
87103
new AuthorizationPoller.TaskFields(id, type, action, getDescription(taskInProgress), parentTaskId, headers),
88104
pollerParameters
89105
);
@@ -100,9 +116,7 @@ public void clusterChanged(ClusterChangedEvent event) {
100116
TASK_NAME,
101117
new AuthorizationTaskParams(),
102118
TimeValue.THIRTY_SECONDS,
103-
ActionListener.wrap(persistentTask -> {
104-
logger.warn("Created authorization poller task");
105-
}, e -> {
119+
ActionListener.wrap(persistentTask -> logger.debug("Created authorization poller task"), e -> {
106120
var t = e instanceof RemoteTransportException ? e.getCause() : e;
107121
if (t instanceof ResourceAlreadyExistsException == false) {
108122
logger.error("Failed to create authorization poller task", e);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.TransportVersions;
1212
import org.elasticsearch.common.io.stream.StreamInput;
1313
import org.elasticsearch.common.io.stream.StreamOutput;
14-
import org.elasticsearch.health.node.selection.HealthNodeTaskParams;
1514
import org.elasticsearch.persistent.PersistentTaskParams;
1615
import org.elasticsearch.xcontent.ObjectParser;
1716
import org.elasticsearch.xcontent.ToXContent;
@@ -23,7 +22,7 @@
2322
import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME;
2423

2524
public class AuthorizationTaskParams implements PersistentTaskParams {
26-
private static final AuthorizationTaskParams INSTANCE = new AuthorizationTaskParams();
25+
public static final AuthorizationTaskParams INSTANCE = new AuthorizationTaskParams();
2726

2827
private static final ObjectParser<AuthorizationTaskParams, Void> PARSER = new ObjectParser<>(TASK_NAME, true, () -> INSTANCE);
2928

@@ -61,7 +60,7 @@ public int hashCode() {
6160
}
6261

6362
@Override
64-
public boolean equals(Object obj) {
65-
return obj instanceof HealthNodeTaskParams;
63+
public boolean equals(Object o) {
64+
return this == o || (o != null && getClass() == o.getClass());
6665
}
6766
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
public class PreconfiguredEndpointModelAdapter {
2323
public static List<Model> getModels(Set<String> inferenceIds, ElasticInferenceServiceComponents elasticInferenceServiceComponents) {
2424
return inferenceIds.stream()
25+
.sorted()
2526
.filter(EIS_PRECONFIGURED_ENDPOINT_IDS::contains)
2627
.map(id -> createModel(InternalPreconfiguredEndpoints.getWithInferenceId(id), elasticInferenceServiceComponents))
2728
.toList();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
import java.util.List;
2929
import java.util.Map;
3030
import java.util.Set;
31+
import java.util.concurrent.CountDownLatch;
32+
import java.util.concurrent.TimeUnit;
33+
import java.util.concurrent.atomic.AtomicInteger;
34+
import java.util.concurrent.atomic.AtomicReference;
3135

3236
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
3337
import static org.hamcrest.Matchers.is;
@@ -222,7 +226,68 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra
222226
verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any());
223227
}
224228

225-
public void testSendsTwoAuthorizationRequests() {
226-
fail("TODO");
229+
public void testSendsTwoAuthorizationRequests() throws InterruptedException {
230+
var mockRegistry = mock(ModelRegistry.class);
231+
when(mockRegistry.isReady()).thenReturn(true);
232+
when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2"));
233+
234+
var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class);
235+
doAnswer(invocation -> {
236+
ActionListener<ElasticInferenceServiceAuthorizationModel> listener = invocation.getArgument(0);
237+
listener.onResponse(
238+
ElasticInferenceServiceAuthorizationModel.of(
239+
new ElasticInferenceServiceAuthorizationResponseEntity(
240+
List.of(
241+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
242+
// this is an unknown model id so it won't trigger storing an inference endpoint because
243+
// it doesn't map to a known one
244+
"abc",
245+
EnumSet.of(TaskType.SPARSE_EMBEDDING)
246+
)
247+
)
248+
)
249+
)
250+
);
251+
return Void.TYPE;
252+
}).when(mockAuthHandler).getAuthorization(any(), any());
253+
254+
var mockClient = mock(Client.class);
255+
var eisComponents = new ElasticInferenceServiceComponents("");
256+
257+
var callbackCount = new AtomicInteger(0);
258+
var latch = new CountDownLatch(2);
259+
final var pollerRef = new AtomicReference<AuthorizationPoller>();
260+
261+
Runnable callback = () -> {
262+
var count = callbackCount.incrementAndGet();
263+
latch.countDown();
264+
265+
// we only want to run the tasks twice, so advance the time on the queue
266+
// which flags the scheduled authorization request to be ready to run
267+
if (count == 1) {
268+
taskQueue.advanceTime();
269+
} else {
270+
pollerRef.get().shutdown();
271+
}
272+
};
273+
274+
var poller = new AuthorizationPoller(
275+
new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()),
276+
createWithEmptySettings(taskQueue.getThreadPool()),
277+
mockAuthHandler,
278+
mock(Sender.class),
279+
ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
280+
eisComponents,
281+
mockRegistry,
282+
mockClient,
283+
callback
284+
);
285+
pollerRef.set(poller);
286+
poller.start();
287+
taskQueue.runAllRunnableTasks();
288+
latch.await(TimeValue.THIRTY_SECONDS.getSeconds(), TimeUnit.SECONDS);
289+
290+
assertThat(callbackCount.get(), is(2));
291+
verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any());
227292
}
228293
}

0 commit comments

Comments
 (0)