Skip to content

Commit 1ed89d4

Browse files
Python: updated API in sync with dotnet (#269)
* updated API in sync with dotnet * fix test * updated name and display_name * fixed mypy setup * add pre-commit cache
1 parent dc993b4 commit 1ed89d4

File tree

13 files changed

+137
-83
lines changed

13 files changed

+137
-83
lines changed

.github/workflows/python-code-quality.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,13 @@ jobs:
3838
cache-dependency-glob: "**/uv.lock"
3939
- name: Install the project
4040
run: uv sync --all-extras --dev
41+
- uses: actions/cache@v3
42+
with:
43+
path: ~/.cache/pre-commit
44+
key: pre-commit|${{ matrix.python-version }}|${{ hashFiles('python/.pre-commit-config.yaml') }}
4145
- uses: pre-commit/[email protected]
4246
name: Run Pre-Commit Hooks
4347
with:
4448
extra_args: --config python/.pre-commit-config.yaml --all-files
4549
- name: Run Mypy
46-
run: uv run mypy -p agent_framework
50+
run: uv run poe mypy

python/packages/azure/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ exclude_dirs = ["tests"]
7474
[tool.poe]
7575
executor.type = "uv"
7676
include = "../../shared_tasks.toml"
77+
[tool.poe.tasks]
78+
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure"
7779

7880
[tool.uv.build-backend]
7981
module-name = "agent_framework_azure"

python/packages/foundry/agent_framework_foundry/_chat_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class FoundrySettings(AFBaseSettings):
9696
class FoundryChatClient(ChatClientBase):
9797
"""Azure AI Foundry Chat client."""
9898

99-
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride]
99+
MODEL_PROVIDER_NAME: ClassVar[str] = "azure_ai_foundry" # type: ignore[reportIncompatibleVariableOverride, misc]
100100
client: AIProjectClient = Field(...)
101101
credential: AsyncTokenCredential | None = Field(...)
102102
agent_id: str | None = Field(default=None)

python/packages/foundry/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ exclude_dirs = ["tests"]
7676
[tool.poe]
7777
executor.type = "uv"
7878
include = "../../shared_tasks.toml"
79+
[tool.poe.tasks]
80+
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_foundry"
7981

8082
[tool.uv.build-backend]
8183
module-name = "agent_framework_foundry"

python/packages/main/agent_framework/_agents.py

Lines changed: 75 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,6 @@
66
from typing import Any, Literal, Protocol, TypeVar, runtime_checkable
77
from uuid import uuid4
88

9-
if sys.version_info >= (3, 11):
10-
from typing import Self # pragma: no cover
11-
else:
12-
from typing_extensions import Self # pragma: no cover
13-
149
from pydantic import BaseModel, Field
1510

1611
from ._clients import ChatClient
@@ -28,12 +23,17 @@
2823
)
2924
from .exceptions import AgentExecutionException
3025

26+
if sys.version_info >= (3, 11):
27+
from typing import Self # pragma: no cover
28+
else:
29+
from typing_extensions import Self # pragma: no cover
30+
3131
TThreadType = TypeVar("TThreadType", bound="AgentThread")
3232

3333
# region AgentThread
3434

