Skip to content

Commit 6b85189

Browse files
jonathan-buttnerelasticsearchmachine
andauthored
[ML] Improve EIS auth call logs and fix revocation bug (elastic#132546)
* Fixing revoking and adding logs * Fixing tests * Update docs/changelog/132546.yaml * [CI] Auto commit changes from spotless * Addressing feedback --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 883f631 commit 6b85189

File tree

7 files changed

+152
-46
lines changed

7 files changed

+152
-46
lines changed

docs/changelog/132546.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 132546
2+
summary: Improve EIS auth call logs and fix revocation bug
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -243,14 +243,15 @@ private void sendAuthorizationRequest() {
243243
}
244244

245245
private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) {
246-
logger.debug("Received authorization response");
247-
var authorizedTaskTypesAndModels = authorizedContent.get().taskTypesAndModels.merge(auth)
248-
.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
246+
logger.debug(() -> Strings.format("Received authorization response, %s", auth));
247+
248+
var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
249+
logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels));
249250

250251
// recalculate which default config ids and models are authorized now
251-
var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth);
252+
var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels);
252253

253-
var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth);
254+
var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels);
254255
var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds);
255256
authorizedContent.set(
256257
new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects)
@@ -337,7 +338,12 @@ private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds)
337338
firstAuthorizationCompletedLatch.countDown();
338339
});
339340

340-
logger.debug("Synchronizing default inference endpoints");
341+
logger.debug(
342+
() -> Strings.format(
343+
"Synchronizing default inference endpoints, attempting to remove ids: %s",
344+
unauthorizedDefaultInferenceEndpointIds
345+
)
346+
);
341347
modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
342348
}
343349
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,16 @@ public boolean equals(Object o) {
161161
public int hashCode() {
162162
return Objects.hash(taskTypeToModels, authorizedTaskTypes, authorizedModelIds);
163163
}
164+
165+
@Override
166+
public String toString() {
167+
return "{"
168+
+ "taskTypeToModels="
169+
+ taskTypeToModels
170+
+ ", authorizedTaskTypes="
171+
+ authorizedTaskTypes
172+
+ ", authorizedModelIds="
173+
+ authorizedModelIds
174+
+ '}';
175+
}
164176
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12-
import org.elasticsearch.ElasticsearchWrapperException;
12+
import org.elasticsearch.ElasticsearchException;
13+
import org.elasticsearch.ExceptionsHelper;
1314
import org.elasticsearch.action.ActionListener;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.core.Nullable;
@@ -86,25 +87,25 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
8687

8788
ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> {
8889
if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
90+
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
8991
listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity));
9092
} else {
91-
logger.warn(
92-
Strings.format(
93-
FAILED_TO_RETRIEVE_MESSAGE + " Received an invalid response type: %s",
94-
results.getClass().getSimpleName()
95-
)
93+
var errorMessage = Strings.format(
94+
"%s Received an invalid response type from the Elastic Inference Service: %s",
95+
FAILED_TO_RETRIEVE_MESSAGE,
96+
results.getClass().getSimpleName()
9697
);
97-
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
98+
99+
logger.warn(errorMessage);
100+
listener.onFailure(new ElasticsearchException(errorMessage));
98101
}
99102
requestCompleteLatch.countDown();
100103
}, e -> {
101-
Throwable exception = e;
102-
if (e instanceof ElasticsearchWrapperException wrapperException) {
103-
exception = wrapperException.getCause();
104-
}
104+
// unwrap because it's likely a retry exception
105+
var exception = ExceptionsHelper.unwrapCause(e);
105106

106-
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception));
107-
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
107+
logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
108+
listener.onFailure(e);
108109
requestCompleteLatch.countDown();
109110
});
110111

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77

88
package org.elasticsearch.xpack.inference.services.elastic.response;
99

