Skip to content

feat: more generic url #2556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contributing/samples/a2a_auth/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ When deploying the remote BigQuery A2A agent to different environments (e.g., Cl
}
```

**Important:** The `url` field in `remote_a2a/bigquery_agent/agent.json` must point to the actual RPC endpoint where your remote BigQuery A2A agent is deployed and accessible.
**Important:** The `url` field in `remote_a2a/bigquery_agent/agent.json` must point to the actual RPC endpoint where your remote BigQuery A2A agent is deployed and accessible. If the `url` field is an empty string, it will be automatically filled by the base URL provided to `get_fast_api_app`.

## Troubleshooting

Expand Down
2 changes: 1 addition & 1 deletion contributing/samples/a2a_basic/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ When deploying the remote A2A agent to different environments (e.g., Cloud Run,
}
```

**Important:** The `url` field in `remote_a2a/check_prime_agent/agent.json` must point to the actual RPC endpoint where your remote A2A agent is deployed and accessible.
**Important:** The `url` field in `remote_a2a/check_prime_agent/agent.json` must point to the actual RPC endpoint where your remote A2A agent is deployed and accessible. If the `url` field is an empty string, it will be automatically filled by the base URL provided to `get_fast_api_app`.

## Troubleshooting

Expand Down
2 changes: 1 addition & 1 deletion contributing/samples/a2a_human_in_loop/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ When deploying the remote approval A2A agent to different environments (e.g., Cl
}
```

**Important:** The `url` field in `remote_a2a/human_in_loop/agent.json` must point to the actual RPC endpoint where your remote approval A2A agent is deployed and accessible.
**Important:** The `url` field in `remote_a2a/human_in_loop/agent.json` must point to the actual RPC endpoint where your remote approval A2A agent is deployed and accessible. If the `url` field is an empty string, it will be automatically filled by the base URL provided to `get_fast_api_app`.

## Troubleshooting

Expand Down
31 changes: 14 additions & 17 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import tempfile
from typing import Optional
from urllib.parse import urlparse

import click
from click.core import ParameterSource
Expand Down Expand Up @@ -616,18 +617,12 @@ def fast_api_common_options():

def decorator(func):
@click.option(
"--host",
"--base_url",
type=str,
help="Optional. The binding host of the server",
default="127.0.0.1",
help="Optional. The base URL of the server.",
default="http://127.0.0.1:8000",
show_default=True,
)
@click.option(
"--port",
type=int,
help="Optional. The port of the server",
default=8000,
)
@click.option(
"--allow_origins",
help="Optional. Any additional origins to allow for CORS.",
Expand Down Expand Up @@ -719,8 +714,7 @@ def cli_web(
eval_storage_uri: Optional[str] = None,
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
host: str = "127.0.0.1",
port: int = 8000,
base_url="http://127.0.0.1:8000",
trace_to_cloud: bool = False,
reload: bool = True,
session_service_uri: Optional[str] = None,
Expand All @@ -741,6 +735,9 @@ def cli_web(
adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir
"""
logs.setup_adk_logger(getattr(logging, log_level.upper()))
parsed_url = urlparse(base_url)
host = parsed_url.hostname
port = parsed_url.port

@asynccontextmanager
async def _lifespan(app: FastAPI):
Expand Down Expand Up @@ -777,8 +774,7 @@ async def _lifespan(app: FastAPI):
trace_to_cloud=trace_to_cloud,
lifespan=_lifespan,
a2a=a2a,
host=host,
port=port,
base_url=base_url,
reload_agents=reload_agents,
)
config = uvicorn.Config(
Expand Down Expand Up @@ -810,8 +806,7 @@ def cli_api_server(
eval_storage_uri: Optional[str] = None,
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
host: str = "127.0.0.1",
port: int = 8000,
base_url="http://127.0.0.1:8000",
trace_to_cloud: bool = False,
reload: bool = True,
session_service_uri: Optional[str] = None,
Expand All @@ -833,6 +828,9 @@ def cli_api_server(
"""
logs.setup_adk_logger(getattr(logging, log_level.upper()))

parsed_url = urlparse(base_url)
host = parsed_url.hostname
port = parsed_url.port
session_service_uri = session_service_uri or session_db_url
artifact_service_uri = artifact_service_uri or artifact_storage_uri
config = uvicorn.Config(
Expand All @@ -846,8 +844,7 @@ def cli_api_server(
web=False,
trace_to_cloud=trace_to_cloud,
a2a=a2a,
host=host,
port=port,
base_url=base_url,
reload_agents=reload_agents,
),
host=host,
Expand Down
7 changes: 5 additions & 2 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def get_fast_api_app(
allow_origins: Optional[list[str]] = None,
web: bool,
a2a: bool = False,
host: str = "127.0.0.1",
port: int = 8000,
base_url: str = "http://127.0.0.1:8000",
trace_to_cloud: bool = False,
reload_agents: bool = False,
lifespan: Optional[Lifespan[FastAPI]] = None,
Expand Down Expand Up @@ -352,6 +351,8 @@ async def _get_a2a_runner_async() -> Runner:
logger.info("Setting up A2A agent: %s", app_name)

try:
a2a_rpc_path = f"{base_url}/a2a/{app_name}"

agent_executor = A2aAgentExecutor(
runner=create_a2a_runner_loader(app_name),
)
Expand All @@ -363,6 +364,8 @@ async def _get_a2a_runner_async() -> Runner:
with (p / "agent.json").open("r", encoding="utf-8") as f:
data = json.load(f)
agent_card = AgentCard(**data)
if agent_card.url == "": # empty url is a placeholder to be filled with the provided url
agent_card.url = a2a_rpc_path

a2a_app = A2AStarletteApplication(
agent_card=agent_card,
Expand Down
48 changes: 34 additions & 14 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from unittest.mock import MagicMock
from unittest.mock import patch

from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH
from fastapi.testclient import TestClient
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.run_config import RunConfig
Expand Down Expand Up @@ -432,8 +433,7 @@ def test_app(
memory_service_uri="",
allow_origins=["*"],
a2a=False, # Disable A2A for most tests
host="127.0.0.1",
port=8000,
base_url="http://127.0.0.1:8000",
)

# Create a TestClient that doesn't start a real server
Expand Down Expand Up @@ -502,11 +502,37 @@ def temp_agents_dir_with_a2a():

# Create agent.json file
agent_card = {
"capabilities": {
"pushNotifications": True,
"streaming": True
},
"defaultInputModes": [
"text",
"text/plain"
],
"defaultOutputModes": [
"text",
"text/plain"
],
"name": "test_a2a_agent",
"description": "Test A2A agent",
"version": "1.0.0",
"author": "test",
"capabilities": ["text"],
"protocolVersion": "0.2.6",
"skills": [
{
"description": "Makes the tests pass",
"examples": [
"Fix the tests."
],
"id": "test_a2a_agent",
"name": "Test A2A agent",
"tags": [
"testing"
]
}
],
"url": "",
}

with open(agent_dir / "agent.json", "w") as f:
Expand Down Expand Up @@ -580,20 +606,12 @@ def test_app_with_a2a(
patch(
"a2a.server.request_handlers.DefaultRequestHandler"
) as mock_handler,
patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app,
):
# Configure mocks
mock_task_store.return_value = MagicMock()
mock_executor.return_value = MagicMock()
mock_handler.return_value = MagicMock()

# Mock A2AStarletteApplication
mock_app_instance = MagicMock()
mock_app_instance.routes.return_value = (
[]
) # Return empty routes for testing
mock_a2a_app.return_value = mock_app_instance

# Change to temp directory
original_cwd = os.getcwd()
os.chdir(temp_agents_dir_with_a2a)
Expand All @@ -607,8 +625,7 @@ def test_app_with_a2a(
memory_service_uri="",
allow_origins=["*"],
a2a=True,
host="127.0.0.1",
port=8000,
base_url="http://127.0.0.1:8000",
)

client = TestClient(app)
Expand Down Expand Up @@ -879,9 +896,12 @@ def test_debug_trace(test_app):
)
def test_a2a_agent_discovery(test_app_with_a2a):
"""Test that A2A agents are properly discovered and configured."""
# This test mainly verifies that the A2A setup doesn't break the app
# This test verifies that the A2A setup doesn't break the app
# and that the well known card works
response = test_app_with_a2a.get("/list-apps")
assert response.status_code == 200
response2 = test_app_with_a2a.get(f"/a2a/test_a2a_agent{AGENT_CARD_WELL_KNOWN_PATH}")
assert response2.status_code == 200
logger.info("A2A agent discovery test passed")


Expand Down