Skip to content

Commit 35146e5

Browse files
Starting tests
1 parent a9f706f commit 35146e5

File tree

5 files changed

+247
-37
lines changed

5 files changed

+247
-37
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,6 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
253253
);
254254
}
255255

256-
@Override
257-
public void onNodeStarted() {
258-
// authorizationHandler.init();
259-
}
260-
261256
@Override
262257
protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
263258
if (returnDocuments != null) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
import java.util.Map;
3333
import java.util.Objects;
3434
import java.util.Set;
35-
import java.util.concurrent.CountDownLatch;
36-
import java.util.concurrent.TimeUnit;
3735
import java.util.concurrent.atomic.AtomicBoolean;
3836
import java.util.concurrent.atomic.AtomicReference;
3937
import java.util.stream.Collectors;
@@ -50,7 +48,6 @@ public class AuthorizationPoller extends AllocatedPersistentTask {
5048
private final ServiceComponents serviceComponents;
5149
private final ModelRegistry modelRegistry;
5250
private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler;
53-
private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1);
5451
private final Sender sender;
5552
private final Runnable callback;
5653
private final AtomicReference<Scheduler.ScheduledCancellable> lastAuthTask = new AtomicReference<>(null);
@@ -117,23 +114,13 @@ public void start() {
117114
}
118115
}
119116

