Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ buildscript {
ext {
opensearch_group = "org.opensearch"
isSnapshot = "true" == System.getProperty("build.snapshot", "true")
opensearch_version = System.getProperty("opensearch.version", "2.19.3-SNAPSHOT")
opensearch_version = System.getProperty("opensearch.version", "2.19.4-SNAPSHOT")
buildVersionQualifier = System.getProperty("build.version_qualifier", "")

// 2.0.0-rc1-SNAPSHOT -> 2.0.0.0-rc1-SNAPSHOT
Expand Down Expand Up @@ -71,6 +71,7 @@ allprojects {

}


subprojects {
configurations {
testImplementation.extendsFrom compileOnly
Expand All @@ -80,6 +81,18 @@ subprojects {
// Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades.
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre"
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
resolutionStrategy.force "org.apache.commons:commons-lang3:${versions.commonslang}"
resolutionStrategy.force 'software.amazon.awssdk:bom:2.32.29'

resolutionStrategy.force 'io.netty:netty-buffer:4.1.125.Final'
resolutionStrategy.force 'io.netty:netty-codec:4.1.125.Final'
resolutionStrategy.force 'io.netty:netty-codec-http:4.1.125.Final'
resolutionStrategy.force 'io.netty:netty-codec-http2:4.1.125.Final'
resolutionStrategy.force 'io.netty:netty-common:4.1.125.Final'
resolutionStrategy.force 'io.netty:netty-handler:4.1.125.Final'
resolutionStrategy.force 'io.netty:netty-resolver:4.1.125.Final'
resolutionStrategy.force 'io.netty:netty-transport:4.1.125.Final'
resolutionStrategy.force 'io.netty:netty-transport-native-unix-common:4.1.125.Final'
}
}

Expand Down
4 changes: 4 additions & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ dependencies {
compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
// Multi-tenant SDK Client
compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.32.29") {
exclude(group: 'org.reactivestreams', module: 'reactive-streams')
exclude(group: 'org.slf4j', module: 'slf4j-api')
}
}

lombok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.httpclient;
package org.opensearch.ml.common.httpclient;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.Arrays;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicBoolean;

import org.opensearch.common.util.concurrent.ThreadContextAccess;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
Expand All @@ -24,19 +23,15 @@
public class MLHttpClientFactory {

public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) {
try {
return AccessController
.doPrivileged(
(PrivilegedExceptionAction<SdkAsyncHttpClient>) () -> NettyNioAsyncHttpClient
.builder()
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.maxConcurrency(maxConnections)
.build()
);
} catch (PrivilegedActionException e) {
return null;
}
return ThreadContextAccess
.doPrivileged(
() -> NettyNioAsyncHttpClient
.builder()
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.maxConcurrency(maxConnections)
.build()
);
}

/**
Expand All @@ -50,7 +45,7 @@ public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout,
public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled)
throws UnknownHostException {
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
log.error("Remote inference protocol is not http or https: " + protocol);
log.error("Remote inference protocol is not http or https: {}", protocol);
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
}
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
Expand All @@ -62,7 +57,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
}
}
if (port < 0 || port > 65536) {
log.error("Remote inference port out of range: " + port);
log.error("Remote inference port out of range: {}", port);
throw new IllegalArgumentException("Port out of range: " + port);
}
validateIp(host, connectorPrivateIpEnabled);
Expand All @@ -71,7 +66,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
InetAddress[] addresses = InetAddress.getAllByName(hostName);
if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) {
log.error("Remote inference host name has private ip address: " + hostName);
log.error("Remote inference host name has private ip address: {}", hostName);
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
}
}
Expand All @@ -83,35 +78,23 @@ private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
if (bytes.length != 4) {
return true;
} else {
int firstOctets = bytes[0] & 0xff;
int firstInOctal = parseWithOctal(String.valueOf(firstOctets));
int firstInHex = Integer.parseInt(String.valueOf(firstOctets), 16);
if (firstInOctal == 127 || firstInHex == 127) {
return bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 1;
} else if (firstInOctal == 10 || firstInHex == 10) {
if (isPrivateIPv4(bytes)) {
return true;
} else if (firstInOctal == 172 || firstInHex == 172) {
int secondOctets = bytes[1] & 0xff;
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
return (secondInOctal >= 16 && secondInOctal <= 32) || (secondInHex >= 16 && secondInHex <= 32);
} else if (firstInOctal == 192 || firstInHex == 192) {
int secondOctets = bytes[1] & 0xff;
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
return secondInOctal == 168 || secondInHex == 168;
}
}
}
}
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
}

private static int parseWithOctal(String input) {
try {
return Integer.parseInt(input, 8);
} catch (NumberFormatException e) {
return Integer.parseInt(input);
}
private static boolean isPrivateIPv4(byte[] bytes) {
int first = bytes[0] & 0xff;
int second = bytes[1] & 0xff;

// 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x
return (first == 10)
|| (first == 172 && second >= 16 && second <= 31)
|| (first == 192 && second == 168)
|| (first == 169 && second == 254);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.httpclient;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;

import java.time.Duration;
import java.util.concurrent.atomic.AtomicBoolean;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import software.amazon.awssdk.http.async.SdkAsyncHttpClient;

public class MLHttpClientFactoryTests {

private static final String TEST_HOST = "api.openai.com";
private static final String HTTP = "http";
private static final String HTTPS = "https";
private static final AtomicBoolean PRIVATE_IP_DISABLED = new AtomicBoolean(false);
private static final AtomicBoolean PRIVATE_IP_ENABLED = new AtomicBoolean(true);

@Rule
public ExpectedException expectedException = ExpectedException.none();

@Test
public void test_getSdkAsyncHttpClient_success() {
SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100);
assertNotNull(client);
}

@Test
public void test_invalidIP_localHost_privateIPDisabled() {
IllegalArgumentException e1 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 127.0.0.1", e1.getMessage());

IllegalArgumentException e2 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "192.168.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 192.168.0.1", e2.getMessage());

IllegalArgumentException e3 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "169.254.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 169.254.0.1", e3.getMessage());

IllegalArgumentException e4 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "172.16.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 172.16.0.1", e4.getMessage());

IllegalArgumentException e5 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "172.31.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 172.31.0.1", e5.getMessage());
}

@Test
public void test_validateIp_validIp_noException() throws Exception {
MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTPS, TEST_HOST, 443, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTPS, "127.0.0.1", 443, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "177.16.0.1", 80, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, "177.0.1.1", 80, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, "177.0.0.2", 80, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, "::ffff", 80, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, "172.32.0.1", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "172.2097152", 80, PRIVATE_IP_ENABLED);
}

@Test
public void test_validateIp_rarePrivateIp_throwException() throws Exception {
try {
MLHttpClientFactory.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_DISABLED);
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validate(HTTP, "172.1048577", 80, PRIVATE_IP_DISABLED);
} catch (Exception e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validate(HTTP, "2886729729", 80, PRIVATE_IP_DISABLED);
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validate(HTTP, "192.11010049", 80, PRIVATE_IP_DISABLED);
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validate(HTTP, "3232300545", 80, PRIVATE_IP_DISABLED);
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_DISABLED);
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_DISABLED);
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validate(HTTP, "177.0.0.1", 80, PRIVATE_IP_DISABLED);
} catch (IllegalArgumentException e) {
assertNotNull(e);
}

try {
MLHttpClientFactory.validate(HTTP, "12.16.2.3", 80, PRIVATE_IP_DISABLED);
} catch (IllegalArgumentException e) {
assertNotNull(e);
}
}

@Test
public void test_validateIp_rarePrivateIp_NotThrowException() throws Exception {
MLHttpClientFactory.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTPS, "0254.020.00.01", 443, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "172.1048577", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "2886729729", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "192.11010049", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "3232300545", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTPS, "0:0:0:0:0:ffff:127.0.0.1", 443, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "10.24.76.186", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTPS, "10.24.76.186", 443, PRIVATE_IP_ENABLED);
}

@Test
public void test_validateSchemaAndPort_success() throws Exception {
MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED);
}

@Test
public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception {
expectedException.expect(IllegalArgumentException.class);
MLHttpClientFactory.validate("ftp", TEST_HOST, 80, PRIVATE_IP_DISABLED);
}

@Test
public void test_validateSchemaAndPort_portNotInRange1_throwException() throws Exception {
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("Port out of range: 65537");
MLHttpClientFactory.validate(HTTPS, TEST_HOST, 65537, PRIVATE_IP_DISABLED);
}

@Test
public void test_validateSchemaAndPort_portNotInRange2_throwException() throws Exception {
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("Port out of range: -10");
MLHttpClientFactory.validate(HTTP, TEST_HOST, -10, PRIVATE_IP_DISABLED);
}

@Test
public void test_validatePort_boundaries_success() throws Exception {
MLHttpClientFactory.validate(HTTP, TEST_HOST, 65536, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, TEST_HOST, 0, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, TEST_HOST, -1, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTPS, TEST_HOST, -1, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(null, TEST_HOST, -1, PRIVATE_IP_DISABLED);
}

}
Loading
Loading