Skip to content

Commit 0e7f310

Browse files
Add unit tests for Ai21ChatCompletionResponseHandler
1 parent 027d73c commit 0e7f310

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.ai21.completion;
9+
10+
import org.apache.http.HttpResponse;
11+
import org.apache.http.StatusLine;
12+
import org.elasticsearch.common.bytes.BytesReference;
13+
import org.elasticsearch.common.xcontent.XContentHelper;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
17+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
18+
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
19+
import org.elasticsearch.xpack.inference.external.request.Request;
20+
21+
import java.io.IOException;
22+
import java.net.URI;
23+
import java.net.URISyntaxException;
24+
import java.nio.charset.StandardCharsets;
25+
26+
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
27+
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
28+
import static org.hamcrest.Matchers.is;
29+
import static org.hamcrest.Matchers.isA;
30+
import static org.mockito.Mockito.mock;
31+
import static org.mockito.Mockito.when;
32+
33+
public class Ai21ChatCompletionResponseHandlerTests extends ESTestCase {
34+
private final Ai21ChatCompletionResponseHandler responseHandler = new Ai21ChatCompletionResponseHandler(
35+
"chat completions",
36+
(a, b) -> mock()
37+
);
38+
39+
public void testFailNotFound() throws IOException {
40+
var responseJson = XContentHelper.stripWhitespace("""
41+
{
42+
"detail": "Not Found"
43+
}
44+
""");
45+
46+
var errorJson = invalidResponseJson(responseJson, 404);
47+
48+
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
49+
{
50+
"error" : {
51+
"code" : "not_found",
52+
"message" : "Resource not found at [https://api.ai21.com/studio/v1/chat/completions] for request from inference entity id \
53+
[id] status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]",
54+
"type" : "ai21_error"
55+
}
56+
}""")));
57+
}
58+
59+
public void testFailUnauthorized() throws IOException {
60+
var responseJson = XContentHelper.stripWhitespace("""
61+
{
62+
"detail": "Forbidden: Bad or missing Apikey/JWT."
63+
}
64+
""");
65+
66+
var errorJson = invalidResponseJson(responseJson, 401);
67+
68+
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
69+
{
70+
"error" : {
71+
"code" : "unauthorized",
72+
"message" : "Received an authentication error status code for request from inference entity id [id] status [401]. Error \
73+
message: [{\\"detail\\":\\"Forbidden: Bad or missing Apikey/JWT.\\"}]",
74+
"type" : "ai21_error"
75+
}
76+
}""")));
77+
}
78+
79+
public void testFailUnprocessableEntity() throws IOException {
80+
var responseJson = XContentHelper.stripWhitespace("""
81+
{
82+
"detail": "The provided model is not supported. See https://docs.ai21.com/docs/jamba-foundation-models#api-versioning \
83+
for a list of supported models"
84+
}
85+
""");
86+
87+
var errorJson = invalidResponseJson(responseJson, 422);
88+
89+
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
90+
{
91+
"error" : {
92+
"code" : "unprocessable_entity",
93+
"message" : "Received an input validation error response for request from inference entity id [id] status [422]. \
94+
Error message: [{\\"detail\\":\\"The provided model is not supported. \
95+
See https://docs.ai21.com/docs/jamba-foundation-models#api-versioning for a list of supported models\\"}]",
96+
"type" : "ai21_error"
97+
}
98+
}""")));
99+
}
100+
101+
private String invalidResponseJson(String responseJson, int statusCode) throws IOException {
102+
var exception = invalidResponse(responseJson, statusCode);
103+
assertThat(exception, isA(RetryException.class));
104+
assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
105+
return toJson((UnifiedChatCompletionException) unwrapCause(exception));
106+
}
107+
108+
private Exception invalidResponse(String responseJson, int statusCode) {
109+
return expectThrows(
110+
RetryException.class,
111+
() -> responseHandler.validateResponse(
112+
mock(),
113+
mock(),
114+
mockRequest(),
115+
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
116+
true
117+
)
118+
);
119+
}
120+
121+
private static Request mockRequest() throws URISyntaxException {
122+
var request = mock(Request.class);
123+
when(request.getInferenceEntityId()).thenReturn("id");
124+
when(request.isStreaming()).thenReturn(true);
125+
when(request.getURI()).thenReturn(new URI("https://api.ai21.com/studio/v1/chat/completions"));
126+
return request;
127+
}
128+
129+
private static HttpResponse mockErrorResponse(int statusCode) {
130+
var statusLine = mock(StatusLine.class);
131+
when(statusLine.getStatusCode()).thenReturn(statusCode);
132+
133+
var response = mock(HttpResponse.class);
134+
when(response.getStatusLine()).thenReturn(statusLine);
135+
136+
return response;
137+
}
138+
139+
private String toJson(UnifiedChatCompletionException e) throws IOException {
140+
try (var builder = XContentFactory.jsonBuilder()) {
141+
e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
142+
try {
143+
xContent.toXContent(builder, EMPTY_PARAMS);
144+
} catch (IOException ex) {
145+
throw new RuntimeException(ex);
146+
}
147+
});
148+
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
149+
}
150+
}
151+
152+
}

0 commit comments

Comments
 (0)