10+
import org.elasticsearch.common.Strings;
1011
import org.elasticsearch.common.io.stream.StreamInput;
1112
import org.elasticsearch.common.io.stream.StreamOutput;
1213
import org.elasticsearch.common.io.stream.Writeable;
1314
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
1415
import org.elasticsearch.inference.InferenceResults;
1516
import org.elasticsearch.inference.InferenceServiceResults;
1617
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.logging.LogManager;
19+
import org.elasticsearch.logging.Logger;
1720
import org.elasticsearch.xcontent.ConstructingObjectParser;
1821
import org.elasticsearch.xcontent.ParseField;
1922
import org.elasticsearch.xcontent.ToXContent;
@@ -39,6 +42,9 @@
3942
public class ElasticInferenceServiceAuthorizationResponseEntity implements InferenceServiceResults {
4043

4144
public static final String NAME = "elastic_inference_service_auth_results";
45+
46+
private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntity.class);
47+
private static final String AUTH_FIELD_NAME = "authorized_models";
4248
private static final Map<String, TaskType> ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of(
4349
"embed/text/sparse",
4450
TaskType.SPARSE_EMBEDDING,
@@ -107,6 +113,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
107113

108114
return builder;
109115
}
116+
117+
@Override
118+
public String toString() {
119+
return Strings.format("{modelName='%s', taskTypes='%s'}", modelName, taskTypes);
120+
}
110121
}
111122

112123
private final List<AuthorizedModel> authorizedModels;
@@ -138,6 +149,11 @@ public List<AuthorizedModel> getAuthorizedModels() {
138149
return authorizedModels;
139150
}
140151

152+
@Override
153+
public String toString() {
154+
return authorizedModels.stream().map(AuthorizedModel::toString).collect(Collectors.joining(", "));
155+
}
156+
141157
@Override
142158
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
143159
throw new UnsupportedOperationException();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,78 @@ public void init() throws Exception {
6767
modelRegistry = getInstanceFromNode(ModelRegistry.class);
6868
}
6969

