Skip to content

Commit f6045d0

Browse files
Adding validation call and tests
1 parent f1539fd commit f6045d0

File tree

13 files changed

+368
-24
lines changed

13 files changed

+368
-24
lines changed

x-pack/plugin/inference/qa/inference-service-tests/build.gradle

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ dependencies {
77
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
88
// Added this to have access to MockWebServer within the tests
99
javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
10+
11+
// Access public test classes/utilities from x-pack:inference tests
12+
javaRestTestImplementation(testArtifact(project(xpackModule('inference'))))
1013
}
1114

1215
tasks.named("javaRestTest").configure {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CCMCrudIT.java

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,48 @@
1919
import org.elasticsearch.rest.RestStatus;
2020
import org.elasticsearch.test.cluster.ElasticsearchCluster;
2121
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
22+
import org.elasticsearch.test.http.MockResponse;
2223
import org.elasticsearch.xpack.core.inference.action.PutCCMConfigurationAction;
24+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
2325
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeatureFlag;
2426
import org.junit.After;
2527
import org.junit.BeforeClass;
2628
import org.junit.ClassRule;
29+
import org.junit.rules.RuleChain;
30+
import org.junit.rules.TestRule;
2731

2832
import java.io.IOException;
2933

34+
import static org.elasticsearch.xpack.inference.action.TransportPutCCMConfigurationActionTests.ERROR_MESSAGE;
35+
import static org.elasticsearch.xpack.inference.action.TransportPutCCMConfigurationActionTests.ERROR_RESPONSE;
36+
import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETRIES;
3037
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_CCM_PATH;
3138
import static org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings.CCM_SUPPORTED_ENVIRONMENT;
3239
import static org.hamcrest.Matchers.containsString;
3340
import static org.hamcrest.Matchers.is;
3441

3542
public class CCMCrudIT extends CCMRestBaseIT {
3643

37-
@ClassRule
38-
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
44+
private static final MockElasticInferenceServiceAuthorizationServer mockEISServer =
45+
new MockElasticInferenceServiceAuthorizationServer();
46+
47+
private static ElasticsearchCluster cluster = ElasticsearchCluster.local()
3948
.distribution(DistributionType.DEFAULT)
4049
.setting("xpack.license.self_generated.type", "basic")
4150
.setting("xpack.security.enabled", "true")
4251
.setting(CCM_SUPPORTED_ENVIRONMENT.getKey(), "true")
52+
.setting(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), mockEISServer::getUrl)
53+
.setting(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), "false")
54+
// TODO disable the authorization task so it doesn't try to call the mock server
4355
.user("x_pack_rest_user", "x-pack-test-password")
4456
.build();
4557

