"""
AWS Bedrock CountTokens API transformation logic.

This module handles the transformation of requests from Anthropic Messages API format
to AWS Bedrock's CountTokens API format and vice versa.
"""

import json
from typing import Any, Dict, List

from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.bedrock.common_utils import BedrockModelInfo


class BedrockCountTokensConfig(BaseAWSLLM):
    """
    Configuration and transformation logic for AWS Bedrock CountTokens API.

    AWS Bedrock CountTokens API Specification:
    - Endpoint: POST /model/{modelId}/count-tokens
    - Input formats: 'invokeModel' or 'converse'
    - Response: {"inputTokens": <number>}
    """

    def _detect_input_type(self, request_data: Dict[str, Any]) -> str:
        """
        Detect whether to use 'converse' or 'invokeModel' input format.

        Args:
            request_data: The original request data

        Returns:
            'converse' or 'invokeModel'
        """
        # If the request has messages in the expected Anthropic format, use converse
        if "messages" in request_data and isinstance(request_data["messages"], list):
            return "converse"

        # For raw text or other formats, use invokeModel
        # This handles cases where the input is prompt-based or already in raw Bedrock format
        return "invokeModel"

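    # Illustrative sketch (not part of the original module): how detection behaves for
    # two hypothetical payloads, assuming the class can be instantiated with no arguments.
    #
    #   config = BedrockCountTokensConfig()
    #   config._detect_input_type(
    #       {"model": "claude-3-5-sonnet", "messages": [{"role": "user", "content": "Hi"}]}
    #   )
    #   # -> "converse"
    #   config._detect_input_type({"model": "claude-3-5-sonnet", "prompt": "Hi"})
    #   # -> "invokeModel"
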
    def transform_anthropic_to_bedrock_count_tokens(
        self,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Transform request to Bedrock CountTokens format.
        Supports both Converse and InvokeModel input types.

        Input (Anthropic format):
        {
            "model": "claude-3-5-sonnet",
            "messages": [{"role": "user", "content": "Hello!"}]
        }

        Output (Bedrock CountTokens format for Converse):
        {
            "input": {
                "converse": {
                    "messages": [...],
                    "system": [...] (if present)
                }
            }
        }

        Output (Bedrock CountTokens format for InvokeModel):
        {
            "input": {
                "invokeModel": {
                    "body": "{...raw model input...}"
                }
            }
        }
        """
        input_type = self._detect_input_type(request_data)

        if input_type == "converse":
            return self._transform_to_converse_format(request_data.get("messages", []))
        else:
            return self._transform_to_invoke_model_format(request_data)

    def _transform_to_converse_format(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Transform to Converse input format."""
        # Extract system messages if present
        system_messages = []
        user_messages = []

        for message in messages:
            if message.get("role") == "system":
                system_messages.append({"text": message.get("content", "")})
            else:
                # Transform message content to Bedrock format
                transformed_message: Dict[str, Any] = {
                    "role": message.get("role"),
                    "content": [],
                }

                # Handle content - ensure it's in the correct array format
                content = message.get("content", "")
                if isinstance(content, str):
                    # String content -> convert to text block
                    transformed_message["content"].append({"text": content})
                elif isinstance(content, list):
                    # Already in blocks format - use as is
                    transformed_message["content"] = content

                user_messages.append(transformed_message)

        # Build the converse input format
        converse_input: Dict[str, Any] = {"messages": user_messages}

        # Add system messages if present
        if system_messages:
            converse_input["system"] = system_messages

        # Build the complete request
        return {
            "input": {
                "converse": converse_input
            }
        }

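    # Worked example (illustrative, not from the original module): an Anthropic-style
    # message list containing a system message is split into Bedrock's "system" and
    # "messages" blocks, with string content wrapped into text blocks.
    #
    #   _transform_to_converse_format([
    #       {"role": "system", "content": "You are terse."},
    #       {"role": "user", "content": "Hello!"},
    #   ])
    #   # -> {"input": {"converse": {
    #   #        "messages": [{"role": "user", "content": [{"text": "Hello!"}]}],
    #   #        "system": [{"text": "You are terse."}]}}}
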
    def _transform_to_invoke_model_format(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform to InvokeModel input format."""
        # For InvokeModel, we need to provide the raw body that would be sent to the model
        # Remove the 'model' field from the body as it's not part of the model input
        body_data = {k: v for k, v in request_data.items() if k != "model"}

        return {
            "input": {
                "invokeModel": {
                    "body": json.dumps(body_data)
                }
            }
        }

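    # Illustrative sketch: a prompt-style payload (no "messages" list) is passed through
    # as the raw InvokeModel body, minus the "model" field.
    #
    #   _transform_to_invoke_model_format({"model": "some-model", "prompt": "Hello!"})
    #   # -> {"input": {"invokeModel": {"body": '{"prompt": "Hello!"}'}}}
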
    def get_bedrock_count_tokens_endpoint(self, model: str, aws_region_name: str) -> str:
        """
        Construct the AWS Bedrock CountTokens API endpoint using existing LiteLLM functions.

        Args:
            model: The resolved model ID from router lookup
            aws_region_name: AWS region (e.g., "eu-west-1")

        Returns:
            Complete endpoint URL for CountTokens API
        """
        # Use existing LiteLLM function to get the base model ID (removes region prefix)
        model_id = BedrockModelInfo.get_base_model(model)

        # Remove bedrock/ prefix if present
        if model_id.startswith("bedrock/"):
            model_id = model_id[8:]  # Remove "bedrock/" prefix

        base_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
        endpoint = f"{base_url}/model/{model_id}/count-tokens"

        return endpoint

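    # Illustrative sketch with hypothetical inputs, assuming BedrockModelInfo.get_base_model
    # strips a cross-region prefix such as "eu." as described in the comment above.
    #
    #   get_bedrock_count_tokens_endpoint(
    #       "eu.anthropic.claude-3-5-sonnet-20240620-v1:0", "eu-west-1"
    #   )
    #   # -> "https://bedrock-runtime.eu-west-1.amazonaws.com/"
    #   #    "model/anthropic.claude-3-5-sonnet-20240620-v1:0/count-tokens"
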
    def transform_bedrock_response_to_anthropic(
        self, bedrock_response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Transform Bedrock CountTokens response to Anthropic format.

        Input (Bedrock response):
        {
            "inputTokens": 123
        }

        Output (Anthropic format):
        {
            "input_tokens": 123
        }
        """
        input_tokens = bedrock_response.get("inputTokens", 0)

        return {
            "input_tokens": input_tokens
        }

    def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None:
        """
        Validate the incoming count tokens request.
        Supports both Converse and InvokeModel input formats.

        Args:
            request_data: The request payload

        Raises:
            ValueError: If the request is invalid
        """
        if not request_data.get("model"):
            raise ValueError("model parameter is required")

        input_type = self._detect_input_type(request_data)

        if input_type == "converse":
            # Validate Converse format (messages-based)
            messages = request_data.get("messages", [])
            if not messages:
                raise ValueError("messages parameter is required for Converse input")

            if not isinstance(messages, list):
                raise ValueError("messages must be a list")

            for i, message in enumerate(messages):
                if not isinstance(message, dict):
                    raise ValueError(f"Message {i} must be a dictionary")

                if "role" not in message:
                    raise ValueError(f"Message {i} must have a 'role' field")

                if "content" not in message:
                    raise ValueError(f"Message {i} must have a 'content' field")
        else:
            # For InvokeModel format, we need at least some content to count tokens
            # The content structure varies by model, so we do minimal validation
            if len(request_data) <= 1:  # Only has 'model' field
                raise ValueError("Request must contain content to count tokens")
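

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes
# BedrockCountTokensConfig can be instantiated without arguments and exercises
# only the pure transformation helpers; no AWS call is made.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = BedrockCountTokensConfig()

    anthropic_request = {
        "model": "claude-3-5-sonnet",
        "messages": [{"role": "user", "content": "Hello!"}],
    }

    # Validate, then convert the Anthropic-style payload to the Bedrock CountTokens shape.
    config.validate_count_tokens_request(anthropic_request)
    bedrock_payload = config.transform_anthropic_to_bedrock_count_tokens(anthropic_request)
    print(bedrock_payload)
    # Expected shape: {"input": {"converse": {"messages": [...]}}}

    # A Bedrock response of {"inputTokens": 123} maps back to Anthropic's shape.
    print(config.transform_bedrock_response_to_anthropic({"inputTokens": 123}))
    # -> {"input_tokens": 123}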