Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,15 @@ private void sendAuthorizationRequest() {
}

private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) {
logger.debug("Received authorization response");
var authorizedTaskTypesAndModels = authorizedContent.get().taskTypesAndModels.merge(auth)
.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
logger.debug(() -> Strings.format("Received authorization response, %s", auth));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future reference, loggers have a formatter, I believe, something like:

logger.debug("Received authorization response, {}", auth);


var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels));

// recalculate which default config ids and models are authorized now
var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bug is on this line and the one below it, where we pass auth instead of authorizedTaskTypesAndModels. The fixed version passes the result of auth.newLimitedToTaskTypes(...) instead.

var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels);

var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth);
var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels);
var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds);
authorizedContent.set(
new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects)
Expand Down Expand Up @@ -337,7 +338,12 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
firstAuthorizationCompletedLatch.countDown();
});

logger.debug("Synchronizing default inference endpoints");
logger.debug(
() -> Strings.format(
"Synchronizing default inference endpoints, attempting to remove ids: %s",
unauthorizedDefaultInferenceEndpointIds
)
);
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,13 @@ public boolean equals(Object o) {
public int hashCode() {
    // Combine the three fields, in this order, into a single hash value.
    final int combined = Objects.hash(taskTypeToModels, authorizedTaskTypes, authorizedModelIds);
    return combined;
}

/**
 * Renders the three fields as {@code {taskTypeToModels=..., authorizedTaskTypes=..., authorizedModelIds=...}}.
 */
@Override
public String toString() {
    final StringBuilder text = new StringBuilder("{");
    text.append("taskTypeToModels=").append(taskTypeToModels);
    text.append(", authorizedTaskTypes=").append(authorizedTaskTypes);
    text.append(", authorizedModelIds=").append(authorizedModelIds);
    text.append('}');
    return text.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchWrapperException;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
Expand Down Expand Up @@ -86,25 +87,25 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization

ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> {
if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity));
} else {
logger.warn(
Strings.format(
FAILED_TO_RETRIEVE_MESSAGE + " Received an invalid response type: %s",
results.getClass().getSimpleName()
)
var errorMessage = Strings.format(
"%s Received an invalid response type from the Elastic Inference Service: %s",
FAILED_TO_RETRIEVE_MESSAGE,
results.getClass().getSimpleName()
);
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());

logger.warn(errorMessage);
listener.onFailure(new ElasticsearchException(errorMessage));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're now returning an error when a failure occurs. The caller will ignore the error, so it will not affect the current authorization. This way we can differentiate between a failed authorization request and a successful one. Whenever onResponse is called, we'll use that response object as the source of truth for authorization.

}
requestCompleteLatch.countDown();
}, e -> {
Throwable exception = e;
if (e instanceof ElasticsearchWrapperException wrapperException) {
exception = wrapperException.getCause();
}
// unwrap because it's likely a retry exception
var exception = ExceptionsHelper.unwrapCause(e);

logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception));
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
listener.onFailure(e);
requestCompleteLatch.countDown();
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@

package org.elasticsearch.xpack.inference.services.elastic.response;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
Expand All @@ -23,10 +28,12 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.io.IOException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
Expand All @@ -39,6 +46,9 @@
public class ElasticInferenceServiceAuthorizationResponseEntity implements InferenceServiceResults {

public static final String NAME = "elastic_inference_service_auth_results";

private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntity.class);
private static final String AUTH_FIELD_NAME = "authorized_models";
private static final Map<String, TaskType> ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of(
"embed/text/sparse",
TaskType.SPARSE_EMBEDDING,
Expand Down Expand Up @@ -107,6 +117,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

return builder;
}

/**
 * Renders this entry as {@code {modelName='...', taskTypes='...'}}.
 */
@Override
public String toString() {
    final String rendered = Strings.format("{modelName='%s', taskTypes='%s'}", modelName, taskTypes);
    return rendered;
}
}

private final List<AuthorizedModel> authorizedModels;
Expand Down Expand Up @@ -138,6 +153,11 @@ public List<AuthorizedModel> getAuthorizedModels() {
return authorizedModels;
}

@Override
public String toString() {
return String.join(", ", authorizedModels.stream().map(AuthorizedModel::toString).toList());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return String.join(", ", authorizedModels.stream().map(AuthorizedModel::toString).toList());
return authorizedModels.stream().map(AuthorizedModel::toString).collect(Collectors.joining(", "));

}

@Override
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,78 @@ public void init() throws Exception {
modelRegistry = getInstanceFromNode(ModelRegistry.class);
}

/**
 * Verifies that a second, empty authorization response revokes what the first response granted.
 * <p>
 * The mocked request handler is given two responses: one authorizing {@code rainbow-sprinkles} for
 * {@link TaskType#CHAT_COMPLETION}, then one with no authorized models. After both callbacks have fired the
 * handler must report no supported task types, no streaming tasks, no default config ids and no default configs.
 */
