Skip to content

Commit eaef29a

Browse files
committed
wip
1 parent d8f6127 commit eaef29a

File tree

29 files changed

+410
-128
lines changed

29 files changed

+410
-128
lines changed

agents/chat/uv.lock

Lines changed: 25 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

apps/agentstack-sdk-py/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ dependencies = [
2222
"httpx", # version determined by a2a-sdk
2323
"mcp>=1.12.3",
2424
"fastapi>=0.116.1",
25+
"authlib>=1.3.0",
26+
"async-lru>=2.0.4",
2527
]
2628

2729
[dependency-groups]
@@ -75,3 +77,4 @@ addopts = "-v"
7577
ignore = ["tests/**", "examples/cli.py"]
7678
venvPath = "."
7779
venv = ".venv"
80+
reportUnusedCallResult = "none"

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/auth/oauth/oauth.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
from typing import TYPE_CHECKING, Any, Self
99
from urllib.parse import parse_qs
1010

11-
import a2a.types
1211
import pydantic
12+
from a2a.server.agent_execution import RequestContext
13+
from a2a.types import Message as A2AMessage
14+
from a2a.types import Role, TextPart
1315
from mcp.client.auth import OAuthClientProvider
1416
from mcp.shared.auth import OAuthClientMetadata
17+
from typing_extensions import override
1518

1619
from agentstack_sdk.a2a.extensions.auth.oauth.storage import MemoryTokenStorageFactory, TokenStorageFactory
1720
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
@@ -58,13 +61,17 @@ class OAuthExtensionMetadata(pydantic.BaseModel):
5861

5962

6063
class OAuthExtensionServer(BaseExtensionServer[OAuthExtensionSpec, OAuthExtensionMetadata]):
64+
context: RunContext
65+
token_storage_factory: TokenStorageFactory
66+
6167
def __init__(self, spec: OAuthExtensionSpec, token_storage_factory: TokenStorageFactory | None = None) -> None:
6268
super().__init__(spec)
6369
self.token_storage_factory = token_storage_factory or MemoryTokenStorageFactory()
6470

65-
def handle_incoming_message(self, message: a2a.types.Message, context: RunContext):
66-
super().handle_incoming_message(message, context)
67-
self.context = context
71+
@override
72+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
73+
super().handle_incoming_message(message, run_context, request_context)
74+
self.context = run_context
6875

6976
def _get_fulfillment_for_resource(self, resource_url: pydantic.AnyUrl):
7077
if not self.data:
@@ -117,7 +124,7 @@ def create_auth_request(self, *, authorization_endpoint_url: pydantic.AnyUrl):
117124
data = AuthRequest(authorization_endpoint_url=authorization_endpoint_url)
118125
return AgentMessage(text="Authorization required", metadata={self.spec.URI: data.model_dump(mode="json")})
119126

120-
def parse_auth_response(self, *, message: a2a.types.Message):
127+
def parse_auth_response(self, *, message: A2AMessage):
121128
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
122129
raise RuntimeError("Invalid auth response")
123130
return AuthResponse.model_validate(data)
@@ -127,18 +134,18 @@ class OAuthExtensionClient(BaseExtensionClient[OAuthExtensionSpec, NoneType]):
127134
def fulfillment_metadata(self, *, oauth_fulfillments: dict[str, Any]) -> dict[str, Any]:
128135
return {self.spec.URI: OAuthExtensionMetadata(oauth_fulfillments=oauth_fulfillments).model_dump(mode="json")}
129136

130-
def parse_auth_request(self, *, message: a2a.types.Message):
137+
def parse_auth_request(self, *, message: A2AMessage):
131138
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
132139
raise ValueError("Invalid auth request")
133140
return AuthRequest.model_validate(data)
134141

135142
def create_auth_response(self, *, task_id: str, redirect_uri: pydantic.AnyUrl):
136143
data = AuthResponse(redirect_uri=redirect_uri)
137144

