Skip to content

Commit 578546f

Browse files
Fixing tests
1 parent 6b6d21b commit 578546f

File tree

4 files changed

+38
-9
lines changed

4 files changed

+38
-9
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2828
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
2929
import org.elasticsearch.xpack.inference.services.ServiceComponents;
30-
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
3130
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
3231
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature;
3332
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService;
@@ -61,7 +60,6 @@ public class AuthorizationPoller extends AllocatedPersistentTask {
6160
private final AtomicBoolean shutdown = new AtomicBoolean(false);
6261
private final ElasticInferenceServiceSettings elasticInferenceServiceSettings;
6362
private final AtomicBoolean initialized = new AtomicBoolean(false);
64-
private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
6563
private final Client client;
6664
private final CountDownLatch receivedFirstAuthResponseLatch = new CountDownLatch(1);
6765
private final CCMFeature ccmFeature;
@@ -118,9 +116,6 @@ private AuthorizationPoller(TaskFields taskFields, Parameters parameters) {
118116
this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler);
119117
this.sender = Objects.requireNonNull(sender);
120118
this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings);
121-
this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents(
122-
elasticInferenceServiceSettings.getElasticInferenceServiceUrl()
123-
);
124119
this.modelRegistry = Objects.requireNonNull(modelRegistry);
125120
this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN);
126121
this.ccmFeature = Objects.requireNonNull(ccmFeature);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureTests.createMockCCMFeature;
4545
import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMServiceTests.createMockCCMService;
4646
import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createAuthorizedEndpoint;
47+
import static org.elasticsearch.xpack.inference.services.elastic.response.AuthorizationResponseEntityTests.createInvalidTaskTypeAuthorizedEndpoint;
4748
import static org.hamcrest.Matchers.is;
4849
import static org.mockito.ArgumentMatchers.any;
4950
import static org.mockito.ArgumentMatchers.eq;
@@ -336,7 +337,7 @@ public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInfe
336337

337338
public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegration_DoesNotSupport() {
338339
var url = "eis-url";
339-
var completionModel = createAuthorizedEndpoint(TaskType.COMPLETION);
340+
var invalidTaskTypeEndpoint = createInvalidTaskTypeAuthorizedEndpoint();
340341

341342
var mockRegistry = mock(ModelRegistry.class);
342343
when(mockRegistry.isReady()).thenReturn(true);
@@ -345,7 +346,7 @@ public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegra
345346
var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class);
346347
doAnswer(invocation -> {
347348
ActionListener<AuthorizationModel> listener = invocation.getArgument(0);
348-
listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(completionModel)), url));
349+
listener.onResponse(AuthorizationModel.of(new AuthorizationResponseEntity(List.of(invalidTaskTypeEndpoint)), url));
349350
return Void.TYPE;
350351
}).when(mockAuthHandler).getAuthorization(any(), any());
351352

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,15 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
203203
var authResponse = listener.actionGet(TIMEOUT);
204204
assertThat(
205205
authResponse.getTaskTypes(),
206-
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK))
206+
is(
207+
EnumSet.of(
208+
TaskType.CHAT_COMPLETION,
209+
TaskType.SPARSE_EMBEDDING,
210+
TaskType.TEXT_EMBEDDING,
211+
TaskType.RERANK,
212+
TaskType.COMPLETION
213+
)
214+
)
207215
);
208216
assertThat(authResponse.getEndpointIds(), containsInAnyOrder(responseData.inferenceIds().toArray(String[]::new)));
209217
assertTrue(authResponse.isAuthorized());

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/AuthorizationResponseEntityTests.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,23 @@ public static AuthorizationResponseEntity createResponse() {
493493
);
494494
}
495495

496+
public static AuthorizationResponseEntity.AuthorizedEndpoint createInvalidTaskTypeAuthorizedEndpoint() {
497+
var id = randomAlphaOfLength(10);
498+
var name = randomAlphaOfLength(10);
499+
var status = randomFrom("ga", "beta", "preview");
500+
501+
return new AuthorizationResponseEntity.AuthorizedEndpoint(
502+
id,
503+
name,
504+
createTaskTypeObject("invalid/task/type", TaskType.ANY.toString()),
505+
status,
506+
null,
507+
"",
508+
"",
509+
null
510+
);
511+
}
512+
496513
public static AuthorizationResponseEntity.AuthorizedEndpoint createAuthorizedEndpoint(TaskType taskType) {
497514
var id = randomAlphaOfLength(10);
498515
var name = randomAlphaOfLength(10);
@@ -572,7 +589,15 @@ public void testParseAllFields() throws IOException {
572589

573590
assertThat(
574591
authModel.getTaskTypes(),
575-
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING, TaskType.RERANK))
592+
is(
593+
EnumSet.of(
594+
TaskType.CHAT_COMPLETION,
595+
TaskType.SPARSE_EMBEDDING,
596+
TaskType.TEXT_EMBEDDING,
597+
TaskType.RERANK,
598+
TaskType.COMPLETION
599+
)
600+
)
576601
);
577602
assertThat(
578603
authModel.getEndpoints(responseData.inferenceIds()),

0 commit comments

Comments
 (0)