diff --git a/contributing/samples/a2a_auth/README.md b/contributing/samples/a2a_auth/README.md index 2e4aa204d..6504972c0 100644 --- a/contributing/samples/a2a_auth/README.md +++ b/contributing/samples/a2a_auth/README.md @@ -185,7 +185,7 @@ When deploying the remote BigQuery A2A agent to different environments (e.g., Cl } ``` -**Important:** The `url` field in `remote_a2a/bigquery_agent/agent.json` must point to the actual RPC endpoint where your remote BigQuery A2A agent is deployed and accessible. +**Important:** The `url` field in `remote_a2a/bigquery_agent/agent.json` must point to the actual RPC endpoint where your remote BigQuery A2A agent is deployed and accessible. If the `url` field is an empty string, it will be automatically filled by the base URL provided to `get_fast_api_app`. ## Troubleshooting diff --git a/contributing/samples/a2a_basic/README.md b/contributing/samples/a2a_basic/README.md index ca61101c2..a07390d2d 100644 --- a/contributing/samples/a2a_basic/README.md +++ b/contributing/samples/a2a_basic/README.md @@ -135,7 +135,7 @@ When deploying the remote A2A agent to different environments (e.g., Cloud Run, } ``` -**Important:** The `url` field in `remote_a2a/check_prime_agent/agent.json` must point to the actual RPC endpoint where your remote A2A agent is deployed and accessible. +**Important:** The `url` field in `remote_a2a/check_prime_agent/agent.json` must point to the actual RPC endpoint where your remote A2A agent is deployed and accessible. If the `url` field is an empty string, it will be automatically filled by the base URL provided to `get_fast_api_app`. ## Troubleshooting diff --git a/contributing/samples/a2a_human_in_loop/README.md b/contributing/samples/a2a_human_in_loop/README.md index 5f90fad9f..9966ddba6 100644 --- a/contributing/samples/a2a_human_in_loop/README.md +++ b/contributing/samples/a2a_human_in_loop/README.md @@ -144,7 +144,7 @@ When deploying the remote approval A2A agent to different environments (e.g., Cl } ``` -**Important:** The `url` field in `remote_a2a/human_in_loop/agent.json` must point to the actual RPC endpoint where your remote approval A2A agent is deployed and accessible. +**Important:** The `url` field in `remote_a2a/human_in_loop/agent.json` must point to the actual RPC endpoint where your remote approval A2A agent is deployed and accessible. If the `url` field is an empty string, it will be automatically filled by the base URL provided to `get_fast_api_app`. ## Troubleshooting diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 2d8092620..0db08bd8e 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -18,6 +18,7 @@ from contextlib import asynccontextmanager import logging import os +from pathlib import Path import time import traceback import typing @@ -31,6 +32,7 @@ from fastapi import HTTPException from fastapi import Query from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from fastapi.responses import RedirectResponse from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles @@ -44,6 +46,8 @@ from opentelemetry.sdk.trace import TracerProvider from pydantic import Field from pydantic import ValidationError +from starlette.responses import Response +from starlette.staticfiles import PathLike from starlette.types import Lifespan from typing_extensions import override from watchdog.observers import Observer @@ -197,6 +201,28 @@ class GetEventGraphResult(common.BaseModel): dot_src: str +class ConfigInjectingStaticFiles(StaticFiles): + """ + Custom StaticFiles that injects config.json for dev-ui. + Fixes https://github.com/google/adk-python/issues/2072 + """ + + def __init__(self, *, directory: PathLike | None = None, base_url: Optional[str] = None, **kwargs): + super().__init__(directory=directory, **kwargs) + if base_url is None: + base_url = "" + self.base_url = base_url + + async def get_response(self, path: str, scope) -> Response: + # Check if the request is for config.json + if Path(path).as_posix() == "assets/config/runtime-config.json": + config = {"backendUrl": self.base_url} + return JSONResponse(content=config) + + # Otherwise, serve static files normally + return await super().get_response(path, scope) + + class AdkWebServer: """Helper class for setting up and running the ADK web server on FastAPI. @@ -288,6 +314,7 @@ def get_fast_api_app( [Observer, "AdkWebServer"], None ] = lambda o, s: None, register_processors: Callable[[TracerProvider], None] = lambda o: None, + base_url: Optional[str] = None, ): """Creates a FastAPI app for the ADK web server. @@ -304,6 +331,8 @@ def get_fast_api_app( tear_down_observer: Callback for cleaning up the file system observer. register_processors: Callback for additional Span processors to be added to the TracerProvider. + base_url: The base URL for the web-ui, useful if fastapi app is mounted as + a sub-application. If none is provided, the host is used in the frontend. Returns: A FastAPI app instance. @@ -1024,7 +1053,12 @@ async def redirect_dev_ui_add_slash(): app.mount( "/dev-ui/", - StaticFiles(directory=web_assets_dir, html=True, follow_symlink=True), + ConfigInjectingStaticFiles( + directory=web_assets_dir, + base_url=base_url, + html=True, + follow_symlink=True, + ), name="static", ) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 3000edf78..e9f6555d0 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -23,6 +23,7 @@ import os import tempfile from typing import Optional +from urllib.parse import urlparse import click from click.core import ParameterSource @@ -616,18 +617,12 @@ def fast_api_common_options(): def decorator(func): @click.option( - "--host", + "--base_url", type=str, - help="Optional. The binding host of the server", - default="127.0.0.1", + help="Optional. The base URL of the server.", + default="http://127.0.0.1:8000", show_default=True, ) - @click.option( - "--port", - type=int, - help="Optional. The port of the server", - default=8000, - ) @click.option( "--allow_origins", help="Optional. Any additional origins to allow for CORS.", @@ -719,8 +714,7 @@ def cli_web( eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, - host: str = "127.0.0.1", - port: int = 8000, + base_url="http://127.0.0.1:8000", trace_to_cloud: bool = False, reload: bool = True, session_service_uri: Optional[str] = None, @@ -741,6 +735,9 @@ def cli_web( adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ logs.setup_adk_logger(getattr(logging, log_level.upper())) + parsed_url = urlparse(base_url) + host = parsed_url.hostname + port = parsed_url.port @asynccontextmanager async def _lifespan(app: FastAPI): @@ -777,8 +774,7 @@ async def _lifespan(app: FastAPI): trace_to_cloud=trace_to_cloud, lifespan=_lifespan, a2a=a2a, - host=host, - port=port, + base_url=base_url, reload_agents=reload_agents, ) config = uvicorn.Config( @@ -810,8 +806,7 @@ def cli_api_server( eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, - host: str = "127.0.0.1", - port: int = 8000, + base_url="http://127.0.0.1:8000", trace_to_cloud: bool = False, reload: bool = True, session_service_uri: Optional[str] = None, @@ -833,6 +828,9 @@ def cli_api_server( """ logs.setup_adk_logger(getattr(logging, log_level.upper())) + parsed_url = urlparse(base_url) + host = parsed_url.hostname + port = parsed_url.port session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri config = uvicorn.Config( @@ -846,8 +844,7 @@ def cli_api_server( web=False, trace_to_cloud=trace_to_cloud, a2a=a2a, - host=host, - port=port, + base_url=base_url, reload_agents=reload_agents, ), host=host, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 7d93b5436..692e910e8 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -64,8 +64,7 @@ def get_fast_api_app( allow_origins: Optional[list[str]] = None, web: bool, a2a: bool = False, - host: str = "127.0.0.1", - port: int = 8000, + base_url: str = "http://127.0.0.1:8000", trace_to_cloud: bool = False, reload_agents: bool = False, lifespan: Optional[Lifespan[FastAPI]] = None, @@ -189,7 +188,9 @@ def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name): ) # Callbacks & other optional args for when constructing the FastAPI instance - extra_fast_api_args = {} + extra_fast_api_args = dict( + base_url=base_url, + ) if trace_to_cloud: from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter @@ -352,6 +353,8 @@ async def _get_a2a_runner_async() -> Runner: logger.info("Setting up A2A agent: %s", app_name) try: + a2a_rpc_path = f"{base_url}/a2a/{app_name}" + agent_executor = A2aAgentExecutor( runner=create_a2a_runner_loader(app_name), ) @@ -363,6 +366,8 @@ async def _get_a2a_runner_async() -> Runner: with (p / "agent.json").open("r", encoding="utf-8") as f: data = json.load(f) agent_card = AgentCard(**data) + if agent_card.url == "": # empty url is a placeholder to be filled with the provided url + agent_card.url = a2a_rpc_path a2a_app = A2AStarletteApplication( agent_card=agent_card, diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index f1c9e9d6e..60800e2ca 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -24,6 +24,8 @@ from unittest.mock import MagicMock from unittest.mock import patch +from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH +from fastapi import FastAPI from fastapi.testclient import TestClient from google.adk.agents.base_agent import BaseAgent from google.adk.agents.run_config import RunConfig @@ -432,8 +434,7 @@ def test_app( memory_service_uri="", allow_origins=["*"], a2a=False, # Disable A2A for most tests - host="127.0.0.1", - port=8000, + base_url="http://127.0.0.1:8000", ) # Create a TestClient that doesn't start a real server @@ -502,11 +503,37 @@ def temp_agents_dir_with_a2a(): # Create agent.json file agent_card = { + "capabilities": { + "pushNotifications": True, + "streaming": True + }, + "defaultInputModes": [ + "text", + "text/plain" + ], + "defaultOutputModes": [ + "text", + "text/plain" + ], "name": "test_a2a_agent", "description": "Test A2A agent", "version": "1.0.0", "author": "test", - "capabilities": ["text"], + "protocolVersion": "0.2.6", + "skills": [ + { + "description": "Makes the tests pass", + "examples": [ + "Fix the tests." + ], + "id": "test_a2a_agent", + "name": "Test A2A agent", + "tags": [ + "testing" + ] + } + ], + "url": "", } with open(agent_dir / "agent.json", "w") as f: @@ -580,20 +607,12 @@ def test_app_with_a2a( patch( "a2a.server.request_handlers.DefaultRequestHandler" ) as mock_handler, - patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, ): # Configure mocks mock_task_store.return_value = MagicMock() mock_executor.return_value = MagicMock() mock_handler.return_value = MagicMock() - # Mock A2AStarletteApplication - mock_app_instance = MagicMock() - mock_app_instance.routes.return_value = ( - [] - ) # Return empty routes for testing - mock_a2a_app.return_value = mock_app_instance - # Change to temp directory original_cwd = os.getcwd() os.chdir(temp_agents_dir_with_a2a) @@ -607,8 +626,7 @@ def test_app_with_a2a( memory_service_uri="", allow_origins=["*"], a2a=True, - host="127.0.0.1", - port=8000, + base_url="http://127.0.0.1:8000", ) client = TestClient(app) @@ -879,9 +897,12 @@ def test_debug_trace(test_app): ) def test_a2a_agent_discovery(test_app_with_a2a): """Test that A2A agents are properly discovered and configured.""" - # This test mainly verifies that the A2A setup doesn't break the app + # This test verifies that the A2A setup doesn't break the app + # and that the well known card works response = test_app_with_a2a.get("/list-apps") assert response.status_code == 200 + response2 = test_app_with_a2a.get(f"/a2a/test_a2a_agent{AGENT_CARD_WELL_KNOWN_PATH}") + assert response2.status_code == 200 logger.info("A2A agent discovery test passed") @@ -897,5 +918,91 @@ def test_a2a_disabled_by_default(test_app): logger.info("A2A disabled by default test passed") +def test_runtime_config_contains_base_url(test_app): + """Test that ./assets/config/runtime-config.json contains the base_url passed to the FastAPI app.""" + # The test_app fixture configures the FastAPI app with base_url="http://127.0.0.1:8000" + expected_base_url = "http://127.0.0.1:8000" + + # Make a request to the runtime config endpoint + response = test_app.get("/dev-ui/assets/config/runtime-config.json") + + # Verify the response + assert response.status_code == 200 + data = response.json() + + # Verify the structure and content + assert isinstance(data, dict) + assert "backendUrl" in data + assert data["backendUrl"] == expected_base_url + + logger.info(f"Runtime config test passed - base_url: {data['backendUrl']}") + + +def test_runtime_config_with_custom_base_url( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """Test that runtime-config.json contains a custom base_url when provided.""" + custom_base_url = "https://example.com:9000/adk" + + # Create a FastAPI app with a custom base_url + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.InMemorySessionService", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.InMemoryArtifactService", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.InMemoryMemoryService", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + ): + adk_app = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=False, + base_url=custom_base_url, + ) + app = FastAPI() + app.mount("/adk", adk_app) + + client = TestClient(app) + + # Make a request to the runtime config endpoint + response = client.get("/adk/dev-ui/assets/config/runtime-config.json") + + # Verify the response contains the custom base_url + assert response.status_code == 200 + data = response.json() + assert isinstance(data, dict) + assert "backendUrl" in data + assert data["backendUrl"] == custom_base_url + + logger.info(f"Custom runtime config test passed - base_url: {data['backendUrl']}") + if __name__ == "__main__": pytest.main(["-xvs", __file__])