138-
return a2a.types.Message(
145+
return A2AMessage(
139146
message_id=str(uuid.uuid4()),
140-
role=a2a.types.Role.user,
141-
parts=[a2a.types.TextPart(text="Authorization completed")], # type: ignore
147+
role=Role.user,
148+
parts=[TextPart(text="Authorization completed")], # type: ignore
142149
task_id=task_id,
143150
metadata={self.spec.URI: data.model_dump(mode="json")},
144151
)

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/auth/secrets/secrets.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import typing
5-
from typing import TYPE_CHECKING
4+
from __future__ import annotations
5+
6+
from typing import TYPE_CHECKING, Self
67

78
import pydantic
9+
from a2a.server.agent_execution.context import RequestContext
810
from a2a.types import Message as A2AMessage
11+
from typing_extensions import override
912

1013
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
1114
from agentstack_sdk.a2a.types import AgentMessage, AuthRequired
@@ -35,7 +38,7 @@ class SecretsExtensionSpec(BaseExtensionSpec[SecretsServiceExtensionParams | Non
3538
URI: str = "https://a2a-extensions.agentstack.beeai.dev/auth/secrets/v1"
3639

3740
@classmethod
38-
def single_demand(cls, name: str, key: str | None = None, description: str | None = None) -> typing.Self:
41+
def single_demand(cls, name: str, key: str | None = None, description: str | None = None) -> Self:
3942
return cls(
4043
params=SecretsServiceExtensionParams(
4144
secret_demands={key or "default": SecretDemand(description=description, name=name)}
@@ -44,9 +47,12 @@ def single_demand(cls, name: str, key: str | None = None, description: str | Non
4447

4548

4649
class SecretsExtensionServer(BaseExtensionServer[SecretsExtensionSpec, SecretsServiceExtensionMetadata]):
47-
def handle_incoming_message(self, message: A2AMessage, context: "RunContext"):
48-
super().handle_incoming_message(message, context)
49-
self.context = context
50+
context: RunContext
51+
52+
@override
53+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
54+
super().handle_incoming_message(message, run_context, request_context)
55+
self.context = run_context
5056

5157
def parse_secret_response(self, message: A2AMessage) -> SecretsServiceExtensionMetadata:
5258
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/base.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
from contextlib import asynccontextmanager
1010
from types import NoneType
1111

12-
import a2a.types
1312
import pydantic
13+
from a2a.server.agent_execution.context import RequestContext
14+
from a2a.types import AgentCard, AgentExtension
15+
from a2a.types import Message as A2AMessage
16+
from typing_extensions import override
1417

1518
ParamsT = typing.TypeVar("ParamsT")
1619
MetadataFromClientT = typing.TypeVar("MetadataFromClientT")
@@ -19,6 +22,7 @@
1922

2023
if typing.TYPE_CHECKING:
2124
from agentstack_sdk.server.context import RunContext
25+
from agentstack_sdk.server.dependencies import Dependency
2226

2327

2428
def _get_generic_args(cls: type, base_class: type) -> tuple[typing.Any, ...]:
@@ -68,7 +72,7 @@ def __init__(self, params: ParamsT) -> None:
6872
self.params = params
6973

7074
@classmethod
71-
def from_agent_card(cls, agent: a2a.types.AgentCard) -> typing.Self | None:
75+
def from_agent_card(cls, agent: AgentCard) -> typing.Self | None:
7276
"""
7377
Client should construct an extension instance using this classmethod.
7478
"""
@@ -81,14 +85,14 @@ def from_agent_card(cls, agent: a2a.types.AgentCard) -> typing.Self | None:
8185
except StopIteration:
8286
return None
8387

84-
def to_agent_card_extensions(self, *, required: bool = False) -> list[a2a.types.AgentExtension]:
88+
def to_agent_card_extensions(self, *, required: bool = False) -> list[AgentExtension]:
8589
"""
8690
Agent should use this method to obtain extension definitions to advertise on the agent card.
8791
This returns a list, as it's possible to support multiple A2A extensions within a single class.
8892
(Usually, that would be different versions of the extension spec.)
8993
"""
9094
return [
91-
a2a.types.AgentExtension(
95+
AgentExtension(
9296
uri=self.URI,
9397
description=self.DESCRIPTION,
9498
params=typing.cast(
@@ -105,7 +109,8 @@ def __init__(self):
105109
super().__init__(None)
106110

107111
@classmethod
108-
def from_agent_card(cls, agent: a2a.types.AgentCard) -> typing.Self | None:
112+
@override
113+
def from_agent_card(cls, agent: AgentCard) -> typing.Self | None:
109114
if any(e.uri == cls.URI for e in agent.capabilities.extensions or []):
110115
return cls()
111116
return None
@@ -125,7 +130,7 @@ def __init_subclass__(cls, **kwargs):
125130
cls.MetadataFromClient = _get_generic_args(cls, BaseExtensionServer)[1]
126131

127132
_metadata_from_client: MetadataFromClientT | None = None
128-
_dependencies: dict
133+
_dependencies: dict[str, Dependency] = {}
129134

130135
@property
131136
def data(self):
@@ -139,7 +144,7 @@ def __init__(self, spec: ExtensionSpecT, *args, **kwargs) -> None:
139144
self._args = args
140145
self._kwargs = kwargs
141146

142-
def parse_client_metadata(self, message: a2a.types.Message) -> MetadataFromClientT | None:
147+
def parse_client_metadata(self, message: A2AMessage) -> MetadataFromClientT | None:
143148
"""
144149
Server should use this method to retrieve extension-associated metadata from a message.
145150
"""
@@ -149,7 +154,7 @@ def parse_client_metadata(self, message: a2a.types.Message) -> MetadataFromClien
149154
else pydantic.TypeAdapter(self.MetadataFromClient).validate_python(message.metadata[self.spec.URI])
150155
)
151156

152-
def handle_incoming_message(self, message: a2a.types.Message, context: RunContext):
157+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
153158
if self._metadata_from_client is None:
154159
self._metadata_from_client = self.parse_client_metadata(message)
155160

@@ -158,12 +163,16 @@ def _fork(self) -> typing.Self:
158163
return type(self)(self.spec, *self._args, **self._kwargs)
159164

160165
def __call__(
161-
self, message: a2a.types.Message, context: RunContext, dependencies: dict[str, typing.Any]
166+
self,
167+
message: A2AMessage,
168+
run_context: RunContext,
169+
request_context: RequestContext,
170+
dependencies: dict[str, Dependency],
162171
) -> typing.Self:
163172
"""Works as a dependency constructor - create a private instance for the request"""
164173
instance = self._fork()
165174
instance._dependencies = dependencies
166-
instance.handle_incoming_message(message, context)
175+
instance.handle_incoming_message(message, run_context, request_context)
167176
return instance
168177

169178
@asynccontextmanager
@@ -185,7 +194,7 @@ def __init_subclass__(cls, **kwargs):
185194
def __init__(self, spec: ExtensionSpecT) -> None:
186195
self.spec = spec
187196

188-
def parse_server_metadata(self, message: a2a.types.Message) -> MetadataFromServerT | None:
197+
def parse_server_metadata(self, message: A2AMessage) -> MetadataFromServerT | None:
189198
"""
190199
Client should use this method to retrieve extension-associated metadata from a message.
191200
"""

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/services/embedding.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55

66
import re
77
from types import NoneType
8-
from typing import Any, Self
8+
from typing import TYPE_CHECKING, Any, Self
99

1010
import pydantic
11+
from a2a.server.agent_execution.context import RequestContext
12+
from a2a.types import Message as A2AMessage
13+
from typing_extensions import override
1114

1215
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
1316

17+
if TYPE_CHECKING:
18+
from agentstack_sdk.server.context import RunContext
19+
1420

1521
class EmbeddingFulfillment(pydantic.BaseModel):
1622
identifier: str | None = None
@@ -78,10 +84,11 @@ class EmbeddingServiceExtensionMetadata(pydantic.BaseModel):
7884
class EmbeddingServiceExtensionServer(
7985
BaseExtensionServer[EmbeddingServiceExtensionSpec, EmbeddingServiceExtensionMetadata]
8086
):
81-
def handle_incoming_message(self, message, context):
87+
@override
88+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
8289
from agentstack_sdk.platform import get_platform_client
8390

84-
super().handle_incoming_message(message, context)
91+
super().handle_incoming_message(message, run_context, request_context)
8592
if not self.data:
8693
return
8794

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/services/llm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from typing import TYPE_CHECKING, Any, Self
99

1010
import pydantic
11+
from a2a.server.agent_execution.context import RequestContext
12+
from a2a.types import Message as A2AMessage
13+
from typing_extensions import override
1114

1215
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
1316

1417
if TYPE_CHECKING:
15-
from a2a.types import Message
16-
1718
from agentstack_sdk.server.context import RunContext
1819

1920

@@ -81,10 +82,11 @@ class LLMServiceExtensionMetadata(pydantic.BaseModel):
8182

8283

8384
class LLMServiceExtensionServer(BaseExtensionServer[LLMServiceExtensionSpec, LLMServiceExtensionMetadata]):
84-
def handle_incoming_message(self, message: Message, context: RunContext):
85+
@override
86+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
8587
from agentstack_sdk.platform import get_platform_client
8688

87-
super().handle_incoming_message(message, context)
89+
super().handle_incoming_message(message, run_context, request_context)
8890
if not self.data:
8991
return
9092

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/services/mcp.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from types import NoneType
99
from typing import TYPE_CHECKING, Annotated, Any, Literal, Self
1010

11-
import a2a.types
1211
import pydantic
12+
from a2a.server.agent_execution.context import RequestContext
13+
from a2a.types import Message as A2AMessage
1314
from mcp.client.stdio import StdioServerParameters, stdio_client
1415
from mcp.client.streamable_http import streamablehttp_client
16+
from typing_extensions import override
1517

1618
from agentstack_sdk.a2a.extensions.auth.oauth.oauth import OAuthExtensionServer
1719
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
@@ -102,8 +104,9 @@ class MCPServiceExtensionMetadata(pydantic.BaseModel):
102104

103105

104106
class MCPServiceExtensionServer(BaseExtensionServer[MCPServiceExtensionSpec, MCPServiceExtensionMetadata]):
105-
def handle_incoming_message(self, message: a2a.types.Message, context: RunContext):
106-
super().handle_incoming_message(message, context)
107+
@override
108+
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
109+
super().handle_incoming_message(message, run_context, request_context)
107110
if not self.data:
108111
return
109112

@@ -115,7 +118,8 @@ def handle_incoming_message(self, message: a2a.types.Message, context: RunContex
115118
except Exception:
116119
logger.warning("Platform URL substitution failed", exc_info=True)
117120

118-
def parse_client_metadata(self, message: a2a.types.Message) -> MCPServiceExtensionMetadata | None:
121+
@override
122+
def parse_client_metadata(self, message: A2AMessage) -> MCPServiceExtensionMetadata | None:
119123
metadata = super().parse_client_metadata(message)
120124
if metadata:
121125
for name, demand in self.spec.params.mcp_demands.items():

0 commit comments

Comments
 (0)