70+
public void testSecondAuthResultRevokesAuthorization() throws Exception {
71+
var callbackCount = new AtomicInteger(0);
72+
// we're only interested in two authorization calls which is why I'm using a value of 2 here
73+
var latch = new CountDownLatch(2);
74+
final AtomicReference<ElasticInferenceServiceAuthorizationHandler> handlerRef = new AtomicReference<>();
75+
76+
Runnable callback = () -> {
77+
// the first authorization response contains a streaming task so we're expecting to support streaming here
78+
if (callbackCount.incrementAndGet() == 1) {
79+
assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
80+
}
81+
latch.countDown();
82+
83+
// we only want to run the tasks twice, so advance the time on the queue
84+
// which flags the scheduled authorization request to be ready to run
85+
if (callbackCount.get() == 1) {
86+
taskQueue.advanceTime();
87+
} else {
88+
try {
89+
handlerRef.get().close();
90+
} catch (IOException e) {
91+
// ignore
92+
}
93+
}
94+
};
95+
96+
var requestHandler = mockAuthorizationRequestHandler(
97+
ElasticInferenceServiceAuthorizationModel.of(
98+
new ElasticInferenceServiceAuthorizationResponseEntity(
99+
List.of(
100+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
101+
"rainbow-sprinkles",
102+
EnumSet.of(TaskType.CHAT_COMPLETION)
103+
)
104+
)
105+
)
106+
),
107+
ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of()))
108+
);
109+
110+
handlerRef.set(
111+
new ElasticInferenceServiceAuthorizationHandler(
112+
createWithEmptySettings(taskQueue.getThreadPool()),
113+
modelRegistry,
114+
requestHandler,
115+
initDefaultEndpoints(),
116+
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION),
117+
null,
118+
mock(Sender.class),
119+
ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
120+
callback
121+
)
122+
);
123+
124+
var handler = handlerRef.get();
125+
handler.init();
126+
taskQueue.runAllRunnableTasks();
127+
latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS);
128+
129+
// this should be after we've received both authorization responses, the second response will revoke authorization
130+
131+
assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
132+
assertThat(handler.defaultConfigIds(), is(List.of()));
133+
assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
134+
135+
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
136+
handler.defaultConfigs(listener);
137+
138+
var configs = listener.actionGet();
139+
assertThat(configs.size(), is(0));
140+
}
141+
70142
public void testSendsAnAuthorizationRequestTwice() throws Exception {
71143
var callbackCount = new AtomicInteger(0);
72144
// we're only interested in two authorization calls which is why I'm using a value of 2 here
@@ -104,6 +176,10 @@ public void testSendsAnAuthorizationRequestTwice() throws Exception {
104176
ElasticInferenceServiceAuthorizationModel.of(
105177
new ElasticInferenceServiceAuthorizationResponseEntity(
106178
List.of(
179+
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
180+
"abc",
181+
EnumSet.of(TaskType.SPARSE_EMBEDDING)
182+
),
107183
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
108184
"rainbow-sprinkles",
109185
EnumSet.of(TaskType.CHAT_COMPLETION)

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.elastic.authorization;
99

1010
import org.apache.logging.log4j.Logger;
11+
import org.elasticsearch.ElasticsearchException;
1112
import org.elasticsearch.action.ActionListener;
1213
import org.elasticsearch.action.support.PlainActionFuture;
1314
import org.elasticsearch.common.settings.Settings;
@@ -18,6 +19,7 @@
1819
import org.elasticsearch.test.http.MockResponse;
1920
import org.elasticsearch.test.http.MockWebServer;
2021
import org.elasticsearch.threadpool.ThreadPool;
22+
import org.elasticsearch.xcontent.XContentParseException;
2123
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
2224
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
2325
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -38,13 +40,14 @@
3840
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
3941
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
4042
import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES;
43+
import static org.hamcrest.Matchers.containsString;
44+
import static org.hamcrest.Matchers.instanceOf;
4145
import static org.hamcrest.Matchers.is;
4246
import static org.mockito.ArgumentMatchers.any;
4347
import static org.mockito.Mockito.doAnswer;
4448
import static org.mockito.Mockito.mock;
4549
import static org.mockito.Mockito.times;
4650
import static org.mockito.Mockito.verify;
47-
import static org.mockito.Mockito.verifyNoMoreInteractions;
4851
import static org.mockito.Mockito.when;
4952

5053
public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase {
@@ -135,22 +138,17 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep
135138
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
136139
authHandler.getAuthorization(listener, sender);
137140

138-
var authResponse = listener.actionGet(TIMEOUT);
139-
assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty());
140-
assertTrue(authResponse.getAuthorizedModelIds().isEmpty());
141-
assertFalse(authResponse.isAuthorized());
141+
var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT));
142+
assertThat(exception.getMessage(), containsString("failed to parse field [models]"));
142143

143-
var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
144-
verify(logger).warn(loggerArgsCaptor.capture());
145-
var message = loggerArgsCaptor.getValue();
146-
assertThat(
147-
message,
148-
is(
149-
"Failed to retrieve the authorization information from the Elastic Inference Service."
150-
+ " Encountered an exception: org.elasticsearch.xcontent.XContentParseException: [4:28] "
151-
+ "[ElasticInferenceServiceAuthorizationResponseEntity] failed to parse field [models]"
152-
)
153-
);
144+
var stringCaptor = ArgumentCaptor.forClass(String.class);
145+
var exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
146+
verify(logger).warn(stringCaptor.capture(), exceptionCaptor.capture());
147+
var message = stringCaptor.getValue();
148+
assertThat(message, containsString("failed to parse field [models]"));
149+
150+
var capturedException = exceptionCaptor.getValue();
151+
assertThat(capturedException, instanceOf(XContentParseException.class));
154152
}
155153
}
156154

@@ -196,7 +194,6 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
196194

197195
var message = loggerArgsCaptor.getValue();
198196
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
199-
verifyNoMoreInteractions(logger);
200197
}
201198
}
202199

@@ -230,7 +227,6 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException {
230227

231228
var message = loggerArgsCaptor.getValue();
232229
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
233-
verifyNoMoreInteractions(logger);
234230
}
235231
}
236232

@@ -252,20 +248,14 @@ public void testGetAuthorization_InvalidResponse() throws IOException {
252248
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
253249

254250
authHandler.getAuthorization(listener, sender);
255-
var result = listener.actionGet(TIMEOUT);
251+
var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
256252

257-
assertThat(result, is(ElasticInferenceServiceAuthorizationModel.newDisabledService()));
253+
assertThat(exception.getMessage(), containsString("Received an invalid response type from the Elastic Inference Service"));
258254

259255
var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
260256
verify(logger).warn(loggerArgsCaptor.capture());
261257
var message = loggerArgsCaptor.getValue();
262-
assertThat(
263-
message,
264-
is(
265-
"Failed to retrieve the authorization information from the Elastic Inference Service."
266-
+ " Received an invalid response type: ChatCompletionResults"
267-
)
268-
);
258+
assertThat(message, containsString("Failed to retrieve the authorization information from the Elastic Inference Service."));
269259
}
270260

271261
}

0 commit comments

Comments
 (0)