Skip to content

Commit 8e2352f

Browse files
committed
Added tests
1 parent 7e6249e commit 8e2352f

File tree

4 files changed

+190
-20
lines changed

4 files changed

+190
-20
lines changed

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from sse_starlette.sse import EventSourceResponse
7272
from starlette.applications import Starlette
7373
from starlette.authentication import BaseUser
74+
from starlette.datastructures import URL
7475
from starlette.exceptions import HTTPException
7576
from starlette.requests import Request
7677
from starlette.responses import JSONResponse, Response
@@ -488,6 +489,12 @@ async def event_generator(
488489
)
489490

490491
def _modify_rpc_url(self, agent_card: AgentCard, request: Request):
492+
"""Modifies Agent's RPC URL based on the AgentCard request.
493+
494+
Args:
495+
agent_card (AgentCard): Original AgentCard
496+
request (Request): AgentCard request
497+
"""
491498
rpc_url = URL(agent_card.url)
492499
rpc_path = rpc_url.path
493500
port = None
@@ -499,6 +506,7 @@ def _modify_rpc_url(self, agent_card: AgentCard, request: Request):
499506

500507
if "X-Forwarded-Proto" in request.headers:
501508
scheme = request.headers["X-Forwarded-Proto"]
509+
port = None
502510
else:
503511
scheme = request.url.scheme
504512
if not scheme:
@@ -526,23 +534,14 @@ def _modify_rpc_url(self, agent_card: AgentCard, request: Request):
526534
new_path = new_path.rstrip("/")
527535
rpc_path = new_path
528536

529-
if port:
530-
agent_card.url = str(
531-
rpc_url.replace(
532-
hostname=host,
533-
port=port,
534-
scheme=scheme,
535-
path=rpc_path
536-
)
537-
)
538-
else:
539-
agent_card.url = str(
540-
rpc_url.replace(
541-
hostname=host,
542-
scheme=scheme,
543-
path=rpc_path
544-
)
537+
agent_card.url = str(
538+
rpc_url.replace(
539+
hostname=host,
540+
port=port,
541+
scheme=scheme,
542+
path=rpc_path
545543
)
544+
)
546545

547546
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
548547
"""Handles GET requests for the agent card endpoint.

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
if TYPE_CHECKING:
99
from sse_starlette.sse import EventSourceResponse
10+
from starlette.datastructures import URL
1011
from starlette.requests import Request
1112
from starlette.responses import JSONResponse, Response
1213

@@ -15,6 +16,7 @@
1516
else:
1617
try:
1718
from sse_starlette.sse import EventSourceResponse
19+
from starlette.datastructures import URL
1820
from starlette.requests import Request
1921
from starlette.responses import JSONResponse, Response
2022

@@ -119,7 +121,8 @@ async def handle_get_agent_card(self, request: Request) -> JSONResponse:
119121
A JSONResponse containing the agent card data.
120122
"""
121123
# The public agent card is a direct serialization of the agent_card
122-
# provided at initialization.
124+
# provided at initialization except for the RPC URL.
125+
self._modify_rpc_url(self.agent_card, request)
123126
return JSONResponse(
124127
self.agent_card.model_dump(mode='json', exclude_none=True)
125128
)
@@ -145,9 +148,65 @@ async def handle_authenticated_agent_card(
145148
message='Authenticated card not supported'
146149
)
147150
)
151+
self._modify_rpc_url(self.agent_card, request)
148152
return JSONResponse(
149153
self.agent_card.model_dump(mode='json', exclude_none=True)
150154
)
155+
156+
def _modify_rpc_url(self, agent_card: AgentCard, request: Request):
157+
"""Modifies Agent's RPC URL based on the AgentCard request.
158+
159+
Args:
160+
agent_card (AgentCard): Original AgentCard
161+
request (Request): AgentCard request
162+
"""
163+
rpc_url = URL(agent_card.url)
164+
rpc_path = rpc_url.path
165+
port = None
166+
if "X-Forwarded-Host" in request.headers:
167+
host = request.headers["X-Forwarded-Host"]
168+
else:
169+
host = request.url.hostname
170+
port = request.url.port
171+
172+
if "X-Forwarded-Proto" in request.headers:
173+
scheme = request.headers["X-Forwarded-Proto"]
174+
port = None
175+
else:
176+
scheme = request.url.scheme
177+
if not scheme:
178+
scheme = "http"
179+
if ":" in host: # type: ignore
180+
comps = host.rsplit(":", 1) # type: ignore
181+
host = comps[0]
182+
port = comps[1]
183+
184+
# Handle URL maps,
185+
# e.g. "agents/my-agent/.well-known/agent-card.json"
186+
if "X-Forwarded-Path" in request.headers:
187+
forwarded_path = request.headers["X-Forwarded-Path"].strip()
188+
if (
189+
forwarded_path and
190+
request.url.path != forwarded_path
191+
and forwarded_path.endswith(request.url.path)
192+
):
193+
# "agents/my-agent" for "agents/my-agent/.well-known/agent-card.json"
194+
extra_path = forwarded_path[:-len(request.url.path)]
195+
new_path = extra_path + rpc_path
196+
# If original path was just "/",
197+
# we remove trailing "/" in the the extended one
198+
if len(new_path) > 1 and rpc_path == "/":
199+
new_path = new_path.rstrip("/")
200+
rpc_path = new_path
201+
202+
agent_card.url = str(
203+
rpc_url.replace(
204+
hostname=host,
205+
port=port,
206+
scheme=scheme,
207+
path=rpc_path
208+
)
209+
)
151210