120-
/**
121-
* Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier.
122-
* @param waitTime the max time to wait
123-
* @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException}
124-
*/
125-
public void waitForAuthorizationToComplete(TimeValue waitTime) {
126-
try {
127-
if (firstAuthorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) {
128-
throw new IllegalStateException("The wait time has expired for authorization to complete.");
129-
}
130-
} catch (InterruptedException e) {
131-
throw new IllegalStateException("Waiting for authorization to complete was interrupted");
132-
}
133-
}
134-
135117
@Override
136118
protected void onCancelled() {
119+
shutdown();
120+
}
121+
122+
// default for testing
123+
void shutdown() {
137124
shutdown.set(true);
138125
if (lastAuthTask.get() != null) {
139126
lastAuthTask.get().cancel();
@@ -182,7 +169,8 @@ private void scheduleAndSendAuthorizationRequest() {
182169
sendAuthorizationRequest();
183170
}
184171

185-
private void sendAuthorizationRequest() {
172+
// default for testing
173+
void sendAuthorizationRequest() {
186174
if (modelRegistry.isReady() == false) {
187175
return;
188176
}
@@ -191,7 +179,6 @@ private void sendAuthorizationRequest() {
191179
if (callback != null) {
192180
callback.run();
193181
}
194-
firstAuthorizationCompletedLatch.countDown();
195182
}).delegateResponse((delegate, e) -> {
196183
logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints");
197184
delegate.onResponse(null);
@@ -227,11 +214,11 @@ private void storePreconfiguredModels(Set<String> newInferenceIds, ActionListene
227214
return;
228215
}
229216

230-
logger.debug("Storing new EIS preconfigured inference endpoints with inference IDs {}", newInferenceIds);
217+
logger.info("Storing new EIS preconfigured inference endpoints with inference IDs {}", newInferenceIds);
231218
var modelsToAdd = PreconfiguredEndpointModelAdapter.getModels(newInferenceIds, elasticInferenceServiceComponents);
232219
var storeRequest = new StoreInferenceEndpointsAction.Request(modelsToAdd, TimeValue.THIRTY_SECONDS);
233220

234-
ActionListener<StoreInferenceEndpointsAction.Response> storeListener = ActionListener.wrap(responses -> {
221+
ActionListener<StoreInferenceEndpointsAction.Response> logResultsListener = ActionListener.wrap(responses -> {
235222
for (var response : responses.getResults()) {
236223
if (response.failed()) {
237224
logger.atWarn()
@@ -247,7 +234,7 @@ private void storePreconfiguredModels(Set<String> newInferenceIds, ActionListene
247234
client.execute(
248235
StoreInferenceEndpointsAction.INSTANCE,
249236
storeRequest,
250-
ActionListener.runAfter(storeListener, () -> listener.onResponse(null))
237+
ActionListener.runAfter(logResultsListener, () -> listener.onResponse(null))
251238
);
252239
}
253240
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public static List<Model> getModels(Set<String> inferenceIds, ElasticInferenceSe
2727
.toList();
2828
}
2929

30-
private static Model createModel(
30+
public static Model createModel(
3131
InternalPreconfiguredEndpoints.MinimalModel minimalModel,
3232
ElasticInferenceServiceComponents elasticInferenceServiceComponents
3333
) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,12 @@ public void testStoreModels_StoresSingleInferenceEndpoint() {
113113
new TestModel.TestSecretSettings(secrets)
114114
);
115115

116-
PlainActionFuture<List<ModelRegistry.ModelStoreResponse>> storeListener = new PlainActionFuture<>();
116+
PlainActionFuture<List<ModelStoreResponse>> storeListener = new PlainActionFuture<>();
117117
registry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS);
118118

119119
var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS);
120120
assertThat(response.size(), is(1));
121-
assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse("1", RestStatus.CREATED, null)));
121+
assertThat(response.get(0), is(new ModelStoreResponse("1", RestStatus.CREATED, null)));
122122

123123
assertMinimalServiceSettings(registry, model);
124124

@@ -158,13 +158,13 @@ public void testStoreModels_StoresMultipleInferenceEndpoints() {
158158
new TestModel.TestSecretSettings(secrets)
159159
);
160160

161-
PlainActionFuture<List<ModelRegistry.ModelStoreResponse>> storeListener = new PlainActionFuture<>();
161+
PlainActionFuture<List<ModelStoreResponse>> storeListener = new PlainActionFuture<>();
162162
registry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS);
163163

164164
var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS);
165165
assertThat(response.size(), is(2));
166-
assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse("1", RestStatus.CREATED, null)));
167-
assertThat(response.get(1), is(new ModelRegistry.ModelStoreResponse("2", RestStatus.CREATED, null)));
166+
assertThat(response.get(0), is(new ModelStoreResponse("1", RestStatus.CREATED, null)));
167+
assertThat(response.get(1), is(new ModelStoreResponse("2", RestStatus.CREATED, null)));
168168

169169
assertModelAndMinimalSettingsWithSecrets(registry, model1, secrets);
170170
assertModelAndMinimalSettingsWithSecrets(registry, model2, secrets);
@@ -214,12 +214,12 @@ public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflic
214214
new TestModel.TestSecretSettings(secrets)
215215
);
216216

217-
PlainActionFuture<List<ModelRegistry.ModelStoreResponse>> storeListener = new PlainActionFuture<>();
217+
PlainActionFuture<List<ModelStoreResponse>> storeListener = new PlainActionFuture<>();
218218
registry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS);
219219

220220
var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS);
221221
assertThat(response.size(), is(2));
222-
assertThat(response.get(0), is(new ModelRegistry.ModelStoreResponse("1", RestStatus.CREATED, null)));
222+
assertThat(response.get(0), is(new ModelStoreResponse("1", RestStatus.CREATED, null)));
223223
assertThat(response.get(1).inferenceId(), is(model2.getInferenceEntityId()));
224224
assertThat(response.get(1).status(), is(RestStatus.CONFLICT));
225225
assertTrue(response.get(1).failed());
@@ -246,7 +246,7 @@ public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyE
246246

247247
storeCorruptedModel(model1);
248248

249-
PlainActionFuture<List<ModelRegistry.ModelStoreResponse>> storeListener = new PlainActionFuture<>();
249+
PlainActionFuture<List<ModelStoreResponse>> storeListener = new PlainActionFuture<>();
250250
registry.storeModels(List.of(model1), storeListener, TimeValue.THIRTY_SECONDS);
251251

252252
var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.authorization;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.client.internal.Client;
12+
import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
13+
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.tasks.TaskId;
16+
import org.elasticsearch.test.ESTestCase;
17+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
18+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
19+
import org.elasticsearch.xpack.inference.registry.StoreInferenceEndpointsAction;
20+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
21+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests;
22+
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
23+
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
24+
import org.junit.Before;
25+
import org.mockito.ArgumentCaptor;
26+
27+
import java.util.EnumSet;
28+
import java.util.List;
29+
import java.util.Map;
30+
import java.util.Set;
31+
32+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
33+
import static org.hamcrest.Matchers.is;
34+
import static org.mockito.ArgumentMatchers.any;
35+
import static org.mockito.ArgumentMatchers.eq;
36+
import static org.mockito.Mockito.doAnswer;
37+
import static org.mockito.Mockito.mock;
38+
import static org.mockito.Mockito.never;
39+
import static org.mockito.Mockito.verify;
40+
import static org.mockito.Mockito.when;
41+
42+
public class AuthorizationPollerTests extends ESTestCase {
43+
private DeterministicTaskQueue taskQueue;
44+
45+
@Before
46+
public void init() throws Exception {
47+
taskQueue = new DeterministicTaskQueue();
48+
}
49+
50+
public void testDoesNotSendAuthorizationRequest_WhenModelRegistryIsNotReady() {
51+
var mockRegistry = mock(ModelRegistry.class);
52+
when(mockRegistry.isReady()).thenReturn(false);
53+
54+
var authorizationRequestHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class);
55+
56+
var poller = new AuthorizationPoller(
57+
new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()),
58+
createWithEmptySettings(taskQueue.getThreadPool()),
59+
authorizationRequestHandler,
60+
mock(Sender.class),
61+
ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
62+
new ElasticInferenceServiceComponents(""),
63+
mockRegistry,
64+
mock(Client.class),
65+
null
66+
);
67+
68+
poller.sendAuthorizationRequest();
69+
70+
verify(authorizationRequestHandler, never()).getAuthorization(any(), any());
71+
}
72+
73+
public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() {
74+
var mockRegistry = mock(ModelRegistry.class);
75+
when(mockRegistry.isReady()).thenReturn(true);
76+
when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2"));
77+
78+
var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class);
79+
doAnswer(invocation -> {
80+
ActionListener<ElasticInferenceServiceAuthorizationModel> listener = invocation.getArgument(0);
81+
listener.onResponse(
82+
ElasticInferenceServiceAuthorizationModel.of(
83+
new ElasticInferenceServiceAuthorizationResponseEntity(
84+
List.of(
85+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
86+
InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID,
87+
EnumSet.of(TaskType.SPARSE_EMBEDDING)
88+
)
89+
)
90+
)
91+
)
92+
);
93+
return Void.TYPE;
94+
}).when(mockAuthHandler).getAuthorization(any(), any());
95+
96+
var mockClient = mock(Client.class);
97+
when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool());
98+
99+
var eisComponents = new ElasticInferenceServiceComponents("");
100+
101+
var poller = new AuthorizationPoller(
102+
new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()),
103+
createWithEmptySettings(taskQueue.getThreadPool()),
104+
mockAuthHandler,
105+
mock(Sender.class),
106+
ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
107+
eisComponents,
108+
mockRegistry,
109+
mockClient,
110+
null
111+
);
112+
113+
var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class);
114+
115+
poller.sendAuthorizationRequest();
116+
verify(mockClient).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any());
117+
var capturedRequest = requestArgCaptor.getValue();
118+
assertThat(
119+
capturedRequest.getModels(),
120+
is(
121+
List.of(
122+
PreconfiguredEndpointModelAdapter.createModel(
123+
InternalPreconfiguredEndpoints.getWithInferenceId(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2),
124+
eisComponents
125+
)
126+
)
127+
)
128+
);
129+
}
130+
131+
public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMapping() {
132+
var mockRegistry = mock(ModelRegistry.class);
133+
when(mockRegistry.isReady()).thenReturn(true);
134+
when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2"));
135+
136+
var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class);
137+
doAnswer(invocation -> {
138+
ActionListener<ElasticInferenceServiceAuthorizationModel> listener = invocation.getArgument(0);
139+
listener.onResponse(
140+
ElasticInferenceServiceAuthorizationModel.of(
141+
new ElasticInferenceServiceAuthorizationResponseEntity(
142+
List.of(
143+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
144+
// This is a model id that does not exist in the preconfigured endpoints map so it will not be stored
145+
"abc",
146+
EnumSet.of(TaskType.SPARSE_EMBEDDING)
147+
)
148+
)
149+
)
150+
)
151+
);
152+
return Void.TYPE;
153+
}).when(mockAuthHandler).getAuthorization(any(), any());
154+
155+
var mockClient = mock(Client.class);
156+
when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool());
157+
158+
var eisComponents = new ElasticInferenceServiceComponents("");
159+
160+
var poller = new AuthorizationPoller(
161+
new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()),
162+
createWithEmptySettings(taskQueue.getThreadPool()),
163+
mockAuthHandler,
164+
mock(Sender.class),
165+
ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
166+
eisComponents,
167+
mockRegistry,
168+
mockClient,
169+
null
170+
);
171+
172+
var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class);
173+
174+
poller.sendAuthorizationRequest();
175+
verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any());
176+
}
177+
178+
public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegration_DoesNotSupport() {
179+
var mockRegistry = mock(ModelRegistry.class);
180+
when(mockRegistry.isReady()).thenReturn(true);
181+
when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2"));
182+
183+
var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class);
184+
doAnswer(invocation -> {
185+
ActionListener<ElasticInferenceServiceAuthorizationModel> listener = invocation.getArgument(0);
186+
listener.onResponse(
187+
ElasticInferenceServiceAuthorizationModel.of(
188+
new ElasticInferenceServiceAuthorizationResponseEntity(
189+
List.of(
190+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
191+
InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID,
192+
// EIS does not yet support completions so this model will be ignored
193+
EnumSet.of(TaskType.COMPLETION)
194+
)
195+
)
196+
)
197+
)
198+
);
199+
return Void.TYPE;
200+
}).when(mockAuthHandler).getAuthorization(any(), any());
201+
202+
var mockClient = mock(Client.class);
203+
when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool());
204+
205+
var eisComponents = new ElasticInferenceServiceComponents("");
206+
207+
var poller = new AuthorizationPoller(
208+
new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()),
209+
createWithEmptySettings(taskQueue.getThreadPool()),
210+
mockAuthHandler,
211+
mock(Sender.class),
212+
ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
213+
eisComponents,
214+
mockRegistry,
215+
mockClient,
216+
null
217+
);
218+
219+
var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class);
220+
221+
poller.sendAuthorizationRequest();
222+
verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any());
223+
}
224+
225+
public void testSendsTwoAuthorizationRequests() {
226+
fail("TODO");
227+
}
228+
}

0 commit comments

Comments
 (0)