Skip to content

Commit b986151

Browse files
zane-neobrianf-aws
authored andcommitted
Move HttpClientFactory to common to expose to other components (opensearch-project#4175)
* Move HttpClientFactory to common to expose to other componenets Signed-off-by: zane-neo <[email protected]> * optimize code for better maintainability Signed-off-by: zane-neo <[email protected]> * Optimize code and increase UT coverage Signed-off-by: zane-neo <[email protected]> * Address comments Signed-off-by: zane-neo <[email protected]> * Use amazon aws version from opensearch core Signed-off-by: zane-neo <[email protected]> * address comments Signed-off-by: zane-neo <[email protected]> --------- Signed-off-by: zane-neo <[email protected]>
1 parent 7a07243 commit b986151

File tree

8 files changed

+230
-165
lines changed

8 files changed

+230
-165
lines changed

common/build.gradle

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ dependencies {
4444
compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
4545
// Multi-tenant SDK Client
4646
compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
47+
compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}") {
48+
exclude(group: 'org.reactivestreams', module: 'reactive-streams')
49+
exclude(group: 'org.slf4j', module: 'slf4j-api')
50+
}
4751
}
4852

4953
lombok {
Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,18 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.engine.httpclient;
6+
package org.opensearch.ml.common.httpclient;
77

88
import java.net.Inet4Address;
99
import java.net.InetAddress;
1010
import java.net.UnknownHostException;
11-
import java.security.AccessController;
12-
import java.security.PrivilegedActionException;
13-
import java.security.PrivilegedExceptionAction;
1411
import java.time.Duration;
1512
import java.util.Arrays;
1613
import java.util.Locale;
1714
import java.util.concurrent.atomic.AtomicBoolean;
1815

16+
import org.opensearch.common.util.concurrent.ThreadContextAccess;
17+
1918
import lombok.extern.log4j.Log4j2;
2019
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
2120
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
@@ -24,19 +23,15 @@
2423
public class MLHttpClientFactory {
2524

2625
public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) {
27-
try {
28-
return AccessController
29-
.doPrivileged(
30-
(PrivilegedExceptionAction<SdkAsyncHttpClient>) () -> NettyNioAsyncHttpClient
31-
.builder()
32-
.connectionTimeout(connectionTimeout)
33-
.readTimeout(readTimeout)
34-
.maxConcurrency(maxConnections)
35-
.build()
36-
);
37-
} catch (PrivilegedActionException e) {
38-
return null;
39-
}
26+
return ThreadContextAccess
27+
.doPrivileged(
28+
() -> NettyNioAsyncHttpClient
29+
.builder()
30+
.connectionTimeout(connectionTimeout)
31+
.readTimeout(readTimeout)
32+
.maxConcurrency(maxConnections)
33+
.build()
34+
);
4035
}
4136

4237
/**
@@ -50,7 +45,7 @@ public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout,
5045
public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled)
5146
throws UnknownHostException {
5247
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
53-
log.error("Remote inference protocol is not http or https: " + protocol);
48+
log.error("Remote inference protocol is not http or https: {}", protocol);
5449
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
5550
}
5651
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
@@ -62,7 +57,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
6257
}
6358
}
6459
if (port < 0 || port > 65536) {
65-
log.error("Remote inference port out of range: " + port);
60+
log.error("Remote inference port out of range: {}", port);
6661
throw new IllegalArgumentException("Port out of range: " + port);
6762
}
6863
validateIp(host, connectorPrivateIpEnabled);
@@ -71,7 +66,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
7166
private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
7267
InetAddress[] addresses = InetAddress.getAllByName(hostName);
7368
if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) {
74-
log.error("Remote inference host name has private ip address: " + hostName);
69+
log.error("Remote inference host name has private ip address: {}", hostName);
7570
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
7671
}
7772
}
@@ -83,35 +78,23 @@ private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
8378
if (bytes.length != 4) {
8479
return true;
8580
} else {
86-
int firstOctets = bytes[0] & 0xff;
87-
int firstInOctal = parseWithOctal(String.valueOf(firstOctets));
88-
int firstInHex = Integer.parseInt(String.valueOf(firstOctets), 16);
89-
if (firstInOctal == 127 || firstInHex == 127) {
90-
return bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 1;
91-
} else if (firstInOctal == 10 || firstInHex == 10) {
81+
if (isPrivateIPv4(bytes)) {
9282
return true;
93-
} else if (firstInOctal == 172 || firstInHex == 172) {
94-
int secondOctets = bytes[1] & 0xff;
95-
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
96-
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
97-
return (secondInOctal >= 16 && secondInOctal <= 32) || (secondInHex >= 16 && secondInHex <= 32);
98-
} else if (firstInOctal == 192 || firstInHex == 192) {
99-
int secondOctets = bytes[1] & 0xff;
100-
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
101-
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
102-
return secondInOctal == 168 || secondInHex == 168;
10383
}
10484
}
10585
}
10686
}
10787
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
10888
}
10989

110-
private static int parseWithOctal(String input) {
111-
try {
112-
return Integer.parseInt(input, 8);
113-
} catch (NumberFormatException e) {
114-
return Integer.parseInt(input);
115-
}
90+
private static boolean isPrivateIPv4(byte[] bytes) {
91+
int first = bytes[0] & 0xff;
92+
int second = bytes[1] & 0xff;
93+
94+
// 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x
95+
return (first == 10)
96+
|| (first == 172 && second >= 16 && second <= 31)
97+
|| (first == 192 && second == 168)
98+
|| (first == 169 && second == 254);
11699
}
117100
}
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.httpclient;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
import static org.junit.Assert.assertThrows;
11+
12+
import java.time.Duration;
13+
import java.util.concurrent.atomic.AtomicBoolean;
14+
15+
import org.junit.Rule;
16+
import org.junit.Test;
17+
import org.junit.rules.ExpectedException;
18+
19+
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
20+
21+
public class MLHttpClientFactoryTests {
22+
23+
private static final String TEST_HOST = "api.openai.com";
24+
private static final String HTTP = "http";
25+
private static final String HTTPS = "https";
26+
private static final AtomicBoolean PRIVATE_IP_DISABLED = new AtomicBoolean(false);
27+
private static final AtomicBoolean PRIVATE_IP_ENABLED = new AtomicBoolean(true);
28+
29+
@Rule
30+
public ExpectedException expectedException = ExpectedException.none();
31+
32+
@Test
33+
public void test_getSdkAsyncHttpClient_success() {
34+
SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100);
35+
assertNotNull(client);
36+
}
37+
38+
@Test
39+
public void test_invalidIP_localHost_privateIPDisabled() {
40+
IllegalArgumentException e1 = assertThrows(
41+
IllegalArgumentException.class,
42+
() -> MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_DISABLED)
43+
);
44+
assertEquals("Remote inference host name has private ip address: 127.0.0.1", e1.getMessage());
45+
46+
IllegalArgumentException e2 = assertThrows(
47+
IllegalArgumentException.class,
48+
() -> MLHttpClientFactory.validate(HTTP, "192.168.0.1", 80, PRIVATE_IP_DISABLED)
49+
);
50+
assertEquals("Remote inference host name has private ip address: 192.168.0.1", e2.getMessage());
51+
52+
IllegalArgumentException e3 = assertThrows(
53+
IllegalArgumentException.class,
54+
() -> MLHttpClientFactory.validate(HTTP, "169.254.0.1", 80, PRIVATE_IP_DISABLED)
55+
);
56+
assertEquals("Remote inference host name has private ip address: 169.254.0.1", e3.getMessage());
57+
58+
IllegalArgumentException e4 = assertThrows(
59+
IllegalArgumentException.class,
60+
() -> MLHttpClientFactory.validate(HTTP, "172.16.0.1", 80, PRIVATE_IP_DISABLED)
61+
);
62+
assertEquals("Remote inference host name has private ip address: 172.16.0.1", e4.getMessage());
63+
64+
IllegalArgumentException e5 = assertThrows(
65+
IllegalArgumentException.class,
66+
() -> MLHttpClientFactory.validate(HTTP, "172.31.0.1", 80, PRIVATE_IP_DISABLED)
67+
);
68+
assertEquals("Remote inference host name has private ip address: 172.31.0.1", e5.getMessage());
69+
}
70+
71+
@Test
72+
public void test_validateIp_validIp_noException() throws Exception {
73+
MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED);
74+
MLHttpClientFactory.validate(HTTPS, TEST_HOST, 443, PRIVATE_IP_DISABLED);
75+
MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_ENABLED);
76+
MLHttpClientFactory.validate(HTTPS, "127.0.0.1", 443, PRIVATE_IP_ENABLED);
77+
MLHttpClientFactory.validate(HTTP, "177.16.0.1", 80, PRIVATE_IP_DISABLED);
78+
MLHttpClientFactory.validate(HTTP, "177.0.1.1", 80, PRIVATE_IP_DISABLED);
79+
MLHttpClientFactory.validate(HTTP, "177.0.0.2", 80, PRIVATE_IP_DISABLED);
80+
MLHttpClientFactory.validate(HTTP, "::ffff", 80, PRIVATE_IP_DISABLED);
81+
MLHttpClientFactory.validate(HTTP, "172.32.0.1", 80, PRIVATE_IP_ENABLED);
82+
MLHttpClientFactory.validate(HTTP, "172.2097152", 80, PRIVATE_IP_ENABLED);
83+
}
84+
85+
@Test
86+
public void test_validateIp_rarePrivateIp_throwException() throws Exception {
87+
try {
88+
MLHttpClientFactory.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_DISABLED);
89+
} catch (IllegalArgumentException e) {
90+
assertNotNull(e);
91+
}
92+
93+
try {
94+
MLHttpClientFactory.validate(HTTP, "172.1048577", 80, PRIVATE_IP_DISABLED);
95+
} catch (Exception e) {
96+
assertNotNull(e);
97+
}
98+
99+
try {
100+
MLHttpClientFactory.validate(HTTP, "2886729729", 80, PRIVATE_IP_DISABLED);
101+
} catch (IllegalArgumentException e) {
102+
assertNotNull(e);
103+
}
104+
105+
try {
106+
MLHttpClientFactory.validate(HTTP, "192.11010049", 80, PRIVATE_IP_DISABLED);
107+
} catch (IllegalArgumentException e) {
108+
assertNotNull(e);
109+
}
110+
111+
try {
112+
MLHttpClientFactory.validate(HTTP, "3232300545", 80, PRIVATE_IP_DISABLED);
113+
} catch (IllegalArgumentException e) {
114+
assertNotNull(e);
115+
}
116+
117+
try {
118+
MLHttpClientFactory.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_DISABLED);
119+
} catch (IllegalArgumentException e) {
120+
assertNotNull(e);
121+
}
122+
123+
try {
124+
MLHttpClientFactory.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_DISABLED);
125+
} catch (IllegalArgumentException e) {
126+
assertNotNull(e);
127+
}
128+
129+
try {
130+
MLHttpClientFactory.validate(HTTP, "177.0.0.1", 80, PRIVATE_IP_DISABLED);
131+
} catch (IllegalArgumentException e) {
132+
assertNotNull(e);
133+
}
134+
135+
try {
136+
MLHttpClientFactory.validate(HTTP, "12.16.2.3", 80, PRIVATE_IP_DISABLED);
137+
} catch (IllegalArgumentException e) {
138+
assertNotNull(e);
139+
}
140+
}
141+
142+
@Test
143+
public void test_validateIp_rarePrivateIp_NotThrowException() throws Exception {
144+
MLHttpClientFactory.validate(HTTP, "0254.020.00.01", 80, PRIVATE_IP_ENABLED);
145+
MLHttpClientFactory.validate(HTTPS, "0254.020.00.01", 443, PRIVATE_IP_ENABLED);
146+
MLHttpClientFactory.validate(HTTP, "172.1048577", 80, PRIVATE_IP_ENABLED);
147+
MLHttpClientFactory.validate(HTTP, "2886729729", 80, PRIVATE_IP_ENABLED);
148+
MLHttpClientFactory.validate(HTTP, "192.11010049", 80, PRIVATE_IP_ENABLED);
149+
MLHttpClientFactory.validate(HTTP, "3232300545", 80, PRIVATE_IP_ENABLED);
150+
MLHttpClientFactory.validate(HTTP, "0:0:0:0:0:ffff:127.0.0.1", 80, PRIVATE_IP_ENABLED);
151+
MLHttpClientFactory.validate(HTTPS, "0:0:0:0:0:ffff:127.0.0.1", 443, PRIVATE_IP_ENABLED);
152+
MLHttpClientFactory.validate(HTTP, "153.24.76.232", 80, PRIVATE_IP_ENABLED);
153+
MLHttpClientFactory.validate(HTTP, "10.24.76.186", 80, PRIVATE_IP_ENABLED);
154+
MLHttpClientFactory.validate(HTTPS, "10.24.76.186", 443, PRIVATE_IP_ENABLED);
155+
}
156+
157+
@Test
158+
public void test_validateSchemaAndPort_success() throws Exception {
159+
MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED);
160+
}
161+
162+
@Test
163+
public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception {
164+
expectedException.expect(IllegalArgumentException.class);
165+
MLHttpClientFactory.validate("ftp", TEST_HOST, 80, PRIVATE_IP_DISABLED);
166+
}
167+
168+
@Test
169+
public void test_validateSchemaAndPort_portNotInRange1_throwException() throws Exception {
170+
expectedException.expect(IllegalArgumentException.class);
171+
expectedException.expectMessage("Port out of range: 65537");
172+
MLHttpClientFactory.validate(HTTPS, TEST_HOST, 65537, PRIVATE_IP_DISABLED);
173+
}
174+
175+
@Test
176+
public void test_validateSchemaAndPort_portNotInRange2_throwException() throws Exception {
177+
expectedException.expect(IllegalArgumentException.class);
178+
expectedException.expectMessage("Port out of range: -10");
179+
MLHttpClientFactory.validate(HTTP, TEST_HOST, -10, PRIVATE_IP_DISABLED);
180+
}
181+
182+
@Test
183+
public void test_validatePort_boundaries_success() throws Exception {
184+
MLHttpClientFactory.validate(HTTP, TEST_HOST, 65536, PRIVATE_IP_DISABLED);
185+
MLHttpClientFactory.validate(HTTP, TEST_HOST, 0, PRIVATE_IP_DISABLED);
186+
MLHttpClientFactory.validate(HTTP, TEST_HOST, -1, PRIVATE_IP_DISABLED);
187+
MLHttpClientFactory.validate(HTTPS, TEST_HOST, -1, PRIVATE_IP_DISABLED);
188+
MLHttpClientFactory.validate(null, TEST_HOST, -1, PRIVATE_IP_DISABLED);
189+
}
190+
191+
}

ml-algorithms/build.gradle

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ dependencies {
5858
// Multi-tenant SDK Client
5959
implementation "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
6060
implementation 'commons-beanutils:commons-beanutils:1.11.0'
61+
implementation "org.opensearch:opensearch-remote-metadata-sdk-ddb-client:${opensearch_build}"
6162

6263
def os = DefaultNativePlatform.currentOperatingSystem
6364
//arm/macos doesn't support GPU
@@ -88,9 +89,7 @@ dependencies {
8889
}
8990
implementation('net.minidev:json-smart:2.5.2')
9091
implementation group: 'org.json', name: 'json', version: '20231013'
91-
implementation(enforcedPlatform("io.netty:netty-bom:4.2.5.Final"))
92-
implementation("software.amazon.awssdk:netty-nio-client")
93-
92+
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18"
9493
testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
9594
testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
9695
testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
import org.opensearch.ml.common.connector.AwsConnector;
2525
import org.opensearch.ml.common.connector.Connector;
2626
import org.opensearch.ml.common.exception.MLException;
27+
import org.opensearch.ml.common.httpclient.MLHttpClientFactory;
2728
import org.opensearch.ml.common.input.MLInput;
2829
import org.opensearch.ml.common.model.MLGuard;
2930
import org.opensearch.ml.common.output.model.ModelTensors;
3031
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
31-
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
3232
import org.opensearch.script.ScriptService;
3333

3434
import lombok.Getter;

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
import org.opensearch.ml.common.connector.Connector;
2727
import org.opensearch.ml.common.connector.HttpConnector;
2828
import org.opensearch.ml.common.exception.MLException;
29+
import org.opensearch.ml.common.httpclient.MLHttpClientFactory;
2930
import org.opensearch.ml.common.input.MLInput;
3031
import org.opensearch.ml.common.model.MLGuard;
3132
import org.opensearch.ml.common.output.model.ModelTensors;
3233
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
33-
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
3434
import org.opensearch.script.ScriptService;
3535

3636
import lombok.Getter;

0 commit comments

Comments
 (0)