@@ -331,15 +331,16 @@ def _convert_delta_to_message_chunk(
 def _convert_chunk_to_generation_chunk(
     chunk: dict,
     default_chunk_class: Type,
-    base_generation_info: Optional[Dict],
-    is_first_chunk: bool,
     is_first_tool_chunk: bool,
+    _prompt_tokens_included: bool,
 ) -> Optional[ChatGenerationChunk]:
     token_usage = chunk.get("usage")
     choices = chunk.get("choices", [])
 
     usage_metadata: Optional[UsageMetadata] = (
-        _create_usage_metadata(token_usage, is_first_chunk) if token_usage else None
+        _create_usage_metadata(token_usage, _prompt_tokens_included)
+        if token_usage
+        else None
     )
 
     if len(choices) == 0:
@@ -356,7 +357,7 @@ def _convert_chunk_to_generation_chunk(
     message_chunk = _convert_delta_to_message_chunk(
         choice["delta"], default_chunk_class, chunk["id"], is_first_tool_chunk
     )
-    generation_info = {**base_generation_info} if base_generation_info else {}
+    generation_info = {}
 
     if finish_reason := choice.get("finish_reason"):
         generation_info["finish_reason"] = finish_reason
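
The two hunks above replace the per-call `is_first_chunk`/`base_generation_info` pair with a single `_prompt_tokens_included` flag. A minimal sketch of why prompt tokens may only ride on one chunk, assuming standard `langchain_core` chunk merging (the token counts are illustrative, not actual watsonx output):

from langchain_core.messages import AIMessageChunk
from langchain_core.messages.ai import UsageMetadata

# LangChain sums usage_metadata when message chunks are concatenated, so a
# stream that repeats prompt_tokens on every chunk double-counts the input.
first = AIMessageChunk(
    content="Hel",
    usage_metadata=UsageMetadata(input_tokens=12, output_tokens=1, total_tokens=13),
)
# With the fix, chunks after the first usage-bearing one carry input_tokens=0.
later = AIMessageChunk(
    content="lo",
    usage_metadata=UsageMetadata(input_tokens=0, output_tokens=1, total_tokens=1),
)

merged = first + later
print(merged.usage_metadata["input_tokens"])  # 12 (counted once, not per chunk)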
@@ -727,25 +728,26 @@ def _stream(
         updated_params = self._merge_params(params, kwargs)
 
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        base_generation_info: dict = {}
-        is_first_chunk = True
         is_first_tool_chunk = True
+        _prompt_tokens_included = False
 
         for chunk in self.watsonx_model.chat_stream(
             messages=message_dicts, **(kwargs | {"params": updated_params})
         ):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             generation_chunk = _convert_chunk_to_generation_chunk(
-                chunk,
-                default_chunk_class,
-                base_generation_info if is_first_chunk else {},
-                is_first_chunk,
-                is_first_tool_chunk,
+                chunk, default_chunk_class, is_first_tool_chunk, _prompt_tokens_included
             )
             if generation_chunk is None:
                 continue
+
+            if (
+                hasattr(generation_chunk.message, "usage_metadata")
+                and generation_chunk.message.usage_metadata
+            ):
+                _prompt_tokens_included = True
             default_chunk_class = generation_chunk.message.__class__
             logprobs = (generation_chunk.generation_info or {}).get("logprobs")
             if run_manager:
@@ -763,8 +765,6 @@ def _stream(
                 if isinstance(first_tool_call, dict) and first_tool_call.get("name"):
                     is_first_tool_chunk = False
 
-            is_first_chunk = False
-
             yield generation_chunk
 
     @staticmethod
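
The flag handoff in `_stream` above can be reproduced without the watsonx client. A standalone sketch with fabricated usage events, under the assumption that the service may attach usage to more than one streamed event; `usage_from_event` and `events` are hypothetical stand-ins, not code from this PR:

def usage_from_event(raw: dict, prompt_tokens_included: bool) -> dict:
    # Mirrors _create_usage_metadata: report prompt_tokens only until
    # some earlier chunk has already carried them.
    input_tokens = 0 if prompt_tokens_included else raw.get("prompt_tokens", 0)
    output_tokens = raw.get("completion_tokens", 0)
    return {"input_tokens": input_tokens, "output_tokens": output_tokens}

events = [
    {"prompt_tokens": 12, "completion_tokens": 1},
    {"prompt_tokens": 12, "completion_tokens": 2},
]

prompt_tokens_included = False
total = {"input_tokens": 0, "output_tokens": 0}
for raw in events:
    usage = usage_from_event(raw, prompt_tokens_included)
    prompt_tokens_included = True  # flips after the first usage-bearing event
    total = {k: total[k] + usage[k] for k in total}

assert total == {"input_tokens": 12, "output_tokens": 3}  # input counted once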
@@ -809,7 +809,7 @@ def _create_chat_result(
         message = _convert_dict_to_message(res["message"], response["id"])
 
         if token_usage and isinstance(message, AIMessage):
-            message.usage_metadata = _create_usage_metadata(token_usage, True)
+            message.usage_metadata = _create_usage_metadata(token_usage, False)
         generation_info = generation_info or {}
         generation_info["finish_reason"] = (
             res.get("finish_reason")
@@ -1200,9 +1200,12 @@ def _lc_invalid_tool_call_to_watsonx_tool_call(
 
 
 def _create_usage_metadata(
-    oai_token_usage: dict, is_first_chunk: bool
+    oai_token_usage: dict,
+    _prompt_tokens_included: bool,
 ) -> UsageMetadata:
-    input_tokens = oai_token_usage.get("prompt_tokens", 0) if is_first_chunk else 0
+    input_tokens = (
+        oai_token_usage.get("prompt_tokens", 0) if not _prompt_tokens_included else 0
+    )
     output_tokens = oai_token_usage.get("completion_tokens", 0)
     total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
     return UsageMetadata(
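
Since the non-streaming path in `_create_chat_result` now passes `False` (a complete response reports usage exactly once, so its prompt tokens should be included), the helper behaves as sketched below. The `UsageMetadata(...)` keyword arguments fall outside the hunk and are reconstructed here as an assumption; note that `total_tokens` still comes from the payload when that key is present, even when prompt tokens are suppressed:

from langchain_core.messages.ai import UsageMetadata

def _create_usage_metadata(
    oai_token_usage: dict,
    _prompt_tokens_included: bool,
) -> UsageMetadata:
    # Suppress prompt_tokens once some earlier chunk already reported them.
    input_tokens = (
        oai_token_usage.get("prompt_tokens", 0) if not _prompt_tokens_included else 0
    )
    output_tokens = oai_token_usage.get("completion_tokens", 0)
    total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
    )

usage = {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}
print(_create_usage_metadata(usage, _prompt_tokens_included=False))
# {'input_tokens': 12, 'output_tokens': 3, 'total_tokens': 15}
print(_create_usage_metadata(usage, _prompt_tokens_included=True))
# {'input_tokens': 0, 'output_tokens': 3, 'total_tokens': 15}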