Skip to content

Commit 1321fd2

Browse files
jonathan-buttner and elasticsearchmachine authored
[9.0] [ML] Improve EIS auth call logs and fix revocation bug (#132546) (#132693)
* [ML] Improve EIS auth call logs and fix revocation bug (#132546)

* Fixing revoking and adding logs
* Fixing tests
* Update docs/changelog/132546.yaml
* [CI] Auto commit changes from spotless
* Addressing feedback

---------

Co-authored-by: elasticsearchmachine <[email protected]>

* Fixing mock registry

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent efccf94 commit 1321fd2

File tree

7 files changed

+152
-46
lines changed

7 files changed

+152
-46
lines changed

docs/changelog/132546.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
pr: 132546
2+
summary: Improve EIS auth call logs and fix revocation bug
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,13 +7,16 @@
77

88
package org.elasticsearch.xpack.inference.external.response.elastic;
99

10+
import org.elasticsearch.common.Strings;
1011
import org.elasticsearch.common.io.stream.StreamInput;
1112
import org.elasticsearch.common.io.stream.StreamOutput;
1213
import org.elasticsearch.common.io.stream.Writeable;
1314
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
1415
import org.elasticsearch.inference.InferenceResults;
1516
import org.elasticsearch.inference.InferenceServiceResults;
1617
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.logging.LogManager;
19+
import org.elasticsearch.logging.Logger;
1720
import org.elasticsearch.xcontent.ConstructingObjectParser;
1821
import org.elasticsearch.xcontent.ParseField;
1922
import org.elasticsearch.xcontent.ToXContent;
@@ -39,6 +42,9 @@
3942
public class ElasticInferenceServiceAuthorizationResponseEntity implements InferenceServiceResults {
4043

4144
public static final String NAME = "elastic_inference_service_auth_results";
45+
46+
private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntity.class);
47+
private static final String AUTH_FIELD_NAME = "authorized_models";
4248
private static final Map<String, TaskType> ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of(
4349
"embed/text/sparse",
4450
TaskType.SPARSE_EMBEDDING,
@@ -103,6 +109,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
103109

104110
return builder;
105111
}
112+
113+
@Override
114+
public String toString() {
115+
return Strings.format("{modelName='%s', taskTypes='%s'}", modelName, taskTypes);
116+
}
106117
}
107118

108119
private final List<AuthorizedModel> authorizedModels;
@@ -134,6 +145,11 @@ public List<AuthorizedModel> getAuthorizedModels() {
134145
return authorizedModels;
135146
}
136147

148+
@Override
149+
public String toString() {
150+
return authorizedModels.stream().map(AuthorizedModel::toString).collect(Collectors.joining(", "));
151+
}
152+
137153
@Override
138154
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
139155
throw new UnsupportedOperationException();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -247,14 +247,15 @@ private void sendAuthorizationRequest() {
247247
}
248248

249249
private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) {
250-
logger.debug("Received authorization response");
251-
var authorizedTaskTypesAndModels = authorizedContent.get().taskTypesAndModels.merge(auth)
252-
.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
250+
logger.debug(() -> Strings.format("Received authorization response, %s", auth));
251+
252+
var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
253+
logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels));
253254

254255
// recalculate which default config ids and models are authorized now
255-
var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth);
256+
var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels);
256257

257-
var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth);
258+
var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels);
258259
var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds);
259260
authorizedContent.set(
260261
new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects)
@@ -341,7 +342,12 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
341342
firstAuthorizationCompletedLatch.countDown();
342343
});
343344

