diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 44df7908c..b71ccfa26 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -18,6 +18,7 @@ from contextlib import asynccontextmanager import logging import os +from pathlib import Path import time import traceback import typing @@ -31,6 +32,7 @@ from fastapi import HTTPException from fastapi import Query from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from fastapi.responses import RedirectResponse from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles @@ -44,6 +46,8 @@ from opentelemetry.sdk.trace import TracerProvider from pydantic import Field from pydantic import ValidationError +from starlette.responses import Response +from starlette.staticfiles import PathLike from starlette.types import Lifespan from typing_extensions import override from watchdog.observers import Observer @@ -196,6 +200,28 @@ class GetEventGraphResult(common.BaseModel): dot_src: str +class ConfigInjectingStaticFiles(StaticFiles): + """ + Custom StaticFiles that injects config.json for dev-ui. + Fixes https://github.com/google/adk-python/issues/2072 + """ + + def __init__(self, *, directory: PathLike | None = None, base_url: Optional[str] = None, **kwargs): + super().__init__(directory=directory, **kwargs) + if base_url is None: + base_url = "" + self.base_url = base_url + + async def get_response(self, path: str, scope) -> Response: + # Check if the request is for config.json + if Path(path).as_posix() == "assets/config/runtime-config.json": + config = {"backendUrl": self.base_url} + return JSONResponse(content=config) + + # Otherwise, serve static files normally + return await super().get_response(path, scope) + + class AdkWebServer: """Helper class for setting up and running the ADK web server on FastAPI. @@ -287,6 +313,7 @@ def get_fast_api_app( [Observer, "AdkWebServer"], None ] = lambda o, s: None, register_processors: Callable[[TracerProvider], None] = lambda o: None, + base_url: Optional[str] = None, ): """Creates a FastAPI app for the ADK web server. @@ -303,6 +330,8 @@ def get_fast_api_app( tear_down_observer: Callback for cleaning up the file system observer. register_processors: Callback for additional Span processors to be added to the TracerProvider. + base_url: The base URL for the web-ui, useful if fastapi app is mounted as + a sub-application. If none is provided, the host is used in the frontend. Returns: A FastAPI app instance. @@ -1013,7 +1042,12 @@ async def redirect_dev_ui_add_slash(): app.mount( "/dev-ui/", - StaticFiles(directory=web_assets_dir, html=True, follow_symlink=True), + ConfigInjectingStaticFiles( + directory=web_assets_dir, + base_url=base_url, + html=True, + follow_symlink=True, + ), name="static", ) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 13a0a620a..7e4bece31 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -23,6 +23,7 @@ import os import tempfile from typing import Optional +from urllib.parse import urlparse import click from click.core import ParameterSource @@ -616,18 +617,12 @@ def fast_api_common_options(): def decorator(func): @click.option( - "--host", + "--base_url", type=str, - help="Optional. The binding host of the server", - default="127.0.0.1", + help="Optional. The base URL of the server.", + default="http://127\.0\.0\.1:8000", show_default=True, ) - @click.option( - "--port", - type=int, - help="Optional. The port of the server", - default=8000, - ) @click.option( "--allow_origins", help="Optional. Any additional origins to allow for CORS.", @@ -719,8 +714,7 @@ def cli_web( eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, - host: str = "127.0.0.1", - port: int = 8000, + base_url="http://127.0.0.1:8000", trace_to_cloud: bool = False, reload: bool = True, session_service_uri: Optional[str] = None, @@ -741,6 +735,9 @@ def cli_web( adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ logs.setup_adk_logger(getattr(logging, log_level.upper())) + parsed_url = urlparse(base_url) + host = parsed_url.hostname + port = parsed_url.port @asynccontextmanager async def _lifespan(app: FastAPI): @@ -777,8 +774,7 @@ async def _lifespan(app: FastAPI): trace_to_cloud=trace_to_cloud, lifespan=_lifespan, a2a=a2a, - host=host, - port=port, + base_url=base_url, reload_agents=reload_agents, ) config = uvicorn.Config( @@ -810,8 +806,7 @@ def cli_api_server( eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, - host: str = "127.0.0.1", - port: int = 8000, + base_url="http://127.0.0.1:8000", trace_to_cloud: bool = False, reload: bool = True, session_service_uri: Optional[str] = None, @@ -833,6 +828,9 @@ def cli_api_server( """ logs.setup_adk_logger(getattr(logging, log_level.upper())) + parsed_url = urlparse(base_url) + host = parsed_url.hostname + port = parsed_url.port session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri config = uvicorn.Config( @@ -846,8 +844,7 @@ def cli_api_server( web=False, trace_to_cloud=trace_to_cloud, a2a=a2a, - host=host, - port=port, + base_url=base_url, reload_agents=reload_agents, ), host=host, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index bc1a75dda..8a8e44e80 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -27,9 +27,12 @@ from fastapi import FastAPI from fastapi import UploadFile from fastapi.responses import FileResponse +from fastapi.responses import JSONResponse from fastapi.responses import PlainTextResponse +from fastapi.staticfiles import StaticFiles from opentelemetry.sdk.trace import export from opentelemetry.sdk.trace import TracerProvider +from starlette.responses import Response from starlette.types import Lifespan from watchdog.observers import Observer @@ -64,8 +67,7 @@ def get_fast_api_app( allow_origins: Optional[list[str]] = None, web: bool, a2a: bool = False, - host: str = "127.0.0.1", - port: int = 8000, + base_url: str = "http://127.0.0.1:8000", trace_to_cloud: bool = False, reload_agents: bool = False, lifespan: Optional[Lifespan[FastAPI]] = None, @@ -189,7 +191,9 @@ def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name): ) # Callbacks & other optional args for when constructing the FastAPI instance - extra_fast_api_args = {} + extra_fast_api_args = dict( + base_url=base_url, + ) if trace_to_cloud: from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter @@ -352,7 +356,7 @@ async def _get_a2a_runner_async() -> Runner: logger.info("Setting up A2A agent: %s", app_name) try: - a2a_rpc_path = f"http://{host}:{port}/a2a/{app_name}" + a2a_rpc_path = f"{base_url}/a2a/{app_name}" agent_executor = A2aAgentExecutor( runner=create_a2a_runner_loader(app_name), diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index f1c9e9d6e..e481f3677 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -24,6 +24,7 @@ from unittest.mock import MagicMock from unittest.mock import patch +from fastapi import FastAPI from fastapi.testclient import TestClient from google.adk.agents.base_agent import BaseAgent from google.adk.agents.run_config import RunConfig @@ -432,8 +433,7 @@ def test_app( memory_service_uri="", allow_origins=["*"], a2a=False, # Disable A2A for most tests - host="127.0.0.1", - port=8000, + base_url="http://127.0.0.1:8000", ) # Create a TestClient that doesn't start a real server @@ -607,8 +607,7 @@ def test_app_with_a2a( memory_service_uri="", allow_origins=["*"], a2a=True, - host="127.0.0.1", - port=8000, + base_url="http://127.0.0.1:8000", ) client = TestClient(app) @@ -897,5 +896,92 @@ def test_a2a_disabled_by_default(test_app): logger.info("A2A disabled by default test passed") +def test_runtime_config_contains_base_url(test_app): + """Test that ./assets/config/runtime-config.json contains the base_url passed to the FastAPI app.""" + # The test_app fixture configures the FastAPI app with base_url="http://127.0.0.1:8000" + expected_base_url = "http://127.0.0.1:8000" + + # Make a request to the runtime config endpoint + response = test_app.get("/dev-ui/assets/config/runtime-config.json") + + # Verify the response + assert response.status_code == 200 + data = response.json() + + # Verify the structure and content + assert isinstance(data, dict) + assert "backendUrl" in data + assert data["backendUrl"] == expected_base_url + + logger.info(f"Runtime config test passed - base_url: {data['backendUrl']}") + + +def test_runtime_config_with_custom_base_url( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """Test that runtime-config.json contains a custom base_url when provided.""" + custom_base_url = "https://example.com:9000/adk" + + # Create a FastAPI app with a custom base_url + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.InMemorySessionService", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.InMemoryArtifactService", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.InMemoryMemoryService", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + ): + adk_app = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=False, + base_url=custom_base_url, + ) + app = FastAPI() + app.mount("/adk", adk_app) + + client = TestClient(app) + + # Make a request to the runtime config endpoint + response = client.get("/adk/dev-ui/assets/config/runtime-config.json") + + # Verify the response contains the custom base_url + assert response.status_code == 200 + data = response.json() + assert isinstance(data, dict) + assert "backendUrl" in data + assert data["backendUrl"] == custom_base_url + + logger.info(f"Custom runtime config test passed - base_url: {data['backendUrl']}") + + if __name__ == "__main__": pytest.main(["-xvs", __file__])