-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ML] Improve EIS auth call logs and fix revocation bug #132546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
ff6cfec
7895cca
abff092
5d2a2d3
f56cee4
48b3b68
8887a1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -243,14 +243,15 @@ private void sendAuthorizationRequest() { | |
} | ||
|
||
private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) { | ||
logger.debug("Received authorization response"); | ||
var authorizedTaskTypesAndModels = authorizedContent.get().taskTypesAndModels.merge(auth) | ||
.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); | ||
logger.debug(() -> Strings.format("Received authorization response, %s", auth)); | ||
|
||
var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); | ||
logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels)); | ||
|
||
// recalculate which default config ids and models are authorized now | ||
var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The bug is here and on the line below, where we reference `auth` instead of `authorizedTaskTypesAndModels`. |
||
var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels); | ||
|
||
var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth); | ||
var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels); | ||
var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); | ||
authorizedContent.set( | ||
new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) | ||
|
@@ -337,7 +338,12 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds) | |
firstAuthorizationCompletedLatch.countDown(); | ||
}); | ||
|
||
logger.debug("Synchronizing default inference endpoints"); | ||
logger.debug( | ||
() -> Strings.format( | ||
"Synchronizing default inference endpoints, attempting to remove ids: %s", | ||
unauthorizedDefaultInferenceEndpointIds | ||
) | ||
); | ||
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,8 @@ | |
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.elasticsearch.ElasticsearchWrapperException; | ||
import org.elasticsearch.ElasticsearchException; | ||
import org.elasticsearch.ExceptionsHelper; | ||
import org.elasticsearch.action.ActionListener; | ||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.core.Nullable; | ||
|
@@ -86,25 +87,25 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization | |
|
||
ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> { | ||
if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) { | ||
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity)); | ||
listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity)); | ||
} else { | ||
logger.warn( | ||
Strings.format( | ||
FAILED_TO_RETRIEVE_MESSAGE + " Received an invalid response type: %s", | ||
results.getClass().getSimpleName() | ||
) | ||
var errorMessage = Strings.format( | ||
"%s Received an invalid response type from the Elastic Inference Service: %s", | ||
FAILED_TO_RETRIEVE_MESSAGE, | ||
results.getClass().getSimpleName() | ||
); | ||
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); | ||
|
||
logger.warn(errorMessage); | ||
listener.onFailure(new ElasticsearchException(errorMessage)); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We're now returning an error when a failure occurs. The caller will ignore the error and it will not affect the authorization. This way we can differentiate between a failed authorization and a successful one. Whenever |
||
} | ||
requestCompleteLatch.countDown(); | ||
}, e -> { | ||
Throwable exception = e; | ||
if (e instanceof ElasticsearchWrapperException wrapperException) { | ||
exception = wrapperException.getCause(); | ||
} | ||
// unwrap because it's likely a retry exception | ||
var exception = ExceptionsHelper.unwrapCause(e); | ||
|
||
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception)); | ||
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService()); | ||
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception); | ||
listener.onFailure(e); | ||
requestCompleteLatch.countDown(); | ||
}); | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7,13 +7,18 @@ | |||||
|
||||||
package org.elasticsearch.xpack.inference.services.elastic.response; | ||||||
|
||||||
import org.elasticsearch.common.Strings; | ||||||
import org.elasticsearch.common.bytes.BytesReference; | ||||||
import org.elasticsearch.common.io.stream.StreamInput; | ||||||
import org.elasticsearch.common.io.stream.StreamOutput; | ||||||
import org.elasticsearch.common.io.stream.Writeable; | ||||||
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; | ||||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; | ||||||
import org.elasticsearch.inference.InferenceResults; | ||||||
import org.elasticsearch.inference.InferenceServiceResults; | ||||||
import org.elasticsearch.inference.TaskType; | ||||||
import org.elasticsearch.logging.LogManager; | ||||||
import org.elasticsearch.logging.Logger; | ||||||
import org.elasticsearch.xcontent.ConstructingObjectParser; | ||||||
import org.elasticsearch.xcontent.ParseField; | ||||||
import org.elasticsearch.xcontent.ToXContent; | ||||||
|
@@ -23,10 +28,12 @@ | |||||
import org.elasticsearch.xcontent.XContentParser; | ||||||
import org.elasticsearch.xcontent.XContentParserConfiguration; | ||||||
import org.elasticsearch.xcontent.XContentType; | ||||||
import org.elasticsearch.xcontent.json.JsonXContent; | ||||||
import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||||||
import org.elasticsearch.xpack.inference.external.request.Request; | ||||||
|
||||||
import java.io.IOException; | ||||||
import java.util.Arrays; | ||||||
import java.util.EnumSet; | ||||||
import java.util.Iterator; | ||||||
import java.util.List; | ||||||
|
@@ -39,6 +46,9 @@ | |||||
public class ElasticInferenceServiceAuthorizationResponseEntity implements InferenceServiceResults { | ||||||
|
||||||
public static final String NAME = "elastic_inference_service_auth_results"; | ||||||
|
||||||
private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntity.class); | ||||||
private static final String AUTH_FIELD_NAME = "authorized_models"; | ||||||
private static final Map<String, TaskType> ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of( | ||||||
"embed/text/sparse", | ||||||
TaskType.SPARSE_EMBEDDING, | ||||||
|
@@ -107,6 +117,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws | |||||
|
||||||
return builder; | ||||||
} | ||||||
|
||||||
@Override | ||||||
public String toString() { | ||||||
return Strings.format("{modelName='%s', taskTypes='%s'}", modelName, taskTypes); | ||||||
} | ||||||
} | ||||||
|
||||||
private final List<AuthorizedModel> authorizedModels; | ||||||
|
@@ -138,6 +153,11 @@ public List<AuthorizedModel> getAuthorizedModels() { | |||||
return authorizedModels; | ||||||
} | ||||||
|
||||||
@Override | ||||||
public String toString() { | ||||||
return String.join(", ", authorizedModels.stream().map(AuthorizedModel::toString).toList()); | ||||||
|
return String.join(", ", authorizedModels.stream().map(AuthorizedModel::toString).toList()); | |
return authorizedModels.stream().map(AuthorizedModel::toString).collect(Collectors.joining(", ")); |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,6 +67,78 @@ public void init() throws Exception { | |
modelRegistry = getInstanceFromNode(ModelRegistry.class); | ||
} | ||
|
||
public void testSecondAuthResultRevokesAuthorization() throws Exception { | ||
var callbackCount = new AtomicInteger(0); | ||
// we're only interested in two authorization calls which is why I'm using a value of 2 here | ||
var latch = new CountDownLatch(2); | ||
final AtomicReference<ElasticInferenceServiceAuthorizationHandler> handlerRef = new AtomicReference<>(); | ||
|
||
Runnable callback = () -> { | ||
// the first authorization response contains a streaming task so we're expecting to support streaming here | ||
if (callbackCount.incrementAndGet() == 1) { | ||
assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); | ||
} | ||
latch.countDown(); | ||
|
||
// we only want to run the tasks twice, so advance the time on the queue | ||
// which flags the scheduled authorization request to be ready to run | ||
if (callbackCount.get() == 1) { | ||
taskQueue.advanceTime(); | ||
} else { | ||
try { | ||
handlerRef.get().close(); | ||
} catch (IOException e) { | ||
// ignore | ||
} | ||
} | ||
}; | ||
|
||
var requestHandler = mockAuthorizationRequestHandler( | ||
ElasticInferenceServiceAuthorizationModel.of( | ||
new ElasticInferenceServiceAuthorizationResponseEntity( | ||
List.of( | ||
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( | ||
"rainbow-sprinkles", | ||
EnumSet.of(TaskType.CHAT_COMPLETION) | ||
) | ||
) | ||
) | ||
), | ||
ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of())) | ||
); | ||
|
||
handlerRef.set( | ||
new ElasticInferenceServiceAuthorizationHandler( | ||
createWithEmptySettings(taskQueue.getThreadPool()), | ||
modelRegistry, | ||
requestHandler, | ||
initDefaultEndpoints(), | ||
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), | ||
null, | ||
mock(Sender.class), | ||
ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), | ||
callback | ||
) | ||
); | ||
|
||
var handler = handlerRef.get(); | ||
handler.init(); | ||
taskQueue.runAllRunnableTasks(); | ||
latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); | ||
|
||
// this should be after we've received both authorization responses, the second response will revoke authorization | ||
|
||
assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); | ||
assertThat(handler.defaultConfigIds(), is(List.of())); | ||
assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); | ||
|
||
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>(); | ||
handler.defaultConfigs(listener); | ||
|
||
var configs = listener.actionGet(); | ||
assertThat(configs.size(), is(0)); | ||
} | ||
|
||
public void testSendsAnAuthorizationRequestTwice() throws Exception { | ||
var callbackCount = new AtomicInteger(0); | ||
// we're only interested in two authorization calls which is why I'm using a value of 2 here | ||
|
@@ -104,6 +176,10 @@ public void testSendsAnAuthorizationRequestTwice() throws Exception { | |
ElasticInferenceServiceAuthorizationModel.of( | ||
new ElasticInferenceServiceAuthorizationResponseEntity( | ||
List.of( | ||
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Previously the first response was merged with the second auth response. Now the latest successful auth response dictates the auth, so we need to repeat the model again in the second response for this test. |
||
"abc", | ||
EnumSet.of(TaskType.SPARSE_EMBEDDING) | ||
), | ||
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( | ||
"rainbow-sprinkles", | ||
EnumSet.of(TaskType.CHAT_COMPLETION) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For future reference, loggers have a formatter, I believe, something like: