Skip to content

Commit e55fffa

Browse files
authored
add integration test for train and predict API (#157)
* add integration test for train and predict API Signed-off-by: Yaliang Wu <[email protected]> * test default parameter for train&predict Kmeans Signed-off-by: Yaliang Wu <[email protected]>
1 parent 638f72e commit e55fffa

File tree

9 files changed

+799
-20
lines changed

9 files changed

+799
-20
lines changed

DEVELOPER_GUIDE.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@ This package uses the [Gradle](https://docs.gradle.org/current/userguide/usergui
3434

3535
1. `./gradlew build` builds and tests
3636
2. `./gradlew :run` launches a single node cluster with ml-commons plugin installed
37-
3. `./gradlew :integTest` launches a single node cluster with ml-commons plugin installed and runs all integration tests except security
38-
4. ` ./gradlew :integTest --tests="**.test execute foo"` runs a single integration test class or method
39-
5. `./gradlew spotlessApply` formats code. And/or import formatting rules in `.eclipseformat.xml` with IDE.
37+
3. `./gradlew :integTest` launches a single node cluster with ml-commons plugin installed and runs all integration tests except security. Use `./gradlew integTest -PnumNodes=<number>` to launch multi-node cluster.
38+
4. ` ./gradlew :integTest --tests="<class path>.<test method>"` runs a single integration test class or method, for example `./gradlew integTest --tests="org.opensearch.ml.rest.RestMLTrainAndPredictIT.testTrainAndPredictKmeansWithEmptyParam"` or `./gradlew integTest --tests="org.opensearch.ml.rest.RestMLTrainAndPredictIT"`
39+
5. `./gradlew integTest -Dtests.class="<class path>"` run specific integ test class, for example `./gradlew integTest -Dtests.class="org.opensearch.ml.rest.RestMLTrainAndPredictIT"`
40+
6. `./gradlew integTest -Dtests.method="<method name>"` run specific integ test method, for example `./gradlew integTest -Dtests.method="testTrainAndPredictKmeans"`
41+
7. `./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=admin` launches integration tests against a local cluster and run tests with security. Detail steps: (1)download OpenSearch tarball to local and install by running `opensearch-tar-install.sh`; (2)build ML plugin zip with your change and install ML plugin zip; (3)restart local test cluster; (4) run this gradle command to test.
42+
8. `./gradlew spotlessApply` formats code. And/or import formatting rules in `.eclipseformat.xml` with IDE.
4043

4144
When launching a cluster using one of the above commands logs are placed in `/build/cluster/run node0/opensearch-<version>/logs`. Though the logs are teed to the console, in practices it's best to check the actual log file.
4245

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForestTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ public void predict() {
5050
int anomalyCount = 0;
5151
for (int i = 0 ;i<dataSize; i++) {
5252
if (i % 100 == 0) {
53-
System.out.println(predictions.getRow(i).getValue(1).doubleValue());
5453
if (predictions.getRow(i).getValue(1).doubleValue() > 0.01) {
5554
anomalyCount++;
5655
}

plugin/build.gradle

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6+
import java.util.concurrent.Callable
67
import org.opensearch.gradle.test.RestIntegTestTask
8+
import org.opensearch.gradle.testclusters.StandaloneRestIntegTestTask
79

810
plugins {
911
id 'java'
@@ -43,24 +45,12 @@ dependencies {
4345
compile "org.opensearch:common-utils:${common_utils_version}"
4446
compile("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
4547
compile("com.fasterxml.jackson.core:jackson-databind:${versions.jackson}")
46-
compile group: 'com.google.guava', name: 'guava', version:'29.0-jre'
48+
implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre'
49+
implementation group: 'com.google.code.gson', name: 'gson', version: '2.9.0'
4750

4851
checkstyle "com.puppycrawl.tools:checkstyle:${project.checkstyle.toolVersion}"
4952
}
5053

51-
test {
52-
include '**/*Test.class'
53-
systemProperty 'tests.security.manager', 'false'
54-
finalizedBy jacocoTestReport
55-
}
56-
57-
jacocoTestReport {
58-
reports {
59-
xml.enabled true
60-
html.enabled true
61-
csv.enabled true
62-
}
63-
}
6454

6555
compileJava {
6656
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
@@ -81,14 +71,16 @@ loggerUsageCheck.enabled = false
8171

8272
def _numNodes = findProperty('numNodes') as Integer ?: 1
8373

84-
def opensearch_tmp_dir = rootProject.file('build/private/opensearch_tmp').absoluteFile
85-
opensearch_tmp_dir.mkdirs()
8674

8775
test {
8876
include '**/*Tests.class'
8977
systemProperty 'tests.security.manager', 'false'
9078
}
9179

80+
def opensearch_tmp_dir = rootProject.file('build/private/opensearch_tmp').absoluteFile
81+
opensearch_tmp_dir.mkdirs()
82+
83+
9284
task integTest(type: RestIntegTestTask) {
9385
description = "Run tests against a cluster"
9486
testClassesDirs = sourceSets.test.output.classesDirs
@@ -105,6 +97,13 @@ integTest {
10597
systemProperty "user", System.getProperty("user")
10698
systemProperty "password", System.getProperty("password")
10799

100+
// Only rest case can run with remote cluster
101+
if (System.getProperty("tests.rest.cluster") != null) {
102+
filter {
103+
includeTestsMatching "org.opensearch.ml.rest.*IT"
104+
}
105+
}
106+
108107
// The 'doFirst' delays till execution time.
109108
doFirst {
110109
// Tell the test JVM if the cluster JVM is running under a debugger so that tests can
@@ -150,6 +149,24 @@ testClusters.integTest {
150149
}
151150
}
152151

152+
task integTestRemote(type: RestIntegTestTask) {
153+
testClassesDirs = sourceSets.test.output.classesDirs
154+
classpath = sourceSets.test.runtimeClasspath
155+
systemProperty 'tests.security.manager', 'false'
156+
systemProperty 'java.io.tmpdir', opensearch_tmp_dir.absolutePath
157+
158+
systemProperty "https", System.getProperty("https")
159+
systemProperty "user", System.getProperty("user")
160+
systemProperty "password", System.getProperty("password")
161+
162+
// Only rest case can run with remote cluster
163+
if (System.getProperty("tests.rest.cluster") != null) {
164+
filter {
165+
includeTestsMatching "org.opensearch.ml.rest.*IT"
166+
}
167+
}
168+
}
169+
153170
run {
154171
doFirst {
155172
// There seems to be an issue when running multi node run or integ tasks with unicast_hosts
@@ -172,6 +189,7 @@ task release(type: Copy, group: 'build') {
172189
jacocoTestReport {
173190
reports {
174191
xml.enabled true
192+
html.enabled true
175193
csv.enabled false
176194
}
177195

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
package org.opensearch.ml.rest;
2+
3+
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED;
4+
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH;
5+
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD;
6+
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD;
7+
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH;
8+
9+
import java.io.IOException;
10+
import java.net.URI;
11+
import java.net.URISyntaxException;
12+
import java.nio.file.Path;
13+
import java.util.Collections;
14+
import java.util.List;
15+
import java.util.Map;
16+
import java.util.Objects;
17+
import java.util.Optional;
18+
import java.util.stream.Collectors;
19+
20+
import org.apache.http.Header;
21+
import org.apache.http.HttpHeaders;
22+
import org.apache.http.HttpHost;
23+
import org.apache.http.auth.AuthScope;
24+
import org.apache.http.auth.UsernamePasswordCredentials;
25+
import org.apache.http.client.CredentialsProvider;
26+
import org.apache.http.conn.ssl.NoopHostnameVerifier;
27+
import org.apache.http.impl.client.BasicCredentialsProvider;
28+
import org.apache.http.message.BasicHeader;
29+
import org.apache.http.ssl.SSLContextBuilder;
30+
import org.apache.http.util.EntityUtils;
31+
import org.junit.After;
32+
import org.opensearch.client.Request;
33+
import org.opensearch.client.Response;
34+
import org.opensearch.client.RestClient;
35+
import org.opensearch.client.RestClientBuilder;
36+
import org.opensearch.common.io.PathUtils;
37+
import org.opensearch.common.settings.Settings;
38+
import org.opensearch.common.unit.TimeValue;
39+
import org.opensearch.common.util.concurrent.ThreadContext;
40+
import org.opensearch.common.xcontent.DeprecationHandler;
41+
import org.opensearch.common.xcontent.NamedXContentRegistry;
42+
import org.opensearch.common.xcontent.XContentParser;
43+
import org.opensearch.common.xcontent.XContentType;
44+
import org.opensearch.commons.rest.SecureRestClientBuilder;
45+
import org.opensearch.ml.utils.TestData;
46+
import org.opensearch.ml.utils.TestHelper;
47+
import org.opensearch.rest.RestStatus;
48+
import org.opensearch.test.rest.OpenSearchRestTestCase;
49+
50+
import com.google.common.collect.ImmutableList;
51+
import com.google.common.collect.ImmutableMap;
52+
53+
public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase {
54+
55+
protected boolean isHttps() {
56+
boolean isHttps = Optional.ofNullable(System.getProperty("https")).map("true"::equalsIgnoreCase).orElse(false);
57+
if (isHttps) {
58+
// currently only external cluster is supported for security enabled testing
59+
if (!Optional.ofNullable(System.getProperty("tests.rest.cluster")).isPresent()) {
60+
throw new RuntimeException("cluster url should be provided for security enabled testing");
61+
}
62+
}
63+
64+
return isHttps;
65+
}
66+
67+
@Override
68+
protected String getProtocol() {
69+
return isHttps() ? "https" : "http";
70+
}
71+
72+
@Override
73+
protected Settings restAdminSettings() {
74+
return Settings
75+
.builder()
76+
// disable the warning exception for admin client since it's only used for cleanup.
77+
.put("strictDeprecationMode", false)
78+
.put("http.port", 9200)
79+
.put(OPENSEARCH_SECURITY_SSL_HTTP_ENABLED, isHttps())
80+
.put(OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH, "sample.pem")
81+
.put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH, "test-kirk.jks")
82+
.put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD, "changeit")
83+
.put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD, "changeit")
84+
.build();
85+
}
86+
87+
// Utility fn for deleting indices. Should only be used when not allowed in a regular context
88+
// (e.g., deleting system indices)
89+
protected static void deleteIndexWithAdminClient(String name) throws IOException {
90+
Request request = new Request("DELETE", "/" + name);
91+
adminClient().performRequest(request);
92+
}
93+
94+
// Utility fn for checking if an index exists. Should only be used when not allowed in a regular context
95+
// (e.g., checking existence of system indices)
96+
protected static boolean indexExistsWithAdminClient(String indexName) throws IOException {
97+
Request request = new Request("HEAD", "/" + indexName);
98+
Response response = adminClient().performRequest(request);
99+
return RestStatus.OK.getStatus() == response.getStatusLine().getStatusCode();
100+
}
101+
102+
@Override
103+
protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException {
104+
boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true);
105+
RestClientBuilder builder = RestClient.builder(hosts);
106+
if (isHttps()) {
107+
String keystore = settings.get(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH);
108+
if (Objects.nonNull(keystore)) {
109+
URI uri = null;
110+
try {
111+
uri = this.getClass().getClassLoader().getResource("security/sample.pem").toURI();
112+
} catch (URISyntaxException e) {
113+
throw new RuntimeException(e);
114+
}
115+
Path configPath = PathUtils.get(uri).getParent().toAbsolutePath();
116+
return new SecureRestClientBuilder(settings, configPath).build();
117+
} else {
118+
configureHttpsClient(builder, settings);
119+
builder.setStrictDeprecationMode(strictDeprecationMode);
120+
return builder.build();
121+
}
122+
123+
} else {
124+
configureClient(builder, settings);
125+
builder.setStrictDeprecationMode(strictDeprecationMode);
126+
return builder.build();
127+
}
128+
129+
}
130+
131+
@SuppressWarnings("unchecked")
132+
@After
133+
protected void wipeAllODFEIndices() throws IOException {
134+
Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all"));
135+
XContentType xContentType = XContentType.fromMediaTypeOrFormat(response.getEntity().getContentType().getValue());
136+
try (
137+
XContentParser parser = xContentType
138+
.xContent()
139+
.createParser(
140+
NamedXContentRegistry.EMPTY,
141+
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
142+
response.getEntity().getContent()
143+
)
144+
) {
145+
XContentParser.Token token = parser.nextToken();
146+
List<Map<String, Object>> parserList = null;
147+
if (token == XContentParser.Token.START_ARRAY) {
148+
parserList = parser.listOrderedMap().stream().map(obj -> (Map<String, Object>) obj).collect(Collectors.toList());
149+
} else {
150+
parserList = Collections.singletonList(parser.mapOrdered());
151+
}
152+
153+
for (Map<String, Object> index : parserList) {
154+
String indexName = (String) index.get("index");
155+
if (indexName != null && !".opendistro_security".equals(indexName)) {
156+
adminClient().performRequest(new Request("DELETE", "/" + indexName));
157+
}
158+
}
159+
}
160+
}
161+
162+
protected static void configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException {
163+
Map<String, String> headers = ThreadContext.buildDefaultHeaders(settings);
164+
Header[] defaultHeaders = new Header[headers.size()];
165+
int i = 0;
166+
for (Map.Entry<String, String> entry : headers.entrySet()) {
167+
defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue());
168+
}
169+
builder.setDefaultHeaders(defaultHeaders);
170+
builder.setHttpClientConfigCallback(httpClientBuilder -> {
171+
String userName = Optional
172+
.ofNullable(System.getProperty("user"))
173+
.orElseThrow(() -> new RuntimeException("user name is missing"));
174+
String password = Optional
175+
.ofNullable(System.getProperty("password"))
176+
.orElseThrow(() -> new RuntimeException("password is missing"));
177+
CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
178+
credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(userName, password));
179+
try {
180+
return httpClientBuilder
181+
.setDefaultCredentialsProvider(credentialsProvider)
182+
// disable the certificate since our testing cluster just uses the default security configuration
183+
.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE)
184+
.setSSLContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build());
185+
} catch (Exception e) {
186+
throw new RuntimeException(e);
187+
}
188+
});
189+
190+
final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT);
191+
final TimeValue socketTimeout = TimeValue
192+
.parseTimeValue(socketTimeoutString == null ? "60s" : socketTimeoutString, CLIENT_SOCKET_TIMEOUT);
193+
builder.setRequestConfigCallback(conf -> conf.setSocketTimeout(Math.toIntExact(socketTimeout.getMillis())));
194+
if (settings.hasValue(CLIENT_PATH_PREFIX)) {
195+
builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX));
196+
}
197+
}
198+
199+
/**
200+
* wipeAllIndices won't work since it cannot delete security index. Use wipeAllODFEIndices instead.
201+
*/
202+
@Override
203+
protected boolean preserveIndicesUponCompletion() {
204+
return true;
205+
}
206+
207+
protected Response ingestIrisData(String indexName) throws IOException {
208+
String irisDataIndexMapping = "";
209+
TestHelper
210+
.makeRequest(
211+
client(),
212+
"PUT",
213+
indexName,
214+
null,
215+
TestHelper.toHttpEntity(irisDataIndexMapping),
216+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
217+
);
218+
219+
Response statsResponse = TestHelper.makeRequest(client(), "GET", indexName, ImmutableMap.of(), "", null);
220+
assertEquals(RestStatus.OK, TestHelper.restStatus(statsResponse));
221+
String result = EntityUtils.toString(statsResponse.getEntity());
222+
assertTrue(result.contains(indexName));
223+
224+
Response bulkResponse = TestHelper
225+
.makeRequest(
226+
client(),
227+
"POST",
228+
"_bulk?refresh=true",
229+
null,
230+
TestHelper.toHttpEntity(TestData.IRIS_DATA),
231+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
232+
);
233+
assertEquals(RestStatus.OK, TestHelper.restStatus(statsResponse));
234+
return bulkResponse;
235+
}
236+
}

0 commit comments

Comments
 (0)