@@ -274,6 +274,7 @@ def _convert_delta_to_message_chunk(
274
274
_dict : Mapping [str , Any ],
275
275
default_class : Type [BaseMessageChunk ],
276
276
call_id : str ,
277
+ is_first_tool_chunk : bool ,
277
278
) -> BaseMessageChunk :
278
279
id_ = call_id
279
280
role = cast (str , _dict .get ("role" ))
@@ -290,8 +291,12 @@ def _convert_delta_to_message_chunk(
290
291
try :
291
292
tool_call_chunks = [
292
293
tool_call_chunk (
293
- name = rtc ["function" ].get ("name" ),
294
+ name = rtc ["function" ].get ("name" )
295
+ if is_first_tool_chunk or (rtc .get ("id" ) is not None )
296
+ else None ,
294
297
args = rtc ["function" ].get ("arguments" ),
298
+ # `id` is provided only for the first delta with unique tool_calls
299
+ # (multiple tool calls scenario)
295
300
id = rtc .get ("id" ),
296
301
index = rtc ["index" ],
297
302
)
@@ -328,6 +333,7 @@ def _convert_chunk_to_generation_chunk(
328
333
default_chunk_class : Type ,
329
334
base_generation_info : Optional [Dict ],
330
335
is_first_chunk : bool ,
336
+ is_first_tool_chunk : bool ,
331
337
) -> Optional [ChatGenerationChunk ]:
332
338
token_usage = chunk .get ("usage" )
333
339
choices = chunk .get ("choices" , [])
@@ -348,7 +354,7 @@ def _convert_chunk_to_generation_chunk(
348
354
return None
349
355
350
356
message_chunk = _convert_delta_to_message_chunk (
351
- choice ["delta" ], default_chunk_class , chunk ["id" ]
357
+ choice ["delta" ], default_chunk_class , chunk ["id" ], is_first_tool_chunk
352
358
)
353
359
generation_info = {** base_generation_info } if base_generation_info else {}
354
360
@@ -722,6 +728,7 @@ def _stream(
722
728
base_generation_info : dict = {}
723
729
724
730
is_first_chunk = True
731
+ is_first_tool_chunk = True
725
732
726
733
for chunk in self .watsonx_model .chat_stream (
727
734
messages = message_dicts , ** (kwargs | {"params" : updated_params })
@@ -733,6 +740,7 @@ def _stream(
733
740
default_chunk_class ,
734
741
base_generation_info if is_first_chunk else {},
735
742
is_first_chunk ,
743
+ is_first_tool_chunk ,
736
744
)
737
745
if generation_chunk is None :
738
746
continue
@@ -742,6 +750,16 @@ def _stream(
742
750
run_manager .on_llm_new_token (
743
751
generation_chunk .text , chunk = generation_chunk , logprobs = logprobs
744
752
)
753
+ if hasattr (generation_chunk .message , "tool_calls" ) and isinstance (
754
+ generation_chunk .message .tool_calls , list
755
+ ):
756
+ first_tool_call = (
757
+ generation_chunk .message .tool_calls [0 ]
758
+ if generation_chunk .message .tool_calls
759
+ else None
760
+ )
761
+ if isinstance (first_tool_call , dict ) and first_tool_call .get ("name" ):
762
+ is_first_tool_chunk = False
745
763
746
764
is_first_chunk = False
747
765
0 commit comments