58+
// The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating
59+
// it to the cluster as a setting.
60+
// Note: @ClassRule is executed once for the entire test class
61+
@ClassRule
62+
public static TestRule ruleChain = RuleChain.outerRule(mockEISServer).around(cluster);
63+
4664
@Override
4765
protected String getTestRestCluster() {
4866
return cluster.getHttpAddresses();
@@ -70,16 +88,61 @@ public void cleanup() {
7088
}
7189

7290
public void testEnablesCCM_Succeeds() throws IOException {
91+
mockEISServer.enqueueAuthorizeAllModelsResponse();
7392
var response = putCCMConfiguration(ENABLE_CCM_REQUEST);
7493

7594
assertTrue(response.isEnabled());
7695
}
7796

97+
public void testEnablesCCM_Fails_IfAuthValidationFails401() throws IOException {
98+
var webServer = mockEISServer.getWebServer();
99+
webServer.enqueue(new MockResponse().setResponseCode(RestStatus.UNAUTHORIZED.getStatus()).setBody(ERROR_RESPONSE));
100+
101+
var exception = expectThrows(ResponseException.class, () -> putRawRequest(INFERENCE_CCM_PATH, ENABLE_CCM_REQUEST));
102+
assertThat(exception.getMessage(), containsString(Strings.format(ERROR_MESSAGE)));
103+
assertThat(exception.getResponse().getStatusLine().getStatusCode(), is(RestStatus.UNAUTHORIZED.getStatus()));
104+
105+
assertFalse(getCCMConfiguration().isEnabled());
106+
}
107+
108+
public void testEnablesCCM_Fails_IfAuthValidationFails400() throws IOException {
109+
var webServer = mockEISServer.getWebServer();
110+
webServer.enqueue(new MockResponse().setResponseCode(RestStatus.BAD_REQUEST.getStatus()).setBody(ERROR_RESPONSE));
111+
112+
var exception = expectThrows(ResponseException.class, () -> putRawRequest(INFERENCE_CCM_PATH, ENABLE_CCM_REQUEST));
113+
assertThat(exception.getMessage(), containsString(Strings.format(ERROR_MESSAGE)));
114+
assertThat(exception.getResponse().getStatusLine().getStatusCode(), is(RestStatus.BAD_REQUEST.getStatus()));
115+
116+
assertFalse(getCCMConfiguration().isEnabled());
117+
}
118+
119+
public void testEnablesCCM_Fails_IfAuthValidationFails500() throws IOException {
120+
// 500 errors are retried so we need to queue up multiple responses
121+
queueResponsesToHandleRetries(
122+
new MockResponse().setResponseCode(RestStatus.INTERNAL_SERVER_ERROR.getStatus()).setBody(ERROR_RESPONSE)
123+
);
124+
125+
var exception = expectThrows(ResponseException.class, () -> putRawRequest(INFERENCE_CCM_PATH, ENABLE_CCM_REQUEST));
126+
assertThat(exception.getMessage(), containsString(Strings.format(ERROR_MESSAGE)));
127+
// 5xx errors are transformed by the response handler to a 400 error to avoid triggering alerts
128+
assertThat(exception.getResponse().getStatusLine().getStatusCode(), is(RestStatus.BAD_REQUEST.getStatus()));
129+
130+
assertFalse(getCCMConfiguration().isEnabled());
131+
}
132+
133+
private void queueResponsesToHandleRetries(MockResponse response) {
134+
for (int i = 0; i < MAX_RETRIES; i++) {
135+
mockEISServer.getWebServer().enqueue(response);
136+
}
137+
}
138+
78139
public void testEnablesCCMTwice_Succeeds() throws IOException {
140+
mockEISServer.enqueueAuthorizeAllModelsResponse();
79141
var response = putCCMConfiguration(ENABLE_CCM_REQUEST);
80142

81143
assertTrue(response.isEnabled());
82144

145+
mockEISServer.enqueueAuthorizeAllModelsResponse();
83146
response = putCCMConfiguration(
84147
PutCCMConfigurationAction.Request.createEnabled("other_key", TimeValue.THIRTY_SECONDS, TimeValue.THIRTY_SECONDS)
85148
);
@@ -125,6 +188,7 @@ public void testGetCCMConfiguration_WhenCCMDisabled_ReturnsDisabled() throws IOE
125188
}
126189

127190
public void testEnablesCCM_ThenDisable() throws IOException {
191+
mockEISServer.enqueueAuthorizeAllModelsResponse();
128192
var response = putCCMConfiguration(ENABLE_CCM_REQUEST);
129193

130194
assertTrue(response.isEnabled());

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ public void enqueueAuthorizeAllModelsResponse() {
6161
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
6262
}
6363

64+
public MockWebServer getWebServer() {
65+
return webServer;
66+
}
67+
6468
public String getUrl() {
6569
return format("http://%s:%s", webServer.getHostName(), webServer.getPort());
6670
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ public Collection<?> createComponents(PluginServices services) {
445445
services.featureService()
446446
)
447447
);
448+
components.add(new PluginComponentBinding<>(ElasticInferenceServiceSettings.class, inferenceServiceSettings));
448449

449450
return components;
450451
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutCCMConfigurationAction.java

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77

88
package org.elasticsearch.xpack.inference.action;
99

10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ElasticsearchStatusException;
13+
import org.elasticsearch.ExceptionsHelper;
1014
import org.elasticsearch.action.ActionListener;
1115
import org.elasticsearch.action.support.ActionFilters;
16+
import org.elasticsearch.action.support.SubscribableListener;
1217
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
1318
import org.elasticsearch.cluster.ClusterState;
1419
import org.elasticsearch.cluster.block.ClusterBlockException;
@@ -17,14 +22,20 @@
1722
import org.elasticsearch.cluster.service.ClusterService;
1823
import org.elasticsearch.common.util.concurrent.EsExecutors;
1924
import org.elasticsearch.injection.guice.Inject;
25+
import org.elasticsearch.rest.RestStatus;
2026
import org.elasticsearch.tasks.Task;
2127
import org.elasticsearch.threadpool.ThreadPool;
2228
import org.elasticsearch.transport.TransportService;
2329
import org.elasticsearch.xpack.core.inference.action.CCMEnabledActionResponse;
2430
import org.elasticsearch.xpack.core.inference.action.PutCCMConfigurationAction;
31+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
32+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
33+
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel;
34+
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
2535
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature;
2636
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMModel;
2737
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService;
38+
import org.elasticsearch.xpack.inference.services.elastic.ccm.ValidationAuthenticationFactory;
2839

2940
import java.util.Objects;
3041

@@ -34,9 +45,14 @@ public class TransportPutCCMConfigurationAction extends TransportMasterNodeActio
3445
PutCCMConfigurationAction.Request,
3546
CCMEnabledActionResponse> {
3647

48+
private static final Logger logger = LogManager.getLogger(TransportPutCCMConfigurationAction.class);
49+
static final String FAILED_VALIDATION_MESSAGE = "Failed to validate the Cloud Connected Mode API key";
50+
3751
private final CCMFeature ccmFeature;
3852
private final CCMService ccmService;
3953
private final ProjectResolver projectResolver;
54+
private final Sender eisSender;
55+
private final ElasticInferenceServiceSettings eisSettings;
4056

4157
@Inject
4258
public TransportPutCCMConfigurationAction(
@@ -46,7 +62,9 @@ public TransportPutCCMConfigurationAction(
4662
ActionFilters actionFilters,
4763
CCMService ccmService,
4864
ProjectResolver projectResolver,
49-
CCMFeature ccmFeature
65+
CCMFeature ccmFeature,
66+
Sender eisSender,
67+
ElasticInferenceServiceSettings eisSettings
5068
) {
5169
super(
5270
PutCCMConfigurationAction.NAME,
@@ -61,6 +79,8 @@ public TransportPutCCMConfigurationAction(
6179
this.ccmService = Objects.requireNonNull(ccmService);
6280
this.projectResolver = Objects.requireNonNull(projectResolver);
6381
this.ccmFeature = Objects.requireNonNull(ccmFeature);
82+
this.eisSender = Objects.requireNonNull(eisSender);
83+
this.eisSettings = Objects.requireNonNull(eisSettings);
6484
}
6585

6686
@Override
@@ -75,11 +95,37 @@ protected void masterOperation(
7595
return;
7696
}
7797

78-
var enabledListener = listener.<Void>delegateFailureIgnoreResponseAndWrap(
79-
delegate -> delegate.onResponse(new CCMEnabledActionResponse(true))
80-
);
98+
SubscribableListener.<ElasticInferenceServiceAuthorizationModel>newForked(authValidationListener -> {
99+
var authRequestHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
100+
eisSettings.getElasticInferenceServiceUrl(),
101+
threadPool,
102+
new ValidationAuthenticationFactory(request.getApiKey())
103+
);
104+
105+
var errorListener = authValidationListener.delegateResponse((delegate, exception) -> {
106+
// The exception will likely be a RetryException, so unwrap it to get to the real cause
107+
var unwrappedException = ExceptionsHelper.unwrapCause(exception);
108+
109+
logger.atWarn().withThrowable(unwrappedException).log(FAILED_VALIDATION_MESSAGE);
110+
111+
if (unwrappedException instanceof ElasticsearchStatusException statusException) {
112+
delegate.onFailure(
113+
new ElasticsearchStatusException(FAILED_VALIDATION_MESSAGE, statusException.status(), statusException)
114+
);
115+
return;
116+
}
117+
118+
delegate.onFailure(new ElasticsearchStatusException(FAILED_VALIDATION_MESSAGE, RestStatus.BAD_REQUEST, unwrappedException));
119+
});
120+
121+
authRequestHandler.getAuthorization(errorListener, eisSender);
122+
}).<CCMEnabledActionResponse>andThen((storeConfigurationListener) -> {
123+
var enabledListener = storeConfigurationListener.<Void>delegateFailureIgnoreResponseAndWrap(
124+
delegate -> delegate.onResponse(new CCMEnabledActionResponse(true))
125+
);
81126

82-
ccmService.storeConfiguration(new CCMModel(request.getApiKey()), enabledListener);
127+
ccmService.storeConfiguration(new CCMModel(request.getApiKey()), enabledListener);
128+
}).addListener(listener);
83129
}
84130

85131
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
public class RetryingHttpSender implements RequestSender {
3737

38-
public static final int MAX_RETIES = 3;
38+
public static final int MAX_RETRIES = 3;
3939

4040
private final HttpClient httpClient;
4141
private final ThrottlerManager throttlerManager;
@@ -108,14 +108,19 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
108108
return;
109109
}
110110

111-
var retryableListener = listener.delegateResponse((l, e) -> {
111+
/*
112+
* This listener handles failures from the http client level such as an IOException or UnknownHostException. We will try
113+
* to determine if the exception is retryable and if so wrap it in a RetryException so that when we pass the failure to the
114+
* tryAction original listener it will get passed to shouldRetry() and be retried.
115+
*/
116+
var httpClientFailureListener = listener.delegateResponse((l, e) -> {
112117
logException(logger, request, responseHandler.getRequestType(), e);
113118
l.onFailure(transformIfRetryable(e));
114119
});
115120

116121
try {
117122
if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) {
118-
httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
123+
httpClient.stream(request.createHttpRequest(), context, httpClientFailureListener.delegateFailure((l, r) -> {
119124
if (r.isSuccessfulResponse()) {
120125
l.onResponse(responseHandler.parseResult(request, r.toHttpResult()));
121126
} else {
@@ -125,22 +130,28 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
125130
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);
126131
ll.onResponse(inferenceResults);
127132
} catch (Exception e) {
133+
// A failure here typically happens when validateResponse() throws an exception for when we get a
134+
// failure
135+
// status code. We pass it back to the original listener so shouldRetry() can determine if we need to
136+
// retry.
128137
logException(logger, request, httpResult, responseHandler.getRequestType(), e);
129-
listener.onFailure(e); // skip retrying
138+
listener.onFailure(e);
130139
}
131140
}));
132141
}
133142
}));
134143
} else {
135-
httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
144+
httpClient.send(request.createHttpRequest(), context, httpClientFailureListener.delegateFailure((l, r) -> {
136145
try {
137146
responseHandler.validateResponse(throttlerManager, logger, request, r);
138147
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, r);
139148

140149
l.onResponse(inferenceResults);
141150
} catch (Exception e) {
151+
// A failure here typically happens when validateResponse() throws an exception for when we get a failure
152+
// status code. We pass it back to the original listener so shouldRetry() can determine if we need to retry.
142153
logException(logger, request, r, responseHandler.getRequestType(), e);
143-
listener.onFailure(e); // skip retrying
154+
listener.onFailure(e);
144155
}
145156
}));
146157
}
@@ -191,7 +202,7 @@ private Exception wrapWithElasticsearchException(Exception e, String inferenceEn
191202

192203
@Override
193204
public boolean shouldRetry(Exception e) {
194-
if (retryCount.get() >= MAX_RETIES) {
205+
if (retryCount.get() >= MAX_RETRIES) {
195206
return false;
196207
}
197208

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
2323
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2424
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
25-
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMAuthenticationApplierFactory;
25+
import org.elasticsearch.xpack.inference.services.elastic.ccm.AuthenticationFactory;
2626
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest;
2727
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
2828
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -55,12 +55,12 @@ private static ResponseHandler createAuthResponseHandler() {
5555
private final ThreadPool threadPool;
5656
private final Logger logger;
5757
private final CountDownLatch requestCompleteLatch = new CountDownLatch(1);
58-
private CCMAuthenticationApplierFactory authFactory;
58+
private final AuthenticationFactory authFactory;
5959

6060
public ElasticInferenceServiceAuthorizationRequestHandler(
6161
@Nullable String baseUrl,
6262
ThreadPool threadPool,
63-
CCMAuthenticationApplierFactory authFactory
63+
AuthenticationFactory authFactory
6464
) {
6565
this(
6666
baseUrl,
@@ -75,7 +75,7 @@ public ElasticInferenceServiceAuthorizationRequestHandler(
7575
@Nullable String baseUrl,
7676
ThreadPool threadPool,
7777
Logger logger,
78-
CCMAuthenticationApplierFactory authFactory
78+
AuthenticationFactory authFactory
7979
) {
8080
this.baseUrl = baseUrl;
8181
this.threadPool = Objects.requireNonNull(threadPool);
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elastic.ccm;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
12+
public interface AuthenticationFactory {
13+
void getAuthenticationApplier(ActionListener<CCMAuthenticationApplierFactory.AuthApplier> listener);
14+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ccm/CCMAuthenticationApplierFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
/**
2424
* Returns a class to handle modifying the HTTP requests with the appropriate CCM authentication information if CCM is configured.
2525
*/
26-
public class CCMAuthenticationApplierFactory {
26+
public class CCMAuthenticationApplierFactory implements AuthenticationFactory {
2727

2828
public static final NoopApplier NOOP_APPLIER = new NoopApplier();
2929

0 commit comments

Comments
 (0)