Skip to content

Commit 5953e9b

Browse files
Addressing feedback
1 parent 91763be commit 5953e9b

File tree

14 files changed

+337
-133
lines changed

14 files changed

+337
-133
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockChatCompletionExecutor.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
import java.util.function.Supplier;
1818

19-
import static org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy.CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION;
20-
2119
public class AmazonBedrockChatCompletionExecutor extends AmazonBedrockExecutor {
2220
private final AmazonBedrockChatCompletionRequest chatCompletionRequest;
2321

@@ -35,11 +33,7 @@ protected AmazonBedrockChatCompletionExecutor(
3533

3634
@Override
3735
protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) {
38-
// Chat completions only supports streaming
39-
if (chatCompletionRequest.isStreaming() == false) {
40-
inferenceResultsListener.onFailure(CHAT_COMPLETION_STREAMING_ONLY_EXCEPTION);
41-
return;
42-
}
36+
assert chatCompletionRequest.isStreaming() : "The chat_completion task type only supports streaming";
4337

4438
var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient);
4539
inferenceResultsListener.onResponse(new StreamingUnifiedChatCompletionResults(publisher));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockStreamingProcessor.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
abstract class AmazonBedrockStreamingProcessor<T> implements Flow.Processor<ConverseStreamOutput, T> {
2828
private static final Logger logger = LogManager.getLogger(AmazonBedrockStreamingProcessor.class);
2929

30-
final AtomicReference<Throwable> error = new AtomicReference<>(null);
30+
private final AtomicReference<Throwable> error = new AtomicReference<>(null);
31+
private final AtomicBoolean onErrorCalled = new AtomicBoolean(false);
32+
private final ThreadPool threadPool;
3133
/**
3234
* The purpose of demand is solely to guard against the situation where the bedrock sdk can complete the future before the publisher
3335
* and subscriber are connected together via {@link #subscribe(Flow.Subscriber)} and {@link AmazonBedrockInferenceClient}
@@ -38,8 +40,7 @@ abstract class AmazonBedrockStreamingProcessor<T> implements Flow.Processor<Conv
3840
final AtomicLong demand = new AtomicLong(0);
3941
final AtomicBoolean isDone = new AtomicBoolean(false);
4042
final AtomicBoolean onCompleteCalled = new AtomicBoolean(false);
41-
final AtomicBoolean onErrorCalled = new AtomicBoolean(false);
42-
final ThreadPool threadPool;
43+
4344
volatile Flow.Subscription upstream;
4445

4546
volatile Flow.Subscriber<? super T> downstream;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ public static AmazonBedrockChatCompletionModel of(AmazonBedrockChatCompletionMod
3131

3232
var requestTaskSettings = AmazonBedrockCompletionRequestTaskSettings.fromMap(taskSettings);
3333
var taskSettingsToUse = AmazonBedrockCompletionTaskSettings.of(completionModel.getTaskSettings(), requestTaskSettings);
34+
35+
// If the task settings didn't change, then return the same model
36+
if (taskSettingsToUse.equals(completionModel.getTaskSettings())) {
37+
return completionModel;
38+
}
39+
3440
return new AmazonBedrockChatCompletionModel(completionModel, taskSettingsToUse);
3541
}
3642

@@ -42,16 +48,18 @@ public static AmazonBedrockChatCompletionModel of(AmazonBedrockChatCompletionMod
4248
* @return The given model when the request does not specify a different model ID; otherwise a new AmazonBedrockChatCompletionModel with the overridden model ID.
4349
*/
4450
public static AmazonBedrockChatCompletionModel of(AmazonBedrockChatCompletionModel model, UnifiedCompletionRequest request) {
45-
if (request.model() == null) {
51+
if (request.model() == null || request.model().equals(model.getServiceSettings().modelId())) {
4652
return model;
4753
}
54+
4855
var originalModelServiceSettings = model.getServiceSettings();
4956
var overriddenServiceSettings = new AmazonBedrockChatCompletionServiceSettings(
5057
originalModelServiceSettings.region(),
5158
Objects.requireNonNull(request.model(), originalModelServiceSettings.modelId()),
5259
originalModelServiceSettings.provider(),
5360
originalModelServiceSettings.rateLimitSettings()
5461
);
62+
5563
return new AmazonBedrockChatCompletionModel(
5664
model.getInferenceEntityId(),
5765
model.getTaskType(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockCompletionTaskSettings.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,18 @@
3232

3333
public class AmazonBedrockCompletionTaskSettings implements TaskSettings {
3434
public static final String NAME = "amazon_bedrock_chat_completion_task_settings";
35-
36-
public static final AmazonBedrockCompletionRequestTaskSettings EMPTY_SETTINGS = new AmazonBedrockCompletionRequestTaskSettings(
35+
private static final AmazonBedrockCompletionTaskSettings EMPTY_SETTINGS = new AmazonBedrockCompletionTaskSettings(
3736
null,
3837
null,
3938
null,
4039
null
4140
);
4241

4342
public static AmazonBedrockCompletionTaskSettings fromMap(Map<String, Object> settings) {
43+
if (settings.isEmpty()) {
44+
return EMPTY_SETTINGS;
45+
}
46+
4447
ValidationException validationException = new ValidationException();
4548

4649
Double temperature = extractOptionalDoubleInRange(
@@ -90,6 +93,13 @@ public static AmazonBedrockCompletionTaskSettings of(
9093
var topK = requestSettings.topK() == null ? originalSettings.topK() : requestSettings.topK();
9194
var maxNewTokens = requestSettings.maxNewTokens() == null ? originalSettings.maxNewTokens() : requestSettings.maxNewTokens();
9295

96+
if (Objects.equals(temperature, originalSettings.temperature())
97+
&& Objects.equals(topP, originalSettings.topP())
98+
&& Objects.equals(topK, originalSettings.topK())
99+
&& Objects.equals(maxNewTokens, originalSettings.maxNewTokens)) {
100+
return originalSettings;
101+
}
102+
93103
return new AmazonBedrockCompletionTaskSettings(temperature, topP, topK, maxNewTokens);
94104
}
95105

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/AmazonBedrockRequest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockBaseClient;
1717

1818
import java.net.URI;
19+
import java.util.Objects;
1920

2021
public abstract class AmazonBedrockRequest implements Request {
2122

@@ -24,7 +25,7 @@ public abstract class AmazonBedrockRequest implements Request {
2425
protected final TimeValue timeout;
2526

2627
protected AmazonBedrockRequest(AmazonBedrockModel model, @Nullable TimeValue timeout) {
27-
this.amazonBedrockModel = model;
28+
this.amazonBedrockModel = Objects.requireNonNull(model);
2829
this.inferenceId = model.getInferenceEntityId();
2930
this.timeout = timeout;
3031
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockChatCompletionRequest.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public TaskType taskType() {
7474
public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> executeStreamChatCompletionRequest(
7575
AmazonBedrockBaseClient awsBedrockClient
7676
) {
77-
var toolChoice = convertToolChoice(requestEntity.tools(), requestEntity.toolChoice());
77+
var toolChoice = buildToolChoice(requestEntity.tools(), requestEntity.toolChoice());
7878

7979
var toolsEnabled = toolChoice != null;
8080

@@ -86,23 +86,29 @@ public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> executeStre
8686

8787
if (toolsEnabled) {
8888
converseStreamRequest.toolConfig(
89-
ToolConfiguration.builder().tools(convertTools(requestEntity.tools())).toolChoice(toolChoice.build()).build()
89+
ToolConfiguration.builder().tools(convertTools(requestEntity.tools())).toolChoice(toolChoice).build()
9090
);
9191
}
9292

9393
inferenceConfig(requestEntity).ifPresent(converseStreamRequest::inferenceConfig);
9494
return awsBedrockClient.converseUnifiedStream(converseStreamRequest.build(), amazonBedrockModel);
9595
}
9696

97-
private static ToolChoice.Builder convertToolChoice(
97+
private static ToolChoice buildToolChoice(
9898
@Nullable List<UnifiedCompletionRequest.Tool> tools,
9999
@Nullable UnifiedCompletionRequest.ToolChoice toolChoice
100100
) {
101101
if (tools == null || tools.isEmpty()) {
102102
return null;
103103
}
104104

105-
return determineToolChoice(toolChoice);
105+
var toolChoiceBuilder = determineToolChoice(toolChoice);
106+
107+
if (toolChoiceBuilder == null) {
108+
return null;
109+
}
110+
111+
return toolChoiceBuilder.build();
106112
}
107113

108114
// default for testing

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockConverseUtils.java

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,8 @@ private static List<Message> mergeOrAddMessage(Message message, List<Message> me
122122
var previousMessage = messages.getLast();
123123

124124
// Detect message types to determine merging strategy
125-
var previousToolResult = findToolResult(previousMessage.content());
126-
var currentToolResult = findToolResult(message.content());
127-
128-
var previousToolResultExists = previousToolResult != null;
129-
var currentToolResultExists = currentToolResult != null;
125+
var previousToolResultExists = hasToolResult(previousMessage.content());
126+
var currentToolResultExists = hasToolResult(message.content());
130127

131128
// If exactly one message is a tool result (XOR), insert assistant transition
132129
if (previousToolResultExists != currentToolResultExists) {
@@ -152,18 +149,18 @@ private static List<Message> mergeOrAddMessage(Message message, List<Message> me
152149
return messages;
153150
}
154151

155-
private static ToolResultBlock findToolResult(List<ContentBlock> blocks) {
152+
private static boolean hasToolResult(List<ContentBlock> blocks) {
156153
if (blocks == null) {
157-
return null;
154+
return false;
158155
}
159156

160157
for (ContentBlock block : blocks) {
161158
if (block != null && block.toolResult() != null) {
162-
return block.toolResult();
159+
return true;
163160
}
164161
}
165162

166-
return null;
163+
return false;
167164
}
168165

169166
private static List<SystemContentBlock> getSystemContentBlock(UnifiedCompletionRequest.Content content) {
@@ -174,14 +171,14 @@ private static List<SystemContentBlock> getSystemContentBlock(UnifiedCompletionR
174171
);
175172
case UnifiedCompletionRequest.ContentObjects objectsContent -> objectsContent.contentObjects()
176173
.stream()
177-
.filter(obj -> obj.type().equals(TEXT_CONTENT_TYPE) && obj.text().isEmpty() == false)
174+
.filter(obj -> obj.text().isEmpty() == false && obj.type().equals(TEXT_CONTENT_TYPE))
178175
.map(obj -> SystemContentBlock.builder().text(obj.text()).build())
179176
.toList();
180177
};
181178
}
182179

183180
private static Message convertToolResultMessage(UnifiedCompletionRequest.Message requestMessage) {
184-
// Bedrock allows empty tool result string content
181+
// Bedrock allows empty tool result string content but not empty tool result object content
185182
var convertedToolResultContentBlock = switch (requestMessage.content()) {
186183
case UnifiedCompletionRequest.ContentString stringContent -> List.of(
187184
ToolResultContentBlock.builder().text(stringContent.content()).build()
@@ -274,13 +271,14 @@ public static Document toDocument(Object value) {
274271
return switch (value) {
275272
case null -> Document.fromNull();
276273
case String stringValue -> Document.fromString(stringValue);
274+
case Boolean booleanValue -> Document.fromBoolean(booleanValue);
277275
case Integer numberValue -> Document.fromNumber(numberValue);
278-
case List<?> values -> Document.fromList(values.stream().map(v -> {
279-
if (v instanceof String) {
280-
return Document.fromString((String) v);
281-
}
282-
return Document.fromNull();
283-
}).collect(Collectors.toList()));
276+
case Long numberValue -> Document.fromNumber(numberValue);
277+
case Double numberValue -> Document.fromNumber(numberValue);
278+
case Float numberValue -> Document.fromNumber(numberValue);
279+
case List<?> values -> Document.fromList(
280+
values.stream().map(AmazonBedrockConverseUtils::toDocument).collect(Collectors.toList())
281+
);
284282
case Map<?, ?> mapValue -> {
285283
final Map<String, Document> converted = new HashMap<>();
286284
for (Map.Entry<?, ?> entry : mapValue.entrySet()) {
@@ -316,14 +314,14 @@ public static Optional<InferenceConfiguration> inferenceConfig(AmazonBedrockComp
316314
}
317315

318316
public static Optional<InferenceConfiguration> inferenceConfig(AmazonBedrockChatCompletionRequestEntity request) {
319-
if (request.temperature() != null || request.topP() != null || request.maxCompletionTokens() != null) {
317+
if (request.temperature() != null || request.topP() != null || request.maxCompletionTokens() != null || request.stop() != null) {
320318
var builder = InferenceConfiguration.builder();
321319
if (request.temperature() != null) {
322-
builder.temperature(request.temperature().floatValue());
320+
builder.temperature(request.temperature());
323321
}
324322

325323
if (request.topP() != null) {
326-
builder.topP(request.topP().floatValue());
324+
builder.topP(request.topP());
327325
}
328326

329327
if (request.maxCompletionTokens() != null) {

0 commit comments

Comments
 (0)