92 changes: 35 additions & 57 deletions litellm/proxy/anthropic_endpoints/endpoints.py
@@ -15,7 +15,6 @@
create_streaming_response,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.types.utils import TokenCountResponse

router = APIRouter()
@@ -52,52 +51,35 @@ async def anthropic_response( # noqa: PLR0915
request_data = await _read_request_body(request=request)
data: dict = {**request_data}
try:
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or data.get("model", None) # default passed in http request
)
if user_model:
data["model"] = user_model
# Initialize the base processor
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)

data = await add_litellm_data_to_request(
data=data, # type: ignore
# Use common processing logic to ensure guardrails are applied
(
data,
logging_obj,
) = await base_llm_response_processor.common_processing_pre_call_logic(
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
proxy_logging_obj=proxy_logging_obj,
version=version,
proxy_config=proxy_config,
)

# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base

### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]

### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
model=None,
route_type="acompletion", # Use acompletion to ensure guardrails are applied
)

tasks = []
tasks.append(
proxy_logging_obj.during_call_hook(
data=data,
user_api_key_dict=user_api_key_dict,
call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
route_type="anthropic_messages" # type: ignore
),
call_type="completion", # Use completion type for anthropic messages
)
)

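Since the comments above lean on "ensure guardrails are applied", here is a hypothetical custom guardrail showing what the route_type change means in practice: /v1/messages traffic now reaches pre-call guardrail hooks with call_type="completion". The BlockCreditCards class is invented for illustration; the base class and hook signature follow litellm's documented CustomGuardrail interface.

```python
# Hypothetical guardrail, for illustration only (not part of this PR).
# Base class and hook signature follow litellm's documented CustomGuardrail interface.
from typing import Optional, Union

from litellm.integrations.custom_guardrail import CustomGuardrail


class BlockCreditCards(CustomGuardrail):
    async def async_pre_call_hook(
        self, user_api_key_dict, cache, data: dict, call_type: str
    ) -> Optional[Union[Exception, str, dict]]:
        # After this PR, /v1/messages requests arrive here with
        # call_type="completion" (mapped from route_type="acompletion"),
        # so a guardrail written for chat completions also covers Anthropic routes.
        for message in data.get("messages", []):
            content = message.get("content")
            if isinstance(content, str) and "4111 1111 1111 1111" in content:
                raise ValueError("Credit card numbers are not allowed")
        return data
```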
@@ -119,8 +101,8 @@ async def anthropic_response( # noqa: PLR0915
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
llm_coro = llm_router.aanthropic_messages(**data, specific_deployment=True)
elif (
llm_router is not None and llm_router.has_model_id(data["model"])
elif llm_router is not None and llm_router.has_model_id(
data["model"]
): # model in router model list
llm_coro = llm_router.aanthropic_messages(**data)
elif (
@@ -202,7 +184,7 @@ async def anthropic_response( # noqa: PLR0915

### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
)

verbose_proxy_logger.debug("\nResponse from Litellm:\n{}".format(response))
@@ -255,35 +237,30 @@ async def count_tokens(
Returns: {"input_tokens": <number>}
"""
from litellm.proxy.proxy_server import token_counter as internal_token_counter

try:
request_data = await _read_request_body(request=request)
data: dict = {**request_data}

# Extract required fields
model_name = data.get("model")
messages = data.get("messages", [])

if not model_name:
raise HTTPException(
status_code=400,
detail={"error": "model parameter is required"}
status_code=400, detail={"error": "model parameter is required"}
)

if not messages:
raise HTTPException(
status_code=400,
detail={"error": "messages parameter is required"}
status_code=400, detail={"error": "messages parameter is required"}
)

# Create TokenCountRequest for the internal endpoint
from litellm.proxy._types import TokenCountRequest

token_request = TokenCountRequest(
model=model_name,
messages=messages
)


token_request = TokenCountRequest(model=model_name, messages=messages)

# Call the internal token counter function with direct request flag set to False
token_response = await internal_token_counter(
request=token_request,
Expand All @@ -294,17 +271,18 @@ async def count_tokens(
_token_response_dict = token_response.model_dump()
elif isinstance(token_response, dict):
_token_response_dict = token_response

# Convert the internal response to Anthropic API format
return {"input_tokens": _token_response_dict.get("total_tokens", 0)}

except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(str(e))
"litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(
str(e)
)
)
raise HTTPException(
status_code=500,
detail={"error": f"Internal server error: {str(e)}"}
status_code=500, detail={"error": f"Internal server error: {str(e)}"}
)
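A minimal client-side sketch of calling this token-counting endpoint through the proxy. The route path (/v1/messages/count_tokens), port, and API key are assumptions for illustration and are not shown in this diff:

```python
# Illustrative request against the count_tokens handler above.
# URL, port, and API key are placeholders; adjust to your proxy deployment.
import requests

resp = requests.post(
    "http://localhost:4000/v1/messages/count_tokens",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "claude-3-5-sonnet-20241022",
        "messages": [{"role": "user", "content": "Hello, world"}],
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # expected shape per the handler above: {"input_tokens": <number>}
```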
21 changes: 21 additions & 0 deletions litellm/proxy/common_request_processing.py
@@ -702,6 +702,27 @@ def _get_pre_call_type(
elif route_type == "aresponses":
return "responses"

@staticmethod
def _map_route_type_to_call_type(route_type: str) -> str:
"""
Maps route_type to call_type for guardrail hooks.
This ensures guardrails receive the correct call_type parameter.
"""
route_to_call_type_map = {
"acompletion": "completion",
"aresponses": "responses",
"atext_completion": "text_completion",
"aimage_edit": "image_generation",
"aimage_generation": "image_generation",
"aembeddings": "embeddings",
"amoderation": "moderation",
"aaudio_transcription": "audio_transcription",
"arerank": "rerank",
"allm_passthrough_route": "pass_through_endpoint",
"amcp_call": "mcp_call",
}
return route_to_call_type_map.get(route_type, route_type)

#########################################################
# Proxy Level Streaming Data Generator
#########################################################
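A quick illustration of the new helper's behavior; the import path is inferred from the file this hunk lives in, and the expected values come straight from the mapping table above:

```python
# Sketch of how the new static helper resolves route types for guardrail hooks.
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing

assert ProxyBaseLLMRequestProcessing._map_route_type_to_call_type("acompletion") == "completion"
assert ProxyBaseLLMRequestProcessing._map_route_type_to_call_type("arerank") == "rerank"
# Unmapped route types fall through unchanged:
assert ProxyBaseLLMRequestProcessing._map_route_type_to_call_type("acustom_route") == "acustom_route"
```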
114 changes: 104 additions & 10 deletions litellm/proxy/guardrails/guardrail_hooks/presidio.py
@@ -68,7 +68,9 @@ def __init__(
output_parse_pii: Optional[bool] = False,
presidio_ad_hoc_recognizers: Optional[str] = None,
logging_only: Optional[bool] = None,
pii_entities_config: Optional[Dict[Union[PiiEntityType, str], PiiAction]] = None,
pii_entities_config: Optional[
Dict[Union[PiiEntityType, str], PiiAction]
] = None,
presidio_language: Optional[str] = None,
**kwargs,
):
@@ -245,9 +247,14 @@ async def anonymize_text(
# Make the request to /anonymize
anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
verbose_proxy_logger.debug("Making request to: %s", anonymize_url)

# Build anonymize payload
anonymize_payload = {
"text": text,
"analyzer_results": analyze_results,
"anonymizers": {
"DEFAULT": {"type": "replace", "new_value": "{REDACTED}"}
},
}

async with session.post(
@@ -412,11 +419,17 @@ async def async_pre_call_hook(
):
messages = data["messages"]
tasks = []
content_types = [] # Track whether each content is string or structured

for m in messages:
content = m.get("content", None)
if content is None:
content_types.append(None)
continue

# Handle string content (OpenAI format)
if isinstance(content, str):
content_types.append("string")
tasks.append(
self.check_pii(
text=content,
Expand All @@ -425,15 +438,54 @@ async def async_pre_call_hook(
request_data=data,
)
)
# Handle structured content (Anthropic format with list of content blocks)
elif isinstance(content, list):
content_types.append("list")
# Process each text block in the content array
for content_block in content:
if (
isinstance(content_block, dict)
and content_block.get("type") == "text"
):
text = content_block.get("text", "")
if text:
tasks.append(
self.check_pii(
text=text,
output_parse_pii=self.output_parse_pii,
presidio_config=presidio_config,
request_data=data,
)
)
else:
content_types.append(None)

responses = await asyncio.gather(*tasks)
for index, r in enumerate(responses):
content = messages[index].get("content", None)

# Apply redacted text back to messages
response_index = 0
for msg_index, content_type in enumerate(content_types):
if content_type is None:
continue

content = messages[msg_index].get("content", None)
if content is None:
continue
if isinstance(content, str):
messages[index][
"content"
] = r # replace content with redacted string

# Handle string content
if content_type == "string":
messages[msg_index]["content"] = responses[response_index]
response_index += 1
# Handle structured content
elif content_type == "list":
for content_block in content:
if (
isinstance(content_block, dict)
and content_block.get("type") == "text"
):
if content_block.get("text"):
content_block["text"] = responses[response_index]
response_index += 1
verbose_proxy_logger.debug(
f"Presidio PII Masking: Redacted pii message: {data['messages']}"
)
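For reference, an invented example of the Anthropic-style structured content this new branch walks; only the text blocks are sent to Presidio, and the redacted strings are written back into the same blocks in order after the gather:

```python
# Illustrative payload (not from the PR): mixed string and structured content.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "My SSN is 123-45-6789"},               # checked by Presidio
            {"type": "image", "source": {"type": "base64", "data": "..."}},  # skipped
            {"type": "text", "text": "Summarize the attached image."},       # checked by Presidio
        ],
    },
    {"role": "user", "content": "Plain string content is handled as before."},
]
```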
@@ -530,10 +582,12 @@ async def async_post_call_success_hook( # type: ignore
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
response: Union[ModelResponse, EmbeddingResponse, ImageResponse, dict],
):
"""
Output parse the response object to replace the masked tokens with user sent values
Output parse the response object to:
1. Replace the masked tokens from input with user sent values (unmask input tokens)
2. Mask any NEW PII found in the LLM's response (mask output PII)
"""
verbose_proxy_logger.debug(
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
Expand All @@ -542,17 +596,57 @@ async def async_post_call_success_hook( # type: ignore
if self.output_parse_pii is False and litellm.output_parse_pii is False:
return response

presidio_config = self.get_presidio_settings_from_request_data(data)

# Handle OpenAI/ModelResponse format
if isinstance(response, ModelResponse) and not isinstance(
response.choices[0], StreamingChoices
): # /chat/completions requests
if isinstance(response.choices[0].message.content, str):
original_content = response.choices[0].message.content
verbose_proxy_logger.debug(
f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
f"self.pii_tokens: {self.pii_tokens}; initial response: {original_content}"
)

# Step 1: Unmask input tokens (original behavior)
for key, value in self.pii_tokens.items():
response.choices[0].message.content = response.choices[
0
].message.content.replace(key, value)

# Step 2: Mask NEW PII found in the output
masked_output = await self.check_pii(
text=response.choices[0].message.content,
output_parse_pii=False, # Don't track tokens for unmasking
presidio_config=presidio_config,
request_data=data,
)
response.choices[0].message.content = masked_output

# Handle Anthropic format (dict response from /v1/messages)
elif isinstance(response, dict):
content = response.get("content", [])
if isinstance(content, list):
for content_block in content:
if (
isinstance(content_block, dict)
and content_block.get("type") == "text"
):
text = content_block.get("text", "")
if text:
# Step 1: Unmask input tokens
for key, value in self.pii_tokens.items():
text = text.replace(key, value)

# Step 2: Mask NEW PII found in the output
masked_text = await self.check_pii(
text=text,
output_parse_pii=False,
presidio_config=presidio_config,
request_data=data,
)
content_block["text"] = masked_text

return response

async def async_post_call_streaming_iterator_hook(
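A condensed, non-authoritative sketch of the two-step output handling the docstring describes (unmask the tokens that were masked on input, then mask any new PII the model produced). The helper name _process_output_text is invented; the attributes and check_pii arguments are the ones used above:

```python
# Conceptual helper, invented for illustration; mirrors the per-text-block logic above.
async def _process_output_text(self, text: str, data: dict, presidio_config) -> str:
    # Step 1: restore the user-sent values for tokens masked on the way in
    for masked_token, original_value in self.pii_tokens.items():
        text = text.replace(masked_token, original_value)
    # Step 2: mask any new PII the model itself produced in its output
    return await self.check_pii(
        text=text,
        output_parse_pii=False,  # do not record tokens for later unmasking
        presidio_config=presidio_config,
        request_data=data,
    )
```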