Skip to content

Commit a257789

Browse files
feat: a2a extensions API and async agent card caching; fix task propagation & streaming
Adds initial extensions API (with registry temporarily no-op), introduces aiocache for async caching, ensures reference task IDs propagate correctly, fixes streamed response model handling, updates streaming tests, and regenerates lockfiles.
1 parent 09f1ba6 commit a257789

File tree

10 files changed

+684
-208
lines changed

10 files changed

+684
-208
lines changed

lib/crewai/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ a2a = [
9595
"a2a-sdk~=0.3.10",
9696
"httpx-auth~=0.23.1",
9797
"httpx-sse~=0.4.0",
98+
"aiocache[redis,memcached]~=0.12.3",
9899
]
99100

100101

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""A2A Protocol Extensions for CrewAI.
2+
3+
This module contains extensions to the A2A (Agent-to-Agent) protocol.
4+
"""
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""Base extension interface for A2A wrapper integrations.
2+
3+
This module defines the protocol for extending A2A wrapper functionality
4+
with custom logic for conversation processing, prompt augmentation, and
5+
agent response handling.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from collections.abc import Sequence
11+
from typing import TYPE_CHECKING, Any, Protocol
12+
13+
14+
if TYPE_CHECKING:
15+
from a2a.types import Message
16+
17+
from crewai.agent.core import Agent
18+
19+
20+
class ConversationState(Protocol):
21+
"""Protocol for extension-specific conversation state.
22+
23+
Extensions can define their own state classes that implement this protocol
24+
to track conversation-specific data extracted from message history.
25+
"""
26+
27+
def is_ready(self) -> bool:
28+
"""Check if the state indicates readiness for some action.
29+
30+
Returns:
31+
True if the state is ready, False otherwise.
32+
"""
33+
...
34+
35+
36+
class A2AExtension(Protocol):
37+
"""Protocol for A2A wrapper extensions.
38+
39+
Extensions can implement this protocol to inject custom logic into
40+
the A2A conversation flow at various integration points.
41+
"""
42+
43+
def inject_tools(self, agent: Agent) -> None:
44+
"""Inject extension-specific tools into the agent.
45+
46+
Called when an agent is wrapped with A2A capabilities. Extensions
47+
can add tools that enable extension-specific functionality.
48+
49+
Args:
50+
agent: The agent instance to inject tools into.
51+
"""
52+
...
53+
54+
def extract_state_from_history(
55+
self, conversation_history: Sequence[Message]
56+
) -> ConversationState | None:
57+
"""Extract extension-specific state from conversation history.
58+
59+
Called during prompt augmentation to allow extensions to analyze
60+
the conversation history and extract relevant state information.
61+
62+
Args:
63+
conversation_history: The sequence of A2A messages exchanged.
64+
65+
Returns:
66+
Extension-specific conversation state, or None if no relevant state.
67+
"""
68+
...
69+
70+
def augment_prompt(
71+
self,
72+
base_prompt: str,
73+
conversation_state: ConversationState | None,
74+
) -> str:
75+
"""Augment the task prompt with extension-specific instructions.
76+
77+
Called during prompt augmentation to allow extensions to add
78+
custom instructions based on conversation state.
79+
80+
Args:
81+
base_prompt: The base prompt to augment.
82+
conversation_state: Extension-specific state from extract_state_from_history.
83+
84+
Returns:
85+
The augmented prompt with extension-specific instructions.
86+
"""
87+
...
88+
89+
def process_response(
90+
self,
91+
agent_response: Any,
92+
conversation_state: ConversationState | None,
93+
) -> Any:
94+
"""Process and potentially modify the agent response.
95+
96+
Called after parsing the agent's response, allowing extensions to
97+
enhance or modify the response based on conversation state.
98+
99+
Args:
100+
agent_response: The parsed agent response.
101+
conversation_state: Extension-specific state from extract_state_from_history.
102+
103+
Returns:
104+
The processed agent response (may be modified or original).
105+
"""
106+
...
107+
108+
109+
class ExtensionRegistry:
110+
"""Registry for managing A2A extensions.
111+
112+
Maintains a collection of extensions and provides methods to invoke
113+
their hooks at various integration points.
114+
"""
115+
116+
def __init__(self) -> None:
117+
"""Initialize the extension registry."""
118+
self._extensions: list[A2AExtension] = []
119+
120+
def register(self, extension: A2AExtension) -> None:
121+
"""Register an extension.
122+
123+
Args:
124+
extension: The extension to register.
125+
"""
126+
self._extensions.append(extension)
127+
128+
def inject_all_tools(self, agent: Agent) -> None:
129+
"""Inject tools from all registered extensions.
130+
131+
Args:
132+
agent: The agent instance to inject tools into.
133+
"""
134+
for extension in self._extensions:
135+
extension.inject_tools(agent)
136+
137+
def extract_all_states(
138+
self, conversation_history: Sequence[Message]
139+
) -> dict[type[A2AExtension], ConversationState]:
140+
"""Extract conversation states from all registered extensions.
141+
142+
Args:
143+
conversation_history: The sequence of A2A messages exchanged.
144+
145+
Returns:
146+
Mapping of extension types to their conversation states.
147+
"""
148+
states: dict[type[A2AExtension], ConversationState] = {}
149+
for extension in self._extensions:
150+
state = extension.extract_state_from_history(conversation_history)
151+
if state is not None:
152+
states[type(extension)] = state
153+
return states
154+
155+
def augment_prompt_with_all(
156+
self,
157+
base_prompt: str,
158+
extension_states: dict[type[A2AExtension], ConversationState],
159+
) -> str:
160+
"""Augment prompt with instructions from all registered extensions.
161+
162+
Args:
163+
base_prompt: The base prompt to augment.
164+
extension_states: Mapping of extension types to conversation states.
165+
166+
Returns:
167+
The fully augmented prompt.
168+
"""
169+
augmented = base_prompt
170+
for extension in self._extensions:
171+
state = extension_states.get(type(extension))
172+
augmented = extension.augment_prompt(augmented, state)
173+
return augmented
174+
175+
def process_response_with_all(
176+
self,
177+
agent_response: Any,
178+
extension_states: dict[type[A2AExtension], ConversationState],
179+
) -> Any:
180+
"""Process response through all registered extensions.
181+
182+
Args:
183+
agent_response: The parsed agent response.
184+
extension_states: Mapping of extension types to conversation states.
185+
186+
Returns:
187+
The processed agent response.
188+
"""
189+
processed = agent_response
190+
for extension in self._extensions:
191+
state = extension_states.get(type(extension))
192+
processed = extension.process_response(processed, state)
193+
return processed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Extension registry factory for A2A configurations.
2+
3+
This module provides utilities for creating extension registries from A2A configurations.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
from typing import TYPE_CHECKING
9+
10+
from crewai.a2a.extensions.base import ExtensionRegistry
11+
12+
13+
if TYPE_CHECKING:
14+
from crewai.a2a.config import A2AConfig
15+
16+
17+
def create_extension_registry_from_config(
18+
a2a_config: list[A2AConfig] | A2AConfig,
19+
) -> ExtensionRegistry:
20+
"""Create an extension registry from A2A configuration.
21+
22+
Args:
23+
a2a_config: A2A configuration (single or list)
24+
25+
Returns:
26+
Configured extension registry with all applicable extensions
27+
"""
28+
registry = ExtensionRegistry()
29+
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
30+
31+
for _ in configs:
32+
pass
33+
34+
return registry

lib/crewai/src/crewai/a2a/utils.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
TextPart,
2424
TransportProtocol,
2525
)
26+
from aiocache import cached # type: ignore[import-untyped]
27+
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
2628
import httpx
2729
from pydantic import BaseModel, Field, create_model
2830

@@ -65,7 +67,7 @@ def _fetch_agent_card_cached(
6567
endpoint: A2A agent endpoint URL
6668
auth_hash: Hash of the auth object
6769
timeout: Request timeout
68-
_ttl_hash: Time-based hash for cache invalidation (unused in body)
70+
_ttl_hash: Time-based hash for cache invalidation
6971
7072
Returns:
7173
Cached AgentCard
@@ -106,7 +108,18 @@ def fetch_agent_card(
106108
A2AClientHTTPError: If authentication fails
107109
"""
108110
if use_cache:
109-
auth_hash = hash((type(auth).__name__, id(auth))) if auth else 0
111+
if auth:
112+
auth_data = auth.model_dump_json(
113+
exclude={
114+
"_access_token",
115+
"_token_expires_at",
116+
"_refresh_token",
117+
"_authorization_callback",
118+
}
119+
)
120+
auth_hash = hash((type(auth).__name__, auth_data))
121+
else:
122+
auth_hash = 0
110123
_auth_store[auth_hash] = auth
111124
ttl_hash = int(time.time() // cache_ttl)
112125
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
@@ -121,6 +134,26 @@ def fetch_agent_card(
121134
loop.close()
122135

123136

137+
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
138+
async def _fetch_agent_card_async_cached(
139+
endpoint: str,
140+
auth_hash: int,
141+
timeout: int,
142+
) -> AgentCard:
143+
"""Cached async implementation of AgentCard fetching.
144+
145+
Args:
146+
endpoint: A2A agent endpoint URL
147+
auth_hash: Hash of the auth object
148+
timeout: Request timeout in seconds
149+
150+
Returns:
151+
Cached AgentCard object
152+
"""
153+
auth = _auth_store.get(auth_hash)
154+
return await _fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout)
155+
156+
124157
async def _fetch_agent_card_async(
125158
endpoint: str,
126159
auth: AuthScheme | None,
@@ -339,7 +372,22 @@ async def _execute_a2a_delegation_async(
339372
Returns:
340373
Dictionary with status, result/error, and new history
341374
"""
342-
agent_card = await _fetch_agent_card_async(endpoint, auth, timeout)
375+
if auth:
376+
auth_data = auth.model_dump_json(
377+
exclude={
378+
"_access_token",
379+
"_token_expires_at",
380+
"_refresh_token",
381+
"_authorization_callback",
382+
}
383+
)
384+
auth_hash = hash((type(auth).__name__, auth_data))
385+
else:
386+
auth_hash = 0
387+
_auth_store[auth_hash] = auth
388+
agent_card = await _fetch_agent_card_async_cached(
389+
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
390+
)
343391

344392
validate_auth_against_agent_card(agent_card, auth)
345393

@@ -556,6 +604,34 @@ async def _execute_a2a_delegation_async(
556604
}
557605
break
558606
except Exception as e:
607+
if isinstance(e, A2AClientHTTPError):
608+
error_msg = f"HTTP Error {e.status_code}: {e!s}"
609+
610+
error_message = Message(
611+
role=Role.agent,
612+
message_id=str(uuid.uuid4()),
613+
parts=[Part(root=TextPart(text=error_msg))],
614+
context_id=context_id,
615+
task_id=task_id,
616+
)
617+
new_messages.append(error_message)
618+
619+
crewai_event_bus.emit(
620+
None,
621+
A2AResponseReceivedEvent(
622+
response=error_msg,
623+
turn_number=turn_number,
624+
is_multiturn=is_multiturn,
625+
status="failed",
626+
agent_role=agent_role,
627+
),
628+
)
629+
return {
630+
"status": "failed",
631+
"error": error_msg,
632+
"history": new_messages,
633+
}
634+
559635
current_exception: Exception | BaseException | None = e
560636
while current_exception:
561637
if hasattr(current_exception, "response"):
@@ -752,4 +828,5 @@ def get_a2a_agents_and_response_model(
752828
Tuple of A2A agent IDs and response model
753829
"""
754830
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
831+
755832
return a2a_agents, create_agent_response_model(agent_ids)

0 commit comments

Comments
 (0)