Skip to content

fix: dev ui url works in sub-apps #2254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from contextlib import asynccontextmanager
import logging
import os
from pathlib import Path
import time
import traceback
import typing
Expand All @@ -31,6 +32,7 @@
from fastapi import HTTPException
from fastapi import Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
Expand All @@ -44,6 +46,8 @@
from opentelemetry.sdk.trace import TracerProvider
from pydantic import Field
from pydantic import ValidationError
from starlette.responses import Response
from starlette.staticfiles import PathLike
from starlette.types import Lifespan
from typing_extensions import override
from watchdog.observers import Observer
Expand Down Expand Up @@ -196,6 +200,28 @@ class GetEventGraphResult(common.BaseModel):
dot_src: str


class ConfigInjectingStaticFiles(StaticFiles):
"""
Custom StaticFiles that injects config.json for dev-ui.
Fixes https://github.com/google/adk-python/issues/2072
"""

def __init__(self, *, directory: PathLike | None = None, base_url: Optional[str] = None, **kwargs):
super().__init__(directory=directory, **kwargs)
if base_url is None:
base_url = ""
self.base_url = base_url

async def get_response(self, path: str, scope) -> Response:
# Check if the request is for config.json
if Path(path).as_posix() == "assets/config/runtime-config.json":
config = {"backendUrl": self.base_url}
return JSONResponse(content=config)

# Otherwise, serve static files normally
return await super().get_response(path, scope)


class AdkWebServer:
"""Helper class for setting up and running the ADK web server on FastAPI.

Expand Down Expand Up @@ -287,6 +313,7 @@ def get_fast_api_app(
[Observer, "AdkWebServer"], None
] = lambda o, s: None,
register_processors: Callable[[TracerProvider], None] = lambda o: None,
base_url: Optional[str] = None,
):
"""Creates a FastAPI app for the ADK web server.

