Skip to content

Commit 9908229

Browse files
authored
Add more UT for remote inference classes (#1077) (#1090)
Signed-off-by: Sicheng Song <[email protected]>
1 parent 20e0978 commit 9908229

13 files changed

+741
-83
lines changed

common/build.gradle

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ jacocoTestCoverageVerification {
3636
rule {
3737
limit {
3838
counter = 'LINE'
39-
minimum = 0.6 //TODO: add more test to meet the coverage bar 0.9
39+
minimum = 0.8 //TODO: add more test to meet the coverage bar 0.9
4040
}
4141
limit {
4242
counter = 'BRANCH'
43-
minimum = 0.5 //TODO: add more test to meet the coverage bar 0.9
43+
minimum = 0.7 //TODO: add more test to meet the coverage bar 0.9
4444
}
4545
}
4646
}

common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ public void writeTo(StreamOutput output) throws IOException {
227227
}
228228
if (!CollectionUtils.isEmpty(backendRoles)) {
229229
output.writeBoolean(true);
230-
output.writeOptionalStringCollection(backendRoles);
230+
output.writeStringCollection(backendRoles);
231231
} else {
232232
output.writeBoolean(false);
233233
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.connector;
7+
8+
import org.junit.Before;
9+
import org.junit.Test;
10+
import org.opensearch.action.ActionRequest;
11+
import org.opensearch.action.ActionRequestValidationException;
12+
import org.opensearch.common.io.stream.BytesStreamOutput;
13+
import org.opensearch.common.io.stream.StreamOutput;
14+
15+
import java.io.IOException;
16+
import java.io.UncheckedIOException;
17+
18+
import static org.junit.Assert.assertEquals;
19+
import static org.junit.Assert.assertNotSame;
20+
import static org.junit.Assert.assertNull;
21+
import static org.junit.Assert.assertSame;
22+
23+
public class MLConnectorDeleteRequestTests {
24+
private String connectorId;
25+
26+
@Before
27+
public void setUp() {
28+
connectorId = "test-connector-id";
29+
}
30+
31+
@Test
32+
public void writeTo_Success() throws IOException {
33+
MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder()
34+
.connectorId(connectorId).build();
35+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
36+
mlConnectorDeleteRequest.writeTo(bytesStreamOutput);
37+
MLConnectorDeleteRequest parsedConnector = new MLConnectorDeleteRequest(bytesStreamOutput.bytes().streamInput());
38+
assertEquals(parsedConnector.getConnectorId(), connectorId);
39+
}
40+
41+
@Test
42+
public void valid_Exception_NullConnectorId() {
43+
MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().build();
44+
ActionRequestValidationException exception = mlConnectorDeleteRequest.validate();
45+
assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage());
46+
}
47+
48+
@Test
49+
public void validate_Success() {
50+
MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder()
51+
.connectorId(connectorId).build();
52+
ActionRequestValidationException actionRequestValidationException = mlConnectorDeleteRequest.validate();
53+
assertNull(actionRequestValidationException);
54+
}
55+
56+
@Test
57+
public void fromActionRequest_Success() {
58+
MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder()
59+
.connectorId(connectorId).build();
60+
ActionRequest actionRequest = new ActionRequest() {
61+
@Override
62+
public ActionRequestValidationException validate() {
63+
return null;
64+
}
65+
66+
@Override
67+
public void writeTo(StreamOutput out) throws IOException {
68+
mlConnectorDeleteRequest.writeTo(out);
69+
}
70+
};
71+
MLConnectorDeleteRequest parsedConnector = MLConnectorDeleteRequest.fromActionRequest(actionRequest);
72+
assertNotSame(parsedConnector, mlConnectorDeleteRequest);
73+
assertEquals(parsedConnector.getConnectorId(), connectorId);
74+
}
75+
76+
@Test(expected = UncheckedIOException.class)
77+
public void fromActionRequest_IOException() {
78+
ActionRequest actionRequest = new ActionRequest() {
79+
@Override
80+
public ActionRequestValidationException validate() {
81+
return null;
82+
}
83+
84+
@Override
85+
public void writeTo(StreamOutput out) throws IOException {
86+
throw new IOException();
87+
}
88+
};
89+
MLConnectorDeleteRequest.fromActionRequest(actionRequest);
90+
}
91+
92+
@Test
93+
public void fromActionRequestWithConnectorDeleteRequest_Success() {
94+
MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder()
95+
.connectorId(connectorId).build();
96+
MLConnectorDeleteRequest mlConnectorDeleteRequestFromActionRequest = MLConnectorDeleteRequest.fromActionRequest(mlConnectorDeleteRequest);
97+
assertSame(mlConnectorDeleteRequest, mlConnectorDeleteRequestFromActionRequest);
98+
assertEquals(mlConnectorDeleteRequest.getConnectorId(), mlConnectorDeleteRequestFromActionRequest.getConnectorId());
99+
}
100+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
7+
package org.opensearch.ml.common.transport.connector;
8+
9+
import java.io.IOException;
10+
import java.io.UncheckedIOException;
11+
12+
import org.junit.Before;
13+
import org.junit.Test;
14+
import org.opensearch.action.ActionRequest;
15+
import org.opensearch.action.ActionRequestValidationException;
16+
import org.opensearch.common.io.stream.BytesStreamOutput;
17+
import org.opensearch.common.io.stream.StreamOutput;
18+
19+
import static org.junit.Assert.assertEquals;
20+
import static org.junit.Assert.assertNotSame;
21+
import static org.junit.Assert.assertNull;
22+
import static org.junit.Assert.assertSame;
23+
24+
public class MLConnectorGetRequestTests {
25+
private String connectorId;
26+
27+
@Before
28+
public void setUp() {
29+
connectorId = "test-connector-id";
30+
}
31+
32+
@Test
33+
public void writeTo_Success() throws IOException {
34+
MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build();
35+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
36+
mlConnectorGetRequest.writeTo(bytesStreamOutput);
37+
MLConnectorGetRequest parsedConnector = new MLConnectorGetRequest(bytesStreamOutput.bytes().streamInput());
38+
assertEquals(connectorId, parsedConnector.getConnectorId());
39+
}
40+
41+
@Test
42+
public void fromActionRequest_Success() {
43+
MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build();
44+
ActionRequest actionRequest = new ActionRequest() {
45+
@Override
46+
public ActionRequestValidationException validate() {
47+
return null;
48+
}
49+
50+
@Override
51+
public void writeTo(StreamOutput out) throws IOException {
52+
mlConnectorGetRequest.writeTo(out);
53+
}
54+
};
55+
MLConnectorGetRequest mlConnectorGetRequestFromActionRequest = MLConnectorGetRequest.fromActionRequest(actionRequest);
56+
assertNotSame(mlConnectorGetRequest, mlConnectorGetRequestFromActionRequest);
57+
assertEquals(mlConnectorGetRequest.getConnectorId(), mlConnectorGetRequestFromActionRequest.getConnectorId());
58+
}
59+
60+
@Test(expected = UncheckedIOException.class)
61+
public void fromActionRequest_IOException() {
62+
ActionRequest actionRequest = new ActionRequest() {
63+
@Override
64+
public ActionRequestValidationException validate() {
65+
return null;
66+
}
67+
68+
@Override
69+
public void writeTo(StreamOutput out) throws IOException {
70+
throw new IOException();
71+
}
72+
};
73+
MLConnectorGetRequest.fromActionRequest(actionRequest);
74+
}
75+
76+
@Test
77+
public void fromActionRequestWithMLConnectorGetRequest_Success() {
78+
MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build();
79+
MLConnectorGetRequest mlConnectorGetRequestFromActionRequest = MLConnectorGetRequest.fromActionRequest(mlConnectorGetRequest);
80+
assertSame(mlConnectorGetRequest, mlConnectorGetRequestFromActionRequest);
81+
assertEquals(mlConnectorGetRequest.getConnectorId(), mlConnectorGetRequestFromActionRequest.getConnectorId());
82+
}
83+
84+
@Test
85+
public void validate_Exception_NullConnctorId() {
86+
MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().build();
87+
ActionRequestValidationException actionRequestValidationException = mlConnectorGetRequest.validate();
88+
assertEquals("Validation Failed: 1: ML connector id can't be null;", actionRequestValidationException.getMessage());
89+
}
90+
91+
@Test
92+
public void validate_Success() {
93+
MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build();
94+
ActionRequestValidationException actionRequestValidationException = mlConnectorGetRequest.validate();
95+
assertNull(actionRequestValidationException);
96+
}
97+
}
98+
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.connector;
7+
8+
import org.junit.Before;
9+
import org.junit.Test;
10+
import org.opensearch.action.ActionResponse;
11+
import org.opensearch.common.Strings;
12+
import org.opensearch.common.io.stream.BytesStreamOutput;
13+
import org.opensearch.common.io.stream.StreamOutput;
14+
import org.opensearch.common.xcontent.XContentFactory;
15+
import org.opensearch.common.xcontent.XContentType;
16+
import org.opensearch.core.xcontent.ToXContent;
17+
import org.opensearch.core.xcontent.XContentBuilder;
18+
import org.opensearch.ml.common.connector.Connector;
19+
import org.opensearch.ml.common.connector.HttpConnectorTest;
20+
21+
import java.io.IOException;
22+
import java.io.UncheckedIOException;
23+
24+
import static org.junit.Assert.assertEquals;
25+
import static org.junit.Assert.assertNotEquals;
26+
import static org.junit.Assert.assertNotNull;
27+
import static org.junit.Assert.assertNotSame;
28+
import static org.junit.Assert.assertSame;
29+
30+
public class MLConnectorGetResponseTests {
31+
Connector connector;
32+
33+
@Before
34+
public void setUp() {
35+
connector = HttpConnectorTest.createHttpConnector();
36+
}
37+
38+
@Test
39+
public void writeTo_Success() throws IOException {
40+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
41+
MLConnectorGetResponse response = MLConnectorGetResponse.builder().mlConnector(connector).build();
42+
response.writeTo(bytesStreamOutput);
43+
MLConnectorGetResponse parsedResponse = new MLConnectorGetResponse(bytesStreamOutput.bytes().streamInput());
44+
assertNotEquals(response, parsedResponse);
45+
assertNotSame(response.mlConnector, parsedResponse.mlConnector);
46+
assertEquals(response.mlConnector, parsedResponse.mlConnector);
47+
assertEquals(response.mlConnector.getName(), parsedResponse.mlConnector.getName());
48+
assertEquals(response.mlConnector.getAccess(), parsedResponse.mlConnector.getAccess());
49+
assertEquals(response.mlConnector.getProtocol(), parsedResponse.mlConnector.getProtocol());
50+
assertEquals(response.mlConnector.getDecryptedHeaders(), parsedResponse.mlConnector.getDecryptedHeaders());
51+
assertEquals(response.mlConnector.getBackendRoles(), parsedResponse.mlConnector.getBackendRoles());
52+
assertEquals(response.mlConnector.getActions(), parsedResponse.mlConnector.getActions());
53+
assertEquals(response.mlConnector.getParameters(), parsedResponse.mlConnector.getParameters());
54+
}
55+
56+
@Test
57+
public void toXContentTest() throws IOException {
58+
MLConnectorGetResponse mlConnectorGetResponse = MLConnectorGetResponse.builder().mlConnector(connector).build();
59+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
60+
mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
61+
assertNotNull(builder);
62+
String jsonStr = Strings.toString(builder);
63+
assertEquals("{\"name\":\"test_connector_name\"," +
64+
"\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," +
65+
"\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," +
66+
"\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," +
67+
"\"headers\":{\"api_key\":\"${credential.key}\"}," +
68+
"\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," +
69+
"\"pre_process_function\":\"connector.pre_process.openai.embedding\"," +
70+
"\"post_process_function\":\"connector.post_process.openai.embedding\"}]," +
71+
"\"backend_roles\":[\"role1\",\"role2\"]," +
72+
"\"access\":\"public\"}", jsonStr);
73+
}
74+
75+
@Test
76+
public void fromActionResponseWithMLConnectorGetResponse_Success() {
77+
MLConnectorGetResponse mlConnectorGetResponse = MLConnectorGetResponse.builder().mlConnector(connector).build();
78+
MLConnectorGetResponse mlConnectorGetResponseFromActionResponse = MLConnectorGetResponse.fromActionResponse(mlConnectorGetResponse);
79+
assertSame(mlConnectorGetResponse, mlConnectorGetResponseFromActionResponse);
80+
assertEquals(mlConnectorGetResponse.mlConnector, mlConnectorGetResponseFromActionResponse.mlConnector);
81+
}
82+
83+
@Test
84+
public void fromActionResponse_Success() {
85+
MLConnectorGetResponse mlConnectorGetResponse = MLConnectorGetResponse.builder().mlConnector(connector).build();
86+
ActionResponse actionResponse = new ActionResponse() {
87+
@Override
88+
public void writeTo(StreamOutput out) throws IOException {
89+
mlConnectorGetResponse.writeTo(out);
90+
}
91+
};
92+
MLConnectorGetResponse mlConnectorGetResponseFromActionResponse = MLConnectorGetResponse.fromActionResponse(actionResponse);
93+
assertNotSame(mlConnectorGetResponse, mlConnectorGetResponseFromActionResponse);
94+
assertEquals(mlConnectorGetResponse.mlConnector, mlConnectorGetResponseFromActionResponse.mlConnector);
95+
}
96+
97+
@Test(expected = UncheckedIOException.class)
98+
public void fromActionResponse_IOException() {
99+
ActionResponse actionResponse = new ActionResponse() {
100+
@Override
101+
public void writeTo(StreamOutput out) throws IOException {
102+
throw new IOException();
103+
}
104+
};
105+
MLConnectorGetResponse.fromActionResponse(actionResponse);
106+
}
107+
}

0 commit comments

Comments
 (0)