16
16
17
17
import json
18
18
from os import environ
19
- from typing import Any , Callable , Dict , Union
19
+ from typing import Any , Callable , Dict , Iterator , Sequence , Union
20
20
21
21
from botocore .eventstream import EventStream , EventStreamError
22
22
from wrapt import ObjectProxy
34
34
_StreamErrorCallableT = Callable [[Exception ], None ]
35
35
36
36
37
+ def _decode_tool_use (tool_use ):
38
+ # input get sent encoded in json
39
+ if "input" in tool_use :
40
+ try :
41
+ tool_use ["input" ] = json .loads (tool_use ["input" ])
42
+ except json .JSONDecodeError :
43
+ pass
44
+ return tool_use
45
+
46
+
37
47
# pylint: disable=abstract-method
38
48
class ConverseStreamWrapper (ObjectProxy ):
39
49
"""Wrapper for botocore.eventstream.EventStream"""
@@ -52,7 +62,7 @@ def __init__(
52
62
# {"usage": {"inputTokens": 0, "outputTokens": 0}, "stopReason": "finish", "output": {"message": {"role": "", "content": [{"text": ""}]}
53
63
self ._response = {}
54
64
self ._message = None
55
- self ._content_buf = ""
65
+ self ._content_block = {}
56
66
self ._record_message = False
57
67
58
68
def __iter__ (self ):
@@ -65,23 +75,40 @@ def __iter__(self):
65
75
raise
66
76
67
77
def _process_event (self , event ):
78
+ # pylint: disable=too-many-branches
68
79
if "messageStart" in event :
69
80
# {'messageStart': {'role': 'assistant'}}
70
81
if event ["messageStart" ].get ("role" ) == "assistant" :
71
82
self ._record_message = True
72
83
self ._message = {"role" : "assistant" , "content" : []}
73
84
return
74
85
86
+ if "contentBlockStart" in event :
87
+ # {'contentBlockStart': {'start': {'toolUse': {'toolUseId': 'id', 'name': 'func_name'}}, 'contentBlockIndex': 1}}
88
+ start = event ["contentBlockStart" ].get ("start" , {})
89
+ if "toolUse" in start :
90
+ tool_use = _decode_tool_use (start ["toolUse" ])
91
+ self ._content_block = {"toolUse" : tool_use }
92
+ return
93
+
75
94
if "contentBlockDelta" in event :
76
95
# {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
96
+ # {'contentBlockDelta': {'delta': {'toolUse': {'input': '{"location":"Seattle"}'}}, 'contentBlockIndex': 1}}
77
97
if self ._record_message :
78
- self ._content_buf += (
79
- event ["contentBlockDelta" ].get ("delta" , {}).get ("text" , "" )
80
- )
98
+ delta = event ["contentBlockDelta" ].get ("delta" , {})
99
+ if "text" in delta :
100
+ self ._content_block .setdefault ("text" , "" )
101
+ self ._content_block ["text" ] += delta ["text" ]
102
+ elif "toolUse" in delta :
103
+ tool_use = _decode_tool_use (delta ["toolUse" ])
104
+ self ._content_block ["toolUse" ].update (tool_use )
81
105
return
82
106
83
107
if "contentBlockStop" in event :
84
108
# {'contentBlockStop': {'contentBlockIndex': 0}}
109
+ if self ._record_message :
110
+ self ._message ["content" ].append (self ._content_block )
111
+ self ._content_block = {}
85
112
return
86
113
87
114
if "messageStop" in event :
@@ -90,8 +117,6 @@ def _process_event(self, event):
90
117
self ._response ["stopReason" ] = stop_reason
91
118
92
119
if self ._record_message :
93
- self ._message ["content" ].append ({"text" : self ._content_buf })
94
- self ._content_buf = ""
95
120
self ._response ["output" ] = {"message" : self ._message }
96
121
self ._record_message = False
97
122
self ._message = None
@@ -134,7 +159,8 @@ def __init__(
134
159
# {"usage": {"inputTokens": 0, "outputTokens": 0}, "stopReason": "finish", "output": {"message": {"role": "", "content": [{"text": ""}]}
135
160
self ._response = {}
136
161
self ._message = None
137
- self ._content_buf = ""
162
+ self ._content_block = {}
163
+ self ._tool_json_input_buf = ""
138
164
self ._record_message = False
139
165
140
166
def __iter__ (self ):
@@ -189,6 +215,8 @@ def _process_amazon_titan_chunk(self, chunk):
189
215
self ._stream_done_callback (self ._response )
190
216
191
217
def _process_amazon_nova_chunk (self , chunk ):
218
+ # pylint: disable=too-many-branches
219
+ # TODO: handle tool calls!
192
220
if "messageStart" in chunk :
193
221
# {'messageStart': {'role': 'assistant'}}
194
222
if chunk ["messageStart" ].get ("role" ) == "assistant" :
@@ -199,9 +227,10 @@ def _process_amazon_nova_chunk(self, chunk):
199
227
if "contentBlockDelta" in chunk :
200
228
# {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
201
229
if self ._record_message :
202
- self ._content_buf += (
203
- chunk ["contentBlockDelta" ].get ("delta" , {}).get ("text" , "" )
204
- )
230
+ delta = chunk ["contentBlockDelta" ].get ("delta" , {})
231
+ if "text" in delta :
232
+ self ._content_block .setdefault ("text" , "" )
233
+ self ._content_block ["text" ] += delta ["text" ]
205
234
return
206
235
207
236
if "contentBlockStop" in chunk :
@@ -214,8 +243,8 @@ def _process_amazon_nova_chunk(self, chunk):
214
243
self ._response ["stopReason" ] = stop_reason
215
244
216
245
if self ._record_message :
217
- self ._message ["content" ].append ({ "text" : self ._content_buf } )
218
- self ._content_buf = ""
246
+ self ._message ["content" ].append (self ._content_block )
247
+ self ._content_block = {}
219
248
self ._response ["output" ] = {"message" : self ._message }
220
249
self ._record_message = False
221
250
self ._message = None
@@ -235,7 +264,7 @@ def _process_amazon_nova_chunk(self, chunk):
235
264
return
236
265
237
266
def _process_anthropic_claude_chunk (self , chunk ):
238
- # pylint: disable=too-many-return-statements
267
+ # pylint: disable=too-many-return-statements,too-many-branches
239
268
if not (message_type := chunk .get ("type" )):
240
269
return
241
270
@@ -252,18 +281,35 @@ def _process_anthropic_claude_chunk(self, chunk):
252
281
253
282
if message_type == "content_block_start" :
254
283
# {'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}
284
+ # {'type': 'content_block_start', 'index': 1, 'content_block': {'type': 'tool_use', 'id': 'id', 'name': 'func_name', 'input': {}}}
285
+ if self ._record_message :
286
+ block = chunk .get ("content_block" , {})
287
+ if block .get ("type" ) == "text" :
288
+ self ._content_block = block
289
+ elif block .get ("type" ) == "tool_use" :
290
+ self ._content_block = block
255
291
return
256
292
257
293
if message_type == "content_block_delta" :
258
294
# {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Here'}}
295
+ # {'type': 'content_block_delta', 'index': 1, 'delta': {'type': 'input_json_delta', 'partial_json': ''}}
259
296
if self ._record_message :
260
- self ._content_buf += chunk .get ("delta" , {}).get ("text" , "" )
297
+ delta = chunk .get ("delta" , {})
298
+ if delta .get ("type" ) == "text_delta" :
299
+ self ._content_block ["text" ] += delta .get ("text" , "" )
300
+ elif delta .get ("type" ) == "input_json_delta" :
301
+ self ._tool_json_input_buf += delta .get ("partial_json" , "" )
261
302
return
262
303
263
304
if message_type == "content_block_stop" :
264
305
# {'type': 'content_block_stop', 'index': 0}
265
- self ._message ["content" ].append ({"text" : self ._content_buf })
266
- self ._content_buf = ""
306
+ if self ._tool_json_input_buf :
307
+ self ._content_block ["input" ] = self ._tool_json_input_buf
308
+ self ._message ["content" ].append (
309
+ _decode_tool_use (self ._content_block )
310
+ )
311
+ self ._content_block = {}
312
+ self ._tool_json_input_buf = ""
267
313
return
268
314
269
315
if message_type == "message_delta" :
@@ -297,16 +343,102 @@ def genai_capture_message_content() -> bool:
297
343
return capture_content .lower () == "true"
298
344
299
345
300
- def message_to_event (message : dict [str , Any ], capture_content : bool ) -> Event :
346
+ def extract_tool_calls (
347
+ message : dict [str , Any ], capture_content : bool
348
+ ) -> Sequence [Dict [str , Any ]] | None :
349
+ content = message .get ("content" )
350
+ if not content :
351
+ return None
352
+
353
+ tool_uses = [item ["toolUse" ] for item in content if "toolUse" in item ]
354
+ if not tool_uses :
355
+ tool_uses = [
356
+ item for item in content if item .get ("type" ) == "tool_use"
357
+ ]
358
+ tool_id_key = "id"
359
+ else :
360
+ tool_id_key = "toolUseId"
361
+
362
+ if not tool_uses :
363
+ return None
364
+
365
+ tool_calls = []
366
+ for tool_use in tool_uses :
367
+ tool_call = {"type" : "function" }
368
+ if call_id := tool_use .get (tool_id_key ):
369
+ tool_call ["id" ] = call_id
370
+
371
+ if function_name := tool_use .get ("name" ):
372
+ tool_call ["function" ] = {"name" : function_name }
373
+
374
+ if (function_input := tool_use .get ("input" )) and capture_content :
375
+ tool_call .setdefault ("function" , {})
376
+ tool_call ["function" ]["arguments" ] = function_input
377
+
378
+ tool_calls .append (tool_call )
379
+ return tool_calls
380
+
381
+
382
def extract_tool_results(
    message: dict[str, Any], capture_content: bool
) -> Iterator[Dict[str, Any]]:
    """Yield one event body per tool result found in a message's content.

    Two content layouts are recognised:

    * Converse: items shaped like ``{"toolResult": {"toolUseId", "content"}}``
    * InvokeModel anthropic.claude: items shaped like
      ``{"type": "tool_result", "tool_use_id", "content"}``

    Args:
        message: message mapping; its ``"content"`` is inspected.
        capture_content: when false, the tool result content is left out of
            the yielded bodies and only the id is reported.

    Yields:
        A dict with an optional ``"id"`` and, when capture is enabled and
        the result has content, a ``"content"`` entry. Nothing is yielded
        when the message carries no tool results.
    """
    content = message.get("content")
    if not content:
        return

    # Content is not guaranteed to be a list of dicts (it can be a plain
    # string); only dict items can describe a tool result, so ignore the
    # rest instead of failing on ``str.get``.
    items = [item for item in content if isinstance(item, dict)]

    # Converse format
    tool_results = [
        item["toolResult"] for item in items if "toolResult" in item
    ]
    tool_id_key = "toolUseId"
    if not tool_results:
        # InvokeModel anthropic.claude format
        tool_results = [
            item for item in items if item.get("type") == "tool_result"
        ]
        tool_id_key = "tool_use_id"

    # one body for each tool result part of the content
    for tool_result in tool_results:
        body = {}
        if tool_id := tool_result.get(tool_id_key):
            body["id"] = tool_id
        tool_content = tool_result.get("content")
        if capture_content and tool_content:
            body["content"] = tool_content

        yield body
416
+
417
+
418
+ def message_to_event (
419
+ message : dict [str , Any ], capture_content : bool
420
+ ) -> Iterator [Event ]:
301
421
attributes = {GEN_AI_SYSTEM : GenAiSystemValues .AWS_BEDROCK .value }
302
422
role = message .get ("role" )
303
423
content = message .get ("content" )
304
424
305
425
body = {}
306
426
if capture_content and content :
307
427
body ["content" ] = content
308
-
309
- return Event (
428
+ if role == "assistant" :
429
+ # the assistant message contains both tool calls and model thinking content
430
+ if tool_calls := extract_tool_calls (message , capture_content ):
431
+ body ["tool_calls" ] = tool_calls
432
+ elif role == "user" :
433
+ # in case of tool calls we send one tool event for tool call and one for the user event
434
+ for tool_body in extract_tool_results (message , capture_content ):
435
+ yield Event (
436
+ name = "gen_ai.tool.message" ,
437
+ attributes = attributes ,
438
+ body = tool_body ,
439
+ )
440
+
441
+ yield Event (
310
442
name = f"gen_ai.{ role } .message" ,
311
443
attributes = attributes ,
312
444
body = body if body else None ,
@@ -331,8 +463,12 @@ def from_converse(
331
463
else :
332
464
# amazon.titan does not serialize the role
333
465
message = {}
334
- if capture_content :
466
+
467
+ if tool_calls := extract_tool_calls (orig_message , capture_content ):
468
+ message ["tool_calls" ] = tool_calls
469
+ elif capture_content :
335
470
message ["content" ] = orig_message ["content" ]
471
+
336
472
return cls (message , response ["stopReason" ], index = 0 )
337
473
338
474
@classmethod
@@ -350,14 +486,11 @@ def from_invoke_amazon_titan(
350
486
def from_invoke_anthropic_claude (
351
487
cls , response : dict [str , Any ], capture_content : bool
352
488
) -> _Choice :
353
- if capture_content :
354
- message = {
355
- "content" : response ["content" ],
356
- "role" : response ["role" ],
357
- }
358
- else :
359
- message = {"role" : response ["role" ]}
360
-
489
+ message = {"role" : response ["role" ]}
490
+ if tool_calls := extract_tool_calls (response , capture_content ):
491
+ message ["tool_calls" ] = tool_calls
492
+ elif capture_content :
493
+ message ["content" ] = response ["content" ]
361
494
return cls (message , response ["stop_reason" ], index = 0 )
362
495
363
496
def _to_body_dict (self ) -> dict [str , Any ]:
0 commit comments