
Commit 4a88b7f

Merge pull request #20 from NAPTlME/bedrock-token-count-callbacks
Bedrock token count callbacks
2 parents: 622756c + bfa0871

5 files changed: +268 additions, -66 deletions

libs/aws/Makefile

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ test tests integration_test integration_tests:
 PYTHON_FILES=.
 MYPY_CACHE=.mypy_cache
 lint format: PYTHON_FILES=.
-lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/aws --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
+lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/aws --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
 lint_package: PYTHON_FILES=langchain_aws
 lint_tests: PYTHON_FILES=tests
 lint_tests: MYPY_CACHE=.mypy_cache_test

libs/aws/langchain_aws/chat_models/bedrock.py

Lines changed: 24 additions & 11 deletions
@@ -35,7 +35,10 @@
 from langchain_core.tools import BaseTool

 from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message
-from langchain_aws.llms.bedrock import BedrockBase
+from langchain_aws.llms.bedrock import (
+    BedrockBase,
+    _combine_generation_info_for_llm_result,
+)
 from langchain_aws.utils import (
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
@@ -383,7 +386,13 @@ def _stream(
             **kwargs,
         ):
             delta = chunk.text
-            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(
+                    content=delta, response_metadata=chunk.generation_info
+                )
+                if chunk.generation_info is not None
+                else AIMessageChunk(content=delta)
+            )

     def _generate(
         self,
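
Note: the reworked yield above attaches the provider's generation_info (e.g. token usage or a stop reason on the final stream event) to each chunk's response_metadata. A minimal sketch of that conditional in isolation, using public langchain_core types; make_chunk is an illustrative name, not part of the diff:

from typing import Any, Dict, Optional

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

def make_chunk(
    delta: str, generation_info: Optional[Dict[str, Any]]
) -> ChatGenerationChunk:
    # Attach provider metadata only when the stream event supplied it,
    # mirroring the conditional yield in _stream above.
    return ChatGenerationChunk(
        message=AIMessageChunk(content=delta, response_metadata=generation_info)
        if generation_info is not None
        else AIMessageChunk(content=delta)
    )

print(make_chunk("Hello", None).message.response_metadata)  # {}
print(make_chunk("", {"usage": {"completion_tokens": 5}}).message.response_metadata)
# {'usage': {'completion_tokens': 5}}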
@@ -393,11 +402,18 @@ def _generate(
         **kwargs: Any,
     ) -> ChatResult:
         completion = ""
-        llm_output: Dict[str, Any] = {"model_id": self.model_id}
-        usage_info: Dict[str, Any] = {}
+        llm_output: Dict[str, Any] = {}
+        provider_stop_reason_code = self.provider_stop_reason_key_map.get(
+            self._get_provider(), "stop_reason"
+        )
         if self.streaming:
+            response_metadata: List[Dict[str, Any]] = []
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
+                response_metadata.append(chunk.message.response_metadata)
+            llm_output = _combine_generation_info_for_llm_result(
+                response_metadata, provider_stop_reason_code
+            )
         else:
             provider = self._get_provider()
             prompt, system, formatted_messages = None, None, None
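
Note: _combine_generation_info_for_llm_result is defined in langchain_aws/llms/bedrock.py, which this PR also changes but which is not shown in this excerpt. Purely as an illustration of the kind of reduction it performs over the collected per-chunk metadata (the real helper may differ), summing usage and keeping the provider-specific stop reason could look like:

from typing import Any, Dict, List, Optional

def combine_generation_info(
    response_metadata: List[Dict[str, Any]], stop_reason_key: str
) -> Dict[str, Any]:
    # Hypothetical reconstruction: sum token counts across all chunks and
    # keep the most recent stop reason reported under the provider's key.
    usage: Dict[str, int] = {}
    stop_reason: Optional[str] = None
    for metadata in response_metadata:
        for token_type, count in metadata.get("usage", {}).items():
            usage[token_type] = usage.get(token_type, 0) + count
        if stop_reason_key in metadata:
            stop_reason = metadata[stop_reason_key]
    return {"usage": usage, "stop_reason": stop_reason}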
@@ -420,7 +436,7 @@ def _generate(
             if stop:
                 params["stop_sequences"] = stop

-            completion, usage_info = self._prepare_input_and_invoke(
+            completion, llm_output = self._prepare_input_and_invoke(
                 prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
@@ -429,14 +445,11 @@ def _generate(
                 **params,
             )

-        llm_output["usage"] = usage_info
-
+        llm_output["model_id"] = self.model_id
         return ChatResult(
             generations=[
                 ChatGeneration(
-                    message=AIMessage(
-                        content=completion, additional_kwargs={"usage": usage_info}
-                    )
+                    message=AIMessage(content=completion, additional_kwargs=llm_output)
                 )
             ],
             llm_output=llm_output,
@@ -447,7 +460,7 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         final_output = {}
         for output in llm_outputs:
             output = output or {}
-            usage = output.pop("usage", {})
+            usage = output.get("usage", {})
             for token_type, token_count in usage.items():
                 final_usage[token_type] += token_count
             final_output.update(output)
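
Note: switching .pop to .get means each output's "usage" dict now survives final_output.update(output), so the combined llm_output retains usage alongside the summed per-token totals. A self-contained sketch of the aggregation; everything around the changed line is reconstructed for illustration and may not match the real method exactly:

from collections import defaultdict
from typing import Any, Dict, List, Optional

def combine_llm_outputs(llm_outputs: List[Optional[dict]]) -> dict:
    final_usage: Dict[str, int] = defaultdict(int)
    final_output: Dict[str, Any] = {}
    for output in llm_outputs:
        output = output or {}
        usage = output.get("usage", {})  # .pop here would strip usage before update()
        for token_type, token_count in usage.items():
            final_usage[token_type] += token_count
        final_output.update(output)
    final_output["usage"] = dict(final_usage)
    return final_output

print(combine_llm_outputs([
    {"model_id": "anthropic.claude-v2", "usage": {"prompt_tokens": 10}},
    {"model_id": "anthropic.claude-v2", "usage": {"prompt_tokens": 7}},
]))
# {'model_id': 'anthropic.claude-v2', 'usage': {'prompt_tokens': 17}}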
