From 4673fadab8a47cbb9451c339cfbaf29e1c5a20f0 Mon Sep 17 00:00:00 2001
From: Ashley Kingscote <25075013+akingscote@users.noreply.github.com>
Date: Thu, 2 Oct 2025 13:46:04 +0100
Subject: [PATCH] working anthropic presidio redaction

---
 .../proxy/anthropic_endpoints/endpoints.py    |  92 +++----
 litellm/proxy/common_request_processing.py    |  21 ++
 .../guardrails/guardrail_hooks/presidio.py    | 114 ++++++++-
 .../test_presidio_anthropic_support.py        | 235 ++++++++++++++++++
 4 files changed, 395 insertions(+), 67 deletions(-)
 create mode 100644 tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio_anthropic_support.py

diff --git a/litellm/proxy/anthropic_endpoints/endpoints.py b/litellm/proxy/anthropic_endpoints/endpoints.py
index e7ce5e888c07..de5c30fd5953 100644
--- a/litellm/proxy/anthropic_endpoints/endpoints.py
+++ b/litellm/proxy/anthropic_endpoints/endpoints.py
@@ -15,7 +15,6 @@
     create_streaming_response,
 )
 from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
-from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
 from litellm.types.utils import TokenCountResponse
 
 router = APIRouter()
@@ -52,42 +51,27 @@ async def anthropic_response(  # noqa: PLR0915
     request_data = await _read_request_body(request=request)
     data: dict = {**request_data}
     try:
-        data["model"] = (
-            general_settings.get("completion_model", None)  # server default
-            or user_model  # model name passed via cli args
-            or data.get("model", None)  # default passed in http request
-        )
-        if user_model:
-            data["model"] = user_model
+        # Initialize the base processor
+        base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
 
-        data = await add_litellm_data_to_request(
-            data=data,  # type: ignore
+        # Use common processing logic to ensure guardrails are applied
+        (
+            data,
+            logging_obj,
+        ) = await base_llm_response_processor.common_processing_pre_call_logic(
             request=request,
             general_settings=general_settings,
             user_api_key_dict=user_api_key_dict,
+            proxy_logging_obj=proxy_logging_obj,
             version=version,
             proxy_config=proxy_config,
-        )
-
-        # override with user settings, these are params passed via cli
-        if user_temperature:
-            data["temperature"] = user_temperature
-        if user_request_timeout:
-            data["request_timeout"] = user_request_timeout
-        if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
-        if user_api_base:
-            data["api_base"] = user_api_base
-
-        ### MODEL ALIAS MAPPING ###
-        # check if model name in model alias map
-        # get the actual model name
-        if data["model"] in litellm.model_alias_map:
-            data["model"] = litellm.model_alias_map[data["model"]]
-
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
-            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
+            user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            model=None,
+            route_type="acompletion",  # Use acompletion to ensure guardrails are applied
         )
 
         tasks = []
@@ -95,9 +79,7 @@ async def anthropic_response(  # noqa: PLR0915
             proxy_logging_obj.during_call_hook(
                 data=data,
                 user_api_key_dict=user_api_key_dict,
-                call_type=ProxyBaseLLMRequestProcessing._get_pre_call_type(
-                    route_type="anthropic_messages"  # type: ignore
-                ),
+                call_type="completion",  # Use completion type for anthropic messages
             )
         )
 
@@ -119,8 +101,8 @@ async def anthropic_response(  # noqa: PLR0915
             llm_router is not None and data["model"] in llm_router.deployment_names
         ):  # model in router deployments, calling a specific deployment on the router
             llm_coro = llm_router.aanthropic_messages(**data, specific_deployment=True)
-        elif (
-            llm_router is not None and llm_router.has_model_id(data["model"])
+        elif llm_router is not None and llm_router.has_model_id(
+            data["model"]
         ):  # model in router model list
             llm_coro = llm_router.aanthropic_messages(**data)
         elif (
@@ -202,7 +184,7 @@ async def anthropic_response(  # noqa: PLR0915
 
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
-            data=data, user_api_key_dict=user_api_key_dict, response=response  # type: ignore
+            data=data, user_api_key_dict=user_api_key_dict, response=response  # type: ignore
         )
 
         verbose_proxy_logger.debug("\nResponse from Litellm:\n{}".format(response))
@@ -255,35 +237,30 @@ async def count_tokens(
     Returns: {"input_tokens": <count>}
     """
     from litellm.proxy.proxy_server import token_counter as internal_token_counter
-    
+
     try:
         request_data = await _read_request_body(request=request)
         data: dict = {**request_data}
-        
+
         # Extract required fields
         model_name = data.get("model")
         messages = data.get("messages", [])
-        
+
         if not model_name:
             raise HTTPException(
-                status_code=400,
-                detail={"error": "model parameter is required"}
+                status_code=400, detail={"error": "model parameter is required"}
             )
-        
+
         if not messages:
             raise HTTPException(
-                status_code=400,
-                detail={"error": "messages parameter is required"}
+                status_code=400, detail={"error": "messages parameter is required"}
            )
-        
+
         # Create TokenCountRequest for the internal endpoint
         from litellm.proxy._types import TokenCountRequest
-        
-        token_request = TokenCountRequest(
-            model=model_name,
-            messages=messages
-        )
-        
+
+        token_request = TokenCountRequest(model=model_name, messages=messages)
+
         # Call the internal token counter function with direct request flag set to False
         token_response = await internal_token_counter(
             request=token_request,
@@ -294,17 +271,18 @@ async def count_tokens(
             _token_response_dict = token_response.model_dump()
         elif isinstance(token_response, dict):
             _token_response_dict = token_response
-        
+
         # Convert the internal response to Anthropic API format
         return {"input_tokens": _token_response_dict.get("total_tokens", 0)}
-        
+
     except HTTPException:
         raise
     except Exception as e:
         verbose_proxy_logger.exception(
-            "litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(str(e))
+            "litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(
+                str(e)
+            )
         )
         raise HTTPException(
-            status_code=500,
-            detail={"error": f"Internal server error: {str(e)}"}
+            status_code=500, detail={"error": f"Internal server error: {str(e)}"}
         )
diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py
index f07a61c544cd..f0c07259775f 100644
--- a/litellm/proxy/common_request_processing.py
+++ b/litellm/proxy/common_request_processing.py
@@ -702,6 +702,27 @@ def _get_pre_call_type(
         elif route_type == "aresponses":
             return "responses"
 
+    @staticmethod
+    def _map_route_type_to_call_type(route_type: str) -> str:
+        """
+        Maps route_type to call_type for guardrail hooks.
+        This ensures guardrails receive the correct call_type parameter.
+        """
+        route_to_call_type_map = {
+            "acompletion": "completion",
+            "aresponses": "responses",
+            "atext_completion": "text_completion",
+            "aimage_edit": "image_generation",
+            "aimage_generation": "image_generation",
+            "aembeddings": "embeddings",
+            "amoderation": "moderation",
+            "aaudio_transcription": "audio_transcription",
+            "arerank": "rerank",
+            "allm_passthrough_route": "pass_through_endpoint",
+            "amcp_call": "mcp_call",
+        }
+        return route_to_call_type_map.get(route_type, route_type)
+
     #########################################################
     # Proxy Level Streaming Data Generator
     #########################################################
diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
index b77e802c717f..d29aea8c470d 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
@@ -68,7 +68,9 @@ def __init__(
         output_parse_pii: Optional[bool] = False,
         presidio_ad_hoc_recognizers: Optional[str] = None,
         logging_only: Optional[bool] = None,
-        pii_entities_config: Optional[Dict[Union[PiiEntityType, str], PiiAction]] = None,
+        pii_entities_config: Optional[
+            Dict[Union[PiiEntityType, str], PiiAction]
+        ] = None,
         presidio_language: Optional[str] = None,
         **kwargs,
     ):
@@ -245,9 +247,14 @@ async def anonymize_text(
                 # Make the request to /anonymize
                 anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
                 verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
+
+                # Build anonymize payload
                 anonymize_payload = {
                     "text": text,
                     "analyzer_results": analyze_results,
+                    "anonymizers": {
+                        "DEFAULT": {"type": "replace", "new_value": "{REDACTED}"}
+                    },
                 }
 
                 async with session.post(
@@ -412,11 +419,17 @@ async def async_pre_call_hook(
         ):
             messages = data["messages"]
             tasks = []
+            content_types = []  # Track whether each content is string or structured
+
             for m in messages:
                 content = m.get("content", None)
                 if content is None:
+                    content_types.append(None)
                     continue
+
+                # Handle string content (OpenAI format)
                 if isinstance(content, str):
+                    content_types.append("string")
                     tasks.append(
                         self.check_pii(
                             text=content,
@@ -425,15 +438,54 @@ async def async_pre_call_hook(
                             request_data=data,
                         )
                     )
+                # Handle structured content (Anthropic format with list of content blocks)
+                elif isinstance(content, list):
+                    content_types.append("list")
+                    # Process each text block in the content array
+                    for content_block in content:
+                        if (
+                            isinstance(content_block, dict)
+                            and content_block.get("type") == "text"
+                        ):
+                            text = content_block.get("text", "")
+                            if text:
+                                tasks.append(
+                                    self.check_pii(
+                                        text=text,
+                                        output_parse_pii=self.output_parse_pii,
+                                        presidio_config=presidio_config,
+                                        request_data=data,
+                                    )
+                                )
+                else:
+                    content_types.append(None)
+
             responses = await asyncio.gather(*tasks)
-            for index, r in enumerate(responses):
-                content = messages[index].get("content", None)
+
+            # Apply redacted text back to messages
+            response_index = 0
+            for msg_index, content_type in enumerate(content_types):
+                if content_type is None:
+                    continue
+
+                content = messages[msg_index].get("content", None)
                 if content is None:
                     continue
-                if isinstance(content, str):
-                    messages[index][
-                        "content"
-                    ] = r  # replace content with redacted string
+
+                # Handle string content
+                if content_type == "string":
+                    messages[msg_index]["content"] = responses[response_index]
+                    response_index += 1
+                # Handle structured content
+                elif content_type == "list":
+                    for content_block in content:
+                        if (
+                            isinstance(content_block, dict)
+                            and content_block.get("type") == "text"
+                        ):
+                            if content_block.get("text"):
+                                content_block["text"] = responses[response_index]
+                                response_index += 1
             verbose_proxy_logger.debug(
                 f"Presidio PII Masking: Redacted pii message: {data['messages']}"
             )
@@ -530,10 +582,12 @@ async def async_post_call_success_hook(  # type: ignore
         self,
         data: dict,
         user_api_key_dict: UserAPIKeyAuth,
-        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
+        response: Union[ModelResponse, EmbeddingResponse, ImageResponse, dict],
     ):
         """
-        Output parse the response object to replace the masked tokens with user sent values
+        Output parse the response object to:
+        1. Replace the masked tokens from input with user sent values (unmask input tokens)
+        2. Mask any NEW PII found in the LLM's response (mask output PII)
         """
         verbose_proxy_logger.debug(
             f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
         )
@@ -542,17 +596,57 @@ async def async_post_call_success_hook(  # type: ignore
         if self.output_parse_pii is False and litellm.output_parse_pii is False:
             return response
 
+        presidio_config = self.get_presidio_settings_from_request_data(data)
+
+        # Handle OpenAI/ModelResponse format
         if isinstance(response, ModelResponse) and not isinstance(
             response.choices[0], StreamingChoices
         ):  # /chat/completions requests
             if isinstance(response.choices[0].message.content, str):
+                original_content = response.choices[0].message.content
                 verbose_proxy_logger.debug(
-                    f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
+                    f"self.pii_tokens: {self.pii_tokens}; initial response: {original_content}"
                 )
+
+                # Step 1: Unmask input tokens (original behavior)
                 for key, value in self.pii_tokens.items():
                     response.choices[0].message.content = response.choices[
                         0
                     ].message.content.replace(key, value)
+
+                # Step 2: Mask NEW PII found in the output
+                masked_output = await self.check_pii(
+                    text=response.choices[0].message.content,
+                    output_parse_pii=False,  # Don't track tokens for unmasking
+                    presidio_config=presidio_config,
+                    request_data=data,
+                )
+                response.choices[0].message.content = masked_output
+
+        # Handle Anthropic format (dict response from /v1/messages)
+        elif isinstance(response, dict):
+            content = response.get("content", [])
+            if isinstance(content, list):
+                for content_block in content:
+                    if (
+                        isinstance(content_block, dict)
+                        and content_block.get("type") == "text"
+                    ):
+                        text = content_block.get("text", "")
+                        if text:
+                            # Step 1: Unmask input tokens
+                            for key, value in self.pii_tokens.items():
+                                text = text.replace(key, value)
+
+                            # Step 2: Mask NEW PII found in the output
+                            masked_text = await self.check_pii(
+                                text=text,
+                                output_parse_pii=False,
+                                presidio_config=presidio_config,
+                                request_data=data,
+                            )
+                            content_block["text"] = masked_text
+
         return response
 
     async def async_post_call_streaming_iterator_hook(
diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio_anthropic_support.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio_anthropic_support.py
new file mode 100644
index 000000000000..c6f24d64bbaf
--- /dev/null
+++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio_anthropic_support.py
@@ -0,0 +1,235 @@
+"""
+Unit tests for Presidio guardrail support for Anthropic /v1/messages endpoint.
+
+Tests the following functionality:
+1. Anthropic structured content format (list of content blocks) is properly masked
+2. Output PII masking works for Anthropic response format
+3. Both input and output masking work together
+"""
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+from litellm.proxy.guardrails.guardrail_hooks.presidio import (
+    _OPTIONAL_PresidioPIIMasking,
+)
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+@pytest.mark.asyncio
+async def test_presidio_anthropic_structured_content_input_masking():
+    """Test that Presidio can mask PII in Anthropic's structured content format (list of content blocks)"""
+
+    # Mock Presidio responses
+    mock_analyze_response = [
+        {"start": 11, "end": 17, "score": 0.85, "entity_type": "PERSON"}
+    ]
+
+    mock_anonymize_response = {
+        "text": "My name is {REDACTED}, what is my name?",
+        "items": [
+            {
+                "start": 11,
+                "end": 21,
+                "entity_type": "PERSON",
+                "text": "{REDACTED}",
+                "operator": "replace",
+            }
+        ],
+    }
+
+    # Create Presidio instance
+    presidio = _OPTIONAL_PresidioPIIMasking(
+        mock_testing=True,
+        presidio_analyzer_api_base="http://fake:3000",
+        presidio_anonymizer_api_base="http://fake:3000",
+        default_on=True,
+        event_hook="pre_call",
+    )
+
+    # Mock the analyze and anonymize methods
+    presidio.analyze_text = AsyncMock(return_value=mock_analyze_response)
+    presidio.anonymize_text = AsyncMock(return_value=mock_anonymize_response["text"])
+
+    # Test data with Anthropic structured content format
+    data = {
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "My name is Ashley, what is my name?"}
+                ],
+            }
+        ]
+    }
+
+    user_api_key_dict = UserAPIKeyAuth(api_key="test_key")
+
+    # Call the pre-call hook
+    result = await presidio.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict,
+        cache=MagicMock(),
+        data=data,
+        call_type="acompletion",
+    )
+
+    # Verify the PII was masked in the structured content
+    assert (
+        result["messages"][0]["content"][0]["text"]
+        == "My name is {REDACTED}, what is my name?"
+    )
+    assert presidio.anonymize_text.called
+
+
+@pytest.mark.asyncio
+async def test_presidio_anthropic_output_pii_masking():
+    """Test that Presidio can mask PII in Anthropic response format (dict with content array)"""
+
+    # Mock Presidio responses for output
+    mock_analyze_response = [
+        {"start": 36, "end": 62, "score": 0.95, "entity_type": "EMAIL_ADDRESS"}
+    ]
+
+    mock_anonymize_response = {
+        "text": "Here's a fake email address:\n\n{REDACTED}",
+        "items": [
+            {
+                "start": 36,
+                "end": 46,
+                "entity_type": "EMAIL_ADDRESS",
+                "text": "{REDACTED}",
+                "operator": "replace",
+            }
+        ],
+    }
+
+    # Create Presidio instance with output parsing enabled
+    presidio = _OPTIONAL_PresidioPIIMasking(
+        mock_testing=True,
+        presidio_analyzer_api_base="http://fake:3000",
+        presidio_anonymizer_api_base="http://fake:3000",
+        output_parse_pii=True,
+        default_on=True,
+        event_hook=["pre_call", "post_call"],
+    )
+
+    # Mock the check_pii method to return masked text
+    presidio.check_pii = AsyncMock(return_value=mock_anonymize_response["text"])
+
+    # Anthropic response format (dict)
+    response = {
+        "id": "msg_123",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3",
+        "content": [
+            {
+                "type": "text",
+                "text": "Here's a fake email address:\n\njohn.smith@example.com",
+            }
+        ],
+        "stop_reason": "end_turn",
+    }
+
+    data = {}
+    user_api_key_dict = UserAPIKeyAuth(api_key="test_key")
+
+    # Call the post-call hook
+    result = await presidio.async_post_call_success_hook(
+        data=data,
+        user_api_key_dict=user_api_key_dict,
+        response=response,
+    )
+
+    # Verify the email was masked in the response
+    assert result["content"][0]["text"] == "Here's a fake email address:\n\n{REDACTED}"
+    assert presidio.check_pii.called
+
+
+@pytest.mark.asyncio
+async def test_presidio_anthropic_multiple_content_blocks():
+    """Test that Presidio handles multiple content blocks in Anthropic format"""
+
+    # Create Presidio instance
+    presidio = _OPTIONAL_PresidioPIIMasking(
+        mock_testing=True,
+        presidio_analyzer_api_base="http://fake:3000",
+        presidio_anonymizer_api_base="http://fake:3000",
+        default_on=True,
+        event_hook="pre_call",
+    )
+
+    # Mock to return masked text
+    presidio.check_pii = AsyncMock(
+        side_effect=["My name is {REDACTED}", "Email: {REDACTED}"]
+    )
+
+    # Test data with multiple text blocks
+    data = {
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "My name is Ashley"},
+                    {"type": "text", "text": "Email: ashley@example.com"},
+                ],
+            }
+        ]
+    }
+
+    user_api_key_dict = UserAPIKeyAuth(api_key="test_key")
+
+    # Call the pre-call hook
+    result = await presidio.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict,
+        cache=MagicMock(),
+        data=data,
+        call_type="acompletion",
+    )
+
+    # Verify both content blocks were masked
+    assert result["messages"][0]["content"][0]["text"] == "My name is {REDACTED}"
+    assert result["messages"][0]["content"][1]["text"] == "Email: {REDACTED}"
+    assert presidio.check_pii.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_presidio_uses_curly_braces_for_redaction():
+    """Test that Presidio uses {REDACTED} instead of <PERSON> to avoid Claude XML interpretation"""
+
+    presidio = _OPTIONAL_PresidioPIIMasking(
+        mock_testing=True,
+        presidio_analyzer_api_base="http://fake:3000",
+        presidio_anonymizer_api_base="http://fake:3000",
+        default_on=True,
+        event_hook="pre_call",
+    )
+
+    # Mock the anonymize_text to capture the anonymizers config
+    async def mock_anonymize(
+        text, analyze_results, output_parse_pii, masked_entity_count
+    ):
+        # Verify the DEFAULT anonymizer uses {REDACTED}
+        # This would be called internally with our custom config
+        return text.replace("Ashley", "{REDACTED}")
+
+    presidio.analyze_text = AsyncMock(
+        return_value=[{"start": 11, "end": 17, "score": 0.85, "entity_type": "PERSON"}]
+    )
+    presidio.anonymize_text = mock_anonymize
+
+    data = {
+        "messages": [
+            {"role": "user", "content": [{"type": "text", "text": "My name is Ashley"}]}
+        ]
+    }
+
+    result = await presidio.async_pre_call_hook(
+        user_api_key_dict=UserAPIKeyAuth(api_key="test"),
+        cache=MagicMock(),
+        data=data,
+        call_type="acompletion",
+    )
+
+    # Verify {REDACTED} format is used, not <PERSON>
+    masked_text = result["messages"][0]["content"][0]["text"]
+    assert "{REDACTED}" in masked_text
+    assert "<PERSON>" not in masked_text
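
As a quick sanity check outside the proxy, the anonymize payload built in anonymize_text() above can be exercised directly against a Presidio stack. A minimal sketch follows, assuming analyzer and anonymizer containers at the URLs below; the host, port, and sample text are assumptions and should be adjusted to your own deployment.

# Minimal illustrative sketch: exercising the same Presidio REST payload that
# anonymize_text() builds above. Host/port values and the sample text are
# assumptions; point them at your own Presidio containers.
import requests

PRESIDIO_ANALYZER_API_BASE = "http://localhost:5002/"  # assumed analyzer address
PRESIDIO_ANONYMIZER_API_BASE = "http://localhost:5001/"  # assumed anonymizer address

text = "My name is Ashley, what is my name?"

# 1) Analyze: returns PII findings with start/end offsets, score and entity_type
analysis = requests.post(
    f"{PRESIDIO_ANALYZER_API_BASE}analyze",
    json={"text": text, "language": "en"},
    timeout=10,
).json()
analyzer_results = [
    {k: finding[k] for k in ("start", "end", "score", "entity_type")}
    for finding in analysis
]

# 2) Anonymize: the DEFAULT operator replaces every finding with "{REDACTED}",
#    rather than Presidio's default "<ENTITY_TYPE>" tags, which Claude can
#    misread as XML when they appear in a prompt.
anonymized = requests.post(
    f"{PRESIDIO_ANONYMIZER_API_BASE}anonymize",
    json={
        "text": text,
        "analyzer_results": analyzer_results,
        "anonymizers": {"DEFAULT": {"type": "replace", "new_value": "{REDACTED}"}},
    },
    timeout=10,
).json()

print(anonymized["text"])  # expected: "My name is {REDACTED}, what is my name?"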