Skip to content

Commit 977bfc4

Browse files
Add unit tests for MistralUnifiedChatCompletionResponseHandler to validate error handling
1 parent 5cc7402 commit 977bfc4

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.mistral;
9+
10+
import org.apache.http.HttpResponse;
11+
import org.apache.http.StatusLine;
12+
import org.elasticsearch.common.bytes.BytesReference;
13+
import org.elasticsearch.common.xcontent.XContentHelper;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
17+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
18+
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
19+
import org.elasticsearch.xpack.inference.external.request.Request;
20+
21+
import java.io.IOException;
22+
import java.net.URI;
23+
import java.net.URISyntaxException;
24+
import java.nio.charset.StandardCharsets;
25+
26+
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
27+
import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
28+
import static org.hamcrest.Matchers.is;
29+
import static org.hamcrest.Matchers.isA;
30+
import static org.mockito.Mockito.mock;
31+
import static org.mockito.Mockito.when;
32+
33+
public class MistralUnifiedChatCompletionResponseHandlerTests extends ESTestCase {
34+
private final MistralUnifiedChatCompletionResponseHandler responseHandler = new MistralUnifiedChatCompletionResponseHandler(
35+
"chat completions",
36+
(a, b) -> mock()
37+
);
38+
39+
public void testFailNotFound() throws IOException {
40+
var responseJson = XContentHelper.stripWhitespace("""
41+
{
42+
"detail": "Not Found"
43+
}
44+
""");
45+
46+
var errorJson = invalidResponseJson(responseJson, 404);
47+
48+
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
49+
{
50+
"error" : {
51+
"code" : "not_found",
52+
"message" : "Resource not found at [https://api.mistral.ai/v1/chat/completions] for request from inference entity id [id] \
53+
status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]",
54+
"type" : "mistral_error"
55+
}
56+
}""")));
57+
}
58+
59+
public void testFailUnauthorized() throws IOException {
60+
var responseJson = XContentHelper.stripWhitespace("""
61+
{
62+
"message": "Unauthorized",
63+
"request_id": "a580d263fb1521778782b22104efb415"
64+
}
65+
""");
66+
67+
var errorJson = invalidResponseJson(responseJson, 401);
68+
69+
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
70+
{
71+
"error" : {
72+
"code" : "unauthorized",
73+
"message" : "Received an authentication error status code for request from inference entity id [id] status [401]. Error \
74+
message: [{\\"message\\":\\"Unauthorized\\",\\"request_id\\":\\"a580d263fb1521778782b22104efb415\\"}]",
75+
"type" : "mistral_error"
76+
}
77+
}""")));
78+
}
79+
80+
public void testFailBadRequest() throws IOException {
81+
var responseJson = XContentHelper.stripWhitespace("""
82+
{
83+
"object": "error",
84+
"message": "Invalid model: mistral-small-l2atest",
85+
"type": "invalid_model",
86+
"param": null,
87+
"code": "1500"
88+
}
89+
""");
90+
91+
var errorJson = invalidResponseJson(responseJson, 400);
92+
93+
assertThat(errorJson, is(XContentHelper.stripWhitespace("""
94+
{
95+
"error" : {
96+
"code" : "bad_request",
97+
"message" : "Received a bad request status code for request from inference entity id [id] status [400]. Error message: \
98+
[{\\"object\\":\\"error\\",\\"message\\":\\"Invalid model: mistral-small-l2atest\\",\\"type\\":\\"invalid_model\\",\\"par\
99+
am\\":null,\\"code\\":\\"1500\\"}]",
100+
"type" : "mistral_error"
101+
}
102+
}""")));
103+
}
104+
105+
private String invalidResponseJson(String responseJson, int statusCode) throws IOException {
106+
var exception = invalidResponse(responseJson, statusCode);
107+
assertThat(exception, isA(RetryException.class));
108+
assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class));
109+
return toJson((UnifiedChatCompletionException) unwrapCause(exception));
110+
}
111+
112+
private Exception invalidResponse(String responseJson, int statusCode) {
113+
return expectThrows(
114+
RetryException.class,
115+
() -> responseHandler.validateResponse(
116+
mock(),
117+
mock(),
118+
mockRequest(),
119+
new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
120+
true
121+
)
122+
);
123+
}
124+
125+
private static Request mockRequest() throws URISyntaxException {
126+
var request = mock(Request.class);
127+
when(request.getInferenceEntityId()).thenReturn("id");
128+
when(request.isStreaming()).thenReturn(true);
129+
when(request.getURI()).thenReturn(new URI("https://api.mistral.ai/v1/chat/completions"));
130+
return request;
131+
}
132+
133+
private static HttpResponse mockErrorResponse(int statusCode) {
134+
var statusLine = mock(StatusLine.class);
135+
when(statusLine.getStatusCode()).thenReturn(statusCode);
136+
137+
var response = mock(HttpResponse.class);
138+
when(response.getStatusLine()).thenReturn(statusLine);
139+
140+
return response;
141+
}
142+
143+
private String toJson(UnifiedChatCompletionException e) throws IOException {
144+
try (var builder = XContentFactory.jsonBuilder()) {
145+
e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
146+
try {
147+
xContent.toXContent(builder, EMPTY_PARAMS);
148+
} catch (IOException ex) {
149+
throw new RuntimeException(ex);
150+
}
151+
});
152+
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
153+
}
154+
}
155+
}

0 commit comments

Comments
 (0)