Skip to content

Commit 7eecba6

Browse files
committed
Implement AWS Bedrock CountTokens API support
- Add support for both Converse and InvokeModel input formats
- Implement endpoint handling in pass_through_endpoints
- Add transformation logic for AWS Bedrock CountTokens API
- Simplify model resolution using existing router patterns
- Support token counting for messages and raw text inputs
1 parent f34bbd1 commit 7eecba6

File tree

3 files changed

+422
-0
lines changed

3 files changed

+422
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
AWS Bedrock CountTokens API handler.
3+
4+
Simplified handler leveraging existing LiteLLM Bedrock infrastructure.
5+
"""
6+
7+
from typing import Any, Dict
8+
9+
from fastapi import HTTPException
10+
11+
from litellm._logging import verbose_logger
12+
from litellm.llms.bedrock.count_tokens.transformation import BedrockCountTokensConfig
13+
14+
15+
class BedrockCountTokensHandler(BedrockCountTokensConfig):
    """
    Simplified handler for AWS Bedrock CountTokens API requests.

    Uses existing LiteLLM infrastructure (BaseAWSLLM helpers inherited via
    BedrockCountTokensConfig) for authentication and request handling.
    """

    async def handle_count_tokens_request(
        self,
        request_data: Dict[str, Any],
        litellm_params: Dict[str, Any],
        resolved_model: str,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using existing LiteLLM patterns.

        Args:
            request_data: The incoming request payload
            litellm_params: LiteLLM configuration parameters
            resolved_model: The actual model ID resolved from router

        Returns:
            Dictionary containing the token count response
            (e.g. ``{"input_tokens": <int>}``).

        Raises:
            HTTPException: with the upstream AWS status code when Bedrock
                rejects the request, or 500 on unexpected processing errors.
        """
        try:
            # Fail fast on malformed payloads before doing any signing or
            # network work.
            self.validate_count_tokens_request(request_data)

            verbose_logger.debug(
                f"Processing CountTokens request for resolved model: {resolved_model}"
            )

            # Resolve the AWS region via the shared BaseAWSLLM helper.
            aws_region_name = self._get_aws_region_name(
                optional_params=litellm_params,
                model=resolved_model,
                model_id=None,
            )

            verbose_logger.debug(f"Retrieved AWS region: {aws_region_name}")

            # Transform request to Bedrock format (supports both Converse
            # and InvokeModel input shapes).
            bedrock_request = self.transform_anthropic_to_bedrock_count_tokens(
                request_data=request_data
            )

            verbose_logger.debug(f"Transformed request: {bedrock_request}")

            # Build the count-tokens endpoint URL for the resolved model/region.
            endpoint_url = self.get_bedrock_count_tokens_endpoint(
                resolved_model, aws_region_name
            )

            verbose_logger.debug(f"Making request to: {endpoint_url}")

            # SigV4-sign the request using the existing BaseAWSLLM helper.
            headers = {"Content-Type": "application/json"}
            signed_headers, signed_body = self._sign_request(
                service_name="bedrock",
                headers=headers,
                optional_params=litellm_params,
                request_data=bedrock_request,
                api_base=endpoint_url,
                model=resolved_model,
            )

            # Local import keeps httpx off the module import path.
            import httpx

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    endpoint_url,
                    headers=signed_headers,
                    content=signed_body,
                    timeout=30.0,
                )

            verbose_logger.debug(f"Response status: {response.status_code}")

            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error(f"AWS Bedrock error: {error_text}")
                # Propagate the upstream status code (403 auth failure,
                # 404 unknown model, 429 throttling, ...) instead of
                # collapsing every AWS error into a generic 400.
                raise HTTPException(
                    status_code=response.status_code,
                    detail={"error": f"AWS Bedrock error: {error_text}"},
                )

            bedrock_response = response.json()

            verbose_logger.debug(f"Bedrock response: {bedrock_response}")

            # Transform response back to the Anthropic-style shape.
            final_response = self.transform_bedrock_response_to_anthropic(
                bedrock_response
            )

            verbose_logger.debug(f"Final response: {final_response}")

            return final_response

        except HTTPException:
            # Re-raise HTTP exceptions as-is (including the AWS error above).
            raise
        except Exception as e:
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail={"error": f"CountTokens processing error: {str(e)}"},
            )
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
"""
2+
AWS Bedrock CountTokens API transformation logic.
3+
4+
This module handles the transformation of requests from Anthropic Messages API format
5+
to AWS Bedrock's CountTokens API format and vice versa.
6+
"""
7+
8+
from typing import Any, Dict, List
9+
10+
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
11+
from litellm.llms.bedrock.common_utils import BedrockModelInfo
12+
13+
14+
class BedrockCountTokensConfig(BaseAWSLLM):
    """
    Configuration and transformation logic for AWS Bedrock CountTokens API.

    AWS Bedrock CountTokens API Specification:
    - Endpoint: POST /model/{modelId}/count-tokens
    - Input formats: 'invokeModel' or 'converse'
    - Response: {"inputTokens": <number>}
    """

    def _detect_input_type(self, request_data: Dict[str, Any]) -> str:
        """
        Detect whether to use 'converse' or 'invokeModel' input format.

        Args:
            request_data: The original request data

        Returns:
            'converse' or 'invokeModel'
        """
        # If the request has messages in the expected Anthropic format,
        # use converse.
        if "messages" in request_data and isinstance(request_data["messages"], list):
            return "converse"

        # For raw text or other formats, use invokeModel.  This handles cases
        # where the input is prompt-based or already in raw Bedrock format.
        return "invokeModel"

    def transform_anthropic_to_bedrock_count_tokens(
        self,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Transform request to Bedrock CountTokens format.
        Supports both Converse and InvokeModel input types.

        Input (Anthropic format):
        {
            "model": "claude-3-5-sonnet",
            "system": "...",          (optional, string or content blocks)
            "messages": [{"role": "user", "content": "Hello!"}]
        }

        Output (Bedrock CountTokens format for Converse):
        {
            "input": {
                "converse": {
                    "messages": [...],
                    "system": [...] (if present)
                }
            }
        }

        Output (Bedrock CountTokens format for InvokeModel):
        {
            "input": {
                "invokeModel": {
                    "body": "{...raw model input...}"
                }
            }
        }
        """
        input_type = self._detect_input_type(request_data)

        if input_type == "converse":
            # Anthropic's Messages API carries the system prompt as a
            # top-level "system" field (string or content-block list),
            # not as a message with role == "system" — pass it through so
            # system tokens are counted too.
            return self._transform_to_converse_format(
                request_data.get("messages", []),
                system=request_data.get("system"),
            )
        return self._transform_to_invoke_model_format(request_data)

    def _transform_to_converse_format(
        self, messages: List[Dict[str, Any]], system: Any = None
    ) -> Dict[str, Any]:
        """
        Transform to Converse input format.

        Args:
            messages: Anthropic-style message dicts.
            system: Optional top-level Anthropic system prompt — either a
                plain string or a list of content blocks.
        """
        system_messages: List[Dict[str, Any]] = []
        user_messages: List[Dict[str, Any]] = []

        # Fold the top-level Anthropic "system" parameter into Bedrock's
        # system block list.
        if isinstance(system, str):
            system_messages.append({"text": system})
        elif isinstance(system, list):
            for block in system:
                if isinstance(block, dict) and "text" in block:
                    system_messages.append({"text": block["text"]})

        for message in messages:
            if message.get("role") == "system":
                # Tolerate clients that put the system prompt inside the
                # message list instead of the top-level field.
                system_messages.append({"text": message.get("content", "")})
            else:
                # Transform message content to Bedrock format.
                transformed_message: Dict[str, Any] = {
                    "role": message.get("role"),
                    "content": [],
                }

                # Handle content - ensure it's in the correct array format.
                content = message.get("content", "")
                if isinstance(content, str):
                    # String content -> convert to a single text block.
                    transformed_message["content"].append({"text": content})
                elif isinstance(content, list):
                    # Already in blocks format - use as is.
                    transformed_message["content"] = content

                user_messages.append(transformed_message)

        # Build the converse input format.
        converse_input: Dict[str, Any] = {"messages": user_messages}

        # Add system messages only if present (Bedrock rejects empty lists).
        if system_messages:
            converse_input["system"] = system_messages

        # Build the complete request.
        return {"input": {"converse": converse_input}}

    def _transform_to_invoke_model_format(
        self, request_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Transform to InvokeModel input format."""
        import json

        # For InvokeModel we provide the raw body that would be sent to the
        # model.  The 'model' field is routing metadata, not model input.
        body_data = {k: v for k, v in request_data.items() if k != "model"}

        return {"input": {"invokeModel": {"body": json.dumps(body_data)}}}

    def get_bedrock_count_tokens_endpoint(
        self, model: str, aws_region_name: str
    ) -> str:
        """
        Construct the AWS Bedrock CountTokens API endpoint using existing
        LiteLLM functions.

        Args:
            model: The resolved model ID from router lookup
            aws_region_name: AWS region (e.g., "eu-west-1")

        Returns:
            Complete endpoint URL for CountTokens API
        """
        # Use existing LiteLLM function to get the base model ID (removes
        # region prefix).
        model_id = BedrockModelInfo.get_base_model(model)

        # Strip the provider prefix if present (no magic slice length).
        prefix = "bedrock/"
        if model_id.startswith(prefix):
            model_id = model_id[len(prefix):]

        base_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
        return f"{base_url}/model/{model_id}/count-tokens"

    def transform_bedrock_response_to_anthropic(
        self, bedrock_response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Transform Bedrock CountTokens response to Anthropic format.

        Input (Bedrock response):  {"inputTokens": 123}
        Output (Anthropic format): {"input_tokens": 123}
        """
        return {"input_tokens": bedrock_response.get("inputTokens", 0)}

    def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None:
        """
        Validate the incoming count tokens request.
        Supports both Converse and InvokeModel input formats.

        Args:
            request_data: The request payload

        Raises:
            ValueError: If the request is invalid
        """
        if not request_data.get("model"):
            raise ValueError("model parameter is required")

        input_type = self._detect_input_type(request_data)

        if input_type == "converse":
            # Validate Converse format (messages-based).
            messages = request_data.get("messages", [])
            if not messages:
                raise ValueError("messages parameter is required for Converse input")

            if not isinstance(messages, list):
                raise ValueError("messages must be a list")

            for i, message in enumerate(messages):
                if not isinstance(message, dict):
                    raise ValueError(f"Message {i} must be a dictionary")

                if "role" not in message:
                    raise ValueError(f"Message {i} must have a 'role' field")

                if "content" not in message:
                    raise ValueError(f"Message {i} must have a 'content' field")
        else:
            # For InvokeModel format, we need at least some content to count
            # tokens.  The content structure varies by model, so we do
            # minimal validation.
            if len(request_data) <= 1:  # Only has 'model' field
                raise ValueError("Request must contain content to count tokens")

0 commit comments

Comments
 (0)