Skip to content

Commit 7855bd8

Browse files
authored
Merge branch 'customize-wpp' into codex/create-simple-chat-agent-implementation
2 parents 94125ce + a52f96a commit 7855bd8

File tree

9 files changed

+458
-2
lines changed

9 files changed

+458
-2
lines changed

src/backend/app_kernel.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from fastapi import FastAPI, HTTPException, Query, Request
2121
from fastapi.middleware.cors import CORSMiddleware
2222
from kernel_agents.agent_factory import AgentFactory
23+
from pydantic import BaseModel
2324

2425
# Local imports
2526
from middleware.health_check import HealthCheckMiddleware
@@ -168,6 +169,55 @@ async def user_browser_language_endpoint(
168169
return {"status": "Language received successfully"}
169170

170171

172+
class ChatRequest(BaseModel):
173+
"""Request model for the simple chat endpoint."""
174+
175+
session_id: str
176+
user_message: str
177+
178+
179+
@app.post("/api/chat")
180+
async def chat_endpoint(chat_request: ChatRequest, request: Request):
181+
"""Handle a simple chat message from the user."""
182+
183+
authenticated_user = get_authenticated_user_details(request_headers=request.headers)
184+
user_id = authenticated_user["user_principal_id"]
185+
186+
if not user_id:
187+
track_event_if_configured(
188+
"UserIdNotFound", {"status_code": 400, "detail": "no user"}
189+
)
190+
raise HTTPException(status_code=400, detail="no user")
191+
192+
kernel, memory_store = await initialize_runtime_and_context(
193+
chat_request.session_id, user_id
194+
)
195+
196+
client = None
197+
try:
198+
client = config.get_ai_project_client()
199+
except Exception as client_exc: # pylint: disable=broad-except
200+
logging.error(f"Error creating AIProjectClient: {client_exc}")
201+
202+
simple_chat_agent = await AgentFactory.create_agent(
203+
agent_type=AgentType.GENERIC,
204+
session_id=chat_request.session_id,
205+
user_id=user_id,
206+
memory_store=memory_store,
207+
client=client,
208+
)
209+
210+
reply = await simple_chat_agent.handle_user_message(chat_request.user_message)
211+
212+
if client:
213+
try:
214+
client.close()
215+
except Exception as e: # pylint: disable=broad-except
216+
logging.error(f"Error sending to AIProjectClient: {e}")
217+
218+
return {"reply": reply}
219+
220+
171221
@app.post("/api/input_task")
172222
async def input_task_endpoint(input_task: InputTask, request: Request):
173223
"""

src/backend/kernel_agents/agent_base.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
system_message: Optional[str] = None,
2929
client=None,
3030
definition=None,
31+
summary_after_n_turns: int = 20,
3132
):
3233
"""Initialize the base agent.
3334
@@ -68,6 +69,7 @@ def __init__(
6869
self._tools = tools
6970
self._system_message = system_message
7071
self._chat_history = [{"role": "system", "content": self._system_message}]
72+
self._summary_after_n_turns: int = summary_after_n_turns
7173
# self._agent = None # Will be initialized in async_init
7274

7375
# Required properties for AgentGroupChat compatibility
@@ -86,6 +88,44 @@ def default_system_message(agent_name=None) -> str:
8688
name = agent_name
8789
return f"You are an AI assistant named {name}. Help the user by providing accurate and helpful information."
8890

91+
async def _add_message_to_history(self, role: str, content: str) -> None:
92+
"""Append a message to chat history and summarize if needed."""
93+
self._chat_history.append({"role": role, "content": content})
94+
await self._summarize_chat_history_if_needed()
95+
96+
async def _extend_chat_history(self, messages: List[Mapping[str, str]]) -> None:
97+
"""Extend chat history with multiple messages and summarize if needed."""
98+
self._chat_history.extend(messages)
99+
await self._summarize_chat_history_if_needed()
100+
101+
def _build_summary(self, messages: List[Mapping[str, str]]) -> str:
102+
"""Create a simple summary from a list of messages."""
103+
return " \n".join(m["content"] for m in messages)
104+
105+
async def _summarize_chat_history_if_needed(self) -> None:
106+
"""Summarize chat history after N turns and replace older entries."""
107+
if len(self._chat_history) - 1 <= self._summary_after_n_turns:
108+
return
109+
110+
# Messages to summarize exclude the system prompt and the most recent N messages
111+
messages_to_summarize = self._chat_history[1:-self._summary_after_n_turns]
112+
summary_text = self._build_summary(messages_to_summarize)
113+
114+
await self._memory_store.add_item(
115+
AgentMessage(
116+
session_id=self._session_id,
117+
user_id=self._user_id,
118+
plan_id="summary",
119+
content=summary_text,
120+
source=self._agent_name,
121+
)
122+
)
123+
124+
self._chat_history = (
125+
[self._chat_history[0], {"role": "system", "content": summary_text}]
126+
+ self._chat_history[-self._summary_after_n_turns :]
127+
)
128+
89129
async def handle_action_request(self, action_request: ActionRequest) -> str:
90130
"""Handle an action request from another agent or the system.
91131
@@ -112,7 +152,7 @@ async def handle_action_request(self, action_request: ActionRequest) -> str:
112152

113153
# Add messages to chat history for context
114154
# This gives the agent visibility of the conversation history
115-
self._chat_history.extend(
155+
await self._extend_chat_history(
116156
[
117157
{"role": "assistant", "content": action_request.action},
118158
{

src/backend/kernel_agents/agent_factory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from kernel_agents.procurement_agent import ProcurementAgent
2121
from kernel_agents.product_agent import ProductAgent
2222
from kernel_agents.tech_support_agent import TechSupportAgent
23+
from kernel_agents.simple_chat_agent import SimpleChatAgent
2324
from models.messages_kernel import AgentType, PlannerResponsePlan
2425
# pylint:disable=E0611
2526
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
@@ -41,6 +42,7 @@ class AgentFactory:
4142
AgentType.HUMAN: HumanAgent,
4243
AgentType.PLANNER: PlannerAgent,
4344
AgentType.GROUP_CHAT_MANAGER: GroupChatManager, # Add GroupChatManager
45+
AgentType.SIMPLE_CHAT: SimpleChatAgent,
4446
}
4547

4648
# Mapping of agent types to their string identifiers (for automatic tool loading)
@@ -54,6 +56,7 @@ class AgentFactory:
5456
AgentType.HUMAN: AgentType.HUMAN.value,
5557
AgentType.PLANNER: AgentType.PLANNER.value,
5658
AgentType.GROUP_CHAT_MANAGER: AgentType.GROUP_CHAT_MANAGER.value,
59+
AgentType.SIMPLE_CHAT: AgentType.SIMPLE_CHAT.value,
5760
}
5861

5962
# System messages for each agent type
@@ -67,6 +70,7 @@ class AgentFactory:
6770
AgentType.HUMAN: HumanAgent.default_system_message(),
6871
AgentType.PLANNER: PlannerAgent.default_system_message(),
6972
AgentType.GROUP_CHAT_MANAGER: GroupChatManager.default_system_message(),
73+
AgentType.SIMPLE_CHAT: SimpleChatAgent.default_system_message(),
7074
}
7175

7276
# Cache of agent instances by session_id and agent_type

src/backend/kernel_agents/simple_chat_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
from typing import List, Optional
23

34
from context.cosmos_memory_kernel import CosmosMemoryContext
@@ -75,3 +76,4 @@ async def handle_user_message(self, content: str) -> str:
7576

7677
return response_content
7778

79+

src/backend/models/messages_kernel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class AgentType(str, Enum):
4848
TECH_SUPPORT = "Tech_Support_Agent"
4949
GROUP_CHAT_MANAGER = "Group_Chat_Manager"
5050
PLANNER = "Planner_Agent"
51+
SIMPLE_CHAT = "Simple_Chat_Agent"
5152

5253
# Add other agents as needed
5354

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import asyncio
2+
import asyncio
3+
import os
4+
import sys
5+
from unittest.mock import MagicMock
6+
import types
7+
8+
# Ensure modules under src/backend are importable
9+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10+
11+
# Mock Azure dependencies required by app_config
12+
azure_module = types.ModuleType("azure")
13+
sys.modules["azure"] = azure_module
14+
sys.modules["azure.ai"] = types.ModuleType("azure.ai")
15+
sys.modules["azure.ai.projects"] = types.ModuleType("azure.ai.projects")
16+
projects_aio = types.ModuleType("azure.ai.projects.aio")
17+
projects_aio.AIProjectClient = MagicMock()
18+
sys.modules["azure.ai.projects.aio"] = projects_aio
19+
cosmos_module = types.ModuleType("azure.cosmos")
20+
sys.modules["azure.cosmos"] = cosmos_module
21+
cosmos_aio = types.ModuleType("azure.cosmos.aio")
22+
cosmos_aio.CosmosClient = MagicMock()
23+
sys.modules["azure.cosmos.aio"] = cosmos_aio
24+
partition_key_module = types.ModuleType("azure.cosmos.partition_key")
25+
partition_key_module.PartitionKey = MagicMock()
26+
sys.modules["azure.cosmos.partition_key"] = partition_key_module
27+
azure_monitor_module = types.ModuleType("azure.monitor")
28+
sys.modules["azure.monitor"] = azure_monitor_module
29+
events_module = types.ModuleType("azure.monitor.events")
30+
sys.modules["azure.monitor.events"] = events_module
31+
events_ext_module = types.ModuleType("azure.monitor.events.extension")
32+
events_ext_module.track_event = MagicMock()
33+
sys.modules["azure.monitor.events.extension"] = events_ext_module
34+
sys.modules["azure.monitor.opentelemetry"] = types.ModuleType("azure.monitor.opentelemetry")
35+
identity_module = types.ModuleType("azure.identity")
36+
identity_module.ManagedIdentityCredential = MagicMock()
37+
identity_module.DefaultAzureCredential = MagicMock()
38+
sys.modules["azure.identity"] = identity_module
39+
identity_aio_module = types.ModuleType("azure.identity.aio")
40+
identity_aio_module.ManagedIdentityCredential = MagicMock()
41+
identity_aio_module.DefaultAzureCredential = MagicMock()
42+
sys.modules["azure.identity.aio"] = identity_aio_module
43+
44+
# Mock semantic kernel dependencies
45+
sys.modules["semantic_kernel"] = types.ModuleType("semantic_kernel")
46+
kernel_module = types.ModuleType("semantic_kernel.kernel")
47+
class Kernel: # pragma: no cover - simple stub for testing
48+
pass
49+
kernel_module.Kernel = Kernel
50+
sys.modules["semantic_kernel.kernel"] = kernel_module
51+
sys.modules["semantic_kernel.agents"] = types.ModuleType("semantic_kernel.agents")
52+
sys.modules["semantic_kernel.agents.azure_ai"] = types.ModuleType("semantic_kernel.agents.azure_ai")
53+
azure_ai_agent_module = types.ModuleType("semantic_kernel.agents.azure_ai.azure_ai_agent")
54+
class AzureAIAgent: # pragma: no cover - simple stub for testing
55+
def __init__(self, *args, **kwargs):
56+
pass
57+
azure_ai_agent_module.AzureAIAgent = AzureAIAgent
58+
sys.modules["semantic_kernel.agents.azure_ai.azure_ai_agent"] = azure_ai_agent_module
59+
functions_module = types.ModuleType("semantic_kernel.functions")
60+
class KernelFunction: # pragma: no cover - simple stub for testing
61+
pass
62+
functions_module.KernelFunction = KernelFunction
63+
sys.modules["semantic_kernel.functions"] = functions_module
64+
numpy_module = types.ModuleType("numpy")
65+
class ndarray: # pragma: no cover - stub
66+
pass
67+
numpy_module.ndarray = ndarray
68+
sys.modules["numpy"] = numpy_module
69+
sys.modules["semantic_kernel.memory"] = types.ModuleType("semantic_kernel.memory")
70+
memory_record_module = types.ModuleType("semantic_kernel.memory.memory_record")
71+
class MemoryRecord: # pragma: no cover - stub
72+
pass
73+
memory_record_module.MemoryRecord = MemoryRecord
74+
sys.modules["semantic_kernel.memory.memory_record"] = memory_record_module
75+
memory_store_module = types.ModuleType("semantic_kernel.memory.memory_store_base")
76+
class MemoryStoreBase: # pragma: no cover - stub
77+
pass
78+
memory_store_module.MemoryStoreBase = MemoryStoreBase
79+
sys.modules["semantic_kernel.memory.memory_store_base"] = memory_store_module
80+
contents_module = types.ModuleType("semantic_kernel.contents")
81+
class ChatMessageContent: # pragma: no cover - stub
82+
def __init__(self, *args, **kwargs):
83+
pass
84+
class ChatHistory(list):
85+
pass
86+
class AuthorRole: # pragma: no cover - stub
87+
pass
88+
contents_module.ChatMessageContent = ChatMessageContent
89+
contents_module.ChatHistory = ChatHistory
90+
contents_module.AuthorRole = AuthorRole
91+
sys.modules["semantic_kernel.contents"] = contents_module
92+
kernel_pydantic_module = types.ModuleType("semantic_kernel.kernel_pydantic")
93+
class Field: # pragma: no cover - stub
94+
def __init__(self, *args, **kwargs):
95+
pass
96+
class KernelBaseModel: # pragma: no cover - stub
97+
def __init__(self, **data):
98+
for k, v in data.items():
99+
setattr(self, k, v)
100+
101+
def model_dump(self):
102+
return self.__dict__
103+
kernel_pydantic_module.Field = Field
104+
kernel_pydantic_module.KernelBaseModel = KernelBaseModel
105+
sys.modules["semantic_kernel.kernel_pydantic"] = kernel_pydantic_module
106+
107+
# Provide required environment variables for AppConfig
108+
os.environ.setdefault("AZURE_OPENAI_DEPLOYMENT_NAME", "test")
109+
os.environ.setdefault("AZURE_OPENAI_API_VERSION", "2024-05-01")
110+
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://test")
111+
os.environ.setdefault("AZURE_AI_SUBSCRIPTION_ID", "sub")
112+
os.environ.setdefault("AZURE_AI_RESOURCE_GROUP", "rg")
113+
os.environ.setdefault("AZURE_AI_PROJECT_NAME", "proj")
114+
os.environ.setdefault("AZURE_AI_AGENT_ENDPOINT", "https://agent")
115+
116+
from src.backend.kernel_agents.agent_base import BaseAgent
117+
from src.backend.models.messages_kernel import AgentMessage
118+
119+
120+
class StubMemoryStore:
121+
def __init__(self):
122+
self.items = []
123+
124+
async def add_item(self, item):
125+
self.items.append(item)
126+
127+
128+
class DummyAgent(BaseAgent):
129+
def __init__(self, memory_store, summary_after_n_turns=4):
130+
# Bypass parent initialization for testing summarization helpers
131+
self._agent_name = "dummy"
132+
self._session_id = "session"
133+
self._user_id = "user"
134+
self._memory_store = memory_store
135+
self._tools = []
136+
self._system_message = "system"
137+
self._chat_history = [{"role": "system", "content": self._system_message}]
138+
self._summary_after_n_turns = summary_after_n_turns
139+
self.name = self._agent_name
140+
141+
@classmethod
142+
async def create(cls, **kwargs) -> "BaseAgent":
143+
raise NotImplementedError
144+
145+
146+
def test_chat_history_summarization_and_truncation():
147+
store = StubMemoryStore()
148+
agent = DummyAgent(store, summary_after_n_turns=4)
149+
150+
async def run():
151+
for i in range(6):
152+
await agent._add_message_to_history("user", f"message {i}")
153+
154+
asyncio.run(run())
155+
156+
# system message + summary + last 4 messages
157+
assert len(agent._chat_history) == 6
158+
summary_entry = agent._chat_history[1]
159+
assert summary_entry["role"] == "system"
160+
assert "message 0" in summary_entry["content"]
161+
assert "message 1" in summary_entry["content"]
162+
assert agent._chat_history[-1]["content"] == "message 5"
163+
164+
assert len(store.items) >= 1
165+
assert store.items[-1].content == summary_entry["content"]

0 commit comments

Comments
 (0)