diff --git a/docs/changelog/132546.yaml b/docs/changelog/132546.yaml new file mode 100644 index 0000000000000..60cd60f2a79a4 --- /dev/null +++ b/docs/changelog/132546.yaml @@ -0,0 +1,5 @@ +pr: 132546 +summary: Improve EIS auth call logs and fix revocation bug +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java index 5912deb006440..b0e3523d2ad4f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.response.elastic; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -14,6 +15,8 @@ import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContent; @@ -39,6 +42,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements InferenceServiceResults { public static final String NAME = "elastic_inference_service_auth_results"; + + private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntity.class); + private static final String AUTH_FIELD_NAME = "authorized_models"; private static final Map ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of( "embed/text/sparse", TaskType.SPARSE_EMBEDDING, @@ -103,6 +109,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + + @Override + public String toString() { + return Strings.format("{modelName='%s', taskTypes='%s'}", modelName, taskTypes); + } } private final List authorizedModels; @@ -134,6 +145,11 @@ public List getAuthorizedModels() { return authorizedModels; } + @Override + public String toString() { + return authorizedModels.stream().map(AuthorizedModel::toString).collect(Collectors.joining(", ")); + } + @Override public Iterator toXContentChunked(ToXContent.Params params) { throw new UnsupportedOperationException(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java index a3b80cd216067..e003694e6b246 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java @@ -247,14 +247,15 @@ private void sendAuthorizationRequest() { } private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) { - logger.debug("Received authorization response"); - var authorizedTaskTypesAndModels = authorizedContent.get().taskTypesAndModels.merge(auth) - .newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); + logger.debug(() -> Strings.format("Received authorization response, %s", auth)); + + var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); + logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels)); // recalculate which default config ids and models are authorized now - var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); + var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels); - var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth); + var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels); var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); authorizedContent.set( new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) @@ -341,7 +342,12 @@ private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) firstAuthorizationCompletedLatch.countDown(); }); - logger.debug("Synchronizing default inference endpoints"); + logger.debug( + () -> Strings.format( + "Synchronizing default inference endpoints, attempting to remove ids: %s", + unauthorizedDefaultInferenceEndpointIds + ) + ); modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java index 6ff3cb950151e..efe00acbc059b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java @@ -161,4 +161,16 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(taskTypeToModels, authorizedTaskTypes, authorizedModelIds); } + + @Override + public String toString() { + return "{" + + "taskTypeToModels=" + + taskTypeToModels + + ", authorizedTaskTypes=" + + authorizedTaskTypes + + ", authorizedModelIds=" + + authorizedModelIds + + '}'; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java index 77381fef98128..df4784ff49119 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java @@ -9,7 +9,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchWrapperException; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; @@ -86,25 +87,25 @@ public void getAuthorization(ActionListener newListener = ActionListener.wrap(results -> { if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { + logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity)); } else { - logger.warn( - Strings.format( - FAILED_TO_RETRIEVE_MESSAGE + " Received an invalid response type: %s", - results.getClass().getSimpleName() - ) + var errorMessage = Strings.format( + "%s Received an invalid response type from the Elastic Inference Service: %s", + FAILED_TO_RETRIEVE_MESSAGE, + results.getClass().getSimpleName() ); - listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); + + logger.warn(errorMessage); + listener.onFailure(new ElasticsearchException(errorMessage)); } requestCompleteLatch.countDown(); }, e -> { - Throwable exception = e; - if (e instanceof ElasticsearchWrapperException wrapperException) { - exception = wrapperException.getCause(); - } + // unwrap because it's likely a retry exception + var exception = ExceptionsHelper.unwrapCause(e); - logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception)); - listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); + logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception); + listener.onFailure(e); requestCompleteLatch.countDown(); }); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index 5435d5b9a6dad..771c0731c30ae 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -53,6 +53,78 @@ public void init() throws Exception { taskQueue = new DeterministicTaskQueue(); } + public void testSecondAuthResultRevokesAuthorization() throws Exception { + var callbackCount = new AtomicInteger(0); + // we're only interested in two authorization calls which is why I'm using a value of 2 here + var latch = new CountDownLatch(2); + final AtomicReference handlerRef = new AtomicReference<>(); + + Runnable callback = () -> { + // the first authorization response contains a streaming task so we're expecting to support streaming here + if (callbackCount.incrementAndGet() == 1) { + assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + } + latch.countDown(); + + // we only want to run the tasks twice, so advance the time on the queue + // which flags the scheduled authorization request to be ready to run + if (callbackCount.get() == 1) { + taskQueue.advanceTime(); + } else { + try { + handlerRef.get().close(); + } catch (IOException e) { + // ignore + } + } + }; + + var requestHandler = mockAuthorizationRequestHandler( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "rainbow-sprinkles", + EnumSet.of(TaskType.CHAT_COMPLETION) + ) + ) + ) + ), + ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of())) + ); + + handlerRef.set( + new ElasticInferenceServiceAuthorizationHandler( + createWithEmptySettings(taskQueue.getThreadPool()), + mockModelRegistry(taskQueue.getThreadPool()), + requestHandler, + initDefaultEndpoints(), + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), + null, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + callback + ) + ); + + var handler = handlerRef.get(); + handler.init(); + taskQueue.runAllRunnableTasks(); + latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); + + // this should be after we've received both authorization responses, the second response will revoke authorization + + assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); + assertThat(handler.defaultConfigIds(), is(List.of())); + assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); + + PlainActionFuture> listener = new PlainActionFuture<>(); + handler.defaultConfigs(listener); + + var configs = listener.actionGet(); + assertThat(configs.size(), is(0)); + } + public void testSendsAnAuthorizationRequestTwice() throws Exception { var callbackCount = new AtomicInteger(0); // we're only interested in two authorization calls which is why I'm using a value of 2 here @@ -90,6 +162,10 @@ public void testSendsAnAuthorizationRequestTwice() throws Exception { ElasticInferenceServiceAuthorizationModel.of( new ElasticInferenceServiceAuthorizationResponseEntity( List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + "abc", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ), new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( "rainbow-sprinkles", EnumSet.of(TaskType.CHAT_COMPLETION) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java index 380c0e8b3be94..aa562b176d5e8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.elastic.authorization; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; @@ -18,6 +19,7 @@ import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -38,13 +40,14 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase { @@ -135,22 +138,17 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var authResponse = listener.actionGet(TIMEOUT); - assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty()); - assertTrue(authResponse.getAuthorizedModelIds().isEmpty()); - assertFalse(authResponse.isAuthorized()); + var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), containsString("failed to parse field [models]")); - var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); - verify(logger).warn(loggerArgsCaptor.capture()); - var message = loggerArgsCaptor.getValue(); - assertThat( - message, - is( - "Failed to retrieve the authorization information from the Elastic Inference Service." - + " Encountered an exception: org.elasticsearch.xcontent.XContentParseException: [4:28] " - + "[ElasticInferenceServiceAuthorizationResponseEntity] failed to parse field [models]" - ) - ); + var stringCaptor = ArgumentCaptor.forClass(String.class); + var exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(logger).warn(stringCaptor.capture(), exceptionCaptor.capture()); + var message = stringCaptor.getValue(); + assertThat(message, containsString("failed to parse field [models]")); + + var capturedException = exceptionCaptor.getValue(); + assertThat(capturedException, instanceOf(XContentParseException.class)); } } @@ -196,7 +194,6 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException { var message = loggerArgsCaptor.getValue(); assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); - verifyNoMoreInteractions(logger); } } @@ -230,7 +227,6 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException { var message = loggerArgsCaptor.getValue(); assertThat(message, is("Retrieving authorization information from the Elastic Inference Service.")); - verifyNoMoreInteractions(logger); } } @@ -252,20 +248,14 @@ public void testGetAuthorization_InvalidResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); authHandler.getAuthorization(listener, sender); - var result = listener.actionGet(TIMEOUT); + var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(result, is(ElasticInferenceServiceAuthorizationModel.newDisabledService())); + assertThat(exception.getMessage(), containsString("Received an invalid response type from the Elastic Inference Service")); var loggerArgsCaptor = ArgumentCaptor.forClass(String.class); verify(logger).warn(loggerArgsCaptor.capture()); var message = loggerArgsCaptor.getValue(); - assertThat( - message, - is( - "Failed to retrieve the authorization information from the Elastic Inference Service." - + " Received an invalid response type: ChatCompletionResults" - ) - ); + assertThat(message, containsString("Failed to retrieve the authorization information from the Elastic Inference Service.")); } }