Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/132546.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 132546
summary: Improve EIS auth call logs and fix revocation bug
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

package org.elasticsearch.xpack.inference.external.response.elastic;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
Expand All @@ -39,6 +42,9 @@
public class ElasticInferenceServiceAuthorizationResponseEntity implements InferenceServiceResults {

public static final String NAME = "elastic_inference_service_auth_results";

private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntity.class);
private static final String AUTH_FIELD_NAME = "authorized_models";
private static final Map<String, TaskType> ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of(
"embed/text/sparse",
TaskType.SPARSE_EMBEDDING,
Expand Down Expand Up @@ -103,6 +109,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

return builder;
}

/**
 * Returns a compact description of this model in the form
 * {@code {modelName='<name>', taskTypes='<types>'}}.
 */
@Override
public String toString() {
    final String template = "{modelName='%s', taskTypes='%s'}";
    return Strings.format(template, modelName, taskTypes);
}
}

private final List<AuthorizedModel> authorizedModels;
Expand Down Expand Up @@ -134,6 +145,11 @@ public List<AuthorizedModel> getAuthorizedModels() {
return authorizedModels;
}

/**
 * Returns the string form of every authorized model, in list order, separated by {@code ", "}.
 * An empty list yields an empty string.
 */
@Override
public String toString() {
    var rendered = new StringBuilder();
    var first = true;
    for (AuthorizedModel model : authorizedModels) {
        // the separator goes between entries only, never before the first one
        if (first == false) {
            rendered.append(", ");
        }
        rendered.append(model.toString());
        first = false;
    }
    return rendered.toString();
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,15 @@ private void sendAuthorizationRequest() {
}

private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) {
logger.debug("Received authorization response");
var authorizedTaskTypesAndModels = authorizedContent.get().taskTypesAndModels.merge(auth)
.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
logger.debug(() -> Strings.format("Received authorization response, %s", auth));

var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels));

// recalculate which default config ids and models are authorized now
var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth);
var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels);

var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth);
var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels);
var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds);
authorizedContent.set(
new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects)
Expand Down Expand Up @@ -341,7 +342,12 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
firstAuthorizationCompletedLatch.countDown();
});

logger.debug("Synchronizing default inference endpoints");
logger.debug(
() -> Strings.format(
"Synchronizing default inference endpoints, attempting to remove ids: %s",
unauthorizedDefaultInferenceEndpointIds
)
);
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,16 @@ public boolean equals(Object o) {
public int hashCode() {
// combines taskTypeToModels, authorizedTaskTypes and authorizedModelIds into a single hash value
return Objects.hash(taskTypeToModels, authorizedTaskTypes, authorizedModelIds);
}

/**
 * Returns a description listing {@code taskTypeToModels}, {@code authorizedTaskTypes} and
 * {@code authorizedModelIds}, in that order, wrapped in braces.
 */
@Override
public String toString() {
    var description = new StringBuilder("{");
    description.append("taskTypeToModels=").append(taskTypeToModels);
    description.append(", authorizedTaskTypes=").append(authorizedTaskTypes);
    description.append(", authorizedModelIds=").append(authorizedModelIds);
    description.append('}');
    return description.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchWrapperException;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
Expand Down Expand Up @@ -86,25 +87,25 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization

ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> {
if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity));
} else {
logger.warn(
Strings.format(
FAILED_TO_RETRIEVE_MESSAGE + " Received an invalid response type: %s",
results.getClass().getSimpleName()
)
var errorMessage = Strings.format(
"%s Received an invalid response type from the Elastic Inference Service: %s",
FAILED_TO_RETRIEVE_MESSAGE,
results.getClass().getSimpleName()
);
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());

logger.warn(errorMessage);
listener.onFailure(new ElasticsearchException(errorMessage));
}
requestCompleteLatch.countDown();
}, e -> {
Throwable exception = e;
if (e instanceof ElasticsearchWrapperException wrapperException) {
exception = wrapperException.getCause();
}
// unwrap because it's likely a retry exception
var exception = ExceptionsHelper.unwrapCause(e);

logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception));
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
listener.onFailure(e);
requestCompleteLatch.countDown();
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,78 @@ public void init() throws Exception {
taskQueue = new DeterministicTaskQueue();
}

/**
 * Verifies that a second authorization result with no authorized models revokes what the first
 * result granted.
 *
 * The request handler mock is given two results: one authorizing "rainbow-sprinkles" for
 * {@code CHAT_COMPLETION} and one built from an empty model list (presumably returned in that
 * order on successive calls; confirm against {@code mockAuthorizationRequestHandler}). After both
 * callbacks have fired, the handler must report no supported task types, no streaming tasks and
 * no default configurations.
 */
