66from typing import Any , Literal , Protocol , TypeVar , runtime_checkable
77from uuid import uuid4
88
9- if sys .version_info >= (3 , 11 ):
10- from typing import Self # pragma: no cover
11- else :
12- from typing_extensions import Self # pragma: no cover
13-
149from pydantic import BaseModel , Field
1510
1611from ._clients import ChatClient
2823)
2924from .exceptions import AgentExecutionException
3025
26+ if sys .version_info >= (3 , 11 ):
27+ from typing import Self # pragma: no cover
28+ else :
29+ from typing_extensions import Self # pragma: no cover
30+
3131TThreadType = TypeVar ("TThreadType" , bound = "AgentThread" )
3232
3333# region AgentThread
3434
3535__all__ = [
36- "Agent " ,
36+ "AIAgent " ,
3737 "AgentBase" ,
3838 "AgentThread" ,
3939 "ChatClientAgent" ,
@@ -77,7 +77,7 @@ def get_messages(self) -> AsyncIterable[ChatMessage]:
7777
7878
7979@runtime_checkable
80- class Agent (Protocol ):
80+ class AIAgent (Protocol ):
8181 """A protocol for an agent that can be invoked."""
8282
8383 @property
@@ -90,6 +90,11 @@ def name(self) -> str | None:
9090 """Returns the name of the agent."""
9191 ...
9292
93+ @property
94+ def display_name (self ) -> str :
95+ """Returns the display name of the agent."""
96+ ...
97+
9398 @property
9499 def description (self ) -> str | None :
95100 """Returns the description of the agent."""
@@ -124,7 +129,7 @@ async def run(
124129 """
125130 ...
126131
127- def run_stream (
132+ def run_streaming (
128133 self ,
129134 messages : str | ChatMessage | list [str ] | list [ChatMessage ] | None = None ,
130135 * ,
@@ -157,17 +162,19 @@ def get_new_thread(self) -> AgentThread:
157162
158163
159164class AgentBase (AFBaseModel ):
160- """Base class for all agents.
165+ """Base class for all Agent Framework agents.
161166
162167 Attributes:
163168 id: The unique identifier of the agent If no id is provided,
164169 a new UUID will be generated.
165- name: The name of the agent
166- description: The description of the agent
170+ name: The name of the agent, can be None.
171+ description: The description of the agent.
172+ display_name: The display name of the agent, which is either the name or id.
173+
167174 """
168175
169176 id : str = Field (default_factory = lambda : str (uuid4 ()))
170- name : str = Field ( default = "UnnamedAgent" )
177+ name : str | None = None
171178 description : str | None = None
172179
173180 async def _notify_thread_of_new_messages (
@@ -177,6 +184,43 @@ async def _notify_thread_of_new_messages(
177184 if isinstance (new_messages , ChatMessage ) or len (new_messages ) > 0 :
178185 await thread .on_new_messages (new_messages )
179186
187+ @property
188+ def display_name (self ) -> str :
189+ """Returns the display name of the agent.
190+
191+ This is the name if present, otherwise the id.
192+ """
193+ return self .name or self .id
194+
195+ def _validate_or_create_thread_type (
196+ self ,
197+ thread : AgentThread | None ,
198+ construct_thread : Callable [[], TThreadType ],
199+ expected_type : type [TThreadType ],
200+ ) -> TThreadType :
201+ """Validate or create a AgentThread of the right type.
202+
203+ Args:
204+ thread: The thread to validate or create.
205+ construct_thread: A callable that constructs a new thread if `thread` is None.
206+ expected_type: The expected type of the thread.
207+
208+ Returns:
209+ The validated or newly created thread of the expected type.
210+
211+ Raises:
212+ AgentExecutionException: If the thread is not of the expected type.
213+ """
214+ if thread is None :
215+ return construct_thread ()
216+
217+ if not isinstance (thread , expected_type ):
218+ raise AgentExecutionException (
219+ f"{ self .__class__ .__name__ } currently only supports agent threads of type { expected_type .__name__ } ."
220+ )
221+
222+ return thread
223+
180224
181225# region ChatClientAgentThread
182226
@@ -442,13 +486,7 @@ async def run(
442486 will only be passed to functions that are called.
443487 """
444488 input_messages = self ._normalize_messages (messages )
445-
446- thread , thread_messages = await self ._prepare_thread_and_messages (
447- thread = thread ,
448- input_messages = input_messages ,
449- construct_thread = lambda : ChatClientAgentThread (),
450- expected_type = ChatClientAgentThread ,
451- )
489+ thread , thread_messages = await self ._prepare_thread_and_messages (thread = thread , input_messages = input_messages )
452490
453491 response = await self .chat_client .get_response (
454492 messages = thread_messages ,
@@ -491,7 +529,7 @@ async def run(
491529 additional_properties = response .additional_properties ,
492530 )
493531
494- async def run_stream (
532+ async def run_streaming (
495533 self ,
496534 messages : str | ChatMessage | list [str ] | list [ChatMessage ] | None = None ,
497535 * ,
@@ -523,7 +561,7 @@ async def run_stream(
523561 """Stream the agent with the given messages and options.
524562
525563 Remarks:
526- Since you won't always call the agent.run_stream directly, but it get's called
564+ Since you won't always call the agent.run_streaming directly, but it get's called
527565 through orchestration, it is advised to set your default values for
528566 all the chat client parameters in the agent constructor.
529567 If both parameters are used, the ones passed to the run methods take precedence.
@@ -552,14 +590,7 @@ async def run_stream(
552590
553591 """
554592 input_messages = self ._normalize_messages (messages )
555-
556- thread , thread_messages = await self ._prepare_thread_and_messages (
557- thread = thread ,
558- input_messages = input_messages ,
559- construct_thread = lambda : ChatClientAgentThread (),
560- expected_type = ChatClientAgentThread ,
561- )
562-
593+ thread , thread_messages = await self ._prepare_thread_and_messages (thread = thread , input_messages = input_messages )
563594 response_updates : list [ChatResponseUpdate ] = []
564595
565596 async for update in self .chat_client .get_streaming_response (
@@ -607,7 +638,7 @@ async def run_stream(
607638 await self ._notify_thread_of_new_messages (thread , input_messages )
608639 await self ._notify_thread_of_new_messages (thread , response .messages )
609640
610- def get_new_thread (self ) -> AgentThread :
641+ def get_new_thread (self ) -> ChatClientAgentThread :
611642 return ChatClientAgentThread ()
612643
613644 def _update_thread_with_type_and_conversation_id (
@@ -644,45 +675,32 @@ async def _prepare_thread_and_messages(
644675 * ,
645676 thread : AgentThread | None ,
646677 input_messages : list [ChatMessage ] | None = None ,
647- construct_thread : Callable [[], TThreadType ],
648- expected_type : type [TThreadType ],
649- ) -> tuple [TThreadType , list [ChatMessage ]]:
650- """Prepare thread and messages for agent execution.
678+ ) -> tuple [ChatClientAgentThread , list [ChatMessage ]]:
679+ """Prepare the messages for agent execution.
651680
652681 Args:
653- thread: The conversation thread, or None to create a new one .
682+ thread: The conversation thread.
654683 input_messages: Messages to process.
655- construct_thread: Factory function to create a new thread.
656- expected_type: Expected thread type for validation.
657684
658685 Returns:
659- Tuple of the thread and normalized messages.
686+ The validated thread and normalized messages.
660687
661688 Raises:
662- AgentExecutionException: If thread type is incompatible .
689+ AgentExecutionException: If the thread is not of the expected type .
663690 """
691+ validated_thread : ChatClientAgentThread = self ._validate_or_create_thread_type ( # type: ignore[reportAssignmentType]
692+ thread = thread ,
693+ construct_thread = self .get_new_thread ,
694+ expected_type = ChatClientAgentThread ,
695+ )
664696 messages : list [ChatMessage ] = []
665697 if self .instructions :
666698 messages .append (ChatMessage (role = ChatRole .SYSTEM , text = self .instructions ))
667-
668- if thread is None :
669- thread = construct_thread ()
670-
671- if not isinstance (thread , expected_type ):
672- raise AgentExecutionException (
673- f"{ self .__class__ .__name__ } currently only supports agent threads of type { expected_type .__name__ } ."
674- )
675-
676- # Add any existing messages from the thread to the messages to be sent to the chat client.
677- if isinstance (thread , MessagesRetrievableThread ):
678- async for message in thread .get_messages ():
699+ if isinstance (validated_thread , MessagesRetrievableThread ):
700+ async for message in validated_thread .get_messages ():
679701 messages .append (message )
680-
681- if input_messages is None :
682- return thread , messages
683-
684- messages .extend (input_messages )
685- return thread , messages
702+ messages .extend (input_messages or [])
703+ return validated_thread , messages
686704
687705 def _normalize_messages (
688706 self ,
0 commit comments