public void testSecondAuthResultRevokesAuthorization() throws Exception {
    var callbackCount = new AtomicInteger(0);
    // we're only interested in two authorization calls, so the latch opens after the second callback
    var latch = new CountDownLatch(2);
    final AtomicReference<ElasticInferenceServiceAuthorizationHandler> handlerRef = new AtomicReference<>();

    Runnable callback = () -> {
        // the first authorization response contains a streaming task so we're expecting to support streaming here
        if (callbackCount.incrementAndGet() == 1) {
            assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
        }
        latch.countDown();

        // we only want to run the tasks twice, so advance the time on the queue
        // which flags the scheduled authorization request to be ready to run
        if (callbackCount.get() == 1) {
            taskQueue.advanceTime();
        } else {
            try {
                handlerRef.get().close();
            } catch (IOException ignored) {
                // closing only stops further scheduled requests; a failure here does not affect the assertions below
            }
        }
    };

    var requestHandler = mockAuthorizationRequestHandler(
        ElasticInferenceServiceAuthorizationModel.of(
            new ElasticInferenceServiceAuthorizationResponseEntity(
                List.of(
                    new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
                        "rainbow-sprinkles",
                        EnumSet.of(TaskType.CHAT_COMPLETION)
                    )
                )
            )
        ),
        ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of()))
    );

    handlerRef.set(
        new ElasticInferenceServiceAuthorizationHandler(
            createWithEmptySettings(taskQueue.getThreadPool()),
            modelRegistry,
            requestHandler,
            initDefaultEndpoints(),
            EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION),
            null,
            mock(Sender.class),
            ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
            callback
        )
    );

    var handler = handlerRef.get();
    handler.init();
    taskQueue.runAllRunnableTasks();

    // await() returns false on timeout; fail here rather than letting the assertions below run
    // against a handler that has not yet processed both authorization responses
    assertThat(
        "timed out waiting for both authorization callbacks",
        latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS),
        is(true)
    );

    // this should be after we've received both authorization responses, the second response will revoke authorization

    assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
    assertThat(handler.defaultConfigIds(), is(List.of()));
    assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));

    PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
    handler.defaultConfigs(listener);

    var configs = listener.actionGet();
    assertThat(configs.size(), is(0));
}

public void testSendsAnAuthorizationRequestTwice() throws Exception {
var callbackCount = new AtomicInteger(0);
// we're only interested in two authorization calls which is why I'm using a value of 2 here
Expand Down Expand Up @@ -104,6 +176,10 @@ public void testSendsAnAuthorizationRequestTwice() throws Exception {
ElasticInferenceServiceAuthorizationModel.of(
new ElasticInferenceServiceAuthorizationResponseEntity(
List.of(
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously the first response was merged with the second auth response. Now the latest successful auth response dictates the authorization, so we need to repeat the model in the second response for this test.

"abc",
EnumSet.of(TaskType.SPARSE_EMBEDDING)
),
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
"rainbow-sprinkles",
EnumSet.of(TaskType.CHAT_COMPLETION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.elastic.authorization;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
Expand All @@ -18,6 +19,7 @@
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
Expand All @@ -38,13 +40,14 @@
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase {
Expand Down Expand Up @@ -135,22 +138,17 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
authHandler.getAuthorization(listener, sender);

var authResponse = listener.actionGet(TIMEOUT);
assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty());
assertTrue(authResponse.getAuthorizedModelIds().isEmpty());
assertFalse(authResponse.isAuthorized());
var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("failed to parse field [models]"));

var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
verify(logger).warn(loggerArgsCaptor.capture());
var message = loggerArgsCaptor.getValue();
assertThat(
message,
is(
"Failed to retrieve the authorization information from the Elastic Inference Service."
+ " Encountered an exception: org.elasticsearch.xcontent.XContentParseException: [4:28] "
+ "[ElasticInferenceServiceAuthorizationResponseEntity] failed to parse field [models]"
)
);
var stringCaptor = ArgumentCaptor.forClass(String.class);
var exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(logger).warn(stringCaptor.capture(), exceptionCaptor.capture());
var message = stringCaptor.getValue();
assertThat(message, containsString("failed to parse field [models]"));

var capturedException = exceptionCaptor.getValue();
assertThat(capturedException, instanceOf(XContentParseException.class));
}
}

Expand Down Expand Up @@ -196,7 +194,6 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {

var message = loggerArgsCaptor.getValue();
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
verifyNoMoreInteractions(logger);
}
}

Expand Down Expand Up @@ -230,7 +227,6 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException {

var message = loggerArgsCaptor.getValue();
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
verifyNoMoreInteractions(logger);
}
}

Expand All @@ -252,20 +248,14 @@ public void testGetAuthorization_InvalidResponse() throws IOException {
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();

authHandler.getAuthorization(listener, sender);
var result = listener.actionGet(TIMEOUT);
var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));

assertThat(result, is(ElasticInferenceServiceAuthorizationModel.newDisabledService()));
assertThat(exception.getMessage(), containsString("Received an invalid response type from the Elastic Inference Service"));

var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
verify(logger).warn(loggerArgsCaptor.capture());
var message = loggerArgsCaptor.getValue();
assertThat(
message,
is(
"Failed to retrieve the authorization information from the Elastic Inference Service."
+ " Received an invalid response type: ChatCompletionResults"
)
);
assertThat(message, containsString("Failed to retrieve the authorization information from the Elastic Inference Service."));
}

}
Expand Down
Loading