Skip to content

Commit b395b4e

Browse files
Add Amazon Bedrock Unified Chat Completions support
1 parent 91f43ab commit b395b4e

File tree

1 file changed

+139
-13
lines changed

1 file changed

+139
-13
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedStreamingChatProcessor.java

Lines changed: 139 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
1212
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart;
1313
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
14+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStopEvent;
1415
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
1516
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
1617
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
1718
import software.amazon.awssdk.services.bedrockruntime.model.MessageStartEvent;
19+
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
20+
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
1821

1922
import org.elasticsearch.ElasticsearchException;
2023
import org.elasticsearch.ExceptionsHelper;
@@ -90,6 +93,13 @@ public void onNext(ConverseStreamOutput item) {
9093
);
9194
return;
9295
}
96+
case ConverseStreamOutput.EventType.MESSAGE_STOP -> {
97+
demand.set(0); // reset demand before we fork to another thread
98+
item.accept(
99+
ConverseStreamResponseHandler.Visitor.builder().onMessageStop(event -> handleMessageStop(event, chunks)).build()
100+
);
101+
return;
102+
}
93103
case ConverseStreamOutput.EventType.CONTENT_BLOCK_START -> {
94104
demand.set(0); // reset demand before we fork to another thread
95105
item.accept(
@@ -108,14 +118,18 @@ public void onNext(ConverseStreamOutput item) {
108118
);
109119
return;
110120
}
111-
case ConverseStreamOutput.EventType.METADATA -> {
121+
case ConverseStreamOutput.EventType.CONTENT_BLOCK_STOP -> {
112122
demand.set(0); // reset demand before we fork to another thread
113-
item.accept(ConverseStreamResponseHandler.Visitor.builder().onMetadata(event -> handleMetadata(event, chunks)).build());
123+
item.accept(
124+
ConverseStreamResponseHandler.Visitor.builder()
125+
.onContentBlockStop(event -> handleContentBlockStop(event, chunks))
126+
.build()
127+
);
114128
return;
115129
}
116-
case ConverseStreamOutput.EventType.MESSAGE_STOP -> {
130+
case ConverseStreamOutput.EventType.METADATA -> {
117131
demand.set(0); // reset demand before we fork to another thread
118-
item.accept(ConverseStreamResponseHandler.Visitor.builder().onMessageStop(event -> Stream.empty()).build());
132+
item.accept(ConverseStreamResponseHandler.Visitor.builder().onMetadata(event -> handleMetadata(event, chunks)).build());
119133
return;
120134
}
121135
default -> {
@@ -146,6 +160,22 @@ private void handleMessageStart(MessageStartEvent event, ArrayDeque<StreamingUni
146160
});
147161
}
148162

163+
private void handleMessageStop(MessageStopEvent event, ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> chunks) {
164+
runOnUtilityThreadPool(() -> {
165+
try {
166+
var messageStop = handleMessageStop(event);
167+
messageStop.forEach(chunks::offer);
168+
} catch (Exception e) {
169+
logger.warn("Failed to parse message stop event from Amazon Bedrock provider: {}", event);
170+
}
171+
if (chunks.isEmpty()) {
172+
upstream.request(1);
173+
} else {
174+
downstream.onNext(new StreamingUnifiedChatCompletionResults.Results(chunks));
175+
}
176+
});
177+
}
178+
149179
private void handleContentBlockStart(
150180
ContentBlockStartEvent event,
151181
ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> chunks
@@ -176,6 +206,22 @@ private void handleContentBlockDelta(
176206
});
177207
}
178208

209+
private void handleContentBlockStop(
210+
ContentBlockStopEvent event,
211+
ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> chunks
212+
) {
213+
runOnUtilityThreadPool(() -> {
214+
try {
215+
var contentBlockStop = handleContentBlockStop(event);
216+
contentBlockStop.forEach(chunks::offer);
217+
} catch (Exception e) {
218+
logger.warn("Failed to parse content block stop event from Amazon Bedrock provider: {}", event);
219+
}
220+
var results = new StreamingUnifiedChatCompletionResults.Results(chunks);
221+
downstream.onNext(results);
222+
});
223+
}
224+
179225
private void handleMetadata(
180226
ConverseStreamMetadataEvent event,
181227
ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> chunks
@@ -283,6 +329,56 @@ public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>
283329
return Stream.of(chunk);
284330
}
285331

332+
/**
333+
* Parse a MessageStopEvent into a ChatCompletionChunk stream
334+
* @param event the MessageStopEvent data
335+
* @return a stream of ChatCompletionChunk
336+
*/
337+
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> handleMessageStop(MessageStopEvent event) {
338+
var finishReason = handleFinishReason(event.stopReason());
339+
var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(null, finishReason, 0);
340+
var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null);
341+
return Stream.of(chunk);
342+
}
343+
344+
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> processEvent(MessageStopEvent event) {
345+
var finishReason = handleFinishReason(event.stopReason());
346+
var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(null, finishReason, 0);
347+
var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null);
348+
return Stream.of(chunk);
349+
}
350+
351+
/**
352+
* This ensures consistent handling of completion termination across different providers.
353+
* For example, both "stop_sequence" and "end_turn" from Bedrock map to the unified "stop" reason.
354+
* @param stopReason the stop reason
355+
* @return a stop reason
356+
*/
357+
public static String handleFinishReason(StopReason stopReason) {
358+
switch (stopReason) {
359+
case StopReason.TOOL_USE -> {
360+
return "FinishReasonToolCalls";
361+
}
362+
case StopReason.MAX_TOKENS -> {
363+
return "FinishReasonLength";
364+
}
365+
case StopReason.CONTENT_FILTERED, StopReason.GUARDRAIL_INTERVENED -> {
366+
return "FinishReasonContentFilter";
367+
}
368+
case StopReason.END_TURN, StopReason.STOP_SEQUENCE -> {
369+
return "FinishReasonStop";
370+
}
371+
default -> {
372+
logger.debug("unhandled stop reason [{}].", stopReason);
373+
return "FinishReasonStop";
374+
}
375+
}
376+
}
377+
378+
public StreamingUnifiedChatCompletionResults.ChatCompletionChunk createBaseChunk() {
379+
return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, null, null, "chat.completion.chunk", null);
380+
}
381+
286382
/**
287383
* processes a tool initialization event from Bedrock
288384
* This occurs when the model first decides to use a tool, providing its name and ID.
@@ -326,11 +422,20 @@ private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.
326422
* @return a stream of ChatCompletionChunk
327423
*/
328424
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> handleContentBlockStart(ContentBlockStartEvent event) {
329-
var toolCall = handleToolUseStart(event.start());
330-
var role = "assistant";
331-
332-
var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, role, List.of(toolCall));
333-
var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, 0);
425+
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta delta = null;
426+
var index = event.contentBlockIndex();
427+
var type = event.start().type();
428+
429+
switch (type) {
430+
case ContentBlockStart.Type.TOOL_USE -> {
431+
var toolCall = handleToolUseStart(event.start());
432+
var role = "assistant";
433+
delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, role, List.of(toolCall));
434+
}
435+
default -> logger.debug("unhandled content block start type [{}].", type);
436+
}
437+
delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, null);
438+
var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, index);
334439
var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null);
335440
return Stream.of(chunk);
336441
}
@@ -342,14 +447,35 @@ public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>
342447
* @return a stream of ChatCompletionChunk
343448
*/
344449
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> handleContentBlockDelta(ContentBlockDeltaEvent event) {
345-
var text = event.delta().text();
346-
var toolCall = handleToolUseDelta(event.delta());
347-
var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(text, null, null, List.of(toolCall));
348-
var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, 0);
450+
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta delta = null;
451+
var type = event.delta().type();
452+
453+
switch (type) {
454+
case ContentBlockDelta.Type.TEXT -> {
455+
var content = event.delta().text();
456+
delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(content, null, null, null);
457+
}
458+
case ContentBlockDelta.Type.TOOL_USE -> {
459+
var toolCall = handleToolUseDelta(event.delta());
460+
delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, List.of(toolCall));
461+
}
462+
default -> logger.debug("unknown content block delta type [{}].", type);
463+
}
464+
var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, event.contentBlockIndex());
349465
var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null);
350466
return Stream.of(chunk);
351467
}
352468

469+
/**
470+
* processes incremental content updates
471+
* Parse a ContentBlockStopEvent into a ChatCompletionChunk stream
472+
* @param event the event data
473+
* @return a stream of ChatCompletionChunk
474+
*/
475+
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> handleContentBlockStop(ContentBlockStopEvent event) {
476+
return Stream.empty();
477+
}
478+
353479
/**
354480
* processes usage statistics
355481
* Parse a ConverseStreamMetadataEvent into a ChatCompletionChunk stream

0 commit comments

Comments
 (0)