344-
logger.debug("Synchronizing default inference endpoints");
345+
logger.debug(
346+
() -> Strings.format(
347+
"Synchronizing default inference endpoints, attempting to remove ids: %s",
348+
unauthorizedDefaultInferenceEndpointIds
349+
)
350+
);
345351
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
346352
}
347353
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -161,4 +161,16 @@ public boolean equals(Object o) {
161161
public int hashCode() {
162162
return Objects.hash(taskTypeToModels, authorizedTaskTypes, authorizedModelIds);
163163
}
164+
165+
@Override
166+
public String toString() {
167+
return "{"
168+
+ "taskTypeToModels="
169+
+ taskTypeToModels
170+
+ ", authorizedTaskTypes="
171+
+ authorizedTaskTypes
172+
+ ", authorizedModelIds="
173+
+ authorizedModelIds
174+
+ '}';
175+
}
164176
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.ElasticsearchWrapperException;
12+
import org.elasticsearch.ElasticsearchException;
13+
import org.elasticsearch.ExceptionsHelper;
1314
import org.elasticsearch.action.ActionListener;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.core.Nullable;
@@ -86,25 +87,25 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
8687

8788
ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> {
8889
if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
90+
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
8991
listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity));
9092
} else {
91-
logger.warn(
92-
Strings.format(
93-
FAILED_TO_RETRIEVE_MESSAGE + " Received an invalid response type: %s",
94-
results.getClass().getSimpleName()
95-
)
93+
var errorMessage = Strings.format(
94+
"%s Received an invalid response type from the Elastic Inference Service: %s",
95+
FAILED_TO_RETRIEVE_MESSAGE,
96+
results.getClass().getSimpleName()
9697
);
97-
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
98+
99+
logger.warn(errorMessage);
100+
listener.onFailure(new ElasticsearchException(errorMessage));
98101
}
99102
requestCompleteLatch.countDown();
100103
}, e -> {
101-
Throwable exception = e;
102-
if (e instanceof ElasticsearchWrapperException wrapperException) {
103-
exception = wrapperException.getCause();
104-
}
104+
// unwrap because it's likely a retry exception
105+
var exception = ExceptionsHelper.unwrapCause(e);
105106

106-
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception));
107-
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
107+
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
108+
listener.onFailure(e);
108109
requestCompleteLatch.countDown();
109110
});
110111

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java

Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,78 @@ public void init() throws Exception {
5353
taskQueue = new DeterministicTaskQueue();
5454
}
5555

