7
7
8
8
package org .elasticsearch .xpack .inference .services .amazonbedrock .client ;
9
9
10
- import org .elasticsearch .xpack .core .inference .results .StreamingUnifiedChatCompletionResults ;
11
-
12
10
import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockDelta ;
13
11
import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockDeltaEvent ;
14
12
import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockStart ;
15
13
import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockStartEvent ;
16
14
import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamMetadataEvent ;
17
15
import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamOutput ;
18
16
import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamResponseHandler ;
17
+ import software .amazon .awssdk .services .bedrockruntime .model .MessageStartEvent ;
19
18
20
19
import org .elasticsearch .ElasticsearchException ;
21
20
import org .elasticsearch .ExceptionsHelper ;
24
23
import org .elasticsearch .logging .LogManager ;
25
24
import org .elasticsearch .logging .Logger ;
26
25
import org .elasticsearch .threadpool .ThreadPool ;
27
-
28
- import software .amazon .awssdk .services .bedrockruntime .model .MessageStartEvent ;
26
+ import org .elasticsearch .xpack .core .inference .results .StreamingUnifiedChatCompletionResults ;
29
27
30
28
import java .util .ArrayDeque ;
31
29
import java .util .List ;
38
36
import static org .elasticsearch .xpack .inference .InferencePlugin .UTILITY_THREAD_POOL_NAME ;
39
37
40
38
@ SuppressWarnings ("checkstyle:LineLength" )
41
- class AmazonBedrockUnifiedStreamingChatProcessor implements Flow .Processor <ConverseStreamOutput , StreamingUnifiedChatCompletionResults .Results > {
39
+ class AmazonBedrockUnifiedStreamingChatProcessor
40
+ implements
41
+ Flow .Processor <ConverseStreamOutput , StreamingUnifiedChatCompletionResults .Results > {
42
42
private static final Logger logger = LogManager .getLogger (AmazonBedrockStreamingChatProcessor .class );
43
43
44
44
private final AtomicReference <Throwable > error = new AtomicReference <>(null );
@@ -85,32 +85,37 @@ public void onNext(ConverseStreamOutput item) {
85
85
switch (eventType ) {
86
86
case ConverseStreamOutput .EventType .MESSAGE_START -> {
87
87
demand .set (0 ); // reset demand before we fork to another thread
88
- item .accept (ConverseStreamResponseHandler .Visitor .builder ()
89
- .onMessageStart (event -> handleMessageStart (event , chunks )).build ());
88
+ item .accept (
89
+ ConverseStreamResponseHandler .Visitor .builder ().onMessageStart (event -> handleMessageStart (event , chunks )).build ()
90
+ );
90
91
return ;
91
92
}
92
93
case ConverseStreamOutput .EventType .CONTENT_BLOCK_START -> {
93
94
demand .set (0 ); // reset demand before we fork to another thread
94
- item .accept (ConverseStreamResponseHandler .Visitor .builder ()
95
- .onContentBlockStart (event -> handleContentBlockStart (event , chunks )).build ());
95
+ item .accept (
96
+ ConverseStreamResponseHandler .Visitor .builder ()
97
+ .onContentBlockStart (event -> handleContentBlockStart (event , chunks ))
98
+ .build ()
99
+ );
96
100
return ;
97
101
}
98
102
case ConverseStreamOutput .EventType .CONTENT_BLOCK_DELTA -> {
99
103
demand .set (0 ); // reset demand before we fork to another thread
100
- item .accept (ConverseStreamResponseHandler .Visitor .builder ()
101
- .onContentBlockDelta (event -> handleContentBlockDelta (event , chunks )).build ());
104
+ item .accept (
105
+ ConverseStreamResponseHandler .Visitor .builder ()
106
+ .onContentBlockDelta (event -> handleContentBlockDelta (event , chunks ))
107
+ .build ()
108
+ );
102
109
return ;
103
110
}
104
111
case ConverseStreamOutput .EventType .METADATA -> {
105
112
demand .set (0 ); // reset demand before we fork to another thread
106
- item .accept (ConverseStreamResponseHandler .Visitor .builder ()
107
- .onMetadata (event -> handleMetadata (event , chunks )).build ());
113
+ item .accept (ConverseStreamResponseHandler .Visitor .builder ().onMetadata (event -> handleMetadata (event , chunks )).build ());
108
114
return ;
109
115
}
110
116
case ConverseStreamOutput .EventType .MESSAGE_STOP -> {
111
117
demand .set (0 ); // reset demand before we fork to another thread
112
- item .accept (ConverseStreamResponseHandler .Visitor .builder ()
113
- .onMessageStop (event -> Stream .empty ()).build ());
118
+ item .accept (ConverseStreamResponseHandler .Visitor .builder ().onMessageStop (event -> Stream .empty ()).build ());
114
119
return ;
115
120
}
116
121
default -> {
@@ -125,9 +130,7 @@ public void onNext(ConverseStreamOutput item) {
125
130
}
126
131
}
127
132
128
- private void handleMessageStart (
129
- MessageStartEvent event ,
130
- ArrayDeque <StreamingUnifiedChatCompletionResults .ChatCompletionChunk > chunks ) {
133
+ private void handleMessageStart (MessageStartEvent event , ArrayDeque <StreamingUnifiedChatCompletionResults .ChatCompletionChunk > chunks ) {
131
134
runOnUtilityThreadPool (() -> {
132
135
try {
133
136
var messageStart = handleMessageStart (event );
@@ -143,7 +146,10 @@ private void handleMessageStart(
143
146
});
144
147
}
145
148
146
- private void handleContentBlockStart (ContentBlockStartEvent event , ArrayDeque <StreamingUnifiedChatCompletionResults .ChatCompletionChunk > chunks ) {
149
+ private void handleContentBlockStart (
150
+ ContentBlockStartEvent event ,
151
+ ArrayDeque <StreamingUnifiedChatCompletionResults .ChatCompletionChunk > chunks
152
+ ) {
147
153
try {
148
154
var contentBlockStart = handleContentBlockStart (event );
149
155
contentBlockStart .forEach (chunks ::offer );
@@ -154,7 +160,10 @@ private void handleContentBlockStart(ContentBlockStartEvent event, ArrayDeque<St
154
160
downstream .onNext (results );
155
161
}
156
162
157
- private void handleContentBlockDelta (ContentBlockDeltaEvent event , ArrayDeque <StreamingUnifiedChatCompletionResults .ChatCompletionChunk > chunks ) {
163
+ private void handleContentBlockDelta (
164
+ ContentBlockDeltaEvent event ,
165
+ ArrayDeque <StreamingUnifiedChatCompletionResults .ChatCompletionChunk > chunks
166
+ ) {
158
167
runOnUtilityThreadPool (() -> {
159
168
try {
160
169
var contentBlockDelta = handleContentBlockDelta (event );
@@ -167,7 +176,10 @@ private void handleContentBlockDelta(ContentBlockDeltaEvent event, ArrayDeque<St
167
176
});
168
177
}
169
178
170
- private void handleMetadata (ConverseStreamMetadataEvent event , ArrayDeque <StreamingUnifiedChatCompletionResults .ChatCompletionChunk > chunks ) {
179
+ private void handleMetadata (
180
+ ConverseStreamMetadataEvent event ,
181
+ ArrayDeque <StreamingUnifiedChatCompletionResults .ChatCompletionChunk > chunks
182
+ ) {
171
183
runOnUtilityThreadPool (() -> {
172
184
try {
173
185
var messageDelta = handleMetadata (event );
@@ -265,10 +277,10 @@ public void cancel() {
265
277
* @return a stream of ChatCompletionChunk
266
278
*/
267
279
public static Stream <StreamingUnifiedChatCompletionResults .ChatCompletionChunk > handleMessageStart (MessageStartEvent event ) {
268
- var delta = new StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta (null , null , event .roleAsString (), null );
269
- var choice = new StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice (delta , null , 0 );
270
- var chunk = new StreamingUnifiedChatCompletionResults .ChatCompletionChunk (null , List .of (choice ), null , null , null );
271
- return Stream .of (chunk );
280
+ var delta = new StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta (null , null , event .roleAsString (), null );
281
+ var choice = new StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice (delta , null , 0 );
282
+ var chunk = new StreamingUnifiedChatCompletionResults .ChatCompletionChunk (null , List .of (choice ), null , null , null );
283
+ return Stream .of (chunk );
272
284
}
273
285
274
286
/**
@@ -278,11 +290,18 @@ public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>
278
290
* @param start the ContentBlockStart data
279
291
* @return a ToolCall
280
292
*/
281
- private static StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta .ToolCall handleToolUseStart (ContentBlockStart start ) {
293
+ private static StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta .ToolCall handleToolUseStart (
294
+ ContentBlockStart start
295
+ ) {
282
296
var type = start .type ();
283
297
var toolUse = start .toolUse ();
284
298
var function = new StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta .ToolCall .Function (null , toolUse .name ());
285
- return new StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta .ToolCall (0 , toolUse .toolUseId (), function , type .name ());
299
+ return new StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta .ToolCall (
300
+ 0 ,
301
+ toolUse .toolUseId (),
302
+ function ,
303
+ type .name ()
304
+ );
286
305
}
287
306
288
307
/**
@@ -292,7 +311,9 @@ private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.
292
311
* @param delta the ContentBlockDelta data
293
312
* @return a ToolCall
294
313
*/
295
- private static StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta .ToolCall handleToolUseDelta (ContentBlockDelta delta ) {
314
+ private static StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta .ToolCall handleToolUseDelta (
315
+ ContentBlockDelta delta
316
+ ) {
296
317
var type = delta .type ();
297
318
var toolUse = delta .toolUse ();
298
319
var function = new StreamingUnifiedChatCompletionResults .ChatCompletionChunk .Choice .Delta .ToolCall .Function (toolUse .input (), null );
0 commit comments