Skip to content

Commit 9242d62

Browse files
authored
Merge branch 'main' into remove-camel-case
2 parents 37668bd + c94d6aa commit 9242d62

File tree

21 files changed

+619
-63
lines changed

21 files changed

+619
-63
lines changed

.ruff.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ inline-quotes = "single"
136136
"PLR2004",
137137
"SLF001",
138138
]
139-
"types.py" = ["D", "E501", "N815"] # Ignore docstring and annotation issues in types.py
139+
"types.py" = ["D", "E501"] # Ignore docstring and annotation issues in types.py
140140
"proto_utils.py" = ["D102", "PLR0911"]
141141
"helpers.py" = ["ANN001", "ANN201", "ANN202"]
142142

scripts/format.sh

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,84 @@
22
set -e
33
set -o pipefail
44

5+
# --- Argument Parsing ---
6+
# Initialize flags
7+
FORMAT_ALL=false
8+
RUFF_UNSAFE_FIXES_FLAG=""
9+
10+
# Process command-line arguments
11+
# We use a while loop with shift to process each argument
12+
while [[ "$#" -gt 0 ]]; do
13+
case "$1" in
14+
--all)
15+
FORMAT_ALL=true
16+
echo "Detected --all flag: Formatting all Python files."
17+
shift # Consume the argument
18+
;;
19+
--unsafe-fixes)
20+
RUFF_UNSAFE_FIXES_FLAG="--unsafe-fixes"
21+
echo "Detected --unsafe-fixes flag: Ruff will run with unsafe fixes."
22+
shift # Consume the argument
23+
;;
24+
*)
25+
# Handle unknown arguments or just ignore them if we only care about specific ones
26+
echo "Warning: Unknown argument '$1'. Ignoring."
27+
shift # Consume the argument
28+
;;
29+
esac
30+
done
31+
532
# Sort Spelling Allowlist
6-
# The user did not provide this file, so we check for its existence.
733
SPELLING_ALLOW_FILE=".github/actions/spelling/allow.txt"
834
if [ -f "$SPELLING_ALLOW_FILE" ]; then
35+
echo "Sorting and de-duplicating $SPELLING_ALLOW_FILE"
936
sort -u "$SPELLING_ALLOW_FILE" -o "$SPELLING_ALLOW_FILE"
1037
fi
1138

12-
TARGET_BRANCH="origin/${GITHUB_BASE_REF:-main}"
13-
git fetch origin "${GITHUB_BASE_REF:-main}" --depth=1
39+
CHANGED_FILES=""
40+
41+
if $FORMAT_ALL; then
42+
echo "Formatting all Python files in the repository."
43+
# Find all Python files, excluding grpc generated files as per original logic.
44+
# `sort -u` ensures unique files and consistent ordering for display/xargs.
45+
CHANGED_FILES=$(find . -name '*.py' -not -path './src/a2a/grpc/*' | sort -u)
1446

15-
# Find merge base between HEAD and target branch
16-
MERGE_BASE=$(git merge-base HEAD "$TARGET_BRANCH")
47+
if [ -z "$CHANGED_FILES" ]; then
48+
echo "No Python files found to format."
49+
exit 0
50+
fi
51+
else
52+
echo "No '--all' flag found. Formatting changed Python files based on git diff."
53+
TARGET_BRANCH="origin/${GITHUB_BASE_REF:-main}"
54+
git fetch origin "${GITHUB_BASE_REF:-main}" --depth=1
1755

18-
# Get python files changed in this PR, excluding grpc generated files
19-
CHANGED_FILES=$(git diff --name-only --diff-filter=ACMRTUXB "$MERGE_BASE" HEAD -- '*.py' ':!src/a2a/grpc/*')
56+
MERGE_BASE=$(git merge-base HEAD "$TARGET_BRANCH")
2057

21-
if [ -z "$CHANGED_FILES" ]; then
22-
echo "No changed Python files to format."
23-
exit 0
58+
# Get python files changed in this PR, excluding grpc generated files
59+
CHANGED_FILES=$(git diff --name-only --diff-filter=ACMRTUXB "$MERGE_BASE" HEAD -- '*.py' ':!src/a2a/grpc/*')
60+
61+
if [ -z "$CHANGED_FILES" ]; then
62+
echo "No changed Python files to format."
63+
exit 0
64+
fi
2465
fi
2566

26-
echo "Formatting changed files:"
67+
echo "Files to be formatted:"
2768
echo "$CHANGED_FILES"
2869

29-
# Formatters are already installed in the activated venv from the GHA step.
30-
# Use xargs to pass the file list to the formatters.
70+
# Helper function to run formatters with the list of files.
71+
# The list of files is passed to xargs via stdin.
3172
run_formatter() {
3273
echo "$CHANGED_FILES" | xargs -r "$@"
3374
}
3475

