
Commit 3110f7b

Add Support for GuardRailConverseContentBlock (#557)
We introduce a new flag, `guard_last_turn_only`, which applies the guardrail selectively to only the last user message. For advanced use cases, we also allow developers to pass in raw (Bedrock) message blocks.

Issue Link: #540

---------

Co-authored-by: Michael Chin <[email protected]>
1 parent 3345624 commit 3110f7b
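
For orientation before the diff, here is a minimal usage sketch of the new flag. It reuses the constructor arguments exercised in the unit tests below (`guardrails`, `guard_last_turn_only`); the guardrail ID and version are placeholders for values from your own account.

```python
from langchain_aws import ChatBedrockConverse
from langchain_core.messages import AIMessage, HumanMessage

# "my-guardrail-id" / "1" are placeholders; substitute a guardrail
# configured in your own AWS account.
llm = ChatBedrockConverse(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-west-2",
    guardrails={"guardrailId": "my-guardrail-id", "guardrailVersion": "1"},
    guard_last_turn_only=True,  # wrap only the final user turn in guardContent
)

response = llm.invoke(
    [
        HumanMessage(content="Earlier turn, relayed as plain text."),
        AIMessage(content="Assistant reply."),
        HumanMessage(content="Only this turn is evaluated by the guardrail."),
    ]
)
```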

File tree

2 files changed: +124 −2 lines changed

libs/aws/langchain_aws/chat_models/bedrock_converse.py
libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py

libs/aws/langchain_aws/chat_models/bedrock_converse.py

Lines changed: 48 additions & 2 deletions
```diff
@@ -456,6 +456,15 @@ class Joke(BaseModel):
     request_metadata: Optional[Dict[str, str]] = None
     """Key-Value pairs that you can use to filter invocation logs."""
 
+    guard_last_turn_only: bool = False
+    """Boolean flag for applying the guardrail to only the last turn."""
+
+    raw_blocks: Optional[List[Dict[str, Any]]] = None
+    """Raw Bedrock message blocks that can be passed in.
+    LangChain will relay them unchanged, enabling any combination of content block types.
+    This is useful for custom guardrail wrapping.
+    """
+
     model_config = ConfigDict(
         extra="forbid",
         populate_by_name=True,
@@ -634,11 +643,31 @@ def validate_environment(self) -> Self:
             service_name="bedrock-runtime",
         )
 
+        if self.guard_last_turn_only and not self.guardrail_config:
+            raise ValueError(
+                "`guard_last_turn_only=True` but no `guardrail_config` supplied. "
+                "Provide a guardrail via `guardrail_config` or "
+                "disable `guard_last_turn_only`."
+            )
         return self
 
     def _get_base_model(self) -> str:
         return self.base_model_id if self.base_model_id else self.model_id
 
+    def _apply_guard_last_turn_only(self, messages: List[Dict[str, Any]]) -> None:
+        for msg in reversed(messages):
+            if msg.get("role") == "user":
+                new_content = []
+                for block in msg["content"]:
+                    if "text" in block:
+                        new_content.append(
+                            {"guardContent": {"text": {"text": block["text"]}}}
+                        )
+                    else:
+                        new_content.append(block)
+                msg["content"] = new_content
+                break
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -647,7 +676,16 @@ def _generate(
         **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
-        bedrock_messages, system = _messages_to_bedrock(messages)
+
+        if self.raw_blocks is not None:
+            logger.debug(f"Using raw blocks: {self.raw_blocks}")
+            bedrock_messages, system = self.raw_blocks, []
+        else:
+            bedrock_messages, system = _messages_to_bedrock(messages)
+            if self.guard_last_turn_only:
+                logger.debug("Applying selective guardrail to only the last turn")
+                self._apply_guard_last_turn_only(bedrock_messages)
+
         logger.debug(f"input message to bedrock: {bedrock_messages}")
         logger.debug(f"System message to bedrock: {system}")
         params = self._converse_params(
@@ -673,7 +711,15 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        bedrock_messages, system = _messages_to_bedrock(messages)
+        if self.raw_blocks is not None:
+            logger.debug(f"Using raw blocks: {self.raw_blocks}")
+            bedrock_messages, system = self.raw_blocks, []
+        else:
+            bedrock_messages, system = _messages_to_bedrock(messages)
+            if self.guard_last_turn_only:
+                logger.debug("Applying selective guardrail to only the last turn")
+                self._apply_guard_last_turn_only(bedrock_messages)
+
         params = self._converse_params(
             stop=stop,
             **_snake_to_camel_keys(
```
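
The `raw_blocks` path above is not exercised by the unit tests, so the following is a hedged sketch of how it might be used; the block shapes follow the Bedrock Converse `guardContent` format that `_apply_guard_last_turn_only` produces, and the guardrail ID is again a placeholder.

```python
from langchain_aws import ChatBedrockConverse
from langchain_core.messages import HumanMessage

# Hand-built Bedrock Converse blocks: since LangChain relays them unchanged,
# guardContent can be placed on any span of text, not just the last turn.
raw = [
    {
        "role": "user",
        "content": [
            {"text": "Context the guardrail should not evaluate."},
            {"guardContent": {"text": {"text": "Text the guardrail should evaluate."}}},
        ],
    }
]

llm = ChatBedrockConverse(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-west-2",
    guardrails={"guardrailId": "my-guardrail-id", "guardrailVersion": "1"},
    raw_blocks=raw,
)

# Per _generate() above, raw_blocks is sent verbatim and the converted input
# messages are ignored, so this message is only a placeholder.
llm.invoke([HumanMessage(content="ignored when raw_blocks is set")])
```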

libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py

Lines changed: 76 additions & 0 deletions
```diff
@@ -1307,3 +1307,79 @@ def test_model_kwargs() -> None:
     )
     assert llm.additional_model_request_fields == {"temperature": 0.2}
     assert llm.temperature is None
+
+
+def _create_mock_llm_guard_last_turn_only() -> (
+    Tuple[ChatBedrockConverse, mock.MagicMock]
+):
+    """Utility to create an LLM with guard_last_turn_only=True and a mocked client."""
+    mocked_client = mock.MagicMock()
+    llm = ChatBedrockConverse(
+        client=mocked_client,
+        model="anthropic.claude-3-sonnet-20240229-v1:0",
+        region_name="us-west-2",
+        guard_last_turn_only=True,
+        guardrails={"guardrailId": "dummy-guardrail", "guardrailVersion": "1"},
+    )
+    return llm, mocked_client
+
+
+def test_guard_last_turn_only_no_guardrail_config() -> None:
+    """Test that an error is raised if guard_last_turn_only is True but no guardrail_config is provided."""
+    with pytest.raises(ValueError):
+        ChatBedrockConverse(
+            client=mock.MagicMock(),
+            model="anthropic.claude-3-sonnet-20240229-v1:0",
+            region_name="us-west-2",
+            guard_last_turn_only=True,
+        )
+
+
+def test_generate_guard_last_turn_only() -> None:
+    """Test that _generate() wraps ONLY the final user turn with guardContent."""
+    llm, mocked_client = _create_mock_llm_guard_last_turn_only()
+
+    mocked_client.converse.return_value = {
+        "output": {"message": {"content": [{"text": "ok"}]}},
+        "usage": {"inputTokens": 1, "outputTokens": 1, "totalTokens": 2},
+    }
+
+    messages = [
+        HumanMessage(content="First user message"),
+        AIMessage(content="Assistant reply"),
+        HumanMessage(content="Second user message"),
+    ]
+
+    llm.invoke(messages)
+    _, kwargs = mocked_client.converse.call_args
+    bedrock_msgs = kwargs["messages"]
+
+    assert bedrock_msgs[0]["content"][0] == {"text": "First user message"}
+    # Last user turn is wrapped in guardContent
+    assert bedrock_msgs[-1]["content"][0] == {
+        "guardContent": {"text": {"text": "Second user message"}}
+    }
+
+
+def test_stream_guard_last_turn_only() -> None:
+    """Test that stream() applies guardContent to the final user turn."""
+    llm, mocked_client = _create_mock_llm_guard_last_turn_only()
+
+    mocked_client.converse_stream.return_value = {
+        "stream": [{"messageStart": {"role": "assistant"}}]
+    }
+
+    messages = [
+        HumanMessage(content="Hello"),
+        AIMessage(content="Hi!"),
+        HumanMessage(content="How are you?"),
+    ]
+    list(llm.stream(messages))
+
+    _, kwargs = mocked_client.converse_stream.call_args
+    bedrock_msgs = kwargs["messages"]
+
+    assert bedrock_msgs[0]["content"][0] == {"text": "Hello"}
+    assert bedrock_msgs[-1]["content"][0] == {
+        "guardContent": {"text": {"text": "How are you?"}}
+    }
```
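
As a plain-data illustration of what these assertions check, here is the shape of the message list before and after `_apply_guard_last_turn_only` runs, using the values from the streaming test above.

```python
# Before: the Bedrock-converted turns produced by _messages_to_bedrock().
before = [
    {"role": "user", "content": [{"text": "Hello"}]},
    {"role": "assistant", "content": [{"text": "Hi!"}]},
    {"role": "user", "content": [{"text": "How are you?"}]},
]

# After: only the text blocks of the final user turn are rewrapped in
# guardContent; every earlier turn is passed through untouched.
after = [
    {"role": "user", "content": [{"text": "Hello"}]},
    {"role": "assistant", "content": [{"text": "Hi!"}]},
    {"role": "user", "content": [{"guardContent": {"text": {"text": "How are you?"}}}]},
]
```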
