diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 13a0a620a..1b708fbf6 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -23,6 +23,7 @@ import os import tempfile from typing import Optional +from urllib.parse import urlparse import click from click.core import ParameterSource @@ -616,18 +617,12 @@ def fast_api_common_options(): def decorator(func): @click.option( - "--host", + "--base_url", type=str, - help="Optional. The binding host of the server", - default="127.0.0.1", + help="Optional. The base URL of the server.", + default="http://127.0.0.1:8000", show_default=True, ) - @click.option( - "--port", - type=int, - help="Optional. The port of the server", - default=8000, - ) @click.option( "--allow_origins", help="Optional. Any additional origins to allow for CORS.", @@ -719,8 +714,7 @@ def cli_web( eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, - host: str = "127.0.0.1", - port: int = 8000, + base_url="http://127.0.0.1:8000", trace_to_cloud: bool = False, reload: bool = True, session_service_uri: Optional[str] = None, @@ -741,6 +735,9 @@ def cli_web( adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ logs.setup_adk_logger(getattr(logging, log_level.upper())) + parsed_url = urlparse(base_url) + host = parsed_url.hostname + port = parsed_url.port @asynccontextmanager async def _lifespan(app: FastAPI): @@ -777,8 +774,7 @@ async def _lifespan(app: FastAPI): trace_to_cloud=trace_to_cloud, lifespan=_lifespan, a2a=a2a, - host=host, - port=port, + base_url=base_url, reload_agents=reload_agents, ) config = uvicorn.Config( @@ -810,8 +806,7 @@ def cli_api_server( eval_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, - host: str = "127.0.0.1", - port: int = 8000, + base_url="http://127.0.0.1:8000", trace_to_cloud: bool = False, reload: bool = True, session_service_uri: Optional[str] = None, @@ -833,6 +828,9 @@ def cli_api_server( """ logs.setup_adk_logger(getattr(logging, log_level.upper())) + parsed_url = urlparse(base_url) + host = parsed_url.hostname + port = parsed_url.port session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri config = uvicorn.Config( @@ -846,8 +844,7 @@ def cli_api_server( web=False, trace_to_cloud=trace_to_cloud, a2a=a2a, - host=host, - port=port, + base_url=base_url, reload_agents=reload_agents, ), host=host, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 7d93b5436..0b4a41968 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -64,8 +64,7 @@ def get_fast_api_app( allow_origins: Optional[list[str]] = None, web: bool, a2a: bool = False, - host: str = "127.0.0.1", - port: int = 8000, + base_url: str = "http://127.0.0.1:8000", trace_to_cloud: bool = False, reload_agents: bool = False, lifespan: Optional[Lifespan[FastAPI]] = None, @@ -352,6 +351,8 @@ async def _get_a2a_runner_async() -> Runner: logger.info("Setting up A2A agent: %s", app_name) try: + a2a_rpc_path = f"{base_url}/a2a/{app_name}" + agent_executor = A2aAgentExecutor( runner=create_a2a_runner_loader(app_name), ) @@ -363,6 +364,8 @@ async def _get_a2a_runner_async() -> Runner: with (p / "agent.json").open("r", encoding="utf-8") as f: data = json.load(f) agent_card = AgentCard(**data) + # todo: if url is not defined, override it here + agent_card.url = a2a_rpc_path a2a_app = A2AStarletteApplication( agent_card=agent_card, diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index f1c9e9d6e..c9be32c1f 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -432,8 +432,7 @@ def test_app( memory_service_uri="", allow_origins=["*"], a2a=False, # Disable A2A for most tests - host="127.0.0.1", - port=8000, + base_url="http://127.0.0.1:8000", ) # Create a TestClient that doesn't start a real server @@ -607,8 +606,7 @@ def test_app_with_a2a( memory_service_uri="", allow_origins=["*"], a2a=True, - host="127.0.0.1", - port=8000, + base_url="http://127.0.0.1:8000", ) client = TestClient(app)