Skip to content

Commit 94125ce

Browse files
committed
Add simple chat agent with message persistence and tests
1 parent 8697c6c commit 94125ce

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from typing import List, Optional
2+
3+
from context.cosmos_memory_kernel import CosmosMemoryContext
4+
from kernel_agents.agent_base import BaseAgent
5+
from models.messages_kernel import MessageRole, StoredMessage
6+
from semantic_kernel.functions import KernelFunction
7+
8+
9+
class SimpleChatAgent(BaseAgent):
10+
"""A minimal chat agent for open-ended conversation."""
11+
12+
def __init__(
13+
self,
14+
session_id: str,
15+
user_id: str,
16+
memory_store: CosmosMemoryContext,
17+
tools: Optional[List[KernelFunction]] = None,
18+
system_message: Optional[str] = None,
19+
agent_name: str = "SimpleChatAgent",
20+
client=None,
21+
definition=None,
22+
) -> None:
23+
super().__init__(
24+
agent_name=agent_name,
25+
session_id=session_id,
26+
user_id=user_id,
27+
memory_store=memory_store,
28+
tools=tools,
29+
system_message=system_message,
30+
client=client,
31+
definition=definition,
32+
)
33+
34+
@staticmethod
35+
def default_system_message(agent_name: str | None = None) -> str:
36+
"""Return the default system message for open-ended chat."""
37+
name = agent_name or "assistant"
38+
return (
39+
f"You are {name}, a friendly AI for open-ended conversation. "
40+
"Engage with the user naturally and helpfully."
41+
)
42+
43+
async def handle_user_message(self, content: str) -> str:
44+
"""Process a user message, storing it and returning the agent's reply."""
45+
# Record the user message locally and persist it
46+
self._chat_history.append({"role": "user", "content": content})
47+
await self._memory_store.add_item(
48+
StoredMessage(
49+
session_id=self._session_id,
50+
user_id=self._user_id,
51+
role=MessageRole.user,
52+
content=content,
53+
source=self._agent_name,
54+
)
55+
)
56+
57+
# Generate a reply from the underlying model
58+
async_generator = self.invoke(messages=str(self._chat_history), thread=None)
59+
response_content = ""
60+
async for chunk in async_generator:
61+
if chunk is not None:
62+
response_content += str(chunk)
63+
64+
# Record the assistant's response
65+
self._chat_history.append({"role": "assistant", "content": response_content})
66+
await self._memory_store.add_item(
67+
StoredMessage(
68+
session_id=self._session_id,
69+
user_id=self._user_id,
70+
role=MessageRole.assistant,
71+
content=response_content,
72+
source=self._agent_name,
73+
)
74+
)
75+
76+
return response_content
77+
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import os
2+
import pytest
3+
4+
# Set required environment variables before importing the agent
5+
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://mock-endpoint")
6+
os.environ.setdefault("AZURE_AI_SUBSCRIPTION_ID", "sub")
7+
os.environ.setdefault("AZURE_AI_RESOURCE_GROUP", "rg")
8+
os.environ.setdefault("AZURE_AI_PROJECT_NAME", "proj")
9+
os.environ.setdefault("AZURE_AI_AGENT_ENDPOINT", "https://agent-endpoint")
10+
11+
from kernel_agents.simple_chat_agent import SimpleChatAgent
12+
from models.messages_kernel import MessageRole
13+
14+
15+
class DummyMemoryStore:
16+
def __init__(self):
17+
self.items = []
18+
19+
async def add_item(self, item):
20+
self.items.append(item)
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_handle_user_message_stores_and_replies():
25+
memory_store = DummyMemoryStore()
26+
agent = SimpleChatAgent(session_id="s", user_id="u", memory_store=memory_store)
27+
28+
async def fake_invoke(self, messages=None, thread=None):
29+
yield "Hi there!"
30+
31+
agent.invoke = fake_invoke.__get__(agent, SimpleChatAgent)
32+
33+
response = await agent.handle_user_message("Hello")
34+
35+
assert response == "Hi there!"
36+
assert agent._chat_history[-2:] == [
37+
{"role": "user", "content": "Hello"},
38+
{"role": "assistant", "content": "Hi there!"},
39+
]
40+
assert len(memory_store.items) == 2
41+
assert memory_store.items[0].content == "Hello"
42+
assert memory_store.items[0].role == MessageRole.user
43+
assert memory_store.items[1].content == "Hi there!"
44+
assert memory_store.items[1].role == MessageRole.assistant
45+

0 commit comments

Comments
 (0)