Skip to content

Commit 04e1a64

Browse files
Adding dns lookup
1 parent 340f0c6 commit 04e1a64

File tree

6 files changed

+144
-29
lines changed

6 files changed

+144
-29
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,7 @@ public Collection<?> createComponents(PluginServices services) {
277277
var inferenceServiceSettings = new ElasticInferenceServiceSettings(settings);
278278
inferenceServiceSettings.init(services.clusterService());
279279

280-
var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
281-
inferenceServiceSettings.getElasticInferenceServiceUrl(),
282-
services.threadPool()
283-
);
280+
var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler(inferenceServiceSettings, services.threadPool());
284281

285282
inferenceServices.add(
286283
() -> List.of(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@ public class ElasticInferenceServiceSettings {
3333
Setting.Property.NodeScope
3434
);
3535

36+
/**
37+
* This controls whether the authorization logic will look up the ip address of the gateway.
38+
* This should only be enabled for testing and debugging purposes.
39+
*/
40+
static final Setting<Boolean> AUTH_DEBUGGING_DNS_LOOKUP_ENABLED = Setting.boolSetting(
41+
"xpack.inference.elastic.authorization_debugging_dns_lookup_enabled",
42+
false,
43+
Setting.Property.NodeScope
44+
);
45+
3646
/**
3747
* This setting is for testing only. It controls whether authorization is only performed once at bootup. If set to true, an
3848
* authorization request will be made repeatedly on an interval.
@@ -75,6 +85,7 @@ public class ElasticInferenceServiceSettings {
7585

7686
private final String elasticInferenceServiceUrl;
7787
private final boolean periodicAuthorizationEnabled;
88+
private final boolean authDebuggingDnsLookupEnabled;
7889
private volatile TimeValue authRequestInterval;
7990
private volatile TimeValue maxAuthorizationRequestJitter;
8091

@@ -84,6 +95,7 @@ public ElasticInferenceServiceSettings(Settings settings) {
8495
periodicAuthorizationEnabled = PERIODIC_AUTHORIZATION_ENABLED.get(settings);
8596
authRequestInterval = AUTHORIZATION_REQUEST_INTERVAL.get(settings);
8697
maxAuthorizationRequestJitter = MAX_AUTHORIZATION_REQUEST_JITTER.get(settings);
98+
authDebuggingDnsLookupEnabled = AUTH_DEBUGGING_DNS_LOOKUP_ENABLED.get(settings);
8799
}
88100

89101
/**
@@ -115,6 +127,14 @@ public TimeValue getMaxAuthorizationRequestJitter() {
115127
return maxAuthorizationRequestJitter;
116128
}
117129

130+
public boolean isAuthDebuggingDnsLookupEnabled() {
131+
return authDebuggingDnsLookupEnabled;
132+
}
133+
134+
public boolean isPeriodicAuthorizationEnabled() {
135+
return periodicAuthorizationEnabled;
136+
}
137+
118138
public static List<Setting<?>> getSettingsDefinitions() {
119139
ArrayList<Setting<?>> settings = new ArrayList<>();
120140
settings.add(EIS_GATEWAY_URL);
@@ -124,14 +144,11 @@ public static List<Setting<?>> getSettingsDefinitions() {
124144
settings.add(PERIODIC_AUTHORIZATION_ENABLED);
125145
settings.add(AUTHORIZATION_REQUEST_INTERVAL);
126146
settings.add(MAX_AUTHORIZATION_REQUEST_JITTER);
147+
settings.add(AUTH_DEBUGGING_DNS_LOOKUP_ENABLED);
127148
return settings;
128149
}
129150

130151
public String getElasticInferenceServiceUrl() {
131152
return Strings.isEmpty(elasticInferenceServiceUrl) ? eisGatewayUrl : elasticInferenceServiceUrl;
132153
}
133-
134-
public boolean isPeriodicAuthorizationEnabled() {
135-
return periodicAuthorizationEnabled;
136-
}
137154
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.elasticsearch.ElasticsearchWrapperException;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.common.Strings;
15-
import org.elasticsearch.core.Nullable;
1615
import org.elasticsearch.core.TimeValue;
1716
import org.elasticsearch.inference.InferenceServiceResults;
1817
import org.elasticsearch.tasks.Task;
@@ -22,8 +21,11 @@
2221
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2322
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceAuthorizationRequest;
2423
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceAuthorizationResponseEntity;
24+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
2525
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
2626

27+
import java.net.InetAddress;
28+
import java.net.URI;
2729
import java.util.Locale;
2830
import java.util.Objects;
2931
import java.util.concurrent.CountDownLatch;
@@ -48,20 +50,25 @@ private static ResponseHandler createAuthResponseHandler() {
4850
);
4951
}
5052

51-
private final String baseUrl;
53+
private final ElasticInferenceServiceSettings elasticInferenceServiceSettings;
5254
private final ThreadPool threadPool;
5355
private final Logger logger;
5456
private final CountDownLatch requestCompleteLatch = new CountDownLatch(1);
5557

56-
public ElasticInferenceServiceAuthorizationRequestHandler(@Nullable String baseUrl, ThreadPool threadPool) {
57-
this.baseUrl = baseUrl;
58-
this.threadPool = Objects.requireNonNull(threadPool);
59-
logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationRequestHandler.class);
58+
public ElasticInferenceServiceAuthorizationRequestHandler(
59+
ElasticInferenceServiceSettings elasticInferenceServiceSettings,
60+
ThreadPool threadPool
61+
) {
62+
this(elasticInferenceServiceSettings, threadPool, LogManager.getLogger(ElasticInferenceServiceAuthorizationRequestHandler.class));
6063
}
6164

6265
// only use for testing
63-
ElasticInferenceServiceAuthorizationRequestHandler(@Nullable String baseUrl, ThreadPool threadPool, Logger logger) {
64-
this.baseUrl = baseUrl;
66+
ElasticInferenceServiceAuthorizationRequestHandler(
67+
ElasticInferenceServiceSettings elasticInferenceServiceSettings,
68+
ThreadPool threadPool,
69+
Logger logger
70+
) {
71+
this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings);
6572
this.threadPool = Objects.requireNonNull(threadPool);
6673
this.logger = Objects.requireNonNull(logger);
6774
}
@@ -75,7 +82,7 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
7582
try {
7683
logger.debug("Retrieving authorization information from the Elastic Inference Service.");
7784

78-
if (Strings.isNullOrEmpty(baseUrl)) {
85+
if (Strings.isNullOrEmpty(elasticInferenceServiceSettings.getElasticInferenceServiceUrl())) {
7986
logger.debug("The base URL for the authorization service is not valid, rejecting authorization.");
8087
listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
8188
return;
@@ -108,7 +115,28 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
108115
requestCompleteLatch.countDown();
109116
});
110117

111-
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo());
118+
if (logger.isDebugEnabled() && elasticInferenceServiceSettings.isAuthDebuggingDnsLookupEnabled()) {
119+
logger.debug("Attempting to look up authorization gateway host info");
120+
try {
121+
var baseUri = new URI(elasticInferenceServiceSettings.getElasticInferenceServiceUrl());
122+
logger.debug(
123+
Strings.format(
124+
"Looking up authorization gateway ip for url: %s, host: %s",
125+
elasticInferenceServiceSettings.getElasticInferenceServiceUrl(),
126+
baseUri.getHost()
127+
)
128+
);
129+
var gatewayAddress = InetAddress.getByName(baseUri.getHost());
130+
logger.debug(Strings.format("Gateway address: %s", gatewayAddress.getHostAddress()));
131+
} catch (Exception e) {
132+
logger.debug("Failed to resolve gateway address", e);
133+
}
134+
}
135+
136+
var request = new ElasticInferenceServiceAuthorizationRequest(
137+
elasticInferenceServiceSettings.getElasticInferenceServiceUrl(),
138+
getCurrentTraceInfo()
139+
);
112140

113141
sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener);
114142
} catch (Exception e) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettingsTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ public static ElasticInferenceServiceSettings create(
4242
return new ElasticInferenceServiceSettings(settings);
4343
}
4444

45+
public static ElasticInferenceServiceSettings createWithUrlAndDnsLookup(String elasticInferenceServiceUrl) {
46+
var settings = Settings.builder()
47+
.put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), elasticInferenceServiceUrl)
48+
.put(ElasticInferenceServiceSettings.AUTH_DEBUGGING_DNS_LOOKUP_ENABLED.getKey(), true)
49+
.build();
50+
51+
return new ElasticInferenceServiceSettings(settings);
52+
}
53+
4554
public void testGetElasticInferenceServiceUrl_WithUrlSetting() {
4655
var settings = Settings.builder()
4756
.put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), ELASTIC_INFERENCE_SERVICE_URL)

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,7 @@ private ElasticInferenceService createServiceWithAuthHandler(HttpRequestSender.F
11631163
createWithEmptySettings(threadPool),
11641164
ElasticInferenceServiceSettingsTests.create(eisGatewayUrl),
11651165
mockModelRegistry(),
1166-
new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool)
1166+
new ElasticInferenceServiceAuthorizationRequestHandler(ElasticInferenceServiceSettingsTests.create(eisGatewayUrl), threadPool)
11671167
);
11681168
}
11691169

@@ -1177,7 +1177,7 @@ public static ElasticInferenceService createServiceWithAuthHandler(
11771177
createWithEmptySettings(threadPool),
11781178
ElasticInferenceServiceSettingsTests.create(eisGatewayUrl),
11791179
mockModelRegistry(threadPool),
1180-
new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool)
1180+
new ElasticInferenceServiceAuthorizationRequestHandler(ElasticInferenceServiceSettingsTests.create(eisGatewayUrl), threadPool)
11811181
);
11821182
}
11831183
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
2525
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2626
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
27+
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests;
2728
import org.junit.After;
2829
import org.junit.Before;
2930
import org.mockito.ArgumentCaptor;
@@ -44,7 +45,6 @@
4445
import static org.mockito.Mockito.mock;
4546
import static org.mockito.Mockito.times;
4647
import static org.mockito.Mockito.verify;
47-
import static org.mockito.Mockito.verifyNoMoreInteractions;
4848
import static org.mockito.Mockito.when;
4949

5050
public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase {
@@ -71,7 +71,11 @@ public void shutdown() throws IOException {
7171
public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws Exception {
7272
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
7373
var logger = mock(Logger.class);
74-
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(null, threadPool, logger);
74+
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
75+
ElasticInferenceServiceSettingsTests.create(null),
76+
threadPool,
77+
logger
78+
);
7579

7680
try (var sender = senderFactory.createSender()) {
7781
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
@@ -93,7 +97,11 @@ public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsNull() throws
9397
public void testDoesNotAttempt_ToRetrieveAuthorization_IfBaseUrlIsEmpty() throws Exception {
9498
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
9599
var logger = mock(Logger.class);
96-
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("", threadPool, logger);
100+
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
101+
ElasticInferenceServiceSettingsTests.create(""),
102+
threadPool,
103+
logger
104+
);
97105

98106
try (var sender = senderFactory.createSender()) {
99107
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
@@ -116,7 +124,11 @@ public void testGetAuthorization_FailsWhenAnInvalidFieldIsFound() throws IOExcep
116124
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
117125
var eisGatewayUrl = getUrl(webServer);
118126
var logger = mock(Logger.class);
119-
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger);
127+
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
128+
ElasticInferenceServiceSettingsTests.create(eisGatewayUrl),
129+
threadPool,
130+
logger
131+
);
120132

121133
try (var sender = senderFactory.createSender()) {
122134
String responseJson = """
@@ -167,7 +179,11 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
167179
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
168180
var eisGatewayUrl = getUrl(webServer);
169181
var logger = mock(Logger.class);
170-
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger);
182+
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
183+
ElasticInferenceServiceSettingsTests.create(eisGatewayUrl),
184+
threadPool,
185+
logger
186+
);
171187

172188
try (var sender = senderFactory.createSender()) {
173189
String responseJson = """
@@ -196,7 +212,48 @@ public void testGetAuthorization_ReturnsAValidResponse() throws IOException {
196212

197213
var message = loggerArgsCaptor.getValue();
198214
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
199-
verifyNoMoreInteractions(logger);
215+
}
216+
}
217+
218+
public void testGetAuthorization_PerformsADnsLookup() throws IOException {
219+
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
220+
var eisGatewayUrl = getUrl(webServer);
221+
var logger = mock(Logger.class);
222+
when(logger.isDebugEnabled()).thenReturn(true);
223+
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
224+
ElasticInferenceServiceSettingsTests.createWithUrlAndDnsLookup(eisGatewayUrl),
225+
threadPool,
226+
logger
227+
);
228+
229+
try (var sender = senderFactory.createSender()) {
230+
String responseJson = """
231+
{
232+
"models": [
233+
{
234+
"model_name": "model-a",
235+
"task_types": ["embed/text/sparse", "chat"]
236+
}
237+
]
238+
}
239+
""";
240+
241+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
242+
243+
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
244+
authHandler.getAuthorization(listener, sender);
245+
246+
var authResponse = listener.actionGet(TIMEOUT);
247+
assertThat(authResponse.getAuthorizedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)));
248+
assertThat(authResponse.getAuthorizedModelIds(), is(Set.of("model-a")));
249+
assertTrue(authResponse.isAuthorized());
250+
251+
var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
252+
verify(logger, times(4)).debug(loggerArgsCaptor.capture());
253+
254+
var messages = loggerArgsCaptor.getAllValues();
255+
assertThat(messages.size(), is(4));
256+
assertThat(messages.get(3), is("Gateway address: 127.0.0.1"));
200257
}
201258
}
202259

@@ -205,7 +262,11 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException {
205262
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
206263
var eisGatewayUrl = getUrl(webServer);
207264
var logger = mock(Logger.class);
208-
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(eisGatewayUrl, threadPool, logger);
265+
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
266+
ElasticInferenceServiceSettingsTests.create(eisGatewayUrl),
267+
threadPool,
268+
logger
269+
);
209270

210271
ActionListener<ElasticInferenceServiceAuthorizationModel> listener = mock(ActionListener.class);
211272
String responseJson = """
@@ -230,7 +291,6 @@ public void testGetAuthorization_OnResponseCalledOnce() throws IOException {
230291

231292
var message = loggerArgsCaptor.getValue();
232293
assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
233-
verifyNoMoreInteractions(logger);
234294
}
235295
}
236296

@@ -246,7 +306,11 @@ public void testGetAuthorization_InvalidResponse() throws IOException {
246306
}).when(senderMock).sendWithoutQueuing(any(), any(), any(), any(), any());
247307

248308
var logger = mock(Logger.class);
249-
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler("abc", threadPool, logger);
309+
var authHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
310+
ElasticInferenceServiceSettingsTests.create("abc"),
311+
threadPool,
312+
logger
313+
);
250314

251315
try (var sender = senderFactory.createSender()) {
252316
PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();

0 commit comments

Comments
 (0)