76+
echo "Running pyupgrade..."
3577
run_formatter pyupgrade --exit-zero-even-if-changed --py310-plus
78+
echo "Running autoflake..."
3679
run_formatter autoflake -i -r --remove-all-unused-imports
37-
run_formatter ruff check --fix-only
80+
echo "Running ruff check (fix-only)..."
81+
run_formatter ruff check --fix-only $RUFF_UNSAFE_FIXES_FLAG
82+
echo "Running ruff format..."
3883
run_formatter ruff format
84+
85+
echo "Formatting complete."

src/a2a/extensions/common.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from a2a.types import AgentCard, AgentExtension
2+
3+
4+
HTTP_EXTENSION_HEADER = 'X-A2A-Extensions'
5+
6+
7+
def get_requested_extensions(values: list[str]) -> set[str]:
8+
"""Get the set of requested extensions from an input list.
9+
10+
This handles the list containing potentially comma-separated values, as
11+
occurs when using a list in an HTTP header.
12+
"""
13+
return {
14+
stripped
15+
for v in values
16+
for ext in v.split(',')
17+
if (stripped := ext.strip())
18+
}
19+
20+
21+
def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None:
22+
"""Find an AgentExtension in an AgentCard given a uri."""
23+
for ext in card.capabilities.extensions or []:
24+
if ext.uri == uri:
25+
return ext
26+
27+
return None

src/a2a/grpc/a2a_pb2.py

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/a2a/server/agent_execution/context.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,24 @@ def metadata(self) -> dict[str, Any]:
143143
return {}
144144
return self._params.metadata or {}
145145

146+
def add_activated_extension(self, uri: str) -> None:
147+
"""Add an extension to the set of activated extensions for this request.
148+
149+
This causes the extension to be indicated back to the client in the
150+
response.
151+
"""
152+
if self._call_context:
153+
self._call_context.activated_extensions.add(uri)
154+
155+
@property
156+
def requested_extensions(self) -> set[str]:
157+
"""Extensions that the client requested to activate."""
158+
return (
159+
self._call_context.requested_extensions
160+
if self._call_context
161+
else set()
162+
)
163+
146164
def _check_or_generate_task_id(self) -> None:
147165
"""Ensures a task ID is present, generating one if necessary."""
148166
if not self._params:

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
from a2a.auth.user import UnauthenticatedUser
2121
from a2a.auth.user import User as A2AUser
22+
from a2a.extensions.common import (
23+
HTTP_EXTENSION_HEADER,
24+
get_requested_extensions,
25+
)
2226
from a2a.server.context import ServerCallContext
2327
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
2428
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -99,7 +103,13 @@ def build(self, request: Request) -> ServerCallContext:
99103
user = StarletteUserProxy(request.user)
100104
state['auth'] = request.auth
101105
state['headers'] = dict(request.headers)
102-
return ServerCallContext(user=user, state=state)
106+
return ServerCallContext(
107+
user=user,
108+
state=state,
109+
requested_extensions=get_requested_extensions(
110+
request.headers.getlist(HTTP_EXTENSION_HEADER)
111+
),
112+
)
103113

104114

105115
class JSONRPCApplication(ABC):
@@ -281,7 +291,7 @@ async def _process_streaming_request(
281291
request_obj, context
282292
)
283293

284-
return self._create_response(handler_result)
294+
return self._create_response(context, handler_result)
285295

