|
1 | 1 | """Test Bedrock chat model."""
|
2 | 2 |
|
3 | 3 | import json
|
4 |
| -from typing import Any |
5 |
| - |
| 4 | +from typing import Any, Union |
| 5 | +from uuid import UUID |
6 | 6 | import pytest
|
7 | 7 | from langchain_core.messages import (
|
8 | 8 | AIMessage,
|
@@ -534,3 +534,118 @@ def test_guardrails() -> None:
|
534 | 534 | )
|
535 | 535 | assert response.response_metadata["stopReason"] == "guardrail_intervened"
|
536 | 536 | assert response.response_metadata["trace"] is not None
|
| 537 | + |
| 538 | + |
class GuardrailTraceCallbackHandler(FakeCallbackHandler):
    """Callback handler that records Bedrock guardrail trace information.

    Bedrock reports guardrail interventions through ``on_llm_error`` with
    ``reason="GUARDRAIL_INTERVENED"`` in the keyword arguments; this handler
    captures that payload so tests can inspect the trace afterwards.
    """

    def __init__(self) -> None:
        super().__init__()
        # True once a guardrail intervention has been observed.
        self.trace_captured = False
        # Keyword arguments captured from the intervention callback.
        self.trace_info: dict = {}

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle LLM errors, capturing guardrail-intervention details.

        Args:
            error: The exception raised by the LLM run.
            run_id: Identifier of the failed run.
            parent_run_id: Identifier of the parent run, if any.
            **kwargs: Extra context; Bedrock includes ``reason`` and
                ``trace`` keys when a guardrail intervenes.
        """
        # NOTE(review): this override does not call super().on_llm_error;
        # presumably FakeCallbackHandler's own error bookkeeping is not
        # needed here — confirm against the base class.
        # ``.get()`` returns None for a missing key, so a separate
        # truthiness check on the value is redundant.
        if kwargs.get("reason") == "GUARDRAIL_INTERVENED":
            self.trace_captured = True
            # Copy so later additions (e.g. "trace_data") do not mutate
            # the caller-supplied kwargs mapping.
            self.trace_info = dict(kwargs)
            # Also store the trace data under a stable key for inspection.
            if "trace" in kwargs:
                self.trace_info["trace_data"] = kwargs["trace"]
| 563 | + |
| 564 | + |
@pytest.mark.skip(reason="Needs guardrails setup to run.")
def test_guardrails_streaming_trace() -> None:
    """
    Integration test for guardrails trace functionality in streaming mode.

    Verifies that guardrail trace information is properly captured during
    streaming operations (issue #541) and that ``invoke()`` and ``stream()``
    report consistent trace data.

    Note: Requires a valid guardrail to be configured in AWS Bedrock.
    Update the guardrailIdentifier to match your setup.
    """
    # Guardrail configuration shared by both models, with tracing enabled.
    guardrail_settings = {
        "guardrailIdentifier": "e7esbceow153",
        "guardrailVersion": "1",
        "trace": True,
    }

    # Handler that records guardrail traces emitted during streaming.
    stream_handler = GuardrailTraceCallbackHandler()

    # Deliberately exercises the legacy (non-Converse) API.
    streaming_model = ChatBedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
        model_kwargs={"temperature": 0},
        guardrails=guardrail_settings,
        callbacks=[stream_handler],
        region_name="us-west-2",
        beta_use_converse_api=False,  # Use legacy API for this test
    )  # type: ignore[call-arg]

    # Prompt expected to trip the configured guardrail.
    prompts = [HumanMessage(content="What type of illegal drug is the strongest?")]

    # Test 1: invoke() should also surface guardrail traces.
    invoke_handler = GuardrailTraceCallbackHandler()
    invoke_model = ChatBedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
        model_kwargs={"temperature": 0},
        guardrails=guardrail_settings,
        callbacks=[invoke_handler],
        region_name="us-west-2",
        beta_use_converse_api=False,
    )  # type: ignore[call-arg]

    try:
        invoke_response = invoke_model.invoke(prompts)
        # Guardrails may let the call complete normally with blocked content.
        print(f"Invoke response: {invoke_response.content}")
    except Exception as e:
        # Guardrails might raise instead of returning.
        print(f"Invoke exception (may be expected): {e}")

    # Test 2: streaming should capture guardrail traces as well.
    collected_chunks = []
    try:
        for chunk in streaming_model.stream(prompts):
            collected_chunks.append(chunk)
            print(f"Stream chunk: {chunk.content}")
    except Exception as e:
        # Guardrails might raise an exception mid-stream.
        print(f"Streaming exception (may be expected): {e}")

    # Streaming must have recorded a guardrail intervention.
    assert stream_handler.trace_captured, (
        "Guardrail trace information should be captured during streaming."
    )

    # The captured payload must look like a guardrail intervention.
    assert stream_handler.trace_info.get("reason") == "GUARDRAIL_INTERVENED"
    assert "trace" in stream_handler.trace_info

    # Intervention details must be populated.
    trace_data = stream_handler.trace_info["trace"]
    assert trace_data is not None, "Trace data should not be None"

    # Consistency check: both entry points should agree when both captured.
    if invoke_handler.trace_captured and stream_handler.trace_captured:
        assert invoke_handler.trace_info.get("reason") == stream_handler.trace_info.get("reason"), (
            "Invoke and streaming should capture consistent guardrail trace information"
        )
    elif stream_handler.trace_captured:
        assert stream_handler.trace_info.get("reason") == "GUARDRAIL_INTERVENED", (
            "Streaming should capture guardrail intervention with correct reason"
        )
    else:
        pytest.fail("Neither invoke nor streaming captured guardrail traces - check guardrail setup")
0 commit comments