public void testSecondAuthResultRevokesAuthorization() throws Exception {
    var callbackCount = new AtomicInteger(0);
    // exactly two authorization calls are of interest, so the latch opens after the second callback
    var latch = new CountDownLatch(2);
    final AtomicReference<ElasticInferenceServiceAuthorizationHandler> handlerRef = new AtomicReference<>();

    Runnable callback = () -> {
        // the first authorization response contains a streaming task so we're expecting to support streaming here
        if (callbackCount.incrementAndGet() == 1) {
            assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
        }
        latch.countDown();

        // we only want to run the tasks twice, so advance the time on the queue
        // which flags the scheduled authorization request to be ready to run
        if (callbackCount.get() == 1) {
            taskQueue.advanceTime();
        } else {
            try {
                handlerRef.get().close();
            } catch (IOException e) {
                // ignore: closing is cleanup only, a failure here is not what this test checks
            }
        }
    };

    var requestHandler = mockAuthorizationRequestHandler(
        ElasticInferenceServiceAuthorizationModel.of(
            new ElasticInferenceServiceAuthorizationResponseEntity(
                List.of(
                    new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
                        "rainbow-sprinkles",
                        EnumSet.of(TaskType.CHAT_COMPLETION)
                    )
                )
            )
        ),
        ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of()))
    );

    handlerRef.set(
        new ElasticInferenceServiceAuthorizationHandler(
            createWithEmptySettings(taskQueue.getThreadPool()),
            mockModelRegistry(taskQueue.getThreadPool()),
            requestHandler,
            initDefaultEndpoints(),
            EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION),
            null,
            mock(Sender.class),
            ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
            callback
        )
    );

    var handler = handlerRef.get();
    handler.init();
    taskQueue.runAllRunnableTasks();
    // await returns false on timeout; assert on it so a missing callback fails here instead of
    // surfacing as a confusing failure in the state assertions below
    assertThat(latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS), is(true));

    // this should be after we've received both authorization responses, the second response will revoke authorization

    assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
    assertThat(handler.defaultConfigIds(), is(List.of()));
    assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));

    PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
    handler.defaultConfigs(listener);

    var configs = listener.actionGet();
    assertThat(configs.size(), is(0));
}

public void testSendsAnAuthorizationRequestTwice() throws Exception {
var callbackCount = new AtomicInteger(0);
// we're only interested in two authorization calls which is why I'm using a value of 2 here
Expand Down Expand Up @@ -90,6 +162,10 @@ public void testSendsAnAuthorizationRequestTwice() throws Exception {
ElasticInferenceServiceAuthorizationModel.of(
new ElasticInferenceServiceAuthorizationResponseEntity(
List.of(
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
"abc",
EnumSet.of(TaskType.SPARSE_EMBEDDING)
),
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
"rainbow-sprinkles",
EnumSet.of(TaskType.CHAT_COMPLETION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.elastic.authorization;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
Expand All @@ -18,6 +19,7 @@
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand All @@ -38,13 +40,14 @@
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase {
Expand Down Expand Up @@ -135,22 +138,17 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
authHandler.getAuthorization(listener, sender);

var authResponse = listener.actionGet(TIMEOUT);
assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty());
assertTrue(authResponse.getAuthorizedModelIds().isEmpty());
assertFalse(authResponse.isAuthorized());
var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("failed to parse field [models]"));

var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
verify(logger).warn(loggerArgsCaptor.capture());
var message = loggerArgsCaptor.getValue();
assertThat(
message,
is(
"Failed to retrieve the authorization information from the Elastic Inference Service."
+ " Encountered an exception: org.elasticsearch.xcontent.XContentParseException: [4:28] "
+ "[ElasticInferenceServiceAuthorizationResponseEntity] failed to parse field [models]"
)
);
var stringCaptor = ArgumentCaptor.forClass(String.class);
var exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(logger).warn(stringCaptor.capture(), exceptionCaptor.capture());
var message = stringCaptor.getValue();
assertThat(message, containsString("failed to parse field [models]"));

var capturedException = exceptionCaptor.getValue();
assertThat(capturedException, instanceOf(XContentParseException.class));
}
}

Expand Down Expand Up @@ -196,7 +194,6 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {

var message = loggerArgsCaptor.getValue();
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
verifyNoMoreInteractions(logger);
}
}

Expand Down Expand Up @@ -230,7 +227,6 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException {

var message = loggerArgsCaptor.getValue();
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
verifyNoMoreInteractions(logger);
}
}

Expand All @@ -252,20 +248,14 @@ public void testGetAuthorization_InvalidResponse() throws IOException {
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();

authHandler.getAuthorization(listener, sender);
var result = listener.actionGet(TIMEOUT);
var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));

assertThat(result, is(ElasticInferenceServiceAuthorizationModel.newDisabledService()));
assertThat(exception.getMessage(), containsString("Received an invalid response type from the Elastic Inference Service"));

var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
verify(logger).warn(loggerArgsCaptor.capture());
var message = loggerArgsCaptor.getValue();
assertThat(
message,
is(
"Failed to retrieve the authorization information from the Elastic Inference Service."
+ " Received an invalid response type: ChatCompletionResults"
)
);
assertThat(message, containsString("Failed to retrieve the authorization information from the Elastic Inference Service."));
}

}
Expand Down