diff --git a/src/bedrock_agentcore/tools/code_interpreter_client.py b/src/bedrock_agentcore/tools/code_interpreter_client.py index b1ceec7..36b672b 100644 --- a/src/bedrock_agentcore/tools/code_interpreter_client.py +++ b/src/bedrock_agentcore/tools/code_interpreter_client.py @@ -30,14 +30,19 @@ class CodeInterpreter: session_id (str, optional): The active session ID. """ - def __init__(self, region: str) -> None: + def __init__(self, region: str, session: Optional[boto3.Session] = None) -> None: """Initialize a Code Interpreter client for the specified AWS region. Args: region (str): The AWS region to use for the Code Interpreter service. + session (Optional[boto3.Session]): Optional boto3 session to use. + If not provided, a new session will be created. This is useful + for cases where you need to use custom credentials or assume roles. """ self.data_plane_service_name = "bedrock-agentcore" - self.client = boto3.client( + if session is None: + session = boto3.Session() + self.client = session.client( self.data_plane_service_name, region_name=region, endpoint_url=get_data_plane_endpoint(region) ) self._identifier = None @@ -160,7 +165,7 @@ def invoke(self, method: str, params: Optional[Dict] = None): @contextmanager -def code_session(region: str) -> Generator[CodeInterpreter, None, None]: +def code_session(region: str, session: Optional[boto3.Session] = None) -> Generator[CodeInterpreter, None, None]: """Context manager for creating and managing a code interpreter session. This context manager handles creating a client, starting a session, and @@ -168,16 +173,18 @@ def code_session(region: str) -> Generator[CodeInterpreter, None, None]: Args: region (str): The AWS region to use for the Code Interpreter service. + session (Optional[boto3.Session]): Optional boto3 session to use. + If not provided, a new session will be created. Yields: - CodeInterpreterClient: An initialized and started code interpreter client. + CodeInterpreter: An initialized and started code interpreter client. Example: >>> with code_session('us-west-2') as client: ... result = client.invoke('listFiles') ... # Process result here """ - client = CodeInterpreter(region) + client = CodeInterpreter(region, session=session) client.start() try: diff --git a/tests/bedrock_agentcore/tools/test_code_interpreter_client.py b/tests/bedrock_agentcore/tools/test_code_interpreter_client.py index 2be8d22..5dcff77 100644 --- a/tests/bedrock_agentcore/tools/test_code_interpreter_client.py +++ b/tests/bedrock_agentcore/tools/test_code_interpreter_client.py @@ -8,8 +8,10 @@ class TestCodeInterpreterClient: @patch("bedrock_agentcore.tools.code_interpreter_client.get_data_plane_endpoint") def test_init(self, mock_get_endpoint, mock_boto3): # Arrange + mock_session = MagicMock() mock_client = MagicMock() - mock_boto3.client.return_value = mock_client + mock_session.client.return_value = mock_client + mock_boto3.Session.return_value = mock_session mock_get_endpoint.return_value = "https://mock-endpoint.com" region = "us-west-2" @@ -17,7 +19,28 @@ def test_init(self, mock_get_endpoint, mock_boto3): client = CodeInterpreter(region) # Assert - mock_boto3.client.assert_called_once_with( + mock_boto3.Session.assert_called_once() + mock_session.client.assert_called_once_with( + "bedrock-agentcore", region_name=region, endpoint_url="https://mock-endpoint.com" + ) + assert client.client == mock_client + assert client.identifier is None + assert client.session_id is None + + @patch("bedrock_agentcore.tools.code_interpreter_client.get_data_plane_endpoint") + def test_init_with_custom_session(self, mock_get_endpoint): + # Arrange + mock_session = MagicMock() + mock_client = MagicMock() + mock_session.client.return_value = mock_client + mock_get_endpoint.return_value = "https://mock-endpoint.com" + region = "us-west-2" + + # Act + client = CodeInterpreter(region, session=mock_session) + + # Assert + mock_session.client.assert_called_once_with( "bedrock-agentcore", region_name=region, endpoint_url="https://mock-endpoint.com" ) assert client.client == mock_client @@ -27,6 +50,11 @@ def test_init(self, mock_get_endpoint, mock_boto3): @patch("bedrock_agentcore.tools.code_interpreter_client.boto3") def test_property_getters_setters(self, mock_boto3): # Arrange + mock_session = MagicMock() + mock_client = MagicMock() + mock_session.client.return_value = mock_client + mock_boto3.Session.return_value = mock_session + client = CodeInterpreter("us-west-2") test_identifier = "test.identifier" test_session_id = "test-session-id" @@ -43,8 +71,10 @@ def test_property_getters_setters(self, mock_boto3): @patch("bedrock_agentcore.tools.code_interpreter_client.uuid.uuid4") def test_start_with_defaults(self, mock_uuid4, mock_boto3): # Arrange + mock_session = MagicMock() mock_client = MagicMock() - mock_boto3.client.return_value = mock_client + mock_session.client.return_value = mock_client + mock_boto3.Session.return_value = mock_session mock_uuid4.return_value.hex = "12345678abcdef" client = CodeInterpreter("us-west-2") @@ -67,8 +97,10 @@ def test_start_with_defaults(self, mock_uuid4, mock_boto3): @patch("bedrock_agentcore.tools.code_interpreter_client.boto3") def test_start_with_custom_params(self, mock_boto3): # Arrange + mock_session = MagicMock() mock_client = MagicMock() - mock_boto3.client.return_value = mock_client + mock_session.client.return_value = mock_client + mock_boto3.Session.return_value = mock_session client = CodeInterpreter("us-west-2") mock_response = {"codeInterpreterIdentifier": "custom.interpreter", "sessionId": "custom-session-123"} @@ -94,8 +126,10 @@ def test_start_with_custom_params(self, mock_boto3): @patch("bedrock_agentcore.tools.code_interpreter_client.boto3") def test_stop_when_session_exists(self, mock_boto3): # Arrange + mock_session = MagicMock() mock_client = MagicMock() - mock_boto3.client.return_value = mock_client + mock_session.client.return_value = mock_client + mock_boto3.Session.return_value = mock_session client = CodeInterpreter("us-west-2") client.identifier = "test.identifier" @@ -114,8 +148,10 @@ def test_stop_when_session_exists(self, mock_boto3): @patch("bedrock_agentcore.tools.code_interpreter_client.boto3") def test_stop_when_no_session(self, mock_boto3): # Arrange + mock_session = MagicMock() mock_client = MagicMock() - mock_boto3.client.return_value = mock_client + mock_session.client.return_value = mock_client + mock_boto3.Session.return_value = mock_session client = CodeInterpreter("us-west-2") client.identifier = None @@ -132,8 +168,10 @@ def test_stop_when_no_session(self, mock_boto3): @patch("bedrock_agentcore.tools.code_interpreter_client.uuid.uuid4") def test_invoke_with_existing_session(self, mock_uuid4, mock_boto3): # Arrange + mock_session = MagicMock() mock_client = MagicMock() - mock_boto3.client.return_value = mock_client + mock_session.client.return_value = mock_client + mock_boto3.Session.return_value = mock_session mock_uuid4.return_value.hex = "12345678abcdef" client = CodeInterpreter("us-west-2") @@ -158,8 +196,10 @@ def test_invoke_with_existing_session(self, mock_uuid4, mock_boto3): @patch("bedrock_agentcore.tools.code_interpreter_client.boto3") def test_invoke_with_no_session(self, mock_boto3): # Arrange + mock_session = MagicMock() mock_client = MagicMock() - mock_boto3.client.return_value = mock_client + mock_session.client.return_value = mock_client + mock_boto3.Session.return_value = mock_session client = CodeInterpreter("us-west-2") client.identifier = None @@ -195,6 +235,22 @@ def test_code_session_context_manager(self, mock_client_class): pass # Assert - mock_client_class.assert_called_once_with("us-west-2") + mock_client_class.assert_called_once_with("us-west-2", session=None) + mock_client.start.assert_called_once() + mock_client.stop.assert_called_once() + + @patch("bedrock_agentcore.tools.code_interpreter_client.CodeInterpreter") + def test_code_session_context_manager_with_session(self, mock_client_class): + # Arrange + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_session = MagicMock() + + # Act + with code_session("us-west-2", session=mock_session): + pass + + # Assert + mock_client_class.assert_called_once_with("us-west-2", session=mock_session) mock_client.start.assert_called_once() mock_client.stop.assert_called_once()