|
24 | 24 | import org.elasticsearch.xpack.core.inference.action.InferenceAction; |
25 | 25 | import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; |
26 | 26 | import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; |
| 27 | +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; |
27 | 28 | import org.elasticsearch.xpack.inference.registry.ModelRegistry; |
28 | 29 |
|
29 | 30 | import java.io.IOException; |
@@ -77,27 +78,33 @@ protected void doExecute(Task task, InferenceActionProxy.Request request, Action |
77 | 78 | } |
78 | 79 | } |
79 | 80 |
|
80 | | - private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) |
81 | | - throws IOException { |
| 81 | + private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) { |
| 82 | + // format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException |
| 83 | + var unifiedErrorFormatListener = listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e))); |
82 | 84 |
|
83 | | - if (request.isStreaming() == false) { |
84 | | - throw new ElasticsearchStatusException( |
85 | | - "The [chat_completion] task type only supports streaming, please try again with the _stream API", |
86 | | - RestStatus.BAD_REQUEST |
87 | | - ); |
88 | | - } |
| 85 | + try { |
| 86 | + if (request.isStreaming() == false) { |
| 87 | + unifiedErrorFormatListener.onFailure(new ElasticsearchStatusException( |
| 88 | + "The [chat_completion] task type only supports streaming, please try again with the _stream API", |
| 89 | + RestStatus.BAD_REQUEST |
| 90 | + )); |
| 91 | + return; |
| 92 | + } |
89 | 93 |
|
90 | | - UnifiedCompletionAction.Request unifiedRequest; |
91 | | - try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { |
92 | | - unifiedRequest = UnifiedCompletionAction.Request.parseRequest( |
93 | | - request.getInferenceEntityId(), |
94 | | - request.getTaskType(), |
95 | | - request.getTimeout(), |
96 | | - parser |
97 | | - ); |
98 | | - } |
| 94 | + UnifiedCompletionAction.Request unifiedRequest; |
| 95 | + try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) { |
| 96 | + unifiedRequest = UnifiedCompletionAction.Request.parseRequest( |
| 97 | + request.getInferenceEntityId(), |
| 98 | + request.getTaskType(), |
| 99 | + request.getTimeout(), |
| 100 | + parser |
| 101 | + ); |
| 102 | + } |
99 | 103 |
|
100 | | - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, listener); |
| 104 | + executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener); |
| 105 | + } catch (Exception e) { |
| 106 | + unifiedErrorFormatListener.onFailure(e); |
| 107 | + } |
101 | 108 | } |
102 | 109 |
|
103 | 110 | private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) |
|
0 commit comments