286296
async def _process_non_streaming_request(
287297
self,
@@ -353,10 +363,11 @@ async def _process_non_streaming_request(
353363
id=request_id, error=error
354364
)
355365

356-
return self._create_response(handler_result)
366+
return self._create_response(context, handler_result)
357367

358368
def _create_response(
359369
self,
370+
context: ServerCallContext,
360371
handler_result: (
361372
AsyncGenerator[SendStreamingMessageResponse]
362373
| JSONRPCErrorResponse
@@ -372,12 +383,16 @@ def _create_response(
372383
payloads.
373384
374385
Args:
386+
context: The ServerCallContext provided to the request handler.
375387
handler_result: The result from a request handler method. Can be an
376388
async generator for streaming or a Pydantic model for non-streaming.
377389
378390
Returns:
379391
A Starlette JSONResponse or EventSourceResponse.
380392
"""
393+
headers = {}
394+
if exts := context.activated_extensions:
395+
headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts))
381396
if isinstance(handler_result, AsyncGenerator):
382397
# Result is a stream of SendStreamingMessageResponse objects
383398
async def event_generator(
@@ -386,17 +401,21 @@ async def event_generator(
386401
async for item in stream:
387402
yield {'data': item.root.model_dump_json(exclude_none=True)}
388403

389-
return EventSourceResponse(event_generator(handler_result))
404+
return EventSourceResponse(
405+
event_generator(handler_result), headers=headers
406+
)
390407
if isinstance(handler_result, JSONRPCErrorResponse):
391408
return JSONResponse(
392409
handler_result.model_dump(
393410
mode='json',
394411
exclude_none=True,
395-
)
412+
),
413+
headers=headers,
396414
)
397415

398416
return JSONResponse(
399-
handler_result.root.model_dump(mode='json', exclude_none=True)
417+
handler_result.root.model_dump(mode='json', exclude_none=True),
418+
headers=headers,
400419
)
401420

402421
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:

src/a2a/server/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@ class ServerCallContext(BaseModel):
2121

2222
state: State = Field(default={})
2323
user: User = Field(default=UnauthenticatedUser())
24+
requested_extensions: set[str] = Field(default_factory=set)
25+
activated_extensions: set[str] = Field(default_factory=set)

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import logging
44

55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterable
6+
from collections.abc import AsyncIterable, Sequence
77

88

99
try:
1010
import grpc
1111
import grpc.aio
12+
13+
from grpc.aio import Metadata
1214
except ImportError as e:
1315
raise ImportError(
1416
'GrpcHandler requires grpcio and grpcio-tools to be installed. '
@@ -20,6 +22,10 @@
2022

2123
from a2a import types
2224
from a2a.auth.user import UnauthenticatedUser
25+
from a2a.extensions.common import (
26+
HTTP_EXTENSION_HEADER,
27+
get_requested_extensions,
28+
)
2329
from a2a.grpc import a2a_pb2
2430
from a2a.server.context import ServerCallContext
2531
from a2a.server.request_handlers.request_handler import RequestHandler
@@ -42,6 +48,19 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
4248
"""Builds a ServerCallContext from a gRPC Request."""
4349

4450

51+
def _get_metadata_value(
52+
context: grpc.aio.ServicerContext, key: str
53+
) -> list[str]:
54+
md = context.invocation_metadata
55+
raw_values: list[str | bytes] = []
56+
if isinstance(md, Metadata):
57+
raw_values = md.get_all(key)
58+
elif isinstance(md, Sequence):
59+
lower_key = key.lower()
60+
raw_values = [e for (k, e) in md if k.lower() == lower_key]
61+
return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values]
62+
63+
4564
class DefaultCallContextBuilder(CallContextBuilder):
4665
"""A default implementation of CallContextBuilder."""
4766

@@ -51,7 +70,13 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
5170
state = {}
5271
with contextlib.suppress(Exception):
5372
state['grpc_context'] = context
54-
return ServerCallContext(user=user, state=state)
73+
return ServerCallContext(
74+
user=user,
75+
state=state,
76+
requested_extensions=get_requested_extensions(
77+
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
78+
),
79+
)
5580

5681

5782
class GrpcHandler(a2a_grpc.A2AServiceServicer):
@@ -102,6 +127,7 @@ async def SendMessage(
102127
task_or_message = await self.request_handler.on_message_send(
103128
a2a_request, server_context
104129
)
130+
self._set_extension_metadata(context, server_context)
105131
return proto_utils.ToProto.task_or_message(task_or_message)
106132
except ServerError as e:
107133
await self.abort_context(e, context)
@@ -140,6 +166,7 @@ async def SendStreamingMessage(
140166
a2a_request, server_context
141167
):
142168
yield proto_utils.ToProto.stream_response(event)
169+
self._set_extension_metadata(context, server_context)
143170
except ServerError as e:
144171
await self.abort_context(e, context)
145172
return
@@ -371,3 +398,16 @@ async def abort_context(
371398
grpc.StatusCode.UNKNOWN,
372399
f'Unknown error type: {error.error}',
373400
)
401+
402+
def _set_extension_metadata(
403+
self,
404+
context: grpc.aio.ServicerContext,
405+
server_context: ServerCallContext,
406+
) -> None:
407+
if server_context.activated_extensions:
408+
context.set_trailing_metadata(
409+
[
410+
(HTTP_EXTENSION_HEADER, e)
411+
for e in sorted(server_context.activated_extensions)
412+
]
413+
)

src/a2a/utils/task.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def new_task(request: Message) -> Task:
1818
1919
Raises:
2020
TypeError: If the message role is None.
21-
ValueError: If the message parts are empty or if any part has empty content.
21+
ValueError: If the message parts are empty, if any part has empty content, or if the provided context_id is invalid.
2222
"""
2323
if not request.role:
2424
raise TypeError('Message role cannot be None')
@@ -28,12 +28,22 @@ def new_task(request: Message) -> Task:
2828
if isinstance(part.root, TextPart) and not part.root.text:
2929
raise ValueError('TextPart content cannot be empty')
3030

31+
context_id_str = request.context_id
32+
if context_id_str is not None:
33+
try:
34+
uuid.UUID(context_id_str)
35+
context_id = context_id_str
36+
except (ValueError, AttributeError, TypeError) as e:
37+
raise ValueError(
38+
f"Invalid context_id: '{context_id_str}' is not a valid UUID."
39+
) from e
40+
else:
41+
context_id = str(uuid.uuid4())
42+
3143
return Task(
3244
status=TaskStatus(state=TaskState.submitted),
3345
id=(request.task_id if request.task_id else str(uuid.uuid4())),
34-
context_id=(
35-
request.context_id if request.context_id else str(uuid.uuid4())
36-
),
46+
context_id=context_id,
3747
history=[request],
3848
)
3949

0 commit comments

Comments
 (0)