Skip to content

add optional session parameter to CodeInterpreter for custom credential management #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/bedrock_agentcore/tools/code_interpreter_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,19 @@ class CodeInterpreter:
session_id (str, optional): The active session ID.
"""

def __init__(self, region: str) -> None:
def __init__(self, region: str, session: Optional[boto3.Session] = None) -> None:
"""Initialize a Code Interpreter client for the specified AWS region.

Args:
region (str): The AWS region to use for the Code Interpreter service.
session (Optional[boto3.Session]): Optional boto3 session to use.
If not provided, a new session will be created. This is useful
for cases where you need to use custom credentials or assume roles.
"""
self.data_plane_service_name = "bedrock-agentcore"
self.client = boto3.client(
if session is None:
session = boto3.Session()
self.client = session.client(
self.data_plane_service_name, region_name=region, endpoint_url=get_data_plane_endpoint(region)
)
self._identifier = None
Expand Down Expand Up @@ -160,24 +165,26 @@ def invoke(self, method: str, params: Optional[Dict] = None):


@contextmanager
def code_session(region: str) -> Generator[CodeInterpreter, None, None]:
def code_session(region: str, session: Optional[boto3.Session] = None) -> Generator[CodeInterpreter, None, None]:
"""Context manager for creating and managing a code interpreter session.

This context manager handles creating a client, starting a session, and
ensuring the session is properly cleaned up when the context exits.

Args:
region (str): The AWS region to use for the Code Interpreter service.
session (Optional[boto3.Session]): Optional boto3 session to use.
If not provided, a new session will be created.

Yields:
CodeInterpreterClient: An initialized and started code interpreter client.
CodeInterpreter: An initialized and started code interpreter client.

