Skip to content

Commit 72c3f9e

Browse files
Adding translation tests
1 parent bbfc154 commit 72c3f9e

File tree

8 files changed

+722
-85
lines changed

8 files changed

+722
-85
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockChatCompletionStreamingProcessor.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,9 @@
2626

2727
import java.util.ArrayDeque;
2828
import java.util.List;
29-
import java.util.concurrent.Flow;
3029
import java.util.stream.Stream;
3130

32-
@SuppressWarnings("checkstyle:LineLength")
33-
class AmazonBedrockChatCompletionStreamingProcessor extends AmazonBedrockStreamingProcessor<StreamingUnifiedChatCompletionResults.Results>
34-
implements
35-
Flow.Processor<ConverseStreamOutput, StreamingUnifiedChatCompletionResults.Results> {
31+
class AmazonBedrockChatCompletionStreamingProcessor extends AmazonBedrockStreamingProcessor<StreamingUnifiedChatCompletionResults.Results> {
3632
private static final Logger logger = LogManager.getLogger(AmazonBedrockChatCompletionStreamingProcessor.class);
3733

3834
protected AmazonBedrockChatCompletionStreamingProcessor(ThreadPool threadPool) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockCompletionStreamingProcessor.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1616

1717
import java.util.ArrayDeque;
18-
import java.util.concurrent.Flow;
1918

20-
class AmazonBedrockCompletionStreamingProcessor extends AmazonBedrockStreamingProcessor<StreamingChatCompletionResults.Results>
21-
implements
22-
Flow.Processor<ConverseStreamOutput, StreamingChatCompletionResults.Results> {
19+
class AmazonBedrockCompletionStreamingProcessor extends AmazonBedrockStreamingProcessor<StreamingChatCompletionResults.Results> {
2320
protected AmazonBedrockCompletionStreamingProcessor(ThreadPool threadPool) {
2421
super(threadPool);
2522
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockStreamingProcessor.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
package org.elasticsearch.xpack.inference.services.amazonbedrock.client;
99

10+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
11+
1012
import org.elasticsearch.ElasticsearchException;
1113
import org.elasticsearch.ExceptionsHelper;
1214
import org.elasticsearch.common.Strings;
@@ -22,7 +24,7 @@
2224

2325
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
2426

25-
class AmazonBedrockStreamingProcessor<T> {
27+
abstract class AmazonBedrockStreamingProcessor<T> implements Flow.Processor<ConverseStreamOutput, T> {
2628
private static final Logger logger = LogManager.getLogger(AmazonBedrockStreamingProcessor.class);
2729

2830
final AtomicReference<Throwable> error = new AtomicReference<>(null);
@@ -35,6 +37,7 @@ class AmazonBedrockStreamingProcessor<T> {
3537

3638
volatile Flow.Subscriber<? super T> downstream;
3739

40+
@Override
3841
public void onSubscribe(Flow.Subscription subscription) {
3942
if (upstream == null) {
4043
upstream = subscription;
@@ -47,6 +50,7 @@ public void onSubscribe(Flow.Subscription subscription) {
4750
}
4851
}
4952

53+
@Override
5054
public void subscribe(Flow.Subscriber<? super T> subscriber) {
5155
if (downstream == null) {
5256
downstream = subscriber;
@@ -56,6 +60,7 @@ public void subscribe(Flow.Subscriber<? super T> subscriber) {
5660
}
5761
}
5862

63+
@Override
5964
public void onError(Throwable amazonBedrockRuntimeException) {
6065
ExceptionsHelper.maybeDieOnAnotherThread(amazonBedrockRuntimeException);
6166
error.set(
@@ -73,6 +78,7 @@ private boolean checkAndResetDemand() {
7378
return demand.getAndUpdate(i -> 0L) > 0L;
7479
}
7580

81+
@Override
7682
public void onComplete() {
7783
if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onCompleteCalled.compareAndSet(false, true)) {
7884
downstream.onComplete();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockChatCompletionRequest.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@
4141
import static org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockConverseUtils.toDocument;
4242

4343
public class AmazonBedrockChatCompletionRequest extends AmazonBedrockRequest {
44-
private static final String AUTO_TOOL_CHOICE = "auto";
45-
private static final String REQUIRED_TOOL_CHOICE = "required";
46-
private static final String NONE_TOOL_CHOICE = "none";
44+
static final String AUTO_TOOL_CHOICE = "auto";
45+
static final String REQUIRED_TOOL_CHOICE = "required";
46+
static final String NONE_TOOL_CHOICE = "none";
47+
static final String FUNCTION_TYPE = "function";
48+
4749
private static final Set<String> VALID_TOOL_CHOICES = Set.of(AUTO_TOOL_CHOICE, REQUIRED_TOOL_CHOICE, NONE_TOOL_CHOICE);
48-
private static final String FUNCTION_TYPE = "function";
4950

50-
public static final String USER_ROLE = "user";
5151
private final AmazonBedrockChatCompletionRequestEntity requestEntity;
5252
private final boolean stream;
5353

@@ -106,7 +106,8 @@ private static ToolChoice.Builder convertToolChoice(
106106
return determineToolChoice(toolChoice);
107107
}
108108

109-
private static ToolChoice.Builder determineToolChoice(@Nullable UnifiedCompletionRequest.ToolChoice toolChoice) {
109+
// default for testing
110+
static ToolChoice.Builder determineToolChoice(@Nullable UnifiedCompletionRequest.ToolChoice toolChoice) {
110111
// If a specific tool choice isn't provided, the chat completion schema (openai) defaults to "auto"
111112
if (toolChoice == null) {
112113
return ToolChoice.builder().auto(AutoToolChoice.builder().build());
@@ -126,7 +127,8 @@ private static ToolChoice.Builder determineToolChoice(@Nullable UnifiedCompletio
126127
};
127128
}
128129

129-
private static List<Tool> convertTools(@Nullable List<UnifiedCompletionRequest.Tool> tools) {
130+
// default for testing
131+
static List<Tool> convertTools(@Nullable List<UnifiedCompletionRequest.Tool> tools) {
130132
if (tools == null || tools.isEmpty()) {
131133
return List.of();
132134
}
@@ -143,6 +145,7 @@ private static List<Tool> convertTools(@Nullable List<UnifiedCompletionRequest.T
143145
builtTools.add(
144146
Tool.builder()
145147
.toolSpec(
148+
// Bedrock does not use the strict field
146149
ToolSpecification.builder()
147150
.name(requestTool.function().name())
148151
.description(requestTool.function().description())
@@ -156,7 +159,8 @@ private static List<Tool> convertTools(@Nullable List<UnifiedCompletionRequest.T
156159
return builtTools;
157160
}
158161

159-
private static Map<String, Document> paramToDocumentMap(UnifiedCompletionRequest.Tool tool) {
162+
// default for testing
163+
static Map<String, Document> paramToDocumentMap(UnifiedCompletionRequest.Tool tool) {
160164
if (tool.function().parameters() == null) {
161165
return Map.of();
162166
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockConverseUtils.java

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,22 @@
3535

3636
public final class AmazonBedrockConverseUtils {
3737

38-
private static final String SYSTEM_ROLE = "system";
39-
private static final String TOOL_ROLE = "tool";
40-
private static final String ASSISTANT_ROLE = "assistant";
41-
private static final String USER_ROLE = "user";
42-
private static final String TEXT_CONTENT_TYPE = "text";
43-
44-
private static final Message DEFAULT_ASSISTANT_MESSAGE = Message.builder()
38+
static final String SYSTEM_ROLE = "system";
39+
static final String TOOL_ROLE = "tool";
40+
static final String ASSISTANT_ROLE = "assistant";
41+
static final String USER_ROLE = "user";
42+
static final String TEXT_CONTENT_TYPE = "text";
43+
static final String HI_TEXT = "Hi";
44+
static final String PLEASE_CONTINUE_TEXT = "Please continue.";
45+
46+
static final Message DEFAULT_USER_MESSAGE = Message.builder()
4547
.role(ConversationRole.USER)
46-
.content(ContentBlock.builder().text("Hi").build())
48+
.content(ContentBlock.builder().text(HI_TEXT).build())
4749
.build();
4850

49-
private static final Message CONTINUE_ASSISTANT_MESSAGE = Message.builder()
51+
static final Message CONTINUE_ASSISTANT_MESSAGE = Message.builder()
5052
.role(ConversationRole.ASSISTANT)
51-
.content(ContentBlock.builder().text("Please continue.").build())
53+
.content(ContentBlock.builder().text(PLEASE_CONTINUE_TEXT).build())
5254
.build();
5355

5456
public static List<Message> getConverseMessageList(List<String> texts) {
@@ -58,17 +60,6 @@ public static List<Message> getConverseMessageList(List<String> texts) {
5860
.toList();
5961
}
6062

61-
public static List<Message> getUnifiedConverseMessageList(List<UnifiedCompletionRequest.Message> messages) {
62-
return messages.stream()
63-
.map(
64-
message -> Message.builder()
65-
.role(message.role())
66-
.content(ContentBlock.builder().text(message.content().toString()).build())
67-
.build()
68-
)
69-
.toList();
70-
}
71-
7263
public record TranslatedMessages(List<Message> messages, List<SystemContentBlock> systemContent) {}
7364

7465
public static TranslatedMessages convertChatCompletionMessagesToConverse(
@@ -123,7 +114,7 @@ private static List<Message> mergeOrAddMessage(Message message, List<Message> me
123114
// If this is the first message and it's an assistant, prepend a default user message
124115
// System messages are not of concern here and tool message are also just user messages
125116
if (messages.isEmpty() && message.role().equals(ConversationRole.ASSISTANT)) {
126-
messages.add(DEFAULT_ASSISTANT_MESSAGE);
117+
messages.add(DEFAULT_USER_MESSAGE);
127118
}
128119

129120
// Check if we should consider merging (not first message and same role)

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockExecutorTests.java

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515
import software.amazon.awssdk.services.bedrockruntime.model.Message;
1616

1717
import org.elasticsearch.ElasticsearchException;
18+
import org.elasticsearch.ElasticsearchStatusException;
1819
import org.elasticsearch.action.support.PlainActionFuture;
20+
import org.elasticsearch.common.settings.SecureString;
1921
import org.elasticsearch.core.TimeValue;
2022
import org.elasticsearch.inference.InferenceServiceResults;
23+
import org.elasticsearch.inference.TaskType;
2124
import org.elasticsearch.inference.UnifiedCompletionRequest;
2225
import org.elasticsearch.test.ESTestCase;
26+
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
2327
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
28+
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;
2429
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests;
30+
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
31+
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockCompletionTaskSettings;
2532
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests;
2633
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequest;
2734
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion.AmazonBedrockChatCompletionRequestEntity;
@@ -38,8 +45,8 @@
3845

3946
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
4047
import static org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResultsTests.buildExpectationFloat;
48+
import static org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy.CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION;
4149
import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator;
42-
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider.ANTHROPIC;
4350
import static org.hamcrest.Matchers.containsString;
4451
import static org.hamcrest.Matchers.is;
4552

@@ -143,14 +150,14 @@ public void testExecute_CompletionFailsProperly_WithElasticsearchException() {
143150
assertThat(exceptionThrown.getCause().getMessage(), containsString("test exception"));
144151
}
145152

146-
public void testExecute_ChatCompletionRequest() {
147-
var model = AmazonBedrockChatCompletionModelTests.createModel(
153+
public void testExecute_ChatCompletionRequest_NonStreaming_Fails() {
154+
var model = new AmazonBedrockChatCompletionModel(
148155
"id",
149-
"region",
150-
"model",
151-
AmazonBedrockProvider.AMAZONTITAN,
152-
"accesskey",
153-
"secretkey"
156+
TaskType.CHAT_COMPLETION,
157+
"amazonbedrock",
158+
new AmazonBedrockChatCompletionServiceSettings("region", "model", AmazonBedrockProvider.AMAZONTITAN, null),
159+
new AmazonBedrockCompletionTaskSettings(null, null, null, null),
160+
new AwsSecretSettings(new SecureString("accessKey"), new SecureString("secretKey"))
154161
);
155162
var content = new UnifiedCompletionRequest.ContentString("content");
156163
var toolCall = new UnifiedCompletionRequest.ToolCall(
@@ -178,43 +185,8 @@ public void testExecute_ChatCompletionRequest() {
178185

179186
var executor = new AmazonBedrockChatCompletionExecutor(request, responseHandler, logger, () -> false, listener, clientCache);
180187
executor.run();
181-
var result = listener.actionGet(new TimeValue(30000));
182-
assertNotNull(result);
183-
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("converse result"))));
184-
}
185-
186-
public void testExecute_ChatCompletionFailsProperly_WithElasticsearchException() {
187-
var model = AmazonBedrockChatCompletionModelTests.createModel("id", "region", "model", ANTHROPIC, "accesskey", "secretkey");
188-
var content = new UnifiedCompletionRequest.ContentString("content");
189-
var toolCall = new UnifiedCompletionRequest.ToolCall(
190-
"id",
191-
new UnifiedCompletionRequest.ToolCall.FunctionField("function", model.model()),
192-
""
193-
);
194-
var message = new UnifiedCompletionRequest.Message(content, "user", "tooluse_Z7IP83_eTt2y_TECni1ULw", List.of(toolCall));
195-
196-
var requestEntity = new AmazonBedrockChatCompletionRequestEntity(
197-
List.of(message),
198-
model.model(),
199-
512L,
200-
null,
201-
null,
202-
null,
203-
null,
204-
null
205-
);
206-
var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null, false);
207-
var responseHandler = new AmazonBedrockChatCompletionResponseHandler();
208-
209-
var clientCache = new AmazonBedrockMockClientCache(null, null, new ElasticsearchException("test exception"));
210-
var listener = new PlainActionFuture<InferenceServiceResults>();
211-
212-
var executor = new AmazonBedrockChatCompletionExecutor(request, responseHandler, logger, () -> false, listener, clientCache);
213-
executor.run();
214-
215-
var exceptionThrown = assertThrows(ElasticsearchException.class, () -> listener.actionGet(new TimeValue(30000)));
216-
assertThat(exceptionThrown.getMessage(), containsString("Failed to send request from inference entity id [id]"));
217-
assertThat(exceptionThrown.getCause().getMessage(), containsString("test exception"));
188+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(new TimeValue(30000)));
189+
assertThat(CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION, is(exception));
218190
}
219191

220192
public static ConverseResponse getTestConverseResult(String resultText) {

0 commit comments

Comments
 (0)