
Commit 275a7a0

fix: usage chat stream (#54)
* fix getting usage field
* return usage only for last chunk
* assure backward compatibility
* fix linting
* update package version
1 parent e215a3b commit 275a7a0
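
The gist of the fix: the watsonx.ai chat stream does not guarantee a `usage` object on the first chunk (it typically arrives only on the last one), so the old `is_first_chunk` logic could drop prompt tokens entirely, while naively reading the full usage from every chunk would double-count them. A minimal, self-contained sketch of the accumulation pattern this commit adopts (the function name and chunk payloads below are invented for illustration):

```python
def aggregate_usage(chunks: list[dict]) -> dict:
    """Sum streamed token usage, counting prompt_tokens exactly once."""
    prompt_tokens_included = False  # mirrors the commit's _prompt_tokens_included flag
    input_tokens = output_tokens = 0
    for chunk in chunks:
        usage = chunk.get("usage")
        if not usage:
            continue  # most chunks carry no usage at all
        if not prompt_tokens_included:
            input_tokens += usage.get("prompt_tokens", 0)
            prompt_tokens_included = True  # the prompt is never counted again
        output_tokens += usage.get("completion_tokens", 0)
    return {"input_tokens": input_tokens, "output_tokens": output_tokens}

stream = [
    {},  # content-only chunks: no usage field
    {},
    {"usage": {"prompt_tokens": 12, "completion_tokens": 7}},  # final chunk
]
print(aggregate_usage(stream))  # -> {'input_tokens': 12, 'output_tokens': 7}
```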

File tree: 2 files changed (+20, -17 lines)


libs/ibm/langchain_ibm/chat_models.py

Lines changed: 19 additions & 16 deletions
```diff
@@ -331,15 +331,16 @@ def _convert_delta_to_message_chunk(
 def _convert_chunk_to_generation_chunk(
     chunk: dict,
     default_chunk_class: Type,
-    base_generation_info: Optional[Dict],
-    is_first_chunk: bool,
     is_first_tool_chunk: bool,
+    _prompt_tokens_included: bool,
 ) -> Optional[ChatGenerationChunk]:
     token_usage = chunk.get("usage")
     choices = chunk.get("choices", [])
 
     usage_metadata: Optional[UsageMetadata] = (
-        _create_usage_metadata(token_usage, is_first_chunk) if token_usage else None
+        _create_usage_metadata(token_usage, _prompt_tokens_included)
+        if token_usage
+        else None
     )
 
     if len(choices) == 0:
```
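
The conditional above now keys off whether prompt tokens were already emitted rather than off chunk position. A small demonstration of the two branches, restating the helper's logic inline (plain dicts stand in for LangChain's `UsageMetadata`; the payload values are invented):

```python
def usage_metadata(usage: dict, prompt_tokens_included: bool) -> dict:
    # Same branching as the patched _create_usage_metadata (full hunk further below).
    input_tokens = usage.get("prompt_tokens", 0) if not prompt_tokens_included else 0
    output_tokens = usage.get("completion_tokens", 0)
    total_tokens = usage.get("total_tokens", input_tokens + output_tokens)
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
    }

usage = {"prompt_tokens": 12, "completion_tokens": 7, "total_tokens": 19}
print(usage_metadata(usage, False))  # prompt not yet counted -> input_tokens == 12
print(usage_metadata(usage, True))   # already counted        -> input_tokens == 0
```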
```diff
@@ -356,7 +357,7 @@ def _convert_chunk_to_generation_chunk(
     message_chunk = _convert_delta_to_message_chunk(
         choice["delta"], default_chunk_class, chunk["id"], is_first_tool_chunk
     )
-    generation_info = {**base_generation_info} if base_generation_info else {}
+    generation_info = {}
 
     if finish_reason := choice.get("finish_reason"):
         generation_info["finish_reason"] = finish_reason
```
```diff
@@ -727,25 +728,26 @@ def _stream(
         updated_params = self._merge_params(params, kwargs)
 
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        base_generation_info: dict = {}
 
-        is_first_chunk = True
         is_first_tool_chunk = True
+        _prompt_tokens_included = False
 
         for chunk in self.watsonx_model.chat_stream(
             messages=message_dicts, **(kwargs | {"params": updated_params})
         ):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             generation_chunk = _convert_chunk_to_generation_chunk(
-                chunk,
-                default_chunk_class,
-                base_generation_info if is_first_chunk else {},
-                is_first_chunk,
-                is_first_tool_chunk,
+                chunk, default_chunk_class, is_first_tool_chunk, _prompt_tokens_included
             )
             if generation_chunk is None:
                 continue
+
+            if (
+                hasattr(generation_chunk.message, "usage_metadata")
+                and generation_chunk.message.usage_metadata
+            ):
+                _prompt_tokens_included = True
             default_chunk_class = generation_chunk.message.__class__
             logprobs = (generation_chunk.generation_info or {}).get("logprobs")
             if run_manager:
```
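
Zeroing `input_tokens` after the first usage-bearing chunk matters because LangChain merges streamed chunks with `+`, summing `usage_metadata` field by field. A quick check of that merge behavior against `langchain_core` (the token counts are invented):

```python
from langchain_core.messages import AIMessageChunk

first = AIMessageChunk(
    content="Hello",
    usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
)
last = AIMessageChunk(
    content=" world",
    # After the fix, follow-up chunks report no prompt tokens:
    usage_metadata={"input_tokens": 0, "output_tokens": 4, "total_tokens": 4},
)
merged = first + last
print(merged.usage_metadata)
# -> {'input_tokens': 12, 'output_tokens': 7, 'total_tokens': 19}
# Had both chunks repeated input_tokens=12, the merged count would be inflated to 24.
```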
```diff
@@ -763,8 +765,6 @@ def _stream(
                 if isinstance(first_tool_call, dict) and first_tool_call.get("name"):
                     is_first_tool_chunk = False
 
-            is_first_chunk = False
-
             yield generation_chunk
 
     @staticmethod
```
```diff
@@ -809,7 +809,7 @@ def _create_chat_result(
             message = _convert_dict_to_message(res["message"], response["id"])
 
             if token_usage and isinstance(message, AIMessage):
-                message.usage_metadata = _create_usage_metadata(token_usage, True)
+                message.usage_metadata = _create_usage_metadata(token_usage, False)
             generation_info = generation_info or {}
             generation_info["finish_reason"] = (
                 res.get("finish_reason")
```
```diff
@@ -1200,9 +1200,12 @@ def _lc_invalid_tool_call_to_watsonx_tool_call(
 
 
 def _create_usage_metadata(
-    oai_token_usage: dict, is_first_chunk: bool
+    oai_token_usage: dict,
+    _prompt_tokens_included: bool,
 ) -> UsageMetadata:
-    input_tokens = oai_token_usage.get("prompt_tokens", 0) if is_first_chunk else 0
+    input_tokens = (
+        oai_token_usage.get("prompt_tokens", 0) if not _prompt_tokens_included else 0
+    )
     output_tokens = oai_token_usage.get("completion_tokens", 0)
     total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
     return UsageMetadata(
```
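
From the caller's side nothing changes: after this fix a consumed stream should report the same totals as a non-streamed call, with prompt tokens counted once. A hypothetical end-to-end sketch (model id, endpoint, and project id are placeholders; credentials are assumed to be supplied via the environment):

```python
from langchain_ibm import ChatWatsonx

chat = ChatWatsonx(
    model_id="ibm/granite-13b-instruct-v2",   # placeholder model id
    url="https://us-south.ml.cloud.ibm.com",  # placeholder region endpoint
    project_id="YOUR_PROJECT_ID",             # placeholder
)

merged = None
for chunk in chat.stream("What is token usage accounting?"):
    merged = chunk if merged is None else merged + chunk

# With 0.3.6, prompt tokens appear exactly once in the aggregated metadata.
print(merged.usage_metadata)
```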

libs/ibm/pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-ibm"
-version = "0.3.5"
+version = "0.3.6"
 description = "An integration package connecting IBM watsonx.ai and LangChain"
 authors = ["IBM"]
 readme = "README.md"
```