Expand All @@ -303,6 +330,8 @@ def get_fast_api_app(
tear_down_observer: Callback for cleaning up the file system observer.
register_processors: Callback for additional Span processors to be added
to the TracerProvider.
base_url: The base URL for the web-ui, useful if fastapi app is mounted as
a sub-application. If none is provided, the host is used in the frontend.

Returns:
A FastAPI app instance.
Expand Down Expand Up @@ -1013,7 +1042,12 @@ async def redirect_dev_ui_add_slash():

app.mount(
"/dev-ui/",
StaticFiles(directory=web_assets_dir, html=True, follow_symlink=True),
ConfigInjectingStaticFiles(
directory=web_assets_dir,
base_url=base_url,
html=True,
follow_symlink=True,
),
name="static",
)

Expand Down
31 changes: 14 additions & 17 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import tempfile
from typing import Optional
from urllib.parse import urlparse

import click
from click.core import ParameterSource
Expand Down Expand Up @@ -616,18 +617,12 @@ def fast_api_common_options():

def decorator(func):
@click.option(
"--host",
"--base_url",
type=str,
help="Optional. The binding host of the server",
default="127.0.0.1",
help="Optional. The base URL of the server.",
default="http://127\.0\.0\.1:8000",
show_default=True,
)
@click.option(
"--port",
type=int,
help="Optional. The port of the server",
default=8000,
)
@click.option(
"--allow_origins",
help="Optional. Any additional origins to allow for CORS.",
Expand Down Expand Up @@ -719,8 +714,7 @@ def cli_web(
eval_storage_uri: Optional[str] = None,
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
host: str = "127.0.0.1",
port: int = 8000,
base_url="http://127.0.0.1:8000",
trace_to_cloud: bool = False,
reload: bool = True,
session_service_uri: Optional[str] = None,
Expand All @@ -741,6 +735,9 @@ def cli_web(
adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir
"""
logs.setup_adk_logger(getattr(logging, log_level.upper()))
parsed_url = urlparse(base_url)
host = parsed_url.hostname
port = parsed_url.port

@asynccontextmanager
async def _lifespan(app: FastAPI):
Expand Down Expand Up @@ -777,8 +774,7 @@ async def _lifespan(app: FastAPI):
trace_to_cloud=trace_to_cloud,
lifespan=_lifespan,
a2a=a2a,
host=host,
port=port,
base_url=base_url,
reload_agents=reload_agents,
)
config = uvicorn.Config(
Expand Down Expand Up @@ -810,8 +806,7 @@ def cli_api_server(
eval_storage_uri: Optional[str] = None,
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
host: str = "127.0.0.1",
port: int = 8000,
base_url="http://127.0.0.1:8000",
trace_to_cloud: bool = False,
reload: bool = True,
session_service_uri: Optional[str] = None,
Expand All @@ -833,6 +828,9 @@ def cli_api_server(
"""
logs.setup_adk_logger(getattr(logging, log_level.upper()))

parsed_url = urlparse(base_url)
host = parsed_url.hostname
port = parsed_url.port
session_service_uri = session_service_uri or session_db_url
artifact_service_uri = artifact_service_uri or artifact_storage_uri
config = uvicorn.Config(
Expand All @@ -846,8 +844,7 @@ def cli_api_server(
web=False,
trace_to_cloud=trace_to_cloud,
a2a=a2a,
host=host,
port=port,
base_url=base_url,
reload_agents=reload_agents,
),
host=host,
Expand Down
12 changes: 8 additions & 4 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
from fastapi import FastAPI
from fastapi import UploadFile
from fastapi.responses import FileResponse
from fastapi.responses import JSONResponse
from fastapi.responses import PlainTextResponse
from fastapi.staticfiles import StaticFiles
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace import TracerProvider
from starlette.responses import Response
from starlette.types import Lifespan
from watchdog.observers import Observer

Expand Down Expand Up @@ -64,8 +67,7 @@ def get_fast_api_app(
allow_origins: Optional[list[str]] = None,
web: bool,
a2a: bool = False,
host: str = "127.0.0.1",
port: int = 8000,
base_url: str = "http://127.0.0.1:8000",
trace_to_cloud: bool = False,
reload_agents: bool = False,
lifespan: Optional[Lifespan[FastAPI]] = None,
Expand Down Expand Up @@ -189,7 +191,9 @@ def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name):
)

# Callbacks & other optional args for when constructing the FastAPI instance
extra_fast_api_args = {}
extra_fast_api_args = dict(
base_url=base_url,
)

if trace_to_cloud:
from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
Expand Down Expand Up @@ -352,7 +356,7 @@ async def _get_a2a_runner_async() -> Runner:
logger.info("Setting up A2A agent: %s", app_name)

try:
a2a_rpc_path = f"http://{host}:{port}/a2a/{app_name}"
a2a_rpc_path = f"{base_url}/a2a/{app_name}"

agent_executor = A2aAgentExecutor(
runner=create_a2a_runner_loader(app_name),
Expand Down
94 changes: 90 additions & 4 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from unittest.mock import MagicMock
from unittest.mock import patch

from fastapi import FastAPI
from fastapi.testclient import TestClient
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.run_config import RunConfig
Expand Down Expand Up @@ -432,8 +433,7 @@ def test_app(
memory_service_uri="",
allow_origins=["*"],
a2a=False, # Disable A2A for most tests
host="127.0.0.1",
port=8000,
base_url="http://127.0.0.1:8000",
)

# Create a TestClient that doesn't start a real server
Expand Down Expand Up @@ -607,8 +607,7 @@ def test_app_with_a2a(
memory_service_uri="",
allow_origins=["*"],
a2a=True,
host="127.0.0.1",
port=8000,
base_url="http://127.0.0.1:8000",
)

client = TestClient(app)
Expand Down Expand Up @@ -897,5 +896,92 @@ def test_a2a_disabled_by_default(test_app):
logger.info("A2A disabled by default test passed")


def test_runtime_config_contains_base_url(test_app):
"""Test that ./assets/config/runtime-config.json contains the base_url passed to the FastAPI app."""
# The test_app fixture configures the FastAPI app with base_url="http://127.0.0.1:8000"
expected_base_url = "http://127.0.0.1:8000"

# Make a request to the runtime config endpoint
response = test_app.get("/dev-ui/assets/config/runtime-config.json")

# Verify the response
assert response.status_code == 200
data = response.json()

# Verify the structure and content
assert isinstance(data, dict)
assert "backendUrl" in data
assert data["backendUrl"] == expected_base_url

logger.info(f"Runtime config test passed - base_url: {data['backendUrl']}")


def test_runtime_config_with_custom_base_url(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
):
"""Test that runtime-config.json contains a custom base_url when provided."""
custom_base_url = "https://example.com:9000/adk"

# Create a FastAPI app with a custom base_url
with (
patch("signal.signal", return_value=None),
patch(
"google.adk.cli.fast_api.InMemorySessionService",
return_value=mock_session_service,
),
patch(
"google.adk.cli.fast_api.InMemoryArtifactService",
return_value=mock_artifact_service,
),
patch(
"google.adk.cli.fast_api.InMemoryMemoryService",
return_value=mock_memory_service,
),
patch(
"google.adk.cli.fast_api.AgentLoader",
return_value=mock_agent_loader,
),
patch(
"google.adk.cli.fast_api.LocalEvalSetsManager",
return_value=mock_eval_sets_manager,
),
patch(
"google.adk.cli.fast_api.LocalEvalSetResultsManager",
return_value=mock_eval_set_results_manager,
),
):
adk_app = get_fast_api_app(
agents_dir=".",
web=True,
session_service_uri="",
artifact_service_uri="",
memory_service_uri="",
allow_origins=["*"],
a2a=False,
base_url=custom_base_url,
)
app = FastAPI()
app.mount("/adk", adk_app)

client = TestClient(app)

# Make a request to the runtime config endpoint
response = client.get("/adk/dev-ui/assets/config/runtime-config.json")

# Verify the response contains the custom base_url
assert response.status_code == 200
data = response.json()
assert isinstance(data, dict)
assert "backendUrl" in data
assert data["backendUrl"] == custom_base_url

logger.info(f"Custom runtime config test passed - base_url: {data['backendUrl']}")


if __name__ == "__main__":
pytest.main(["-xvs", __file__])