2424import java .util .HashMap ;
2525import java .util .List ;
2626import java .util .Map ;
27+ import java .util .stream .Collectors ;
2728import javax .annotation .Nullable ;
2829import software .amazon .awssdk .core .SdkRequest ;
2930import software .amazon .awssdk .core .SdkResponse ;
3031import software .amazon .awssdk .core .async .SdkPublisher ;
3132import software .amazon .awssdk .core .document .Document ;
3233import software .amazon .awssdk .protocols .json .SdkJsonGenerator ;
34+ import software .amazon .awssdk .protocols .jsoncore .JsonNode ;
35+ import software .amazon .awssdk .protocols .jsoncore .JsonNodeParser ;
3336import software .amazon .awssdk .services .bedrockruntime .BedrockRuntimeAsyncClient ;
3437import software .amazon .awssdk .services .bedrockruntime .model .ContentBlock ;
38+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockDelta ;
39+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockDeltaEvent ;
40+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockStartEvent ;
41+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockStopEvent ;
3542import software .amazon .awssdk .services .bedrockruntime .model .ConverseRequest ;
3643import software .amazon .awssdk .services .bedrockruntime .model .ConverseResponse ;
3744import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamMetadataEvent ;
4148import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamResponseHandler ;
4249import software .amazon .awssdk .services .bedrockruntime .model .InferenceConfiguration ;
4350import software .amazon .awssdk .services .bedrockruntime .model .Message ;
51+ import software .amazon .awssdk .services .bedrockruntime .model .MessageStartEvent ;
4452import software .amazon .awssdk .services .bedrockruntime .model .MessageStopEvent ;
4553import software .amazon .awssdk .services .bedrockruntime .model .StopReason ;
4654import software .amazon .awssdk .services .bedrockruntime .model .TokenUsage ;
4755import software .amazon .awssdk .services .bedrockruntime .model .ToolResultContentBlock ;
4856import software .amazon .awssdk .services .bedrockruntime .model .ToolUseBlock ;
57+ import software .amazon .awssdk .services .bedrockruntime .model .ToolUseBlockStart ;
4958import software .amazon .awssdk .thirdparty .jackson .core .JsonFactory ;
5059
5160/**
@@ -59,6 +68,8 @@ private BedrockRuntimeImpl() {}
5968 private static final AttributeKey <String > GEN_AI_SYSTEM = stringKey ("gen_ai.system" );
6069
6170 private static final JsonFactory JSON_FACTORY = new JsonFactory ();
71+ private static final JsonNodeParser JSON_PARSER = JsonNode .parser ();
72+ private static final DocumentUnmarshaller DOCUMENT_UNMARSHALLER = new DocumentUnmarshaller ();
6273
6374 static boolean isBedrockRuntimeRequest (SdkRequest request ) {
6475 if (request instanceof ConverseRequest ) {
@@ -202,35 +213,54 @@ static Long getUsageOutputTokens(Response response) {
202213 static void recordRequestEvents (
203214 Context otelContext , Logger eventLogger , SdkRequest request , boolean captureMessageContent ) {
204215 if (request instanceof ConverseRequest ) {
205- for (Message message : ((ConverseRequest ) request ).messages ()) {
206- long numToolResults =
207- message .content ().stream ().filter (block -> block .toolResult () != null ).count ();
208- if (numToolResults > 0 ) {
209- // Tool results are different from others, emitting multiple events for a single message,
210- // so treat them separately.
211- emitToolResultEvents (otelContext , eventLogger , message , captureMessageContent );
212- if (numToolResults == message .content ().size ()) {
213- continue ;
214- }
215- // There are content blocks besides tool results in the same message. While models
216- // generally don't expect such usage, the SDK allows it so go ahead and generate a normal
217- // message too.
218- }
219- LogRecordBuilder event = newEvent (otelContext , eventLogger );
220- switch (message .role ()) {
221- case ASSISTANT :
222- event .setAttribute (EVENT_NAME , "gen_ai.assistant.message" );
223- break ;
224- case USER :
225- event .setAttribute (EVENT_NAME , "gen_ai.user.message" );
226- break ;
227- default :
228- // unknown role, shouldn't happen in practice
229- continue ;
216+ recordRequestMessageEvents (
217+ otelContext , eventLogger , ((ConverseRequest ) request ).messages (), captureMessageContent );
218+ }
219+ if (request instanceof ConverseStreamRequest ) {
220+ recordRequestMessageEvents (
221+ otelContext ,
222+ eventLogger ,
223+ ((ConverseStreamRequest ) request ).messages (),
224+ captureMessageContent );
225+
226+ // Good a time as any to store the context for a streaming request.
227+ TracingConverseStreamResponseHandler .fromContext (otelContext ).setOtelContext (otelContext );
228+ }
229+ }
230+
231+ private static void recordRequestMessageEvents (
232+ Context otelContext ,
233+ Logger eventLogger ,
234+ List <Message > messages ,
235+ boolean captureMessageContent ) {
236+ for (Message message : messages ) {
237+ long numToolResults =
238+ message .content ().stream ().filter (block -> block .toolResult () != null ).count ();
239+ if (numToolResults > 0 ) {
240+ // Tool results are different from others, emitting multiple events for a single message,
241+ // so treat them separately.
242+ emitToolResultEvents (otelContext , eventLogger , message , captureMessageContent );
243+ if (numToolResults == message .content ().size ()) {
244+ continue ;
230245 }
231- // Requests don't have index or stop reason.
232- event .setBody (convertMessage (message , -1 , null , captureMessageContent )).emit ();
246+ // There are content blocks besides tool results in the same message. While models
247+ // generally don't expect such usage, the SDK allows it so go ahead and generate a normal
248+ // message too.
233249 }
250+ LogRecordBuilder event = newEvent (otelContext , eventLogger );
251+ switch (message .role ()) {
252+ case ASSISTANT :
253+ event .setAttribute (EVENT_NAME , "gen_ai.assistant.message" );
254+ break ;
255+ case USER :
256+ event .setAttribute (EVENT_NAME , "gen_ai.user.message" );
257+ break ;
258+ default :
259+ // unknown role, shouldn't happen in practice
260+ continue ;
261+ }
262+ // Requests don't have index or stop reason.
263+ event .setBody (convertMessage (message , -1 , null , captureMessageContent )).emit ();
234264 }
235265 }
236266
@@ -248,7 +278,7 @@ static void recordResponseEvents(
248278 convertMessage (
249279 converseResponse .output ().message (),
250280 0 ,
251- converseResponse .stopReason (),
281+ converseResponse .stopReasonAsString (),
252282 captureMessageContent ))
253283 .emit ();
254284 }
@@ -270,7 +300,8 @@ private static Double floatToDouble(Float value) {
270300 return Double .valueOf (value );
271301 }
272302
273- public static BedrockRuntimeAsyncClient wrap (BedrockRuntimeAsyncClient asyncClient ) {
303+ public static BedrockRuntimeAsyncClient wrap (
304+ BedrockRuntimeAsyncClient asyncClient , Logger eventLogger , boolean captureMessageContent ) {
274305 // proxy BedrockRuntimeAsyncClient so we can wrap the subscriber to converseStream to capture
275306 // events.
276307 return (BedrockRuntimeAsyncClient )
@@ -283,7 +314,9 @@ public static BedrockRuntimeAsyncClient wrap(BedrockRuntimeAsyncClient asyncClie
283314 && args [1 ] instanceof ConverseStreamResponseHandler ) {
284315 TracingConverseStreamResponseHandler wrapped =
285316 new TracingConverseStreamResponseHandler (
286- (ConverseStreamResponseHandler ) args [1 ]);
317+ (ConverseStreamResponseHandler ) args [1 ],
318+ eventLogger ,
319+ captureMessageContent );
287320 args [1 ] = wrapped ;
288321 try (Scope ignored = wrapped .makeCurrent ()) {
289322 return invokeProxyMethod (method , asyncClient , args );
@@ -318,12 +351,29 @@ public static TracingConverseStreamResponseHandler fromContext(Context context)
318351 ContextKey .named ("bedrock-runtime-converse-stream-response-handler" );
319352
320353 private final ConverseStreamResponseHandler delegate ;
354+ private final Logger eventLogger ;
355+ private final boolean captureMessageContent ;
356+
357+ private StringBuilder currentText ;
358+
359+ // The response handler is created and stored into context before the span, so we need to
360+ // also pass the later context in for recording events. While subscribers are called from a
361+ // single thread, it is not clear if that is guaranteed to be the same as the execution
362+ // interceptor so we use volatile.
363+ private volatile Context otelContext ;
364+
365+ private List <ToolUseBlock > tools ;
366+ private ToolUseBlock .Builder currentTool ;
367+ private StringBuilder currentToolArgs ;
321368
322369 List <String > stopReasons ;
323370 TokenUsage usage ;
324371
325- TracingConverseStreamResponseHandler (ConverseStreamResponseHandler delegate ) {
372+ TracingConverseStreamResponseHandler (
373+ ConverseStreamResponseHandler delegate , Logger eventLogger , boolean captureMessageContent ) {
326374 this .delegate = delegate ;
375+ this .eventLogger = eventLogger ;
376+ this .captureMessageContent = captureMessageContent ;
327377 }
328378
329379 @ Override
@@ -336,19 +386,66 @@ public void onEventStream(SdkPublisher<ConverseStreamOutput> sdkPublisher) {
336386 delegate .onEventStream (
337387 sdkPublisher .map (
338388 event -> {
339- if (event instanceof MessageStopEvent ) {
340- if (stopReasons == null ) {
341- stopReasons = new ArrayList <>();
342- }
343- stopReasons .add (((MessageStopEvent ) event ).stopReasonAsString ());
344- }
345- if (event instanceof ConverseStreamMetadataEvent ) {
346- usage = ((ConverseStreamMetadataEvent ) event ).usage ();
347- }
389+ handleEvent (event );
348390 return event ;
349391 }));
350392 }
351393
394+ private void handleEvent (ConverseStreamOutput event ) {
395+ if (captureMessageContent && event instanceof MessageStartEvent ) {
396+ if (currentText == null ) {
397+ currentText = new StringBuilder ();
398+ }
399+ currentText .setLength (0 );
400+ }
401+ if (event instanceof ContentBlockStartEvent ) {
402+ ToolUseBlockStart toolUse = ((ContentBlockStartEvent ) event ).start ().toolUse ();
403+ if (toolUse != null ) {
404+ if (currentToolArgs == null ) {
405+ currentToolArgs = new StringBuilder ();
406+ }
407+ currentToolArgs .setLength (0 );
408+ currentTool = ToolUseBlock .builder ().name (toolUse .name ()).toolUseId (toolUse .toolUseId ());
409+ }
410+ }
411+ if (event instanceof ContentBlockDeltaEvent ) {
412+ ContentBlockDelta delta = ((ContentBlockDeltaEvent ) event ).delta ();
413+ if (captureMessageContent && delta .text () != null ) {
414+ currentText .append (delta .text ());
415+ }
416+ if (delta .toolUse () != null ) {
417+ currentToolArgs .append (delta .toolUse ().input ());
418+ }
419+ }
420+ if (event instanceof ContentBlockStopEvent ) {
421+ if (currentTool != null ) {
422+ if (tools == null ) {
423+ tools = new ArrayList <>();
424+ }
425+ if (currentToolArgs != null ) {
426+ Document args = deserializeDocument (currentToolArgs .toString ());
427+ currentTool .input (args );
428+ }
429+ tools .add (currentTool .build ());
430+ currentTool = null ;
431+ }
432+ }
433+ if (event instanceof MessageStopEvent ) {
434+ if (stopReasons == null ) {
435+ stopReasons = new ArrayList <>();
436+ }
437+ String stopReason = ((MessageStopEvent ) event ).stopReasonAsString ();
438+ stopReasons .add (stopReason );
439+ newEvent (otelContext , eventLogger )
440+ .setAttribute (EVENT_NAME , "gen_ai.choice" )
441+ .setBody (convertMessageData (currentText , tools , 0 , stopReason , captureMessageContent ))
442+ .emit ();
443+ }
444+ if (event instanceof ConverseStreamMetadataEvent ) {
445+ usage = ((ConverseStreamMetadataEvent ) event ).usage ();
446+ }
447+ }
448+
352449 @ Override
353450 public void exceptionOccurred (Throwable throwable ) {
354451 delegate .exceptionOccurred (throwable );
@@ -363,6 +460,10 @@ public void complete() {
363460 public Context storeInContext (Context context ) {
364461 return context .with (KEY , this );
365462 }
463+
464+ void setOtelContext (Context otelContext ) {
465+ this .otelContext = otelContext ;
466+ }
366467 }
367468
368469 private static LogRecordBuilder newEvent (Context otelContext , Logger eventLogger ) {
@@ -401,9 +502,9 @@ private static void emitToolResultEvents(
401502 }
402503
403504 private static Value <?> convertMessage (
404- Message message , int index , @ Nullable StopReason stopReason , boolean captureMessageContent ) {
505+ Message message , int index , @ Nullable String stopReason , boolean captureMessageContent ) {
405506 StringBuilder text = null ;
406- List <Value <?> > toolCalls = null ;
507+ List <ToolUseBlock > toolCalls = null ;
407508 for (ContentBlock content : message .content ()) {
408509 if (captureMessageContent && content .text () != null ) {
409510 if (text == null ) {
@@ -415,15 +516,29 @@ private static Value<?> convertMessage(
415516 if (toolCalls == null ) {
416517 toolCalls = new ArrayList <>();
417518 }
418- toolCalls .add (convertToolCall ( content .toolUse (), captureMessageContent ));
519+ toolCalls .add (content .toolUse ());
419520 }
420521 }
522+
523+ return convertMessageData (text , toolCalls , index , stopReason , captureMessageContent );
524+ }
525+
526+ private static Value <?> convertMessageData (
527+ @ Nullable StringBuilder text ,
528+ List <ToolUseBlock > toolCalls ,
529+ int index ,
530+ @ Nullable String stopReason ,
531+ boolean captureMessageContent ) {
421532 Map <String , Value <?>> body = new HashMap <>();
422533 if (text != null ) {
423534 body .put ("content" , Value .of (text .toString ()));
424535 }
425536 if (toolCalls != null ) {
426- body .put ("toolCalls" , Value .of (toolCalls ));
537+ List <Value <?>> toolCallValues =
538+ toolCalls .stream ()
539+ .map (tool -> convertToolCall (tool , captureMessageContent ))
540+ .collect (Collectors .toList ());
541+ body .put ("toolCalls" , Value .of (toolCallValues ));
427542 }
428543 if (stopReason != null ) {
429544 body .put ("finish_reason" , Value .of (stopReason .toString ()));
@@ -451,4 +566,9 @@ private static String serializeDocument(Document document) {
451566 document .accept (marshaller );
452567 return new String (generator .getBytes (), StandardCharsets .UTF_8 );
453568 }
569+
570+ private static Document deserializeDocument (String json ) {
571+ JsonNode node = JSON_PARSER .parse (json );
572+ return node .visit (DOCUMENT_UNMARSHALLER );
573+ }
454574}
0 commit comments