Skip to content

Commit bfaab8a

Browse files
Merge pull request #14557 from timelfrink/fix/issue-14478-bedrock-count-tokens-endpoint
Implement AWS Bedrock CountTokens API support
2 parents ff36dfd + 7538fc0 commit bfaab8a

File tree

5 files changed

+668
-154
lines changed

5 files changed

+668
-154
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
AWS Bedrock CountTokens API handler.
3+
4+
Simplified handler leveraging existing LiteLLM Bedrock infrastructure.
5+
"""
6+
7+
from typing import Any, Dict
8+
9+
from fastapi import HTTPException
10+
11+
from litellm._logging import verbose_logger
12+
from litellm.llms.bedrock.count_tokens.transformation import BedrockCountTokensConfig
13+
14+
15+
class BedrockCountTokensHandler(BedrockCountTokensConfig):
    """
    Simplified handler for AWS Bedrock CountTokens API requests.

    Uses existing LiteLLM infrastructure (SigV4 signing and region
    resolution inherited via BedrockCountTokensConfig / BaseAWSLLM).
    """

    # Hard cap on how long we wait for the Bedrock CountTokens HTTP call.
    COUNT_TOKENS_TIMEOUT_SECONDS: float = 30.0

    async def handle_count_tokens_request(
        self,
        request_data: Dict[str, Any],
        litellm_params: Dict[str, Any],
        resolved_model: str,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using existing LiteLLM patterns.

        Args:
            request_data: The incoming request payload (Anthropic-style).
            litellm_params: LiteLLM configuration parameters (credentials,
                region, etc.) consumed by the signing helpers.
            resolved_model: The actual model ID resolved from the router.

        Returns:
            Dictionary containing the token count response, e.g.
            ``{"input_tokens": 123}``.

        Raises:
            HTTPException: 400 for invalid request payloads, the upstream
                status code for AWS Bedrock errors, and 500 for unexpected
                local failures.
        """
        try:
            # Fail fast on malformed payloads before doing any AWS work.
            self.validate_count_tokens_request(request_data)

            verbose_logger.debug(
                f"Processing CountTokens request for resolved model: {resolved_model}"
            )

            # Resolve the AWS region via the shared BaseAWSLLM helper.
            aws_region_name = self._get_aws_region_name(
                optional_params=litellm_params,
                model=resolved_model,
                model_id=None,
            )

            verbose_logger.debug(f"Retrieved AWS region: {aws_region_name}")

            # Transform request to Bedrock format (supports both Converse
            # and InvokeModel input shapes).
            bedrock_request = self.transform_anthropic_to_bedrock_count_tokens(
                request_data=request_data
            )

            verbose_logger.debug(f"Transformed request: {bedrock_request}")

            endpoint_url = self.get_bedrock_count_tokens_endpoint(
                resolved_model, aws_region_name
            )

            verbose_logger.debug(f"Making request to: {endpoint_url}")

            # Sign the request with SigV4 using the shared BaseAWSLLM helper.
            headers = {"Content-Type": "application/json"}
            signed_headers, signed_body = self._sign_request(
                service_name="bedrock",
                headers=headers,
                optional_params=litellm_params,
                request_data=bedrock_request,
                api_base=endpoint_url,
                model=resolved_model,
            )

            # Imported lazily so the module can load even when httpx is
            # only needed at request time.
            import httpx

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    endpoint_url,
                    headers=signed_headers,
                    content=signed_body,
                    timeout=self.COUNT_TOKENS_TIMEOUT_SECONDS,
                )

                verbose_logger.debug(f"Response status: {response.status_code}")

                if response.status_code != 200:
                    error_text = response.text
                    verbose_logger.error(f"AWS Bedrock error: {error_text}")
                    # Propagate the upstream status (403, 429, 5xx, ...)
                    # instead of collapsing every AWS failure into a 400.
                    raise HTTPException(
                        status_code=response.status_code,
                        detail={"error": f"AWS Bedrock error: {error_text}"},
                    )

                bedrock_response = response.json()

            verbose_logger.debug(f"Bedrock response: {bedrock_response}")

            # Map Bedrock's {"inputTokens": N} to Anthropic's {"input_tokens": N}.
            final_response = self.transform_bedrock_response_to_anthropic(
                bedrock_response
            )

            verbose_logger.debug(f"Final response: {final_response}")

            return final_response

        except HTTPException:
            # Re-raise HTTP exceptions as-is (AWS and validation mappings above).
            raise
        except ValueError as e:
            # Validation failures are client errors, not server errors.
            verbose_logger.error(f"Invalid CountTokens request: {str(e)}")
            raise HTTPException(
                status_code=400,
                detail={"error": f"Invalid request: {str(e)}"},
            )
        except Exception as e:
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail={"error": f"CountTokens processing error: {str(e)}"},
            )
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""
2+
AWS Bedrock CountTokens API transformation logic.
3+
4+
This module handles the transformation of requests from Anthropic Messages API format
5+
to AWS Bedrock's CountTokens API format and vice versa.
6+
"""
7+
8+
from typing import Any, Dict, List
9+
10+
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
11+
from litellm.llms.bedrock.common_utils import BedrockModelInfo
12+
13+
14+
class BedrockCountTokensConfig(BaseAWSLLM):
    """
    Configuration and transformation logic for the AWS Bedrock CountTokens API.

    AWS Bedrock CountTokens API specification:
    - Endpoint: POST /model/{modelId}/count-tokens
    - Input formats: 'invokeModel' or 'converse'
    - Response: {"inputTokens": <number>}
    """

    def _detect_input_type(self, request_data: Dict[str, Any]) -> str:
        """
        Decide which CountTokens input format applies to this payload.

        A request carrying a ``messages`` list is treated as Anthropic /
        Converse-style; anything else (prompt-based or raw Bedrock bodies)
        falls back to InvokeModel.

        Args:
            request_data: The original request data.

        Returns:
            Either ``'converse'`` or ``'invokeModel'``.
        """
        messages = request_data.get("messages")
        if isinstance(messages, list):
            return "converse"
        return "invokeModel"

    def transform_anthropic_to_bedrock_count_tokens(
        self,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Transform a request into Bedrock CountTokens format, supporting
        both Converse and InvokeModel input types.

        Input (Anthropic format)::

            {"model": "claude-3-5-sonnet",
             "messages": [{"role": "user", "content": "Hello!"}]}

        Output (Converse)::

            {"input": {"converse": {"messages": [...], "system": [...]}}}

        Output (InvokeModel)::

            {"input": {"invokeModel": {"body": "{...raw model input...}"}}}
        """
        if self._detect_input_type(request_data) == "converse":
            return self._transform_to_converse_format(request_data.get("messages", []))
        return self._transform_to_invoke_model_format(request_data)

    def _transform_to_converse_format(
        self, messages: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Build the Converse-style CountTokens input from a message list."""
        system_blocks: List[Dict[str, Any]] = []
        chat_messages: List[Dict[str, Any]] = []

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content", "")

            if role == "system":
                # NOTE(review): assumes system content is a plain string;
                # list-of-blocks system content would be wrapped as-is — confirm.
                system_blocks.append({"text": content})
                continue

            entry: Dict[str, Any] = {"role": role, "content": []}
            if isinstance(content, str):
                # Plain string -> single text block.
                entry["content"] = [{"text": content}]
            elif isinstance(content, list):
                # Already block-formatted -> pass through unchanged.
                entry["content"] = content
            chat_messages.append(entry)

        converse_payload: Dict[str, Any] = {"messages": chat_messages}
        if system_blocks:
            converse_payload["system"] = system_blocks

        return {"input": {"converse": converse_payload}}

    def _transform_to_invoke_model_format(
        self, request_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Build the InvokeModel-style CountTokens input (raw JSON body)."""
        import json

        # The raw body is everything except 'model', which is routing
        # metadata rather than model input.
        payload = dict(request_data)
        payload.pop("model", None)

        return {"input": {"invokeModel": {"body": json.dumps(payload)}}}

    def get_bedrock_count_tokens_endpoint(
        self, model: str, aws_region_name: str
    ) -> str:
        """
        Build the complete AWS Bedrock CountTokens endpoint URL.

        Args:
            model: The resolved model ID from the router lookup.
            aws_region_name: AWS region (e.g., "eu-west-1").

        Returns:
            Complete endpoint URL for the CountTokens API.
        """
        # Shared helper strips any region prefix from the model ID.
        model_id = BedrockModelInfo.get_base_model(model)

        # Drop a leading provider prefix if present.
        provider_prefix = "bedrock/"
        if model_id.startswith(provider_prefix):
            model_id = model_id[len(provider_prefix):]

        return (
            f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
            f"/model/{model_id}/count-tokens"
        )

    def transform_bedrock_response_to_anthropic(
        self, bedrock_response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Convert Bedrock's ``{"inputTokens": N}`` into Anthropic's
        ``{"input_tokens": N}``. A missing count defaults to 0.
        """
        return {"input_tokens": bedrock_response.get("inputTokens", 0)}

    def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None:
        """
        Validate an incoming CountTokens request.

        Supports both Converse and InvokeModel input formats.

        Args:
            request_data: The request payload.

        Raises:
            ValueError: If the request is structurally invalid.
        """
        if not request_data.get("model"):
            raise ValueError("model parameter is required")

        if self._detect_input_type(request_data) != "converse":
            # InvokeModel bodies vary by model; only require that the
            # request carries something beyond the 'model' field.
            if len(request_data) <= 1:
                raise ValueError("Request must contain content to count tokens")
            return

        messages = request_data.get("messages", [])
        if not messages:
            raise ValueError("messages parameter is required for Converse input")

        if not isinstance(messages, list):
            raise ValueError("messages must be a list")

        for idx, msg in enumerate(messages):
            if not isinstance(msg, dict):
                raise ValueError(f"Message {idx} must be a dictionary")
            if "role" not in msg:
                raise ValueError(f"Message {idx} must have a 'role' field")
            if "content" not in msg:
                raise ValueError(f"Message {idx} must have a 'content' field")

0 commit comments

Comments
 (0)