diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index edc45a8a..e08f22a3 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -44,7 +44,10 @@ test = [ "isort==6.0.1", "mypy==1.15.0", "pytest==8.3.5", - "pytest-aioresponses==0.3.0" + "pytest-aioresponses==0.3.0", + "pytest-asyncio==0.26.0", + "google-cloud-secret-manager==2.23.2", + "google-cloud-storage==3.1.0", ] [build-system] requires = ["setuptools"] diff --git a/packages/toolbox-core/tests/conftest.py b/packages/toolbox-core/tests/conftest.py new file mode 100644 index 00000000..231ef349 --- /dev/null +++ b/packages/toolbox-core/tests/conftest.py @@ -0,0 +1,166 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains pytest fixtures that are accessible from all +files present in the same directory.""" + +from __future__ import annotations + +import os +import platform +import subprocess +import tempfile +import time +from typing import Generator + +import google +import pytest_asyncio +from google.auth import compute_engine +from google.cloud import secretmanager, storage + + +#### Define Utility Functions +def get_env_var(key: str) -> str: + """Gets environment variables.""" + value = os.environ.get(key) + if value is None: + raise ValueError(f"Must set env var {key}") + return value + + +def access_secret_version( + project_id: str, secret_id: str, version_id: str = "latest" +) -> str: + """Accesses the payload of a given secret version from Secret Manager.""" + client = secretmanager.SecretManagerServiceClient() + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}" + response = client.access_secret_version(request={"name": name}) + return response.payload.data.decode("UTF-8") + + +def create_tmpfile(content: str) -> str: + """Creates a temporary file with the given content.""" + with tempfile.NamedTemporaryFile(delete=False, mode="w") as tmpfile: + tmpfile.write(content) + return tmpfile.name + + +def download_blob( + bucket_name: str, source_blob_name: str, destination_file_name: str +) -> None: + """Downloads a blob from a GCS bucket.""" + storage_client = storage.Client() + + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(source_blob_name) + blob.download_to_filename(destination_file_name) + + print(f"Blob {source_blob_name} downloaded to {destination_file_name}.") + + +def get_toolbox_binary_url(toolbox_version: str) -> str: + """Constructs the GCS path to the toolbox binary.""" + os_system = platform.system().lower() + arch = ( + "arm64" if os_system == "darwin" and platform.machine() == "arm64" else "amd64" + ) + return f"v{toolbox_version}/{os_system}/{arch}/toolbox" + + +def get_auth_token(client_id: str) -> str: + """Retrieves an authentication token""" + request = google.auth.transport.requests.Request() + credentials = compute_engine.IDTokenCredentials( + request=request, + target_audience=client_id, + use_metadata_identity_endpoint=True, + ) + if not credentials.valid: + credentials.refresh(request) + return credentials.token + + +#### Define Fixtures +@pytest_asyncio.fixture(scope="session") +def project_id() -> str: + return get_env_var("GOOGLE_CLOUD_PROJECT") + + +@pytest_asyncio.fixture(scope="session") +def toolbox_version() -> str: + return get_env_var("TOOLBOX_VERSION") + + +@pytest_asyncio.fixture(scope="session") +def tools_file_path(project_id: str) -> Generator[str]: + """Provides a temporary file path containing the tools manifest.""" + tools_manifest = access_secret_version( + project_id=project_id, secret_id="sdk_testing_tools" + ) + tools_file_path = create_tmpfile(tools_manifest) + yield tools_file_path + os.remove(tools_file_path) + + +@pytest_asyncio.fixture(scope="session") +def auth_token1(project_id: str) -> str: + client_id = access_secret_version( + project_id=project_id, secret_id="sdk_testing_client1" + ) + return get_auth_token(client_id) + + +@pytest_asyncio.fixture(scope="session") +def auth_token2(project_id: str) -> str: + client_id = access_secret_version( + project_id=project_id, secret_id="sdk_testing_client2" + ) + return get_auth_token(client_id) + + +@pytest_asyncio.fixture(scope="session") +def toolbox_server(toolbox_version: str, tools_file_path: str) -> Generator[None]: + """Starts the toolbox server as a subprocess.""" + print("Downloading toolbox binary from gcs bucket...") + source_blob_name = get_toolbox_binary_url(toolbox_version) + download_blob("genai-toolbox", source_blob_name, "toolbox") + print("Toolbox binary downloaded successfully.") + try: + print("Opening toolbox server process...") + # Make toolbox executable + os.chmod("toolbox", 0o700) + # Run toolbox binary + toolbox_server = subprocess.Popen( + ["./toolbox", "--tools_file", tools_file_path] + ) + + # Wait for server to start + # Retry logic with a timeout + for _ in range(5): # retries + time.sleep(4) + print("Checking if toolbox is successfully started...") + if toolbox_server.poll() is None: + print("Toolbox server started successfully.") + break + else: + raise RuntimeError("Toolbox server failed to start after 5 retries.") + except subprocess.CalledProcessError as e: + print(e.stderr.decode("utf-8")) + print(e.stdout.decode("utf-8")) + raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e + yield + + # Clean up toolbox server + toolbox_server.terminate() + toolbox_server.wait() diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py new file mode 100644 index 00000000..e9411fb0 --- /dev/null +++ b/packages/toolbox-core/tests/test_e2e.py @@ -0,0 +1,62 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio + +from toolbox_core.client import ToolboxClient +from toolbox_core.tool import ToolboxTool + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestE2EClient: + @pytest_asyncio.fixture(scope="function") + async def toolbox(self): + toolbox = ToolboxClient("http://localhost:5000") + return toolbox + + @pytest_asyncio.fixture(scope="function") + async def get_n_rows_tool(self, toolbox: ToolboxClient) -> ToolboxTool: + tool = await toolbox.load_tool("get-n-rows") + assert tool.__name__ == "get-n-rows" + return tool + + #### Basic e2e tests + @pytest.mark.parametrize( + "toolset_name, expected_length, expected_tools", + [ + ("my-toolset", 1, ["get-row-by-id"]), + ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]), + ], + ) + async def test_load_toolset_specific( + self, + toolbox: ToolboxClient, + toolset_name: str, + expected_length: int, + expected_tools: list[str], + ): + toolset = await toolbox.load_toolset(toolset_name) + assert len(toolset) == expected_length + tool_names = {tool.__name__ for tool in toolset} + assert tool_names == set(expected_tools) + + async def test_run_tool(self, get_n_rows_tool: ToolboxTool): + response = await get_n_rows_tool(num_rows="2") + + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" not in response diff --git a/packages/toolbox-core/tests/test_tool.py b/packages/toolbox-core/tests/test_tool.py new file mode 100644 index 00000000..593f50fe --- /dev/null +++ b/packages/toolbox-core/tests/test_tool.py @@ -0,0 +1,189 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from inspect import Parameter, Signature +from typing import Any, Callable, Optional +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from toolbox_core.tool import ToolboxTool + + +class TestToolboxTool: + @pytest.fixture + def mock_session(self) -> MagicMock: # Added self + session = MagicMock() + session.post = MagicMock() + return session + + @pytest.fixture + def tool_details(self) -> dict: + base_url = "http://fake-toolbox.com" + tool_name = "test_tool" + params = [ + Parameter("arg1", Parameter.POSITIONAL_OR_KEYWORD, annotation=str), + Parameter( + "opt_arg", + Parameter.POSITIONAL_OR_KEYWORD, + default=123, + annotation=Optional[int], + ), + ] + return { + "base_url": base_url, + "name": tool_name, + "desc": "A tool for testing.", + "params": params, + "signature": Signature(parameters=params, return_annotation=str), + "expected_url": f"{base_url}/api/tool/{tool_name}/invoke", + "annotations": {"arg1": str, "opt_arg": Optional[int]}, + } + + @pytest.fixture + def tool(self, mock_session: MagicMock, tool_details: dict) -> ToolboxTool: + return ToolboxTool( + session=mock_session, + base_url=tool_details["base_url"], + name=tool_details["name"], + desc=tool_details["desc"], + params=tool_details["params"], + ) + + @pytest.fixture + def configure_mock_response(self, mock_session: MagicMock) -> Callable: + def _configure(json_data: Any, status: int = 200): + mock_resp = MagicMock() + mock_resp.status = status + mock_resp.json = AsyncMock(return_value=json_data) + mock_resp.__aenter__.return_value = mock_resp + mock_resp.__aexit__.return_value = None + mock_session.post.return_value = mock_resp + + return _configure + + @pytest.mark.asyncio + async def test_initialization_and_introspection( + self, tool: ToolboxTool, tool_details: dict + ): + """Verify attributes are set correctly during initialization.""" + assert tool.__name__ == tool_details["name"] + assert tool.__doc__ == tool_details["desc"] + assert tool._ToolboxTool__url == tool_details["expected_url"] + assert tool._ToolboxTool__session is tool._ToolboxTool__session + assert tool.__signature__ == tool_details["signature"] + assert tool.__annotations__ == tool_details["annotations"] + # assert hasattr(tool, "__qualname__") + + @pytest.mark.asyncio + async def test_call_success( + self, + tool: ToolboxTool, + mock_session: MagicMock, + tool_details: dict, + configure_mock_response: Callable, + ): + expected_result = "Operation successful!" + configure_mock_response({"result": expected_result}) + + arg1_val = "test_value" + opt_arg_val = 456 + result = await tool(arg1_val, opt_arg=opt_arg_val) + + assert result == expected_result + mock_session.post.assert_called_once_with( + tool_details["expected_url"], + json={"arg1": arg1_val, "opt_arg": opt_arg_val}, + ) + mock_session.post.return_value.__aenter__.return_value.json.assert_awaited_once() + + @pytest.mark.asyncio + async def test_call_success_with_defaults( + self, + tool: ToolboxTool, + mock_session: MagicMock, + tool_details: dict, + configure_mock_response: Callable, + ): + expected_result = "Default success!" + configure_mock_response({"result": expected_result}) + + arg1_val = "another_test" + default_opt_val = tool_details["params"][1].default + result = await tool(arg1_val) + + assert result == expected_result + mock_session.post.assert_called_once_with( + tool_details["expected_url"], + json={"arg1": arg1_val, "opt_arg": default_opt_val}, + ) + mock_session.post.return_value.__aenter__.return_value.json.assert_awaited_once() + + @pytest.mark.asyncio + async def test_call_api_error( + self, + tool: ToolboxTool, + mock_session: MagicMock, + tool_details: dict, + configure_mock_response: Callable, + ): + error_message = "Tool execution failed on server" + configure_mock_response({"error": error_message}) + default_opt_val = tool_details["params"][1].default + + with pytest.raises(Exception) as exc_info: + await tool("some_arg") + + assert str(exc_info.value) == error_message + mock_session.post.assert_called_once_with( + tool_details["expected_url"], + json={"arg1": "some_arg", "opt_arg": default_opt_val}, + ) + mock_session.post.return_value.__aenter__.return_value.json.assert_awaited_once() + + @pytest.mark.asyncio + async def test_call_missing_result_key( + self, + tool: ToolboxTool, + mock_session: MagicMock, + tool_details: dict, + configure_mock_response: Callable, + ): + fallback_response = {"status": "completed", "details": "some info"} + configure_mock_response(fallback_response) + default_opt_val = tool_details["params"][1].default + + result = await tool("value_for_arg1") + + assert result == fallback_response + mock_session.post.assert_called_once_with( + tool_details["expected_url"], + json={"arg1": "value_for_arg1", "opt_arg": default_opt_val}, + ) + mock_session.post.return_value.__aenter__.return_value.json.assert_awaited_once() + + @pytest.mark.asyncio + async def test_call_invalid_arguments_type_error( + self, tool: ToolboxTool, mock_session: MagicMock + ): + with pytest.raises(TypeError): + await tool("val1", 2, 3) + + with pytest.raises(TypeError): + await tool("val1", non_existent_arg="bad") + + with pytest.raises(TypeError): + await tool(opt_arg=500) + + mock_session.post.assert_not_called()