diff --git a/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java b/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java index 109a5de5f8..59831f4ae5 100644 --- a/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java +++ b/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java @@ -5,15 +5,9 @@ package org.opensearch.ml.common.httpclient; -import java.net.Inet4Address; -import java.net.InetAddress; -import java.net.UnknownHostException; -import java.time.Duration; -import java.util.Arrays; -import java.util.Locale; -import java.util.concurrent.atomic.AtomicBoolean; +import static org.opensearch.secure_sm.AccessController.doPrivileged; -import org.opensearch.common.util.concurrent.ThreadContextAccess; +import java.time.Duration; import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; @@ -22,79 +16,27 @@ @Log4j2 public class MLHttpClientFactory { - public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) { - return ThreadContextAccess - .doPrivileged( - () -> NettyNioAsyncHttpClient - .builder() - .connectionTimeout(connectionTimeout) - .readTimeout(readTimeout) - .maxConcurrency(maxConnections) - .build() - ); - } - - /** - * Validate the input parameters, such as protocol, host and port. - * @param protocol The protocol supported in remote inference, currently only http and https are supported. - * @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost. - * @param port The port number of the remote inference server, port number must be in range [0, 65536]. - * @param connectorPrivateIpEnabled The port number of the remote inference server, port number must be in range [0, 65536]. - * @throws UnknownHostException Allow to use private IP or not. - */ - public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled) - throws UnknownHostException { - if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) { - log.error("Remote inference protocol is not http or https: {}", protocol); - throw new IllegalArgumentException("Protocol is not http or https: " + protocol); - } - // When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol. - if (port == -1) { - if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) { - port = 80; - } else { - port = 443; - } - } - if (port < 0 || port > 65536) { - log.error("Remote inference port out of range: {}", port); - throw new IllegalArgumentException("Port out of range: " + port); - } - validateIp(host, connectorPrivateIpEnabled); - } - - private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException { - InetAddress[] addresses = InetAddress.getAllByName(hostName); - if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) { - log.error("Remote inference host name has private ip address: {}", hostName); - throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName); - } - } - - private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) { - for (InetAddress ip : ipAddress) { - if (ip instanceof Inet4Address) { - byte[] bytes = ip.getAddress(); - if (bytes.length != 4) { - return true; - } else { - if (isPrivateIPv4(bytes)) { - return true; - } - } - } - } - return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress()); - } - - private static boolean isPrivateIPv4(byte[] bytes) { - int first = bytes[0] & 0xff; - int second = bytes[1] & 0xff; - - // 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x - return (first == 10) - || (first == 172 && second >= 16 && second <= 31) - || (first == 192 && second == 168) - || (first == 169 && second == 254); + public static SdkAsyncHttpClient getAsyncHttpClient( + Duration connectionTimeout, + Duration readTimeout, + int maxConnections, + boolean connectorPrivateIpEnabled + ) { + return doPrivileged(() -> { + log + .debug( + "Creating MLHttpClient with connectionTimeout: {}, readTimeout: {}, maxConnections: {}", + connectionTimeout, + readTimeout, + maxConnections + ); + SdkAsyncHttpClient delegate = NettyNioAsyncHttpClient + .builder() + .connectionTimeout(connectionTimeout) + .readTimeout(readTimeout) + .maxConcurrency(maxConnections) + .build(); + return new MLValidatableAsyncHttpClient(delegate, connectorPrivateIpEnabled); + }); } } diff --git a/common/src/main/java/org/opensearch/ml/common/httpclient/MLValidatableAsyncHttpClient.java b/common/src/main/java/org/opensearch/ml/common/httpclient/MLValidatableAsyncHttpClient.java new file mode 100644 index 0000000000..35595765b4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/httpclient/MLValidatableAsyncHttpClient.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.httpclient; + +import java.net.Inet4Address; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.Locale; +import java.util.concurrent.CompletableFuture; + +import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.http.async.AsyncExecuteRequest; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; + +@Log4j2 +public class MLValidatableAsyncHttpClient implements SdkAsyncHttpClient { + private final SdkAsyncHttpClient delegate; + private final boolean connectorPrivateIpEnabled; + + protected MLValidatableAsyncHttpClient(SdkAsyncHttpClient client, boolean connectorPrivateIpEnabled) { + this.delegate = client; + this.connectorPrivateIpEnabled = connectorPrivateIpEnabled; + } + + @Override + public CompletableFuture execute(AsyncExecuteRequest request) { + String protocol = request.request().protocol(); + String host = request.request().host(); + int port = request.request().port(); + try { + validate(protocol, host, port, connectorPrivateIpEnabled); + return delegate.execute(request); + } catch (Exception e) { + log.error("Failed to validate request!", e); + throw new IllegalArgumentException(e.getMessage(), e); + } + } + + @Override + public void close() { + delegate.close(); + } + + /** + * Validate the input parameters, such as protocol, host and port. + * @param protocol The protocol supported in remote inference, currently only http and https are supported. + * @param host The host name of the remote inference server, host must be a valid ip address or domain name and must not be localhost. + * @param port The port number of the remote inference server, port number must be in range [0, 65536]. + * @param connectorPrivateIpEnabled The port number of the remote inference server, port number must be in range [0, 65536]. + * @throws UnknownHostException Allow to use private IP or not. + */ + public void validate(String protocol, String host, int port, boolean connectorPrivateIpEnabled) throws UnknownHostException { + if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) { + log.error("Remote inference protocol is not http or https: {}", protocol); + throw new IllegalArgumentException("Protocol is not http or https: " + protocol); + } + // When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol. + if (port == -1) { + if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) { + port = 80; + } else { + port = 443; + } + } + if (port < 0 || port > 65536) { + log.error("Remote inference port out of range: {}", port); + throw new IllegalArgumentException("Port out of range: " + port); + } + validateIp(host, connectorPrivateIpEnabled); + } + + private void validateIp(String hostName, boolean connectorPrivateIpEnabled) throws UnknownHostException { + InetAddress[] addresses = InetAddress.getAllByName(hostName); + if (!connectorPrivateIpEnabled && hasPrivateIpAddress(addresses)) { + log.error("Remote inference host name has private ip address: {}", hostName); + throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName); + } + } + + private boolean hasPrivateIpAddress(InetAddress[] ipAddress) { + for (InetAddress ip : ipAddress) { + if (ip instanceof Inet4Address) { + byte[] bytes = ip.getAddress(); + if (bytes.length != 4) { + return true; + } else { + if (isPrivateIPv4(bytes)) { + return true; + } + } + } + } + return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress()); + } + + private boolean isPrivateIPv4(byte[] bytes) { + int first = bytes[0] & 0xff; + int second = bytes[1] & 0xff; + + // 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x + return (first == 10) + || (first == 172 && second >= 16 && second <= 31) + || (first == 192 && second == 168) + || (first == 169 && second == 254); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java index dacfef73ba..2642a12db6 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java @@ -25,7 +25,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -38,7 +37,7 @@ public class MLFeatureEnabledSetting { private volatile Boolean isAgentFrameworkEnabled; private volatile Boolean isLocalModelEnabled; - private volatile AtomicBoolean isConnectorPrivateIpEnabled; + private volatile Boolean isConnectorPrivateIpEnabled; private volatile Boolean isControllerEnabled; private volatile Boolean isBatchIngestionEnabled; @@ -70,7 +69,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings); isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings); - isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings)); + isConnectorPrivateIpEnabled = ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings); isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings); isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings); isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings); @@ -94,7 +93,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it); clusterService .getClusterSettings() - .addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it)); + .addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it); clusterService .getClusterSettings() @@ -145,7 +144,7 @@ public boolean isLocalModelEnabled() { return isLocalModelEnabled; } - public AtomicBoolean isConnectorPrivateIpEnabled() { + public boolean isConnectorPrivateIpEnabled() { return isConnectorPrivateIpEnabled; } diff --git a/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java b/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java index 1c01172344..d0664cca3a 100644 --- a/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java +++ b/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java @@ -5,187 +5,20 @@ package org.opensearch.ml.common.httpclient; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; import java.time.Duration; -import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; public class MLHttpClientFactoryTests { - private static final String TEST_HOST = "api.openai.com"; - private static final String HTTP = "http"; - private static final String HTTPS = "https"; - private static final AtomicBoolean PRIVATE_IP_DISABLED = new AtomicBoolean(false); - private static final AtomicBoolean PRIVATE_IP_ENABLED = new AtomicBoolean(true); - - @Rule - public ExpectedException expectedException = ExpectedException.none(); - @Test public void test_getSdkAsyncHttpClient_success() { - SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100); + SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false); assertNotNull(client); } - @Test - public void test_invalidIP_localHost_privateIPDisabled() { - IllegalArgumentException e1 = assertThrows( - IllegalArgumentException.class, - () -> MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_DISABLED) - ); - assertEquals("Remote inference host name has private ip address: 127.0.0.1", e1.getMessage()); - - IllegalArgumentException e2 = assertThrows( - IllegalArgumentException.class, - () -> MLHttpClientFactory.validate(HTTP, "192.168.0.1", 80, PRIVATE_IP_DISABLED) - ); - assertEquals("Remote inference host name has private ip address: 192.168.0.1", e2.getMessage()); - - IllegalArgumentException e3 = assertThrows( - IllegalArgumentException.class, - () -> MLHttpClientFactory.validate(HTTP, "169.254.0.1", 80, PRIVATE_IP_DISABLED) - ); - assertEquals("Remote inference host name has private ip address: 169.254.0.1", e3.getMessage()); - - IllegalArgumentException e4 = assertThrows( - IllegalArgumentException.class, - () -> MLHttpClientFactory.validate(HTTP, "172.16.0.1", 80, PRIVATE_IP_DISABLED) - ); - assertEquals("Remote inference host name has private ip address: 172.16.0.1", e4.getMessage()); - - IllegalArgumentException e5 = assertThrows( - IllegalArgumentException.class, - () -> MLHttpClientFactory.validate(HTTP, "172.31.0.1", 80, PRIVATE_IP_DISABLED) - ); - assertEquals("Remote inference host name has private ip address: 172.31.0.1", e5.getMessage()); - } - - @Test - public void test_validateIp_validIp_noException() throws Exception { - MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(HTTPS, TEST_HOST, 443, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTPS, "127.0.0.1", 443, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTP, "177.16.0.1", 80, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(HTTP, "177.0.1.1", 80, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(HTTP, "177.0.0.2", 80, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(HTTP, "::ffff", 80, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(HTTP, "172.32.0.1", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTP, "172.2097152", 80, PRIVATE_IP_ENABLED); - } - - @Test - public void test_validateIp_rarePrivateIp_throwException() throws Exception { - try { - MLHttpClientFactory.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_DISABLED); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate(HTTP, "172.1048577", 80, PRIVATE_IP_DISABLED); - } catch (Exception e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate(HTTP, "2886729729", 80, PRIVATE_IP_DISABLED); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate(HTTP, "192.11010049", 80, PRIVATE_IP_DISABLED); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate(HTTP, "3232300545", 80, PRIVATE_IP_DISABLED); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_DISABLED); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_DISABLED); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate(HTTP, "177.0.0.1", 80, PRIVATE_IP_DISABLED); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - - try { - MLHttpClientFactory.validate(HTTP, "12.16.2.3", 80, PRIVATE_IP_DISABLED); - } catch (IllegalArgumentException e) { - assertNotNull(e); - } - } - - @Test - public void test_validateIp_rarePrivateIp_NotThrowException() throws Exception { - MLHttpClientFactory.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTPS, "0254.020.00.01", 443, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTP, "172.1048577", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTP, "2886729729", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTP, "192.11010049", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTP, "3232300545", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTPS, "0:0:0:0:0:ffff:127.0.0.1", 443, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTP, "10.24.76.186", 80, PRIVATE_IP_ENABLED); - MLHttpClientFactory.validate(HTTPS, "10.24.76.186", 443, PRIVATE_IP_ENABLED); - } - - @Test - public void test_validateSchemaAndPort_success() throws Exception { - MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED); - } - - @Test - public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception { - expectedException.expect(IllegalArgumentException.class); - MLHttpClientFactory.validate("ftp", TEST_HOST, 80, PRIVATE_IP_DISABLED); - } - - @Test - public void test_validateSchemaAndPort_portNotInRange1_throwException() throws Exception { - expectedException.expect(IllegalArgumentException.class); - expectedException.expectMessage("Port out of range: 65537"); - MLHttpClientFactory.validate(HTTPS, TEST_HOST, 65537, PRIVATE_IP_DISABLED); - } - - @Test - public void test_validateSchemaAndPort_portNotInRange2_throwException() throws Exception { - expectedException.expect(IllegalArgumentException.class); - expectedException.expectMessage("Port out of range: -10"); - MLHttpClientFactory.validate(HTTP, TEST_HOST, -10, PRIVATE_IP_DISABLED); - } - - @Test - public void test_validatePort_boundaries_success() throws Exception { - MLHttpClientFactory.validate(HTTP, TEST_HOST, 65536, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(HTTP, TEST_HOST, 0, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(HTTP, TEST_HOST, -1, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(HTTPS, TEST_HOST, -1, PRIVATE_IP_DISABLED); - MLHttpClientFactory.validate(null, TEST_HOST, -1, PRIVATE_IP_DISABLED); - } - } diff --git a/common/src/test/java/org/opensearch/ml/common/httpclient/MLValidatableAsyncHttpClientTests.java b/common/src/test/java/org/opensearch/ml/common/httpclient/MLValidatableAsyncHttpClientTests.java new file mode 100644 index 0000000000..60b902120e --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/httpclient/MLValidatableAsyncHttpClientTests.java @@ -0,0 +1,187 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.httpclient; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; + +public class MLValidatableAsyncHttpClientTests { + private static final String TEST_HOST = "api.openai.com"; + private static final String HTTP = "http"; + private static final String HTTPS = "https"; + private static final boolean PRIVATE_IP_DISABLED = false; + private static final boolean PRIVATE_IP_ENABLED = true; + + private final MLValidatableAsyncHttpClient validatingHttpClient = new MLValidatableAsyncHttpClient( + mock(SdkAsyncHttpClient.class), + PRIVATE_IP_DISABLED + ); + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void test_invalidIP_localHost_privateIPDisabled() { + IllegalArgumentException e1 = assertThrows( + IllegalArgumentException.class, + () -> validatingHttpClient.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 127.0.0.1", e1.getMessage()); + + IllegalArgumentException e2 = assertThrows( + IllegalArgumentException.class, + () -> validatingHttpClient.validate(HTTP, "192.168.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 192.168.0.1", e2.getMessage()); + + IllegalArgumentException e3 = assertThrows( + IllegalArgumentException.class, + () -> validatingHttpClient.validate(HTTP, "169.254.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 169.254.0.1", e3.getMessage()); + + IllegalArgumentException e4 = assertThrows( + IllegalArgumentException.class, + () -> validatingHttpClient.validate(HTTP, "172.16.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 172.16.0.1", e4.getMessage()); + + IllegalArgumentException e5 = assertThrows( + IllegalArgumentException.class, + () -> validatingHttpClient.validate(HTTP, "172.31.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 172.31.0.1", e5.getMessage()); + } + + @Test + public void test_validateIp_validIp_noException() throws Exception { + validatingHttpClient.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(HTTPS, TEST_HOST, 443, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTPS, "127.0.0.1", 443, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTP, "177.16.0.1", 80, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(HTTP, "177.0.1.1", 80, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(HTTP, "177.0.0.2", 80, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(HTTP, "::ffff", 80, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(HTTP, "172.32.0.1", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTP, "172.2097152", 80, PRIVATE_IP_ENABLED); + } + + @Test + public void test_validateIp_rarePrivateIp_throwException() throws Exception { + try { + validatingHttpClient.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + validatingHttpClient.validate(HTTP, "172.1048577", 80, PRIVATE_IP_DISABLED); + } catch (Exception e) { + assertNotNull(e); + } + + try { + validatingHttpClient.validate(HTTP, "2886729729", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + validatingHttpClient.validate(HTTP, "192.11010049", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + validatingHttpClient.validate(HTTP, "3232300545", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + validatingHttpClient.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + validatingHttpClient.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + validatingHttpClient.validate(HTTP, "177.0.0.1", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + validatingHttpClient.validate(HTTP, "12.16.2.3", 80, PRIVATE_IP_DISABLED); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + } + + @Test + public void test_validateIp_rarePrivateIp_NotThrowException() throws Exception { + validatingHttpClient.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTPS, "0254.020.00.01", 443, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTP, "172.1048577", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTP, "2886729729", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTP, "192.11010049", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTP, "3232300545", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTPS, "0:0:0:0:0:ffff:127.0.0.1", 443, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTP, "10.24.76.186", 80, PRIVATE_IP_ENABLED); + validatingHttpClient.validate(HTTPS, "10.24.76.186", 443, PRIVATE_IP_ENABLED); + } + + @Test + public void test_validateSchemaAndPort_success() throws Exception { + validatingHttpClient.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED); + } + + @Test + public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception { + expectedException.expect(IllegalArgumentException.class); + validatingHttpClient.validate("ftp", TEST_HOST, 80, PRIVATE_IP_DISABLED); + } + + @Test + public void test_validateSchemaAndPort_portNotInRange1_throwException() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Port out of range: 65537"); + validatingHttpClient.validate(HTTPS, TEST_HOST, 65537, PRIVATE_IP_DISABLED); + } + + @Test + public void test_validateSchemaAndPort_portNotInRange2_throwException() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Port out of range: -10"); + validatingHttpClient.validate(HTTP, TEST_HOST, -10, PRIVATE_IP_DISABLED); + } + + @Test + public void test_validatePort_boundaries_success() throws Exception { + validatingHttpClient.validate(HTTP, TEST_HOST, 65536, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(HTTP, TEST_HOST, 0, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(HTTP, TEST_HOST, -1, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(HTTPS, TEST_HOST, -1, PRIVATE_IP_DISABLED); + validatingHttpClient.validate(null, TEST_HOST, -1, PRIVATE_IP_DISABLED); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java b/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java index 9e25f693c5..f6a7daad1d 100644 --- a/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java +++ b/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java @@ -81,7 +81,7 @@ public void testDefaults_allFeaturesEnabled() { assertTrue(setting.isRemoteInferenceEnabled()); assertTrue(setting.isAgentFrameworkEnabled()); assertTrue(setting.isLocalModelEnabled()); - assertTrue(setting.isConnectorPrivateIpEnabled().get()); + assertTrue(setting.isConnectorPrivateIpEnabled()); assertTrue(setting.isControllerEnabled()); assertTrue(setting.isOfflineBatchIngestionEnabled()); assertTrue(setting.isOfflineBatchInferenceEnabled()); @@ -122,7 +122,7 @@ public void testDefaults_someFeaturesDisabled() { assertFalse(setting.isRemoteInferenceEnabled()); assertFalse(setting.isAgentFrameworkEnabled()); assertFalse(setting.isLocalModelEnabled()); - assertFalse(setting.isConnectorPrivateIpEnabled().get()); + assertFalse(setting.isConnectorPrivateIpEnabled()); assertFalse(setting.isControllerEnabled()); assertFalse(setting.isOfflineBatchIngestionEnabled()); assertFalse(setting.isOfflineBatchInferenceEnabled()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 70673870ce..3b53935aaf 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -18,6 +18,7 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; import org.apache.commons.text.StringEscapeUtils; import org.apache.logging.log4j.Logger; @@ -40,6 +41,8 @@ import org.opensearch.transport.StreamTransportService; import org.opensearch.transport.client.Client; +import com.google.common.annotations.VisibleForTesting; + import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -70,19 +73,18 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor { @Getter private MLGuard mlGuard; - private SdkAsyncHttpClient httpClient; + private final AtomicReference httpClientRef = new AtomicReference<>(); @Setter @Getter private StreamTransportService streamTransportService; + @Setter + private boolean connectorPrivateIpEnabled; + public AwsConnectorExecutor(Connector connector) { super.initialize(connector); this.connector = (AwsConnector) connector; - Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); - Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); - Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); - this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); } @Override @@ -129,7 +131,8 @@ public void invokeRemoteService( ) ) .build(); - AccessController.doPrivileged((PrivilegedExceptionAction>) () -> httpClient.execute(executeRequest)); + AccessController + .doPrivileged((PrivilegedExceptionAction>) () -> getHttpClient().execute(executeRequest)); } catch (RuntimeException exception) { log.error("Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception); actionListener.onFailure(exception); @@ -154,7 +157,7 @@ public void invokeRemoteServiceStream( llmInterface = StringEscapeUtils.unescapeJava(llmInterface); validateLLMInterface(llmInterface); - StreamingHandler handler = StreamingHandlerFactory.createHandler(llmInterface, connector, httpClient, null); + StreamingHandler handler = StreamingHandlerFactory.createHandler(llmInterface, connector, getHttpClient(), null); handler.startStream(action, parameters, payload, actionListener); } catch (Exception e) { log.error("Failed to execute streaming", e); @@ -180,4 +183,19 @@ private void validateLLMInterface(String llmInterface) { throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface)); } } + + @VisibleForTesting + protected SdkAsyncHttpClient getHttpClient() { + if (httpClientRef.get() == null) { + Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); + Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); + this.httpClientRef + .compareAndSet( + null, + MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled) + ); + } + return httpClientRef.get(); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index b8984a1246..7804770258 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -11,14 +11,13 @@ import static software.amazon.awssdk.http.SdkHttpMethod.GET; import static software.amazon.awssdk.http.SdkHttpMethod.POST; -import java.net.URL; import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.time.Duration; import java.util.Locale; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import org.apache.commons.text.StringEscapeUtils; import org.apache.logging.log4j.Logger; @@ -41,6 +40,8 @@ import org.opensearch.transport.StreamTransportService; import org.opensearch.transport.client.Client; +import com.google.common.annotations.VisibleForTesting; + import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -72,9 +73,9 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { @Getter private MLGuard mlGuard; @Setter - private volatile AtomicBoolean connectorPrivateIpEnabled; + private volatile boolean connectorPrivateIpEnabled; - private SdkAsyncHttpClient httpClient; + private final AtomicReference httpClientRef = new AtomicReference<>(); @Setter @Getter @@ -83,10 +84,6 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { public HttpJsonConnectorExecutor(Connector connector) { super.initialize(connector); this.connector = (HttpConnector) connector; - Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); - Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); - Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); - this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); } @Override @@ -109,11 +106,9 @@ public void invokeRemoteService( switch (connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT)) { case "POST": log.debug("original payload to remote model: " + payload); - validateHttpClientParameters(action, parameters); request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); break; case "GET": - validateHttpClientParameters(action, parameters); request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET); break; default: @@ -135,7 +130,8 @@ public void invokeRemoteService( ) ) .build(); - AccessController.doPrivileged((PrivilegedExceptionAction>) () -> httpClient.execute(executeRequest)); + AccessController + .doPrivileged((PrivilegedExceptionAction>) () -> getHttpClient().execute(executeRequest)); } catch (RuntimeException e) { log.error("Fail to execute http connector", e); actionListener.onFailure(e); @@ -169,15 +165,6 @@ public void invokeRemoteServiceStream( } } - private void validateHttpClientParameters(String action, Map parameters) throws Exception { - String endpoint = connector.getActionEndpoint(action, parameters); - URL url = new URL(endpoint); - String protocol = url.getProtocol(); - String host = url.getHost(); - int port = url.getPort(); - MLHttpClientFactory.validate(protocol, host, port, connectorPrivateIpEnabled); - } - private void validateLLMInterface(String llmInterface) { switch (llmInterface) { case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS: @@ -186,4 +173,19 @@ private void validateLLMInterface(String llmInterface) { throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface)); } } + + @VisibleForTesting + protected SdkAsyncHttpClient getHttpClient() { + if (httpClientRef.get() == null) { + Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); + Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); + this.httpClientRef + .compareAndSet( + null, + MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled) + ); + } + return httpClientRef.get(); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 6867ebc0b1..5181e8d087 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -19,7 +19,6 @@ import java.util.Locale; import java.util.Map; import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; @@ -183,7 +182,7 @@ default void setScriptService(ScriptService scriptService) {} default void setClient(Client client) {} - default void setConnectorPrivateIpEnabled(AtomicBoolean connectorPrivateIpEnabled) {} + default void setConnectorPrivateIpEnabled(boolean connectorPrivateIpEnabled) {} default void setXContentRegistry(NamedXContentRegistry xContentRegistry) {} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 6cc380b395..a000989ee6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -11,7 +11,6 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; -import java.util.concurrent.atomic.AtomicBoolean; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -127,7 +126,7 @@ public CompletionStage initModelAsync(MLModel model, Map) params.get(USER_RATE_LIMITER_MAP)); this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS)); - this.connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED)); + this.connectorExecutor.setConnectorPrivateIpEnabled((boolean) params.getOrDefault(CONNECTOR_PRIVATE_IP_ENABLED, false)); return CompletableFuture.completedStage(true); }).exceptionally(e -> { log.error("Failed to init remote model.", e); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index b1234dae93..59998b714e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -22,7 +22,6 @@ import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; -import java.lang.reflect.Field; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -448,7 +447,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF } @Test - public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException() throws NoSuchFieldException, IllegalAccessException { + public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(PREDICT) @@ -469,14 +468,11 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException( .actions(Arrays.asList(predictAction)) .build(); connector.decrypt(PREDICT.name(), (c, tenantId) -> encryptor.decrypt(c, null), null); - AwsConnectorExecutor executor0 = new AwsConnectorExecutor(connector); - Field httpClientField = AwsConnectorExecutor.class.getDeclaredField("httpClient"); - httpClientField.setAccessible(true); - httpClientField.set(executor0, null); - AwsConnectorExecutor executor = spy(executor0); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); + when(executor.getHttpClient()).thenReturn(null); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index feb8dc173c..ff3298f7e9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -10,15 +10,15 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; -import java.lang.reflect.Field; import java.util.Arrays; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Before; import org.junit.Rule; @@ -96,8 +96,7 @@ public void invokeRemoteService_invalidIpAddress() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - AtomicBoolean privateIpEnabled = new AtomicBoolean(false); - executor.setConnectorPrivateIpEnabled(privateIpEnabled); + executor.setConnectorPrivateIpEnabled(false); executor .invokeRemoteService( PREDICT.name(), @@ -130,8 +129,7 @@ public void invokeRemoteService_EnabledPrivateIpAddress() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - AtomicBoolean privateIpEnabled = new AtomicBoolean(true); - executor.setConnectorPrivateIpEnabled(privateIpEnabled); + executor.setConnectorPrivateIpEnabled(true); executor .invokeRemoteService( PREDICT.name(), @@ -161,8 +159,7 @@ public void invokeRemoteService_DisabledPrivateIpAddress() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - AtomicBoolean privateIpEnabled = new AtomicBoolean(false); - executor.setConnectorPrivateIpEnabled(privateIpEnabled); + executor.setConnectorPrivateIpEnabled(false); executor .invokeRemoteService( PREDICT.name(), @@ -195,8 +192,6 @@ public void invokeRemoteService_Empty_payload() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - AtomicBoolean privateIpEnabled = new AtomicBoolean(false); - executor.setConnectorPrivateIpEnabled(privateIpEnabled); executor.invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); @@ -246,7 +241,7 @@ public void invokeRemoteService_post_request() { } @Test - public void invokeRemoteService_nullHttpClient_throwMLException() throws NoSuchFieldException, IllegalAccessException { + public void invokeRemoteService_nullHttpClient_throwMLException() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(PREDICT) @@ -261,10 +256,8 @@ public void invokeRemoteService_nullHttpClient_throwMLException() throws NoSuchF .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - Field httpClientField = HttpJsonConnectorExecutor.class.getDeclaredField("httpClient"); - httpClientField.setAccessible(true); - httpClientField.set(executor, null); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(null); executor .invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/build.gradle b/plugin/build.gradle index e2ef299232..01a01e1358 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -672,6 +672,15 @@ configurations.all { resolutionStrategy.force 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.18.3' resolutionStrategy.force 'com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.18.3' resolutionStrategy.force 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.3' + resolutionStrategy.force "io.netty:netty-codec-http:${versions.netty}" + resolutionStrategy.force "io.netty:netty-codec-http2:${versions.netty}" + resolutionStrategy.force "io.netty:netty-codec:${versions.netty}" + resolutionStrategy.force "io.netty:netty-transport:${versions.netty}" + resolutionStrategy.force "io.netty:netty-common:${versions.netty}" + resolutionStrategy.force "io.netty:netty-buffer:${versions.netty}" + resolutionStrategy.force "io.netty:netty-handler:${versions.netty}" + resolutionStrategy.force "io.netty:netty-resolver:${versions.netty}" + resolutionStrategy.force "io.netty:netty-transport-native-unix-common:${versions.netty}" } apply plugin: 'com.netflix.nebula.ospackage' diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 52ec2c6e34..93d525f305 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -65,7 +65,6 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -285,7 +284,7 @@ public void setup() throws URISyntaxException { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); when(mlFeatureEnabledSetting.isControllerEnabled()).thenReturn(true); - when(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()).thenReturn(new AtomicBoolean(false)); + when(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()).thenReturn(false); modelManager = spy( new MLModelManager( diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java index f73478d092..dd618c7b2f 100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java @@ -185,6 +185,7 @@ private void setUpClusterSettings() throws IOException { updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); updateClusterSettings("plugins.ml_commons.memory_feature_enabled", true); updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$")); + updateClusterSettings("plugins.ml_commons.connector.private_ip_enabled", true); } private void restoreClusterSettings() throws IOException {