Skip to content

Commit b086620

Browse files
Refactor error handling and add unit tests for Ai21 service
1 parent 8a185b9 commit b086620

File tree

6 files changed

+623
-7
lines changed

6 files changed

+623
-7
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ErrorResponse.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,20 @@ public static ErrorResponse fromResponse(HttpResult response) {
6868

6969
return ErrorResponse.UNDEFINED_ERROR;
7070
}
71+
72+
/**
73+
* Parses a string response into an ErrorResponse.
74+
* If the string is not blank, creates a new ErrorResponse with the string as the error message.
75+
* If the string is blank, returns UNDEFINED_ERROR.
76+
*
77+
* @param response the error response as a string
78+
* @return an ErrorResponse instance
79+
*/
80+
public static ErrorResponse fromString(String response) {
81+
if (Objects.nonNull(response) && response.isBlank() == false) {
82+
return new ErrorResponse(response);
83+
} else {
84+
return ErrorResponse.UNDEFINED_ERROR;
85+
}
86+
}
7187
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/completion/Ai21ChatCompletionResponseHandler.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.inference.services.ai21.completion;
99

10+
import org.elasticsearch.rest.RestStatus;
1011
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
1112
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1213
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
@@ -16,6 +17,8 @@
1617

1718
import java.util.Locale;
1819

20+
import static org.elasticsearch.core.Strings.format;
21+
1922
/**
2023
* Handles streaming chat completion responses and error parsing for Mistral inference endpoints.
2124
* Adapts the OpenAI handler to support Mistral's error schema.
@@ -36,4 +39,30 @@ protected Exception buildError(String message, Request request, HttpResult resul
3639
var restStatus = toRestStatus(responseStatusCode);
3740
return new UnifiedChatCompletionException(restStatus, errorMessage, AI_21_ERROR, restStatus.name().toLowerCase(Locale.ROOT));
3841
}
42+
43+
protected Exception buildMidStreamError(Request request, String message, Exception e) {
44+
var errorResponse = ErrorResponse.fromString(message);
45+
if (errorResponse.errorStructureFound()) {
46+
return new UnifiedChatCompletionException(
47+
RestStatus.INTERNAL_SERVER_ERROR,
48+
format(
49+
"%s for request from inference entity id [%s]. Error message: [%s]",
50+
SERVER_ERROR_OBJECT,
51+
request.getInferenceEntityId(),
52+
errorResponse.getErrorMessage()
53+
),
54+
AI_21_ERROR,
55+
"stream_error"
56+
);
57+
} else if (e != null) {
58+
return UnifiedChatCompletionException.fromThrowable(e);
59+
} else {
60+
return new UnifiedChatCompletionException(
61+
RestStatus.INTERNAL_SERVER_ERROR,
62+
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
63+
createErrorType(errorResponse),
64+
"stream_error"
65+
);
66+
}
67+
}
3968
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ public static ErrorResponse fromResponse(HttpResult response) {
149149
}
150150
}
151151

152-
static ErrorResponse fromString(String response) {
152+
public static ErrorResponse fromString(String response) {
153153
try (
154154
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
155155
.createParser(XContentParserConfiguration.EMPTY, response)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ private static class StreamingHuggingFaceErrorResponseEntity extends ErrorRespon
141141
* @param response the raw JSON string representing an error
142142
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
143143
*/
144-
private static ErrorResponse fromString(String response) {
144+
public static ErrorResponse fromString(String response) {
145145
try (
146146
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
147147
.createParser(XContentParserConfiguration.EMPTY, response)

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,17 @@ public void testParsePersistedConfigWithSecrets_ThrowsUnsupportedModelType() thr
351351

352352
assertThat(
353353
exception.getMessage(),
354-
containsString(Strings.format("service does not support task type [%s]", parseConfigTestConfig.unsupportedTaskType))
354+
containsString(
355+
Strings.format(fetchPersistedConfigTaskTypeParsingErrorMessageFormat(), parseConfigTestConfig.unsupportedTaskType)
356+
)
355357
);
356358
}
357359
}
358360

361+
protected String fetchPersistedConfigTaskTypeParsingErrorMessageFormat() {
362+
return "service does not support task type [%s]";
363+
}
364+
359365
public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
360366
var parseConfigTestConfig = testConfiguration.commonConfig;
361367

@@ -374,7 +380,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
374380
persistedConfigMap.secrets()
375381
);
376382

377-
parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
383+
parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
378384
}
379385
}
380386

@@ -396,7 +402,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInServ
396402
persistedConfigMap.secrets()
397403
);
398404

399-
parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
405+
parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
400406
}
401407
}
402408

@@ -413,7 +419,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInTask
413419

414420
var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
415421

416-
parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
422+
parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
417423
}
418424
}
419425

@@ -430,7 +436,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsWhenAnExtraKeyExistsInSecr
430436

431437
var model = service.parsePersistedConfigWithSecrets("id", parseConfigTestConfig.taskType, config.config(), config.secrets());
432438

433-
parseConfigTestConfig.assertModel(model, TaskType.TEXT_EMBEDDING);
439+
parseConfigTestConfig.assertModel(model, parseConfigTestConfig.taskType);
434440
}
435441
}
436442

0 commit comments

Comments
 (0)