diff --git a/examples/bedrock.py b/examples/bedrock.py index 7b1d85f51..dd5ea957f 100644 --- a/examples/bedrock.py +++ b/examples/bedrock.py @@ -10,6 +10,8 @@ # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html client = AnthropicBedrock() +model = "anthropic.claude-sonnet-4-5-20250929-v1:0" + print("------ standard response ------") message = client.messages.create( max_tokens=1024, @@ -19,7 +21,7 @@ "content": "Hello!", } ], - model="anthropic.claude-sonnet-4-5-20250929-v1:0", + model=model, ) print(message.model_dump_json(indent=2)) @@ -33,7 +35,7 @@ "content": "Say hello there!", } ], - model="anthropic.claude-sonnet-4-5-20250929-v1:0", + model=model, ) as stream: for text in stream.text_stream: print(text, end="", flush=True) @@ -44,3 +46,15 @@ # inside of the context manager accumulated = stream.get_final_message() print("accumulated message: ", accumulated.model_dump_json(indent=2)) + +print("------ count tokens ------") +count = client.messages.count_tokens( + model=model, + messages=[ + { + "role": "user", + "content": "Hello, world!", + } + ], +) +print(count.model_dump_json(indent=2)) diff --git a/src/anthropic/lib/bedrock/_client.py b/src/anthropic/lib/bedrock/_client.py index 013d27023..69957a7d4 100644 --- a/src/anthropic/lib/bedrock/_client.py +++ b/src/anthropic/lib/bedrock/_client.py @@ -1,6 +1,8 @@ from __future__ import annotations import os +import json +import base64 import logging import urllib.parse from typing import Any, Union, Mapping, TypeVar @@ -62,7 +64,23 @@ def _prepare_options(input_options: FinalRequestOptions) -> FinalRequestOptions: raise AnthropicError("The Batch API is not supported in Bedrock yet") if options.url == "/v1/messages/count_tokens": - raise AnthropicError("Token counting is not supported in Bedrock yet") + if not is_dict(options.json_data): + raise RuntimeError("Expected dictionary json_data for /v1/messages/count_tokens endpoint") + + model = options.json_data.pop("model", None) + model = urllib.parse.quote(str(model), safe=":") + + # max_tokens is required for the request to be valid. + # Use 500 which is enough to get a response. + options.json_data["max_tokens"] = 500 + + # body element of the request is base64 encoded. + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModelTokensRequest.html + input_to_count = json.dumps(options.json_data) + encoded_bytes = base64.b64encode(input_to_count.encode("utf-8")).decode("utf-8") + options.json_data = {"input": {"invokeModel": {"body": encoded_bytes}}} + + options.url = f"/model/{model}/count-tokens" return options diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index fe62da434..a25f2c6ca 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -1,4 +1,6 @@ import re +import json +import base64 import typing as t import tempfile from typing import TypedDict, cast @@ -96,6 +98,38 @@ def test_messages_retries(respx_mock: MockRouter) -> None: ) +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.respx() +def test_messages_count_tokens(respx_mock: MockRouter) -> None: + respx_mock.post(re.compile(r"https://bedrock-runtime\.us-east-1\.amazonaws\.com/model/.*/count-tokens")).mock( + side_effect=[httpx.Response(200, json={"foo": "bar"})], + ) + + sync_client.messages.count_tokens( + model="anthropic.claude-3-5-sonnet-20241022-v2:0", + messages=[{"role": "user", "content": "Hello, world!"}], + ) + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 1 + assert ( + calls[0].request.url + == "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20241022-v2:0/count-tokens" + ) + + # Check that the request content is correct. + requested_content = json.loads(calls[0].request.content) + assert "input" in requested_content + assert "invokeModel" in requested_content["input"] + assert "body" in requested_content["input"]["invokeModel"] + decoded_body = base64.b64decode(requested_content["input"]["invokeModel"]["body"]).decode("utf-8") + assert json.loads(decoded_body) == { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 500, + "messages": [{"role": "user", "content": "Hello, world!"}], + } + + @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.respx() @pytest.mark.asyncio()