3535
__all__ = [
36-
"Agent",
36+
"AIAgent",
3737
"AgentBase",
3838
"AgentThread",
3939
"ChatClientAgent",
@@ -77,7 +77,7 @@ def get_messages(self) -> AsyncIterable[ChatMessage]:
7777

7878

7979
@runtime_checkable
80-
class Agent(Protocol):
80+
class AIAgent(Protocol):
8181
"""A protocol for an agent that can be invoked."""
8282

8383
@property
@@ -90,6 +90,11 @@ def name(self) -> str | None:
9090
"""Returns the name of the agent."""
9191
...
9292

93+
@property
94+
def display_name(self) -> str:
95+
"""Returns the display name of the agent."""
96+
...
97+
9398
@property
9499
def description(self) -> str | None:
95100
"""Returns the description of the agent."""
@@ -124,7 +129,7 @@ async def run(
124129
"""
125130
...
126131

127-
def run_stream(
132+
def run_streaming(
128133
self,
129134
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
130135
*,
@@ -157,17 +162,19 @@ def get_new_thread(self) -> AgentThread:
157162

158163

159164
class AgentBase(AFBaseModel):
160-
"""Base class for all agents.
165+
"""Base class for all Agent Framework agents.
161166
162167
Attributes:
163168
id: The unique identifier of the agent If no id is provided,
164169
a new UUID will be generated.
165-
name: The name of the agent
166-
description: The description of the agent
170+
name: The name of the agent, can be None.
171+
description: The description of the agent.
172+
display_name: The display name of the agent, which is either the name or id.
173+
167174
"""
168175

169176
id: str = Field(default_factory=lambda: str(uuid4()))
170-
name: str = Field(default="UnnamedAgent")
177+
name: str | None = None
171178
description: str | None = None
172179

173180
async def _notify_thread_of_new_messages(
@@ -177,6 +184,43 @@ async def _notify_thread_of_new_messages(
177184
if isinstance(new_messages, ChatMessage) or len(new_messages) > 0:
178185
await thread.on_new_messages(new_messages)
179186

187+
@property
188+
def display_name(self) -> str:
189+
"""Returns the display name of the agent.
190+
191+
This is the name if present, otherwise the id.
192+
"""
193+
return self.name or self.id
194+
195+
def _validate_or_create_thread_type(
196+
self,
197+
thread: AgentThread | None,
198+
construct_thread: Callable[[], TThreadType],
199+
expected_type: type[TThreadType],
200+
) -> TThreadType:
201+
"""Validate or create a AgentThread of the right type.
202+
203+
Args:
204+
thread: The thread to validate or create.
205+
construct_thread: A callable that constructs a new thread if `thread` is None.
206+
expected_type: The expected type of the thread.
207+
208+
Returns:
209+
The validated or newly created thread of the expected type.
210+
211+
Raises:
212+
AgentExecutionException: If the thread is not of the expected type.
213+
"""
214+
if thread is None:
215+
return construct_thread()
216+
217+
if not isinstance(thread, expected_type):
218+
raise AgentExecutionException(
219+
f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}."
220+
)
221+
222+
return thread
223+
180224

181225
# region ChatClientAgentThread
182226

@@ -442,13 +486,7 @@ async def run(
442486
will only be passed to functions that are called.
443487
"""
444488
input_messages = self._normalize_messages(messages)
445-
446-
thread, thread_messages = await self._prepare_thread_and_messages(
447-
thread=thread,
448-
input_messages=input_messages,
449-
construct_thread=lambda: ChatClientAgentThread(),
450-
expected_type=ChatClientAgentThread,
451-
)
489+
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
452490

453491
response = await self.chat_client.get_response(
454492
messages=thread_messages,
@@ -491,7 +529,7 @@ async def run(
491529
additional_properties=response.additional_properties,
492530
)
493531

494-
async def run_stream(
532+
async def run_streaming(
495533
self,
496534
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
497535
*,
@@ -523,7 +561,7 @@ async def run_stream(
523561
"""Stream the agent with the given messages and options.
524562
525563
Remarks:
526-
Since you won't always call the agent.run_stream directly, but it get's called
564+
Since you won't always call the agent.run_streaming directly, but it get's called
527565
through orchestration, it is advised to set your default values for
528566
all the chat client parameters in the agent constructor.
529567
If both parameters are used, the ones passed to the run methods take precedence.
@@ -552,14 +590,7 @@ async def run_stream(
552590
553591
"""
554592
input_messages = self._normalize_messages(messages)
555-
556-
thread, thread_messages = await self._prepare_thread_and_messages(
557-
thread=thread,
558-
input_messages=input_messages,
559-
construct_thread=lambda: ChatClientAgentThread(),
560-
expected_type=ChatClientAgentThread,
561-
)
562-
593+
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
563594
response_updates: list[ChatResponseUpdate] = []
564595

565596
async for update in self.chat_client.get_streaming_response(
@@ -607,7 +638,7 @@ async def run_stream(
607638
await self._notify_thread_of_new_messages(thread, input_messages)
608639
await self._notify_thread_of_new_messages(thread, response.messages)
609640

610-
def get_new_thread(self) -> AgentThread:
641+
def get_new_thread(self) -> ChatClientAgentThread:
611642
return ChatClientAgentThread()
612643

613644
def _update_thread_with_type_and_conversation_id(
@@ -644,45 +675,32 @@ async def _prepare_thread_and_messages(
644675
*,
645676
thread: AgentThread | None,
646677
input_messages: list[ChatMessage] | None = None,
647-
construct_thread: Callable[[], TThreadType],
648-
expected_type: type[TThreadType],
649-
) -> tuple[TThreadType, list[ChatMessage]]:
650-
"""Prepare thread and messages for agent execution.
678+
) -> tuple[ChatClientAgentThread, list[ChatMessage]]:
679+
"""Prepare the messages for agent execution.
651680
652681
Args:
653-
thread: The conversation thread, or None to create a new one.
682+
thread: The conversation thread.
654683
input_messages: Messages to process.
655-
construct_thread: Factory function to create a new thread.
656-
expected_type: Expected thread type for validation.
657684
658685
Returns:
659-
Tuple of the thread and normalized messages.
686+
The validated thread and normalized messages.
660687
661688
Raises:
662-
AgentExecutionException: If thread type is incompatible.
689+
AgentExecutionException: If the thread is not of the expected type.
663690
"""
691+
validated_thread: ChatClientAgentThread = self._validate_or_create_thread_type( # type: ignore[reportAssignmentType]
692+
thread=thread,
693+
construct_thread=self.get_new_thread,
694+
expected_type=ChatClientAgentThread,
695+
)
664696
messages: list[ChatMessage] = []
665697
if self.instructions:
666698
messages.append(ChatMessage(role=ChatRole.SYSTEM, text=self.instructions))
667-
668-
if thread is None:
669-
thread = construct_thread()
670-
671-
if not isinstance(thread, expected_type):
672-
raise AgentExecutionException(
673-
f"{self.__class__.__name__} currently only supports agent threads of type {expected_type.__name__}."
674-
)
675-
676-
# Add any existing messages from the thread to the messages to be sent to the chat client.
677-
if isinstance(thread, MessagesRetrievableThread):
678-
async for message in thread.get_messages():
699+
if isinstance(validated_thread, MessagesRetrievableThread):
700+
async for message in validated_thread.get_messages():
679701
messages.append(message)
680-
681-
if input_messages is None:
682-
return thread, messages
683-
684-
messages.extend(input_messages)
685-
return thread, messages
702+
messages.extend(input_messages or [])
703+
return validated_thread, messages
686704

687705
def _normalize_messages(
688706
self,

python/packages/main/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ exclude_dirs = ["tests"]
8787
[tool.poe]
8888
executor.type = "uv"
8989
include = "../../shared_tasks.toml"
90+
[tool.poe.tasks]
91+
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework"
9092

9193
[tool.uv.build-backend]
9294
module-name = "agent_framework"

0 commit comments

Comments
 (0)