Skip to content

Commit e2dce7c

Browse files
Add unit tests for LlamaChatCompletionResponseHandler to validate error response handling
1 parent c6fc56f commit e2dce7c

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.llama.completion;
9+
10+
import org.apache.http.HttpResponse;
11+
import org.apache.http.StatusLine;
12+
import org.elasticsearch.common.bytes.BytesReference;
13+
import org.elasticsearch.common.xcontent.XContentHelper;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
17+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
18+
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
19+
import org.elasticsearch.xpack.inference.external.request.Request;
20+
21+
import java.io.IOException;
22+
import java.net.URI;
23+
import java.net.URISyntaxException;
24+
import java.nio.charset.StandardCharsets;
25+
26+
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
27+
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
28+
import static org.hamcrest.Matchers.is;
29+
import static org.hamcrest.Matchers.isA;
30+
import static org.mockito.Mockito.mock;
31+
import static org.mockito.Mockito.when;
32+
33+
public class LlamaChatCompletionResponseHandlerTests extends ESTestCase {
34+
private final LlamaChatCompletionResponseHandler responseHandler = new LlamaChatCompletionResponseHandler(
35+
"chat completions",
36+
(a, b) -> mock()
37+
);
38+
39+
public void testFailNotFound() throws IOException {
40+
var responseJson = XContentHelper.stripWhitespace("""
41+
{
42+
"detail": "Not Found"
43+
}
44+
""");
45+
46+
var errorJson = invalidResponseJson(responseJson, 404);
47+
48+
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
49+
{
50+
"error" : {
51+
"code" : "not_found",
52+
"message" : "Resource not found at [https://api.llama.ai/v1/chat/completions] for request from inference entity id [id] \
53+
status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]",
54+
"type" : "llama_error"
55+
}
56+
}""")));
57+
}
58+
59+
public void testFailBadRequest() throws IOException {
60+
var responseJson = XContentHelper.stripWhitespace("""
61+
{
62+
"error": {
63+
"detail": {
64+
"errors": [{
65+
"loc": [
66+
"body",
67+
"messages"
68+
],
69+
"msg": "Field required",
70+
"type": "missing"
71+
}
72+
]
73+
}
74+
}
75+
}
76+
""");
77+
78+
var errorJson = invalidResponseJson(responseJson, 400);
79+
80+
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
81+
{
82+
"error": {
83+
"code": "bad_request",
84+
"message": "Received a bad request status code for request from inference entity id [id] status [400].\
85+
Error message: [{\\"error\\":{\\"detail\\":{\\"errors\\":[{\\"loc\\":[\\"body\\",\\"messages\\"],\\"msg\\":\\"Field\
86+
required\\",\\"type\\":\\"missing\\"}]}}}]",
87+
"type": "llama_error"
88+
}
89+
}
90+
""")));
91+
}
92+
93+
public void testFailValidationWithInvalidJson() throws IOException {
94+
var responseJson = """
95+
what? this isn't a json
96+
""";
97+
98+
var errorJson = invalidResponseJson(responseJson, 500);
99+
100+
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
101+
{
102+
"error": {
103+
"code": "bad_request",
104+
"message": "Received a server error status code for request from inference entity id [id] status [500]. Error message: \
105+
[what? this isn't a json\\n]",
106+
"type": "llama_error"
107+
}
108+
}
109+
""")));
110+
}
111+
112+
private String invalidResponseJson(String responseJson, int statusCode) throws IOException {
113+
var exception = invalidResponse(responseJson, statusCode);
114+
assertThat(exception, isA(RetryException.class));
115+
assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
116+
return toJson((UnifiedChatCompletionException) unwrapCause(exception));
117+
}
118+
119+
private Exception invalidResponse(String responseJson, int statusCode) {
120+
return expectThrows(
121+
RetryException.class,
122+
() -> responseHandler.validateResponse(
123+
mock(),
124+
mock(),
125+
mockRequest(),
126+
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
127+
true
128+
)
129+
);
130+
}
131+
132+
private static Request mockRequest() throws URISyntaxException {
133+
var request = mock(Request.class);
134+
when(request.getInferenceEntityId()).thenReturn("id");
135+
when(request.isStreaming()).thenReturn(true);
136+
when(request.getURI()).thenReturn(new URI("https://api.llama.ai/v1/chat/completions"));
137+
return request;
138+
}
139+
140+
private static HttpResponse mockErrorResponse(int statusCode) {
141+
var statusLine = mock(StatusLine.class);
142+
when(statusLine.getStatusCode()).thenReturn(statusCode);
143+
144+
var response = mock(HttpResponse.class);
145+
when(response.getStatusLine()).thenReturn(statusLine);
146+
147+
return response;
148+
}
149+
150+
private String toJson(UnifiedChatCompletionException e) throws IOException {
151+
try (var builder = XContentFactory.jsonBuilder()) {
152+
e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
153+
try {
154+
xContent.toXContent(builder, EMPTY_PARAMS);
155+
} catch (IOException ex) {
156+
throw new RuntimeException(ex);
157+
}
158+
});
159+
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
160+
}
161+
}
162+
}

0 commit comments

Comments
 (0)