56+
public void testSecondAuthResultRevokesAuthorization() throws Exception {
57+
var callbackCount = new AtomicInteger(0);
58+
// we're only interested in two authorization calls which is why I'm using a value of 2 here
59+
var latch = new CountDownLatch(2);
60+
final AtomicReference<ElasticInferenceServiceAuthorizationHandler> handlerRef = new AtomicReference<>();
61+
62+
Runnable callback = () -> {
63+
// the first authorization response contains a streaming task so we're expecting to support streaming here
64+
if (callbackCount.incrementAndGet() == 1) {
65+
assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
66+
}
67+
latch.countDown();
68+
69+
// we only want to run the tasks twice, so advance the time on the queue
70+
// which flags the scheduled authorization request to be ready to run
71+
if (callbackCount.get() == 1) {
72+
taskQueue.advanceTime();
73+
} else {
74+
try {
75+
handlerRef.get().close();
76+
} catch (IOException e) {
77+
// ignore
78+
}
79+
}
80+
};
81+
82+
var requestHandler = mockAuthorizationRequestHandler(
83+
ElasticInferenceServiceAuthorizationModel.of(
84+
new ElasticInferenceServiceAuthorizationResponseEntity(
85+
List.of(
86+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
87+
"rainbow-sprinkles",
88+
EnumSet.of(TaskType.CHAT_COMPLETION)
89+
)
90+
)
91+
)
92+
),
93+
ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of()))
94+
);
95+
96+
handlerRef.set(
97+
new ElasticInferenceServiceAuthorizationHandler(
98+
createWithEmptySettings(taskQueue.getThreadPool()),
99+
mockModelRegistry(taskQueue.getThreadPool()),
100+
requestHandler,
101+
initDefaultEndpoints(),
102+
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION),
103+
null,
104+
mock(Sender.class),
105+
ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
106+
callback
107+
)
108+
);
109+
110+
var handler = handlerRef.get();
111+
handler.init();
112+
taskQueue.runAllRunnableTasks();
113+
latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS);
114+
115+
// this should be after we've received both authorization responses, the second response will revoke authorization
116+
117+
assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
118+
assertThat(handler.defaultConfigIds(), is(List.of()));
119+
assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
120+
121+
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
122+
handler.defaultConfigs(listener);
123+
124+
var configs = listener.actionGet();
125+
assertThat(configs.size(), is(0));
126+
}
127+
56128
public void testSendsAnAuthorizationRequestTwice() throws Exception {
57129
var callbackCount = new AtomicInteger(0);
58130
// we're only interested in two authorization calls which is why I'm using a value of 2 here
@@ -90,6 +162,10 @@ public void testSendsAnAuthorizationRequestTwice() throws Exception {
90162
ElasticInferenceServiceAuthorizationModel.of(
91163
new ElasticInferenceServiceAuthorizationResponseEntity(
92164
List.of(
165+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
166+
"abc",
167+
EnumSet.of(TaskType.SPARSE_EMBEDDING)
168+
),
93169
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
94170
"rainbow-sprinkles",
95171
EnumSet.of(TaskType.CHAT_COMPLETION)

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java

Lines changed: 17 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.elastic.authorization;
99

1010
import org.apache.logging.log4j.Logger;
11+
import org.elasticsearch.ElasticsearchException;
1112
import org.elasticsearch.action.ActionListener;
1213
import org.elasticsearch.action.support.PlainActionFuture;
1314
import org.elasticsearch.common.settings.Settings;
@@ -18,6 +19,7 @@
1819
import org.elasticsearch.test.http.MockResponse;
1920
import org.elasticsearch.test.http.MockWebServer;
2021
import org.elasticsearch.threadpool.ThreadPool;
22+
import org.elasticsearch.xcontent.XContentParseException;
2123
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
2224
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
2325
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -38,13 +40,14 @@
3840
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
3941
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
4042
import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES;
43+
import static org.hamcrest.Matchers.containsString;
44+
import static org.hamcrest.Matchers.instanceOf;
4145
import static org.hamcrest.Matchers.is;
4246
import static org.mockito.ArgumentMatchers.any;
4347
import static org.mockito.Mockito.doAnswer;
4448
import static org.mockito.Mockito.mock;
4549
import static org.mockito.Mockito.times;
4650
import static org.mockito.Mockito.verify;
47-
import static org.mockito.Mockito.verifyNoMoreInteractions;
4851
import static org.mockito.Mockito.when;
4952

5053
public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase {
@@ -135,22 +138,17 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep
135138
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
136139
authHandler.getAuthorization(listener, sender);
137140

138-
var authResponse = listener.actionGet(TIMEOUT);
139-
assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty());
140-
assertTrue(authResponse.getAuthorizedModelIds().isEmpty());
141-
assertFalse(authResponse.isAuthorized());
141+
var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT));
142+
assertThat(exception.getMessage(), containsString("failed to parse field [models]"));
142143

143-
var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
144-
verify(logger).warn(loggerArgsCaptor.capture());
145-
var message = loggerArgsCaptor.getValue();
146-
assertThat(
147-
message,
148-
is(
149-
"Failed to retrieve the authorization information from the Elastic Inference Service."
150-
+ " Encountered an exception: org.elasticsearch.xcontent.XContentParseException: [4:28] "
151-
+ "[ElasticInferenceServiceAuthorizationResponseEntity] failed to parse field [models]"
152-
)
153-
);
144+
var stringCaptor = ArgumentCaptor.forClass(String.class);
145+
var exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
146+
verify(logger).warn(stringCaptor.capture(), exceptionCaptor.capture());
147+
var message = stringCaptor.getValue();
148+
assertThat(message, containsString("failed to parse field [models]"));
149+
150+
var capturedException = exceptionCaptor.getValue();
151+
assertThat(capturedException, instanceOf(XContentParseException.class));
154152
}
155153
}
156154

@@ -196,7 +194,6 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
196194

197195
var message = loggerArgsCaptor.getValue();
198196
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
199-
verifyNoMoreInteractions(logger);
200197
}
201198
}
202199

@@ -230,7 +227,6 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException {
230227

231228
var message = loggerArgsCaptor.getValue();
232229
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
233-
verifyNoMoreInteractions(logger);
234230
}
235231
}
236232

@@ -252,20 +248,14 @@ public void testGetAuthorization_InvalidResponse() throws IOException {
252248
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
253249

254250
authHandler.getAuthorization(listener, sender);
255-
var result = listener.actionGet(TIMEOUT);
251+
var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
256252

257-
assertThat(result, is(ElasticInferenceServiceAuthorizationModel.newDisabledService()));
253+
assertThat(exception.getMessage(), containsString("Received an invalid response type from the Elastic Inference Service"));
258254

259255
var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
260256
verify(logger).warn(loggerArgsCaptor.capture());
261257
var message = loggerArgsCaptor.getValue();
262-
assertThat(
263-
message,
264-
is(
265-
"Failed to retrieve the authorization information from the Elastic Inference Service."
266-
+ " Received an invalid response type: ChatCompletionResults"
267-
)
268-
);
258+
assertThat(message, containsString("Failed to retrieve the authorization information from the Elastic Inference Service."));
269259
}
270260

271261
}

0 commit comments

Comments
 (0)