|
15 | 15 | import software.amazon.awssdk.services.bedrockruntime.model.Message; |
16 | 16 |
|
17 | 17 | import org.elasticsearch.ElasticsearchException; |
| 18 | +import org.elasticsearch.ElasticsearchStatusException; |
18 | 19 | import org.elasticsearch.action.support.PlainActionFuture; |
| 20 | +import org.elasticsearch.common.settings.SecureString; |
19 | 21 | import org.elasticsearch.core.TimeValue; |
20 | 22 | import org.elasticsearch.inference.InferenceServiceResults; |
| 23 | +import org.elasticsearch.inference.TaskType; |
21 | 24 | import org.elasticsearch.inference.UnifiedCompletionRequest; |
22 | 25 | import org.elasticsearch.test.ESTestCase; |
| 26 | +import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; |
23 | 27 | import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; |
| 28 | +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; |
24 | 29 | import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; |
| 30 | +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; |
| 31 | +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockCompletionTaskSettings; |
25 | 32 | import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; |
26 | 33 | import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequest; |
27 | 34 | import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequestEntity; |
|
38 | 45 |
|
39 | 46 | import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; |
40 | 47 | import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat; |
| 48 | +import static org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy.CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION; |
41 | 49 | import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator; |
42 | | -import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider.ANTHROPIC; |
43 | 50 | import static org.hamcrest.Matchers.containsString; |
44 | 51 | import static org.hamcrest.Matchers.is; |
45 | 52 |
|
@@ -143,14 +150,14 @@ public void testExecute_CompletionFailsProperly_WithElasticsearchException() { |
143 | 150 | assertThat(exceptionThrown.getCause().getMessage(), containsString("test exception")); |
144 | 151 | } |
145 | 152 |
|
146 | | - public void testExecute_ChatCompletionRequest() { |
147 | | - var model = AmazonBedrockChatCompletionModelTests.createModel( |
| 153 | + public void testExecute_ChatCompletionRequest_NonStreaming_Fails() { |
| 154 | + var model = new AmazonBedrockChatCompletionModel( |
148 | 155 | "id", |
149 | | - "region", |
150 | | - "model", |
151 | | - AmazonBedrockProvider.AMAZONTITAN, |
152 | | - "accesskey", |
153 | | - "secretkey" |
| 156 | + TaskType.CHAT_COMPLETION, |
| 157 | + "amazonbedrock", |
| 158 | + new AmazonBedrockChatCompletionServiceSettings("region", "model", AmazonBedrockProvider.AMAZONTITAN, null), |
| 159 | + new AmazonBedrockCompletionTaskSettings(null, null, null, null), |
| 160 | + new AwsSecretSettings(new SecureString("accessKey"), new SecureString("secretKey")) |
154 | 161 | ); |
155 | 162 | var content = new UnifiedCompletionRequest.ContentString("content"); |
156 | 163 | var toolCall = new UnifiedCompletionRequest.ToolCall( |
@@ -178,43 +185,8 @@ public void testExecute_ChatCompletionRequest() { |
178 | 185 |
|
179 | 186 | var executor = new AmazonBedrockChatCompletionExecutor(request, responseHandler, logger, () -> false, listener, clientCache); |
180 | 187 | executor.run(); |
181 | | - var result = listener.actionGet(new TimeValue(30000)); |
182 | | - assertNotNull(result); |
183 | | - assertThat(result.asMap(), is(buildExpectationCompletion(List.of("converse result")))); |
184 | | - } |
185 | | - |
186 | | - public void testExecute_ChatCompletionFailsProperly_WithElasticsearchException() { |
187 | | - var model = AmazonBedrockChatCompletionModelTests.createModel("id", "region", "model", ANTHROPIC, "accesskey", "secretkey"); |
188 | | - var content = new UnifiedCompletionRequest.ContentString("content"); |
189 | | - var toolCall = new UnifiedCompletionRequest.ToolCall( |
190 | | - "id", |
191 | | - new UnifiedCompletionRequest.ToolCall.FunctionField("function", model.model()), |
192 | | - "" |
193 | | - ); |
194 | | - var message = new UnifiedCompletionRequest.Message(content, "user", "tooluse_Z7IP83_eTt2y_TECni1ULw", List.of(toolCall)); |
195 | | - |
196 | | - var requestEntity = new AmazonBedrockChatCompletionRequestEntity( |
197 | | - List.of(message), |
198 | | - model.model(), |
199 | | - 512L, |
200 | | - null, |
201 | | - null, |
202 | | - null, |
203 | | - null, |
204 | | - null |
205 | | - ); |
206 | | - var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null, false); |
207 | | - var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); |
208 | | - |
209 | | - var clientCache = new AmazonBedrockMockClientCache(null, null, new ElasticsearchException("test exception")); |
210 | | - var listener = new PlainActionFuture<InferenceServiceResults>(); |
211 | | - |
212 | | - var executor = new AmazonBedrockChatCompletionExecutor(request, responseHandler, logger, () -> false, listener, clientCache); |
213 | | - executor.run(); |
214 | | - |
215 | | - var exceptionThrown = assertThrows(ElasticsearchException.class, () -> listener.actionGet(new TimeValue(30000))); |
216 | | - assertThat(exceptionThrown.getMessage(), containsString("Failed to send request from inference entity id [id]")); |
217 | | - assertThat(exceptionThrown.getCause().getMessage(), containsString("test exception")); |
| 188 | + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(new TimeValue(30000))); |
| 189 | + assertThat(CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION, is(exception)); |
218 | 190 | } |
219 | 191 |
|
220 | 192 | public static ConverseResponse getTestConverseResult(String resultText) { |
|
0 commit comments