Skip to content

Commit f387803

Browse files
Merge pull request #14658 from ARajan1084/bedrock-custom-guardrail-fix
fix: check for AWS exceptions despite a 200 response
2 parents 3671c6a + 2bbcf5a commit f387803

File tree

2 files changed

+136
-12
lines changed

2 files changed

+136
-12
lines changed

litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,9 @@ async def make_bedrock_api_request(
384384
)
385385
#########################################################
386386
if response.status_code == 200:
387+
# check if the response contains an error
388+
if self._check_bedrock_response_for_exception(response=response):
389+
raise self._get_http_exception_for_failed_guardrail(response)
387390
# check if the response was flagged
388391
_json_response = response.json()
389392
redacted_response = _redact_pii_matches(_json_response)
@@ -404,16 +407,64 @@ async def make_bedrock_api_request(
404407

405408
return bedrock_guardrail_response
406409

410+
def _check_bedrock_response_for_exception(self, response) -> bool:
    """
    Return True if the Bedrock ApplyGuardrail response indicates an exception.

    Bedrock can answer with HTTP 200 while the body carries an error marker
    shaped like ``{"Output": {"__type": "...Exception", ...}}``. This helper
    looks for that marker.

    Works with real httpx.Response objects and MagicMock responses used in tests.

    Args:
        response: Response-like object. Its ``json()`` method is tried first;
            if that yields nothing, ``content`` (bytes) or ``text`` (str) is
            parsed as JSON instead.

    Returns:
        True when the parsed body is a dict whose ``Output`` entry is a dict
        holding a string ``__type`` that contains "Exception". False in every
        other case, including a body that cannot be parsed.
    """
    payload = None

    try:
        json_method = getattr(response, "json", None)
        if callable(json_method):
            payload = json_method()
    except Exception:
        # Deliberately broad: the response object is opaque here (httpx
        # response or a test double), and a failed json() only means we
        # fall back to the raw body below.
        payload = None

    if payload is None:
        try:
            raw = getattr(response, "content", None)
            if isinstance(raw, (bytes, bytearray)):
                payload = json.loads(raw.decode("utf-8"))
            else:
                text = getattr(response, "text", None)
                if isinstance(text, str):
                    payload = json.loads(text)
        except Exception:
            # Can't parse -> assume no explicit Exception marker
            return False

    if not isinstance(payload, dict):
        return False

    # "Output" may be present but null or a non-dict value; chaining .get()
    # on it would raise AttributeError instead of answering the question.
    output = payload.get("Output")
    if not isinstance(output, dict):
        return False

    # "__type" may likewise be null; only a string can carry the marker.
    error_type = output.get("__type")
    return isinstance(error_type, str) and "Exception" in error_type
442+
407443
def _get_bedrock_guardrail_response_status(
    self, response: httpx.Response
) -> Literal["success", "failure"]:
    """
    Classify a Bedrock guardrail HTTP response as "success" or "failure".

    A response is a success only when its status code is 200 and
    `_check_bedrock_response_for_exception` does not flag it; anything
    else is a failure.
    """
    # Any non-200 status is a failure regardless of the body.
    if response.status_code != 200:
        return "failure"

    # A 200 can still carry an exception marker in its body.
    body_has_exception = self._check_bedrock_response_for_exception(response)
    return "failure" if body_has_exception else "success"
416454

455+
def _get_http_exception_for_failed_guardrail(
    self, response: httpx.Response
) -> HTTPException:
    """
    Build the HTTPException raised when Bedrock reports an exception in its body.

    Args:
        response: The Bedrock response whose body carries the error details.

    Returns:
        An HTTPException with status 400. Its detail holds a fixed error
        message and, under "bedrock_guardrail_response", the body's "Output"
        entry (an empty dict when the body has no usable "Output").
    """
    payload = None
    try:
        payload = json.loads(response.content.decode("utf-8"))
    except (AttributeError, TypeError, ValueError):
        # The raw body was not decodable JSON bytes. Fall back to json() so
        # that constructing the error never masks the guardrail failure
        # it is meant to report.
        try:
            payload = response.json()
        except Exception:
            payload = None

    # Only a dict body can carry "Output"; anything else yields an empty detail.
    output = payload.get("Output", {}) if isinstance(payload, dict) else {}

    return HTTPException(
        status_code=400,
        detail={
            "error": "Guardrail application failed.",
            "bedrock_guardrail_response": output,
        },
    )
467+
417468
def _get_http_exception_for_blocked_guardrail(
418469
self, response: BedrockGuardrailResponse
419470
) -> HTTPException:
@@ -562,11 +613,11 @@ async def async_pre_call_hook(
562613
#########################################################
563614
########## 2. Update the messages with the guardrail response ##########
564615
#########################################################
565-
data["messages"] = (
566-
self._update_messages_with_updated_bedrock_guardrail_response(
567-
messages=new_messages,
568-
bedrock_guardrail_response=bedrock_guardrail_response,
569-
)
616+
data[
617+
"messages"
618+
] = self._update_messages_with_updated_bedrock_guardrail_response(
619+
messages=new_messages,
620+
bedrock_guardrail_response=bedrock_guardrail_response,
570621
)
571622

572623
#########################################################
@@ -617,11 +668,11 @@ async def async_moderation_hook(
617668
#########################################################
618669
########## 2. Update the messages with the guardrail response ##########
619670
#########################################################
620-
data["messages"] = (
621-
self._update_messages_with_updated_bedrock_guardrail_response(
622-
messages=new_messages,
623-
bedrock_guardrail_response=bedrock_guardrail_response,
624-
)
671+
data[
672+
"messages"
673+
] = self._update_messages_with_updated_bedrock_guardrail_response(
674+
messages=new_messages,
675+
bedrock_guardrail_response=bedrock_guardrail_response,
625676
)
626677

627678
#########################################################

tests/test_litellm/proxy/guardrails/guardrail_hooks/test_bedrock_guardrails.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""
22
Unit tests for Bedrock Guardrails
33
"""
4-
4+
import json
55
import os
66
import sys
77
from unittest.mock import AsyncMock, MagicMock, patch
88

99
import pytest
10+
from fastapi import HTTPException
1011

1112
sys.path.insert(0, os.path.abspath("../../../../../.."))
1213

@@ -860,7 +861,6 @@ async def test__redact_pii_matches_comprehensive_coverage():
860861

861862
print("Comprehensive coverage redaction test passed")
862863

863-
864864
@pytest.mark.asyncio
865865
async def test_bedrock_guardrail_respects_custom_runtime_endpoint(monkeypatch):
866866
"""Test that BedrockGuardrail respects aws_bedrock_runtime_endpoint when set"""
@@ -1049,3 +1049,76 @@ async def test_bedrock_guardrail_parameter_takes_precedence_over_env(monkeypatch
10491049
), f"Expected parameter endpoint to take precedence. Got: {prepped_request.url}"
10501050

10511051
print(f"Parameter precedence test passed. URL: {prepped_request.url}")
1052+
1053+
@pytest.mark.asyncio
async def test_bedrock_guardrail_200_with_exception_in_output_raises_and_logs_failure():
    """
    Bedrock may reply with HTTP 200 while the body carries Output.__type
    containing 'Exception'. In that case the guardrail is expected to:
    - raise an HTTPException(400) whose detail includes the Output payload
    - record the request trace with guardrail_status='failure'
    """
    bedrock_guardrail = BedrockGuardrail(
        guardrailIdentifier="test-guardrail", guardrailVersion="DRAFT"
    )

    # HTTP-level "success" whose body embeds an Exception marker
    error_output = {
        "__type": "com.amazonaws#InternalServerException",
        "message": "Something went wrong upstream",
    }
    bedrock_body = {"Output": error_output, "action": "NONE"}
    serialized_body = json.dumps(bedrock_body)

    fake_response = MagicMock(
        status_code=200,
        content=serialized_body.encode("utf-8"),
        text=serialized_body,
    )
    fake_response.json.return_value = bedrock_body

    # Smallest request that exercises the guardrail call
    data = {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "hello"}],
    }

    # Stand-ins for AWS credentials and the signed request
    aws_credentials = MagicMock(access_key="ak", secret_key="sk", token=None)
    prepared_request = MagicMock(url="http://example", headers={}, body=b"")

    with patch.object(
        bedrock_guardrail.async_handler, "post", new_callable=AsyncMock
    ) as post_mock, patch.object(
        bedrock_guardrail,
        "_load_credentials",
        return_value=(aws_credentials, "us-east-1"),
    ), patch.object(
        bedrock_guardrail, "_prepare_request", return_value=prepared_request
    ), patch.object(
        bedrock_guardrail,
        "add_standard_logging_guardrail_information_to_request_data",
    ) as trace_mock:
        post_mock.return_value = fake_response

        with pytest.raises(HTTPException) as raised:
            await bedrock_guardrail.make_bedrock_api_request(
                source="INPUT",
                messages=data["messages"],
                request_data=data,
            )

        # 1) The failure surfaces as an HTTPException with status 400
        raised_error = raised.value
        assert raised_error.status_code == 400
        assert raised_error.detail["error"] == "Guardrail application failed."

        # 2) The detail carries the Output object from the Bedrock body
        assert raised_error.detail["bedrock_guardrail_response"] == error_output

        # 3) Trace logging was invoked with a 'failure' status
        assert trace_mock.called
        logged_kwargs = trace_mock.call_args.kwargs
        assert logged_kwargs["guardrail_status"] == "failure"
        # ...and it was handed the same JSON body we received
        assert logged_kwargs["guardrail_json_response"] == bedrock_body

0 commit comments

Comments
 (0)