Skip to content

Commit 544d2ba

Browse files
authored
Added author name logic to ChatClientAgent (#313)
1 parent f5b35d8 commit 544d2ba

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

python/packages/main/agent_framework/_agents.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ async def run(
491491
"""
492492
input_messages = self._normalize_messages(messages)
493493
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
494+
agent_name = self._get_agent_name()
494495

495496
response = await self.chat_client.get_response(
496497
messages=thread_messages,
@@ -519,6 +520,11 @@ async def run(
519520

520521
self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
521522

523+
# Ensure that the author name is set for each message in the response.
524+
for message in response.messages:
525+
if message.author_name is None:
526+
message.author_name = agent_name
527+
522528
# Only notify the thread of new messages if the chatResponse was successful
523529
# to avoid inconsistent messages state in the thread.
524530
await self._notify_thread_of_new_messages(thread, input_messages)
@@ -595,6 +601,7 @@ async def run_streaming(
595601
"""
596602
input_messages = self._normalize_messages(messages)
597603
thread, thread_messages = await self._prepare_thread_and_messages(thread=thread, input_messages=input_messages)
604+
agent_name = self._get_agent_name()
598605
response_updates: list[ChatResponseUpdate] = []
599606

600607
async for update in self.chat_client.get_streaming_response(
@@ -622,6 +629,10 @@ async def run_streaming(
622629
**kwargs,
623630
):
624631
response_updates.append(update)
632+
633+
if update.author_name is None:
634+
update.author_name = agent_name
635+
625636
yield AgentRunResponseUpdate(
626637
contents=update.contents,
627638
role=update.role,
@@ -720,3 +731,6 @@ def _normalize_messages(
720731
return [messages]
721732

722733
return [ChatMessage(role=ChatRole.USER, text=msg) if isinstance(msg, str) else msg for msg in messages]
734+
735+
def _get_agent_name(self) -> str:
736+
return self.name or "UnnamedAgent"

python/packages/main/tests/main/test_agents.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ def name(self) -> str | None:
4343
"""Returns the name of the agent."""
4444
return "Name"
4545

46+
@property
47+
def display_name(self) -> str:
48+
"""Returns the name of the agent."""
49+
return "Display Name"
50+
4651
@property
4752
def description(self) -> str | None:
4853
return "Description"
@@ -243,7 +248,7 @@ async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatCl
243248
message = ChatMessage(role=ChatRole.USER, text="Hello")
244249
thread = ChatClientAgentThread(messages=[message])
245250

246-
result_thread = agent._validate_or_create_thread_type(
251+
result_thread = agent._validate_or_create_thread_type( # type: ignore[reportPrivateUsage]
247252
thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread
248253
) # type: ignore[reportPrivateUsage]
249254

@@ -264,7 +269,7 @@ async def test_chat_client_agent_validate_or_create_thread(chat_client: ChatClie
264269
agent = ChatClientAgent(chat_client=chat_client)
265270
thread = None
266271

267-
result_thread = agent._validate_or_create_thread_type(
272+
result_thread = agent._validate_or_create_thread_type( # type: ignore[reportPrivateUsage]
268273
thread, lambda: ChatClientAgentThread(), expected_type=ChatClientAgentThread
269274
) # type: ignore[reportPrivateUsage]
270275

@@ -313,3 +318,36 @@ async def test_chat_client_agent_update_thread_conversation_id_missing(chat_clie
313318

314319
with raises(AgentExecutionException, match="Service did not return a valid conversation id"):
315320
agent._update_thread_with_type_and_conversation_id(thread, None) # type: ignore[reportPrivateUsage]
321+
322+
323+
async def test_chat_client_agent_default_author_name(chat_client: ChatClient) -> None:
324+
# Name is not specified here, so default name should be used
325+
agent = ChatClientAgent(chat_client=chat_client)
326+
327+
result = await agent.run("Hello")
328+
assert result.text == "test response"
329+
assert result.messages[0].author_name == "UnnamedAgent"
330+
331+
332+
async def test_chat_client_agent_author_name_as_agent_name(chat_client: ChatClient) -> None:
333+
# Name is specified here, so it should be used as author name
334+
agent = ChatClientAgent(chat_client=chat_client, name="TestAgent")
335+
336+
result = await agent.run("Hello")
337+
assert result.text == "test response"
338+
assert result.messages[0].author_name == "TestAgent"
339+
340+
341+
async def test_chat_client_agent_author_name_is_used_from_response() -> None:
342+
chat_client = MockChatClient(
343+
mock_response=ChatResponse(
344+
messages=[
345+
ChatMessage(role=ChatRole.ASSISTANT, contents=[TextContent("test response")], author_name="TestAuthor")
346+
]
347+
)
348+
)
349+
agent = ChatClientAgent(chat_client=chat_client)
350+
351+
result = await agent.run("Hello")
352+
assert result.text == "test response"
353+
assert result.messages[0].author_name == "TestAuthor"

0 commit comments

Comments
 (0)