152211
def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
153212
"""Constructs a dictionary of API routes and their corresponding handlers.

tests/server/apps/jsonrpc/test_jsonrpc_app.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
RequestHandler,
2727
) # For mock spec
2828
from a2a.types import (
29+
AgentCapabilities,
2930
AgentCard,
3031
Message,
3132
MessageSendParams,
@@ -36,7 +37,7 @@
3637
SendMessageSuccessResponse,
3738
TextPart,
3839
)
39-
40+
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
4041

4142
# --- StarletteUserProxy Tests ---
4243

@@ -356,5 +357,58 @@ def side_effect(request, context: ServerCallContext):
356357
}
357358

358359

360+
class TestAgentCardHandler:
361+
@pytest.fixture
362+
def agent_card(self):
363+
return AgentCard(
364+
name='APIKeyAgent',
365+
description='An agent that uses API Key auth.',
366+
url='http://localhost:8000',
367+
version='1.0.0',
368+
capabilities=AgentCapabilities(),
369+
default_input_modes=['text/plain'],
370+
default_output_modes=['text/plain'],
371+
skills=[],
372+
)
373+
374+
def test_agent_card_url_rewriting(
375+
self, agent_card: AgentCard,
376+
):
377+
"""
378+
Tests that the A2AStarletteApplication endpoint correctly handles Agent URL rewriting.
379+
"""
380+
handler = AsyncMock()
381+
app_instance = A2AStarletteApplication(agent_card, handler)
382+
client = TestClient(
383+
app_instance.build(),
384+
base_url="https://my-agents.com:5000"
385+
)
386+
387+
response = client.get(AGENT_CARD_WELL_KNOWN_PATH)
388+
response.raise_for_status()
389+
assert response.json()["url"] == "https://my-agents.com:5000"
390+
391+
response = client.get(
392+
AGENT_CARD_WELL_KNOWN_PATH,
393+
headers={
394+
"X-Forwarded-Host": "my-great-agents.com:5678",
395+
"X-Forwarded-Proto": "http",
396+
"X-Forwarded-Path":
397+
"/agents/my-agent" + AGENT_CARD_WELL_KNOWN_PATH
398+
}
399+
)
400+
assert response.json()["url"] == "http://my-great-agents.com:5678/agents/my-agent"
401+
402+
client = TestClient(
403+
app_instance.build(
404+
agent_card_url="/agents/my-agent" + AGENT_CARD_WELL_KNOWN_PATH
405+
),
406+
base_url="https://my-mighty-agents.com"
407+
)
408+
409+
response = client.get("/agents/my-agent" + AGENT_CARD_WELL_KNOWN_PATH)
410+
assert response.json()["url"] == "https://my-mighty-agents.com/agents/my-agent"
411+
412+
359413
if __name__ == '__main__':
360414
pytest.main([__file__])

tests/server/apps/rest/test_rest_fastapi_app.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from typing import Any
4-
from unittest.mock import MagicMock
4+
from unittest.mock import AsyncMock, MagicMock
55

66
import pytest
77

@@ -15,6 +15,7 @@
1515
from a2a.server.apps.rest.rest_adapter import RESTAdapter
1616
from a2a.server.request_handlers.request_handler import RequestHandler
1717
from a2a.types import (
18+
AgentCapabilities,
1819
AgentCard,
1920
Message,
2021
Part,
@@ -24,7 +25,7 @@
2425
TaskStatus,
2526
TextPart,
2627
)
27-
28+
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -222,5 +223,62 @@ async def test_send_message_success_task(
222223
assert expected_response == actual_response
223224

224225

226+
class TestAgentCardHandler:
227+
@pytest.fixture
228+
def agent_card(self):
229+
return AgentCard(
230+
name='APIKeyAgent',
231+
description='An agent that uses API Key auth.',
232+
url='http://localhost:8000',
233+
version='1.0.0',
234+
capabilities=AgentCapabilities(),
235+
default_input_modes=['text/plain'],
236+
default_output_modes=['text/plain'],
237+
skills=[],
238+
)
239+
240+
@pytest.mark.anyio
241+
async def test_agent_card_url_rewriting(
242+
self, agent_card: AgentCard,
243+
):
244+
"""
245+
Tests that the REST endpoint correctly handles Agent URL rewriting.
246+
"""
247+
app_instance = A2ARESTFastAPIApplication(agent_card, AsyncMock())
248+
app = app_instance.build(
249+
agent_card_url=AGENT_CARD_WELL_KNOWN_PATH
250+
)
251+
client = AsyncClient(
252+
transport=ASGITransport(app=app),
253+
base_url="https://my-agents.com:5000"
254+
)
255+
256+
response = await client.get(AGENT_CARD_WELL_KNOWN_PATH)
257+
response.raise_for_status()
258+
assert response.json()["url"] == "https://my-agents.com:5000"
259+
260+
response = await client.get(
261+
AGENT_CARD_WELL_KNOWN_PATH,
262+
headers={
263+
"X-Forwarded-Host": "my-great-agents.com:5678",
264+
"X-Forwarded-Proto": "http",
265+
"X-Forwarded-Path":
266+
"/agents/my-agent" + AGENT_CARD_WELL_KNOWN_PATH
267+
}
268+
)
269+
assert response.json()["url"] == "http://my-great-agents.com:5678/agents/my-agent"
270+
271+
app = app_instance.build(
272+
agent_card_url="/agents/my-agent" + AGENT_CARD_WELL_KNOWN_PATH
273+
)
274+
client = AsyncClient(
275+
transport=ASGITransport(app=app),
276+
base_url="https://my-mighty-agents.com"
277+
)
278+
279+
response = await client.get("/agents/my-agent" + AGENT_CARD_WELL_KNOWN_PATH)
280+
assert response.json()["url"] == "https://my-mighty-agents.com/agents/my-agent"
281+
282+
225283
if __name__ == '__main__':
226284
pytest.main([__file__])

0 commit comments

Comments
 (0)