Skip to content

Commit d0b2047

Browse files
Add Amazon Bedrock Unified Chat Completions support
1 parent 3455a63 commit d0b2047

File tree

1 file changed

+13
-38
lines changed

1 file changed

+13
-38
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockUnifiedStreamingChatProcessor.java

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
1212
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart;
1313
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
14-
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStopEvent;
1514
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
1615
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
1716
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
@@ -118,20 +117,20 @@ public void onNext(ConverseStreamOutput item) {
118117
);
119118
return;
120119
}
120+
case ConverseStreamOutput.EventType.METADATA -> {
121+
demand.set(0); // reset demand before we fork to another thread
122+
item.accept(ConverseStreamResponseHandler.Visitor.builder().onMetadata(event -> handleMetadata(event, chunks)).build());
123+
return;
124+
}
121125
case ConverseStreamOutput.EventType.CONTENT_BLOCK_STOP -> {
122126
demand.set(0); // reset demand before we fork to another thread
123127
item.accept(
124128
ConverseStreamResponseHandler.Visitor.builder()
125-
.onContentBlockStop(event -> handleContentBlockStop(event, chunks))
129+
.onContentBlockStop(event -> Stream.empty())
126130
.build()
127131
);
128132
return;
129133
}
130-
case ConverseStreamOutput.EventType.METADATA -> {
131-
demand.set(0); // reset demand before we fork to another thread
132-
item.accept(ConverseStreamResponseHandler.Visitor.builder().onMetadata(event -> handleMetadata(event, chunks)).build());
133-
return;
134-
}
135134
default -> {
136135
logger.debug("Unknown event type [{}] for line [{}].", eventType, item);
137136
}
@@ -206,22 +205,6 @@ private void handleContentBlockDelta(
206205
});
207206
}
208207

209-
private void handleContentBlockStop(
210-
ContentBlockStopEvent event,
211-
ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> chunks
212-
) {
213-
runOnUtilityThreadPool(() -> {
214-
try {
215-
var contentBlockStop = handleContentBlockStop(event);
216-
contentBlockStop.forEach(chunks::offer);
217-
} catch (Exception e) {
218-
logger.warn("Failed to parse content block stop event from Amazon Bedrock provider: {}", event);
219-
}
220-
var results = new StreamingUnifiedChatCompletionResults.Results(chunks);
221-
downstream.onNext(results);
222-
});
223-
}
224-
225208
private void handleMetadata(
226209
ConverseStreamMetadataEvent event,
227210
ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> chunks
@@ -422,19 +405,21 @@ private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.
422405
* @return a stream of ChatCompletionChunk
423406
*/
424407
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> handleContentBlockStart(ContentBlockStartEvent event) {
425-
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta delta = null;
426408
var index = event.contentBlockIndex();
427409
var type = event.start().type();
428410

429411
switch (type) {
430412
case ContentBlockStart.Type.TOOL_USE -> {
431413
var toolCall = handleToolUseStart(event.start());
432414
var role = "assistant";
433-
delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, role, List.of(toolCall));
415+
var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, role, List.of(toolCall));
416+
var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, index);
417+
var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null);
418+
return Stream.of(chunk);
434419
}
435420
default -> logger.debug("unhandled content block start type [{}].", type);
436421
}
437-
delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, null);
422+
var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, null);
438423
var choice = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, null, index);
439424
var chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(null, List.of(choice), null, null, null);
440425
return Stream.of(chunk);
@@ -449,15 +434,15 @@ public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>
449434
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> handleContentBlockDelta(ContentBlockDeltaEvent event) {
450435
StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta delta = null;
451436
var type = event.delta().type();
437+
var content = event.delta().text();
452438

453439
switch (type) {
454440
case ContentBlockDelta.Type.TEXT -> {
455-
var content = event.delta().text();
456441
delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(content, null, null, null);
457442
}
458443
case ContentBlockDelta.Type.TOOL_USE -> {
459444
var toolCall = handleToolUseDelta(event.delta());
460-
delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(null, null, null, List.of(toolCall));
445+
delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(content, null, null, List.of(toolCall));
461446
}
462447
default -> logger.debug("unknown content block delta type [{}].", type);
463448
}
@@ -466,16 +451,6 @@ public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>
466451
return Stream.of(chunk);
467452
}
468453

469-
/**
470-
* processes incremental content updates
471-
* Parse a ContentBlockStopEvent into a ChatCompletionChunk stream
472-
* @param event the event data
473-
* @return a stream of ChatCompletionChunk
474-
*/
475-
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> handleContentBlockStop(ContentBlockStopEvent event) {
476-
return Stream.empty();
477-
}
478-
479454
/**
480455
* processes usage statistics
481456
* Parse a ConverseStreamMetadataEvent into a ChatCompletionChunk stream

0 commit comments

Comments
 (0)