Skip to content

Commit d8f6127

Browse files
authored
chore: simplify agent executor implementation (#1730)
Signed-off-by: Radek Ježek <[email protected]>
1 parent f6c7e07 commit d8f6127

File tree

12 files changed

+370
-322
lines changed

12 files changed

+370
-322
lines changed

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def from_agent_card(cls, agent: a2a.types.AgentCard) -> typing.Self | None:
111111
return None
112112

113113

114-
ExtensionSpecT = typing.TypeVar("ExtensionSpecT", bound=BaseExtensionSpec)
114+
ExtensionSpecT = typing.TypeVar("ExtensionSpecT", bound=BaseExtensionSpec[typing.Any])
115115

116116

117117
class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromClientT]):

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/ui/citation.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from __future__ import annotations
55

66
from types import NoneType
7-
from typing import Any
87

98
import pydantic
109
from a2a.types import DataPart, FilePart, Part, TextPart
@@ -59,11 +58,7 @@ class CitationExtensionSpec(NoParamsBaseExtensionSpec):
5958

6059

6160
class CitationExtensionServer(BaseExtensionServer[CitationExtensionSpec, NoneType]):
62-
def citation_metadata(
63-
self,
64-
*,
65-
citations: list[Citation],
66-
) -> Metadata[str, Any]:
61+
def citation_metadata(self, *, citations: list[Citation]) -> Metadata:
6762
return Metadata({self.spec.URI: CitationMetadata(citations=citations).model_dump(mode="json")})
6863

6964
def message(

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/ui/error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def context(self) -> JsonDict:
144144
)
145145
return {}
146146

147-
def error_metadata(self, error: BaseException) -> Metadata[str, Any]:
147+
def error_metadata(self, error: BaseException) -> Metadata:
148148
"""
149149
Create metadata for an error.
150150

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/ui/trajectory.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from __future__ import annotations
55

66
from types import NoneType
7-
from typing import Any
87

98
import pydantic
109
from a2a.types import DataPart, FilePart, Part, TextPart
@@ -49,7 +48,7 @@ class TrajectoryExtensionSpec(NoParamsBaseExtensionSpec):
4948
class TrajectoryExtensionServer(BaseExtensionServer[TrajectoryExtensionSpec, NoneType]):
5049
def trajectory_metadata(
5150
self, *, title: str | None = None, content: str | None = None, group_id: str | None = None
52-
) -> Metadata[str, Any]:
51+
) -> Metadata:
5352
return Metadata(
5453
{self.spec.URI: Trajectory(title=title, content=content, group_id=group_id).model_dump(mode="json")}
5554
)

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/types.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
22
# SPDX-License-Identifier: Apache-2.0
3-
import typing
43
import uuid
5-
from typing import Generic, Literal, TypeAlias, Union
4+
from typing import TYPE_CHECKING, Literal, TypeAlias
65

76
from a2a.types import (
87
Artifact,
@@ -20,13 +19,20 @@
2019
TextPart,
2120
)
2221
from pydantic import Field, model_validator
23-
from typing_extensions import TypeAliasType
2422

25-
K = typing.TypeVar("K")
26-
V = typing.TypeVar("V")
23+
if TYPE_CHECKING:
24+
JsonValue: TypeAlias = list["JsonValue"] | dict[str, "JsonValue"] | str | bool | int | float | None
25+
JsonDict: TypeAlias = dict[str, JsonValue]
26+
else:
27+
from typing import Union
2728

29+
from typing_extensions import TypeAliasType
2830

29-
class Metadata(dict[K, V], Generic[K, V]): ...
31+
JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") # noqa: UP007
32+
JsonDict = TypeAliasType("JsonDict", "dict[str, JsonValue]")
33+
34+
35+
class Metadata(dict[str, JsonValue]): ...
3036

3137

3238
RunYield: TypeAlias = (
@@ -43,15 +49,15 @@ class Metadata(dict[K, V], Generic[K, V]): ...
4349
| TaskStatusUpdateEvent
4450
| TaskArtifactUpdateEvent
4551
| str
46-
| dict
52+
| JsonDict
4753
| Exception
4854
)
4955
RunYieldResume: TypeAlias = Message | None
5056

5157

5258
class AgentArtifact(Artifact):
5359
artifact_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
54-
parts: list[Part | TextPart | FilePart | DataPart] # pyright: ignore [reportIncompatibleVariableOverride]
60+
parts: list[Part | TextPart | FilePart | DataPart]
5561

5662
@model_validator(mode="after")
5763
def text_message_validate(self):
@@ -61,7 +67,7 @@ def text_message_validate(self):
6167

6268
class ArtifactChunk(Artifact):
6369
last_chunk: bool = False
64-
parts: list[Part | TextPart | FilePart | DataPart] # pyright: ignore [reportIncompatibleVariableOverride]
70+
parts: list[Part | TextPart | FilePart | DataPart]
6571

6672
@model_validator(mode="after")
6773
def text_message_validate(self):
@@ -105,9 +111,3 @@ def text_message_validate(self):
105111

106112
class AuthRequired(InputRequired):
107113
state: Literal[TaskState.auth_required] = TaskState.auth_required # pyright: ignore [reportIncompatibleVariableOverride]
108-
109-
110-
JsonDict = TypeAliasType(
111-
"JsonDict",
112-
"Union[dict[str, JsonDict], list[JsonDict], str, int, float, bool, None]", # pyright: ignore[reportDeprecated] # noqa: UP007
113-
)

0 commit comments

Comments
 (0)