
Commit c234b13

Apply code formatting and linting fixes
- Apply Black formatting to all Bedrock CountTokens files
- Clean up imports and remove unused variables in tests
- Fix indentation and simplify test structure
- Fix pyright type error with type ignore annotation
- All tests continue to pass after cleanup
1 parent e74ac35 commit c234b13
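
The changes are mechanical formatting edits. As a quick illustration of the style Black enforces in the hunks below (the lines are taken from the handler diff), an over-long call is wrapped with the closing parenthesis on its own line:

# Before formatting: a single line that exceeds Black's line-length limit
verbose_logger.debug(f"Processing CountTokens request for resolved model: {resolved_model}")

# After formatting: Black wraps the argument onto its own line and, where a
# call keeps multiple arguments across lines, adds a trailing comma (as in
# the HTTPException calls below)
verbose_logger.debug(
    f"Processing CountTokens request for resolved model: {resolved_model}"
)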

5 files changed: +221 / -225 lines changed


litellm/llms/bedrock/count_tokens/handler.py

Lines changed: 13 additions & 6 deletions
@@ -40,7 +40,9 @@ async def handle_count_tokens_request(
             # Validate the request
             self.validate_count_tokens_request(request_data)

-            verbose_logger.debug(f"Processing CountTokens request for resolved model: {resolved_model}")
+            verbose_logger.debug(
+                f"Processing CountTokens request for resolved model: {resolved_model}"
+            )

             # Get AWS region using existing LiteLLM function
             aws_region_name = self._get_aws_region_name(
@@ -59,7 +61,9 @@ async def handle_count_tokens_request(
             verbose_logger.debug(f"Transformed request: {bedrock_request}")

             # Get endpoint URL using simplified function
-            endpoint_url = self.get_bedrock_count_tokens_endpoint(resolved_model, aws_region_name)
+            endpoint_url = self.get_bedrock_count_tokens_endpoint(
+                resolved_model, aws_region_name
+            )

             verbose_logger.debug(f"Making request to: {endpoint_url}")

@@ -76,6 +80,7 @@ async def handle_count_tokens_request(

             # Make HTTP request
             import httpx
+
             async with httpx.AsyncClient() as client:
                 response = await client.post(
                     endpoint_url,
@@ -91,15 +96,17 @@ async def handle_count_tokens_request(
                     verbose_logger.error(f"AWS Bedrock error: {error_text}")
                     raise HTTPException(
                         status_code=400,
-                        detail={"error": f"AWS Bedrock error: {error_text}"}
+                        detail={"error": f"AWS Bedrock error: {error_text}"},
                     )

             bedrock_response = response.json()

             verbose_logger.debug(f"Bedrock response: {bedrock_response}")

             # Transform response back to expected format
-            final_response = self.transform_bedrock_response_to_anthropic(bedrock_response)
+            final_response = self.transform_bedrock_response_to_anthropic(
+                bedrock_response
+            )

             verbose_logger.debug(f"Final response: {final_response}")

@@ -112,5 +119,5 @@ async def handle_count_tokens_request(
             verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
             raise HTTPException(
                 status_code=500,
-                detail={"error": f"CountTokens processing error: {str(e)}"}
-            )
+                detail={"error": f"CountTokens processing error: {str(e)}"},
+            )

litellm/llms/bedrock/count_tokens/transformation.py

Lines changed: 18 additions & 28 deletions
@@ -79,7 +79,9 @@ def transform_anthropic_to_bedrock_count_tokens(
         else:
             return self._transform_to_invoke_model_format(request_data)

-    def _transform_to_converse_format(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def _transform_to_converse_format(
+        self, messages: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
         """Transform to Converse input format."""
         # Extract system messages if present
         system_messages = []
@@ -90,10 +92,7 @@ def _transform_to_converse_format(self, messages: List[Dict[str, Any]]) -> Dict[
                 system_messages.append({"text": message.get("content", "")})
             else:
                 # Transform message content to Bedrock format
-                transformed_message = {
-                    "role": message.get("role"),
-                    "content": []
-                }
+                transformed_message = {"role": message.get("role"), "content": []}

                 # Handle content - ensure it's in the correct array format
                 content = message.get("content", "")
@@ -107,38 +106,30 @@ def _transform_to_converse_format(self, messages: List[Dict[str, Any]]) -> Dict[
                 user_messages.append(transformed_message)

         # Build the converse input format
-        converse_input = {
-            "messages": user_messages
-        }
+        converse_input = {"messages": user_messages}

         # Add system messages if present
         if system_messages:
             converse_input["system"] = system_messages

         # Build the complete request
-        return {
-            "input": {
-                "converse": converse_input
-            }
-        }
+        return {"input": {"converse": converse_input}}

-    def _transform_to_invoke_model_format(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
+    def _transform_to_invoke_model_format(
+        self, request_data: Dict[str, Any]
+    ) -> Dict[str, Any]:
         """Transform to InvokeModel input format."""
         import json

         # For InvokeModel, we need to provide the raw body that would be sent to the model
         # Remove the 'model' field from the body as it's not part of the model input
         body_data = {k: v for k, v in request_data.items() if k != "model"}

-        return {
-            "input": {
-                "invokeModel": {
-                    "body": json.dumps(body_data)
-                }
-            }
-        }
+        return {"input": {"invokeModel": {"body": json.dumps(body_data)}}}

-    def get_bedrock_count_tokens_endpoint(self, model: str, aws_region_name: str) -> str:
+    def get_bedrock_count_tokens_endpoint(
+        self, model: str, aws_region_name: str
+    ) -> str:
         """
         Construct the AWS Bedrock CountTokens API endpoint using existing LiteLLM functions.

@@ -161,8 +152,9 @@ def get_bedrock_count_tokens_endpoint(self, model: str, aws_region_name: str) ->

         return endpoint

-
-    def transform_bedrock_response_to_anthropic(self, bedrock_response: Dict[str, Any]) -> Dict[str, Any]:
+    def transform_bedrock_response_to_anthropic(
+        self, bedrock_response: Dict[str, Any]
+    ) -> Dict[str, Any]:
         """
         Transform Bedrock CountTokens response to Anthropic format.

@@ -178,9 +170,7 @@ def transform_bedrock_response_to_anthropic(self, bedrock_response: Dict[str, An
         """
         input_tokens = bedrock_response.get("inputTokens", 0)

-        return {
-            "input_tokens": input_tokens
-        }
+        return {"input_tokens": input_tokens}

     def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None:
         """
@@ -220,4 +210,4 @@ def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None:
         # For InvokeModel format, we need at least some content to count tokens
         # The content structure varies by model, so we do minimal validation
         if len(request_data) <= 1:  # Only has 'model' field
-            raise ValueError("Request must contain content to count tokens")
+            raise ValueError("Request must contain content to count tokens")
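
The reformatting above collapses several dict literals onto single lines without changing the payload shapes. For reference, a minimal standalone sketch of the same mapping (hypothetical helper names and a placeholder model id, not the class methods themselves) shows what the Converse and InvokeModel request bodies look like after transformation:

import json
from typing import Any, Dict, List


def to_converse_payload(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Simplified mirror of _transform_to_converse_format: system messages go
    # into a separate "system" list, other messages get their text content
    # wrapped in Bedrock's content-block array format.
    system = [
        {"text": m.get("content", "")} for m in messages if m.get("role") == "system"
    ]
    converse: Dict[str, Any] = {
        "messages": [
            {"role": m.get("role"), "content": [{"text": m.get("content", "")}]}
            for m in messages
            if m.get("role") != "system"
        ]
    }
    if system:
        converse["system"] = system
    return {"input": {"converse": converse}}


def to_invoke_model_payload(request_data: Dict[str, Any]) -> Dict[str, Any]:
    # Simplified mirror of _transform_to_invoke_model_format: everything
    # except the "model" field becomes the raw JSON body.
    body = {k: v for k, v in request_data.items() if k != "model"}
    return {"input": {"invokeModel": {"body": json.dumps(body)}}}


print(to_converse_payload([{"role": "user", "content": "Hello"}]))
# {'input': {'converse': {'messages': [{'role': 'user', 'content': [{'text': 'Hello'}]}]}}}

print(to_invoke_model_payload({"model": "example-model-id", "prompt": "Hi"}))
# {'input': {'invokeModel': {'body': '{"prompt": "Hi"}'}}}

The response side is the inverse and equally small: transform_bedrock_response_to_anthropic maps Bedrock's inputTokens field to Anthropic's input_tokens key, so {"inputTokens": 42} becomes {"input_tokens": 42}.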

litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py

Lines changed: 8 additions & 8 deletions
@@ -172,7 +172,9 @@ async def gemini_proxy_route(
         request=request, api_key=f"Bearer {google_ai_studio_api_key}"
     )

-    base_target_url = os.getenv("GEMINI_API_BASE") or "https://generativelanguage.googleapis.com"
+    base_target_url = (
+        os.getenv("GEMINI_API_BASE") or "https://generativelanguage.googleapis.com"
+    )
     encoded_endpoint = httpx.URL(endpoint).path

     # Ensure endpoint starts with '/' for proper URL construction
@@ -489,8 +491,7 @@ async def handle_bedrock_count_tokens(
         model = request_body.get("model")
         if not model:
             raise HTTPException(
-                status_code=400,
-                detail={"error": "Model is required in request body"}
+                status_code=400, detail={"error": "Model is required in request body"}
             )

         # Get model parameters from router
@@ -511,7 +512,7 @@ async def handle_bedrock_count_tokens(
         # Copy all litellm_params - BaseAWSLLM will handle AWS credential discovery
         for key, value in model_litellm_params.items():
             if key != "user_api_key_dict":  # Don't overwrite user_api_key_dict
-                litellm_params[key] = value
+                litellm_params[key] = value  # type: ignore

         verbose_proxy_logger.debug(f"Count tokens litellm_params: {litellm_params}")
         verbose_proxy_logger.debug(f"Resolved model: {resolved_model}")
@@ -531,8 +532,7 @@ async def handle_bedrock_count_tokens(
     except Exception as e:
         verbose_proxy_logger.error(f"Error in handle_bedrock_count_tokens: {str(e)}")
         raise HTTPException(
-            status_code=500,
-            detail={"error": f"CountTokens processing error: {str(e)}"}
+            status_code=500, detail={"error": f"CountTokens processing error: {str(e)}"}
         )


@@ -588,13 +588,13 @@ async def bedrock_llm_proxy_route(
                 "error": "Model missing from endpoint. Expected format: /model/<Model>/<endpoint>. Got: "
                 + endpoint,
             },
-            )
+        )

     data["method"] = request.method
     data["endpoint"] = endpoint
     data["data"] = request_body
     data["custom_llm_provider"] = "bedrock"
-    
+
     try:
         result = await base_llm_response_processor.base_passthrough_process_llm_request(
             request=request,
