Commit 09842d6

fix: add guardrails trace support for ChatBedrock streaming (#541) (#587)
Fixes #541: track guardrails interventions during streaming operations and call callback handlers with trace info when a guardrail intervenes.
1 parent 3847ad8 commit 09842d6

2 files changed: 135 additions, 2 deletions

libs/aws/langchain_aws/chat_models/bedrock.py

Lines changed: 18 additions & 0 deletions
@@ -833,6 +833,9 @@ def _stream(
         )

         added_model_name = False
+        # Track guardrails trace information for callback handling
+        guardrails_trace_info = None
+
         for chunk in self._prepare_input_and_invoke_stream(
             prompt=prompt,
             system=system,
@@ -852,6 +855,12 @@
                 delta = chunk.text
                 response_metadata = None
                 if generation_info := chunk.generation_info:
+                    # Check for guardrail intervention in the streaming chunk
+                    services_trace = self._get_bedrock_services_signal(generation_info)
+                    if services_trace.get("signal") and run_manager:
+                        # Store trace info for potential callback
+                        guardrails_trace_info = services_trace
+
                     usage_metadata = generation_info.pop("usage_metadata", None)
                     response_metadata = generation_info
                 if not added_model_name:
@@ -873,6 +882,15 @@
                         generation_chunk.text, chunk=generation_chunk
                     )
                 yield generation_chunk
+
+        # If guardrails intervened during streaming, notify the callback handler
+        if guardrails_trace_info and run_manager:
+            run_manager.on_llm_error(
+                Exception(
+                    f"Error raised by bedrock service: {guardrails_trace_info.get('reason')}"
+                ),
+                **guardrails_trace_info,
+            )

     def _generate(
         self,
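
Taken together, the change means a streaming caller can now observe guardrail interventions through its callback handler, as non-streaming calls already could. The sketch below shows that consumer-side view. It assumes the kwargs forwarded to on_llm_error carry the "reason" and "trace" keys exercised by the integration test that follows; the handler class, guardrail identifier, and prompt are illustrative placeholders, not part of this commit.

from typing import Any, Union
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import HumanMessage

from langchain_aws import ChatBedrock


class GuardrailLoggingHandler(BaseCallbackHandler):
    """Illustrative handler that records guardrail interventions seen during streaming."""

    def __init__(self) -> None:
        self.interventions: list = []

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ) -> Any:
        # With this commit, the legacy streaming path forwards the guardrails trace info here.
        if kwargs.get("reason") == "GUARDRAIL_INTERVENED":
            self.interventions.append(kwargs)


handler = GuardrailLoggingHandler()
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    guardrails={
        "guardrailIdentifier": "<your-guardrail-id>",  # placeholder
        "guardrailVersion": "1",
        "trace": True,
    },
    callbacks=[handler],
    region_name="us-west-2",
    beta_use_converse_api=False,  # the fix targets the legacy (non-Converse) streaming path
)  # type: ignore[call-arg]

for chunk in llm.stream([HumanMessage(content="<prompt expected to trip the guardrail>")]):
    pass  # consume the stream; interventions surface through the handler

for intervention in handler.interventions:
    print(intervention.get("trace"))

Note that, per the diff, the generator still yields whatever chunks arrived before the intervention and then notifies the callback manager after the loop; no exception is raised to the streaming caller itself.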

libs/aws/tests/integration_tests/chat_models/test_bedrock.py

Lines changed: 117 additions & 2 deletions
@@ -1,8 +1,8 @@
 """Test Bedrock chat model."""

 import json
-from typing import Any
-
+from typing import Any, Union
+from uuid import UUID
 import pytest
 from langchain_core.messages import (
     AIMessage,
@@ -534,3 +534,118 @@ def test_guardrails() -> None:
     )
     assert response.response_metadata["stopReason"] == "guardrail_intervened"
     assert response.response_metadata["trace"] is not None
+
+
+class GuardrailTraceCallbackHandler(FakeCallbackHandler):
+    """Callback handler to capture guardrail trace information."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.trace_captured = False
+        self.trace_info: dict = {}
+
+    def on_llm_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Union[UUID, None] = None,
+        **kwargs: Any
+    ) -> Any:
+        """Handle LLM errors, including guardrail interventions."""
+        reason = kwargs.get("reason")
+        if reason and reason == "GUARDRAIL_INTERVENED":
+            self.trace_captured = True
+            self.trace_info = kwargs
+            # Also store the trace data for inspection
+            if "trace" in kwargs:
+                self.trace_info["trace_data"] = kwargs["trace"]
+
+
+@pytest.mark.skip(reason="Needs guardrails setup to run.")
+def test_guardrails_streaming_trace() -> None:
+    """
+    Integration test for guardrails trace functionality in streaming mode.
+
+    This test verifies that guardrail trace information is properly captured
+    during streaming operations, resolving issue #541.
+
+    Note: Requires a valid guardrail to be configured in AWS Bedrock.
+    Update the guardrailIdentifier to match your setup.
+    """
+    # Create callback handler to capture guardrail traces
+    guardrail_callback = GuardrailTraceCallbackHandler()
+
+    # Configure guardrails with trace enabled
+    guardrail_config = {
+        "guardrailIdentifier": "e7esbceow153",
+        "guardrailVersion": "1",
+        "trace": True
+    }
+
+    # Create ChatBedrock with guardrails (NOT using the Converse API)
+    chat_model = ChatBedrock(
+        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
+        model_kwargs={"temperature": 0},
+        guardrails=guardrail_config,
+        callbacks=[guardrail_callback],
+        region_name="us-west-2",
+        beta_use_converse_api=False  # Use legacy API for this test
+    )  # type: ignore[call-arg]

+    # Test message that should trigger guardrail intervention
+    messages = [
+        HumanMessage(content="What type of illegal drug is the strongest?")
+    ]
+
+    # Test 1: Verify invoke() captures guardrail traces
+    invoke_callback = GuardrailTraceCallbackHandler()
+    chat_model_invoke = ChatBedrock(
+        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
+        model_kwargs={"temperature": 0},
+        guardrails=guardrail_config,
+        callbacks=[invoke_callback],
+        region_name="us-west-2",
+        beta_use_converse_api=False
+    )  # type: ignore[call-arg]
+
+    try:
+        invoke_response = chat_model_invoke.invoke(messages)
+        # If guardrails intervene, this might complete normally with blocked content
+        print(f"Invoke response: {invoke_response.content}")
+    except Exception as e:
+        # Guardrails might raise an exception
+        print(f"Invoke exception (may be expected): {e}")
+
+    # Test 2: Verify streaming captures guardrail traces
+    stream_chunks = []
+    try:
+        for chunk in chat_model.stream(messages):
+            stream_chunks.append(chunk)
+            print(f"Stream chunk: {chunk.content}")
+    except Exception as e:
+        # Guardrails might raise an exception during streaming
+        print(f"Streaming exception (may be expected): {e}")
+
+    # Verify guardrail trace was captured during streaming
+    assert guardrail_callback.trace_captured, (
+        "Guardrail trace information should be captured during streaming."
+    )
+
+    # Verify trace contains expected guardrail information
+    assert guardrail_callback.trace_info.get("reason") == "GUARDRAIL_INTERVENED"
+    assert "trace" in guardrail_callback.trace_info
+
+    # The trace should contain guardrail intervention details
+    trace_data = guardrail_callback.trace_info["trace"]
+    assert trace_data is not None, "Trace data should not be None"
+
+    # Consistency check: Both invoke and streaming should capture traces
+    if invoke_callback.trace_captured and guardrail_callback.trace_captured:
+        assert invoke_callback.trace_info.get("reason") == guardrail_callback.trace_info.get("reason"), \
+            "Invoke and streaming should capture consistent guardrail trace information"
+    elif guardrail_callback.trace_captured:
+        assert guardrail_callback.trace_info.get("reason") == "GUARDRAIL_INTERVENED", \
+            "Streaming should capture guardrail intervention with correct reason"
+    else:
+        pytest.fail("Neither invoke nor streaming captured guardrail traces - check guardrail setup")
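
For orientation, the kwargs the test's handler receives are the dict built by _get_bedrock_services_signal and splatted into on_llm_error in the change above. Judging only from the keys the new code and test inspect ("signal", "reason", "trace"), the payload looks roughly like the sketch below; the contents of "trace" come from Bedrock's guardrail assessment and are not defined by this commit.

# Rough, assumed shape of the guardrails trace info forwarded to on_llm_error;
# only the keys inspected by the new code and the test above are shown.
guardrails_trace_info = {
    "signal": True,                    # checked via services_trace.get("signal")
    "reason": "GUARDRAIL_INTERVENED",  # matched by GuardrailTraceCallbackHandler
    "trace": {},                       # Bedrock guardrail trace; structure is service-defined
}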
