Skip to content

Commit 2908972

Browse files
committed
Adding guardrail support to all strands agents. Uses existing guardrail created by user (used in several places throughout the application outside of agents already). No guardrail is used if users don't specify one at deploy time. Note this required adding a helper function, so IDPAgent's can either create their own strands.BedrockModel (not recommended) or can use the create_strands_bedrock_model helper function to create one to automatically include IDP guardrails (recommended)
1 parent da7c256 commit 2908972

File tree

6 files changed

+65
-8
lines changed

6 files changed

+65
-8
lines changed

lib/idp_common_pkg/idp_common/agents/analytics/agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111

1212
import boto3
1313
import strands
14-
from strands.models import BedrockModel
1514

1615
from ..common.config import load_result_format_description
16+
from ..common.strands_bedrock_model import create_strands_bedrock_model
1717
from .config import load_python_plot_generation_examples
1818
from .tools import CodeInterpreterTools, get_database_info, run_athena_query
1919
from .utils import register_code_interpreter_tools
@@ -138,7 +138,9 @@ def run_athena_query_with_config(
138138
# Get model ID from environment variable
139139
model_id = os.environ.get("DOCUMENT_ANALYSIS_AGENT_MODEL_ID")
140140

141-
bedrock_model = BedrockModel(model_id=model_id, boto_session=session)
141+
bedrock_model = create_strands_bedrock_model(
142+
model_id=model_id, boto_session=session
143+
)
142144

143145
# Create the Strands agent with tools and system prompt
144146
strands_agent = strands.Agent(
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: MIT-0
3+
4+
"""
5+
Helper function for creating BedrockModel instances with automatic guardrail support.
6+
"""
7+
8+
import os
9+
10+
from strands.models import BedrockModel
11+
12+
13+
def create_strands_bedrock_model(
14+
model_id: str, boto_session=None, **kwargs
15+
) -> BedrockModel:
16+
"""
17+
Create a BedrockModel with automatic guardrail configuration from environment.
18+
19+
Args:
20+
model_id: The Bedrock model ID to use
21+
boto_session: Optional boto3 session
22+
**kwargs: Additional arguments to pass to BedrockModel
23+
24+
Returns:
25+
BedrockModel instance with guardrails applied if configured
26+
"""
27+
# Get guardrail configuration from environment if available
28+
guardrail_env = os.environ.get("GUARDRAIL_ID_AND_VERSION", "")
29+
if guardrail_env:
30+
try:
31+
guardrail_id, guardrail_version = guardrail_env.split(":")
32+
if guardrail_id and guardrail_version:
33+
kwargs.update(
34+
{
35+
"guardrail_id": guardrail_id,
36+
"guardrail_version": guardrail_version,
37+
"guardrail_trace": "enabled",
38+
}
39+
)
40+
except ValueError:
41+
pass # Invalid format, continue without guardrails
42+
43+
return BedrockModel(model_id=model_id, boto_session=boto_session, **kwargs)

lib/idp_common_pkg/idp_common/agents/external_mcp/agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
import boto3
1313
from mcp.client.streamable_http import streamablehttp_client
1414
from strands import Agent
15-
from strands.models import BedrockModel
1615
from strands.tools.mcp import MCPClient
1716

1817
from ..common.oauth_auth import get_cognito_bearer_token
18+
from ..common.strands_bedrock_model import create_strands_bedrock_model
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -130,7 +130,9 @@ def create_external_mcp_agent(
130130
raise Exception(error_msg)
131131

132132
# Create Bedrock model
133-
bedrock_model = BedrockModel(model_id=model_id, boto_session=session)
133+
bedrock_model = create_strands_bedrock_model(
134+
model_id=model_id, boto_session=session
135+
)
134136

135137
# Create system prompt
136138
system_prompt = f"""

lib/idp_common_pkg/idp_common/agents/orchestrator/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
import strands
1717
from strands import tool
18-
from strands.models import BedrockModel
1918

2019
from ..common.config import load_result_format_description
20+
from ..common.strands_bedrock_model import create_strands_bedrock_model
2121

2222
logger = logging.getLogger(__name__)
2323

@@ -161,7 +161,7 @@ def tool_func(query: str) -> str:
161161
)
162162

163163
# Create the orchestrator agent
164-
model = BedrockModel(
164+
model = create_strands_bedrock_model(
165165
model_id=model_id,
166166
session=session,
167167
)

lib/idp_common_pkg/idp_common/agents/sample_calculator/agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
import boto3
99
import strands
10-
from strands.models import BedrockModel
1110
from strands_tools import calculator
1211

12+
from ..common.strands_bedrock_model import create_strands_bedrock_model
13+
1314
logger = logging.getLogger(__name__)
1415

1516

@@ -31,7 +32,7 @@ def create_sample_calculator_agent(
3132
model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
3233

3334
# Create Bedrock model
34-
model = BedrockModel(model_id=model_id, session=session)
35+
model = create_strands_bedrock_model(model_id=model_id, session=session)
3536

3637
# Create and return agent with calculator tool
3738
return strands.Agent(model=model, tools=[calculator])

template.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3608,6 +3608,7 @@ Resources:
36083608
- !Ref ReportingBucketName
36093609
DOCUMENT_ANALYSIS_AGENT_MODEL_ID: !Ref DocumentAnalysisAgentModelId
36103610
AWS_STACK_NAME: !Ref AWS::StackName
3611+
GUARDRAIL_ID_AND_VERSION: !If [HasGuardrailConfig, !Sub "${BedrockGuardrailId}:${BedrockGuardrailVersion}", ""]
36113612
LoggingConfig:
36123613
LogGroup: !Ref AgentProcessorLogGroup
36133614
Policies:
@@ -3645,6 +3646,14 @@ Resources:
36453646
Resource:
36463647
- !Sub "arn:${AWS::Partition}:bedrock:*::foundation-model/*"
36473648
- !Sub "arn:${AWS::Partition}:bedrock:${AWS::Region}:${AWS::AccountId}:inference-profile/*"
3649+
- !If
3650+
- HasGuardrailConfig
3651+
- Effect: Allow
3652+
Action:
3653+
- "bedrock:ApplyGuardrail"
3654+
Resource:
3655+
- !Sub "arn:${AWS::Partition}:bedrock:${AWS::Region}:${AWS::AccountId}:guardrail/${BedrockGuardrailId}"
3656+
- !Ref AWS::NoValue
36483657
- Effect: Allow
36493658
Action:
36503659
- appsync:GraphQL

0 commit comments

Comments
 (0)