Example:
>>> with code_session('us-west-2') as client:
... result = client.invoke('listFiles')
... # Process result here
"""
client = CodeInterpreter(region)
client = CodeInterpreter(region, session=session)
client.start()

try:
Expand Down
74 changes: 65 additions & 9 deletions tests/bedrock_agentcore/tools/test_code_interpreter_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,39 @@ class TestCodeInterpreterClient:
@patch("bedrock_agentcore.tools.code_interpreter_client.get_data_plane_endpoint")
def test_init(self, mock_get_endpoint, mock_boto3):
# Arrange
mock_session = MagicMock()
mock_client = MagicMock()
mock_boto3.client.return_value = mock_client
mock_session.client.return_value = mock_client
mock_boto3.Session.return_value = mock_session
mock_get_endpoint.return_value = "https://mock-endpoint.com"
region = "us-west-2"

# Act
client = CodeInterpreter(region)

# Assert
mock_boto3.client.assert_called_once_with(
mock_boto3.Session.assert_called_once()
mock_session.client.assert_called_once_with(
"bedrock-agentcore", region_name=region, endpoint_url="https://mock-endpoint.com"
)
assert client.client == mock_client
assert client.identifier is None
assert client.session_id is None

@patch("bedrock_agentcore.tools.code_interpreter_client.get_data_plane_endpoint")
def test_init_with_custom_session(self, mock_get_endpoint):
# Arrange
mock_session = MagicMock()
mock_client = MagicMock()
mock_session.client.return_value = mock_client
mock_get_endpoint.return_value = "https://mock-endpoint.com"
region = "us-west-2"

# Act
client = CodeInterpreter(region, session=mock_session)

# Assert
mock_session.client.assert_called_once_with(
"bedrock-agentcore", region_name=region, endpoint_url="https://mock-endpoint.com"
)
assert client.client == mock_client
Expand All @@ -27,6 +50,11 @@ def test_init(self, mock_get_endpoint, mock_boto3):
@patch("bedrock_agentcore.tools.code_interpreter_client.boto3")
def test_property_getters_setters(self, mock_boto3):
# Arrange
mock_session = MagicMock()
mock_client = MagicMock()
mock_session.client.return_value = mock_client
mock_boto3.Session.return_value = mock_session

client = CodeInterpreter("us-west-2")
test_identifier = "test.identifier"
test_session_id = "test-session-id"
Expand All @@ -43,8 +71,10 @@ def test_property_getters_setters(self, mock_boto3):
@patch("bedrock_agentcore.tools.code_interpreter_client.uuid.uuid4")
def test_start_with_defaults(self, mock_uuid4, mock_boto3):
# Arrange
mock_session = MagicMock()
mock_client = MagicMock()
mock_boto3.client.return_value = mock_client
mock_session.client.return_value = mock_client
mock_boto3.Session.return_value = mock_session
mock_uuid4.return_value.hex = "12345678abcdef"

client = CodeInterpreter("us-west-2")
Expand All @@ -67,8 +97,10 @@ def test_start_with_defaults(self, mock_uuid4, mock_boto3):
@patch("bedrock_agentcore.tools.code_interpreter_client.boto3")
def test_start_with_custom_params(self, mock_boto3):
# Arrange
mock_session = MagicMock()
mock_client = MagicMock()
mock_boto3.client.return_value = mock_client
mock_session.client.return_value = mock_client
mock_boto3.Session.return_value = mock_session

client = CodeInterpreter("us-west-2")
mock_response = {"codeInterpreterIdentifier": "custom.interpreter", "sessionId": "custom-session-123"}
Expand All @@ -94,8 +126,10 @@ def test_start_with_custom_params(self, mock_boto3):
@patch("bedrock_agentcore.tools.code_interpreter_client.boto3")
def test_stop_when_session_exists(self, mock_boto3):
# Arrange
mock_session = MagicMock()
mock_client = MagicMock()
mock_boto3.client.return_value = mock_client
mock_session.client.return_value = mock_client
mock_boto3.Session.return_value = mock_session

client = CodeInterpreter("us-west-2")
client.identifier = "test.identifier"
Expand All @@ -114,8 +148,10 @@ def test_stop_when_session_exists(self, mock_boto3):
@patch("bedrock_agentcore.tools.code_interpreter_client.boto3")
def test_stop_when_no_session(self, mock_boto3):
# Arrange
mock_session = MagicMock()
mock_client = MagicMock()
mock_boto3.client.return_value = mock_client
mock_session.client.return_value = mock_client
mock_boto3.Session.return_value = mock_session

client = CodeInterpreter("us-west-2")
client.identifier = None
Expand All @@ -132,8 +168,10 @@ def test_stop_when_no_session(self, mock_boto3):
@patch("bedrock_agentcore.tools.code_interpreter_client.uuid.uuid4")
def test_invoke_with_existing_session(self, mock_uuid4, mock_boto3):
# Arrange
mock_session = MagicMock()
mock_client = MagicMock()
mock_boto3.client.return_value = mock_client
mock_session.client.return_value = mock_client
mock_boto3.Session.return_value = mock_session
mock_uuid4.return_value.hex = "12345678abcdef"

client = CodeInterpreter("us-west-2")
Expand All @@ -158,8 +196,10 @@ def test_invoke_with_existing_session(self, mock_uuid4, mock_boto3):
@patch("bedrock_agentcore.tools.code_interpreter_client.boto3")
def test_invoke_with_no_session(self, mock_boto3):
# Arrange
mock_session = MagicMock()
mock_client = MagicMock()
mock_boto3.client.return_value = mock_client
mock_session.client.return_value = mock_client
mock_boto3.Session.return_value = mock_session

client = CodeInterpreter("us-west-2")
client.identifier = None
Expand Down Expand Up @@ -195,6 +235,22 @@ def test_code_session_context_manager(self, mock_client_class):
pass

# Assert
mock_client_class.assert_called_once_with("us-west-2")
mock_client_class.assert_called_once_with("us-west-2", session=None)
mock_client.start.assert_called_once()
mock_client.stop.assert_called_once()

@patch("bedrock_agentcore.tools.code_interpreter_client.CodeInterpreter")
def test_code_session_context_manager_with_session(self, mock_client_class):
# Arrange
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_session = MagicMock()

# Act
with code_session("us-west-2", session=mock_session):
pass

# Assert
mock_client_class.assert_called_once_with("us-west-2", session=mock_session)
mock_client.start.assert_called_once()
mock_client.stop.assert_called_once()