Skip to content

Commit 4f65ae6

Browse files
fix: Integrate the Agent SDK into the chart generation
2 parents a279612 + 0351918 commit 4f65ae6

File tree

5 files changed

+250
-101
lines changed

5 files changed

+250
-101
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from azure.identity import DefaultAzureCredential
2+
from azure.ai.projects import AIProjectClient
3+
4+
from agents.agent_factory_base import BaseAgentFactory
5+
6+
7+
class ChartAgentFactory(BaseAgentFactory):
8+
"""
9+
Factory class for creating Chart agents that generate chart.js compatible JSON
10+
based on numerical and structured data from RAG responses.
11+
"""
12+
13+
@classmethod
14+
async def create_agent(cls, config):
15+
"""
16+
Asynchronously creates an AI agent configured to convert structured data
17+
into chart.js-compatible JSON using Azure AI Project.
18+
19+
Args:
20+
config: Configuration object containing AI project and model settings.
21+
22+
Returns:
23+
dict: A dictionary containing the created 'agent' and its associated 'client'.
24+
"""
25+
instructions = """You are an assistant that helps generate valid chart data to be shown using chart.js with version 4.4.4 compatible.
26+
Include chart type and chart options.
27+
Pick the best chart type for given data.
28+
Do not generate a chart unless the input contains some numbers. Otherwise return a message that Chart cannot be generated.
29+
Only return a valid JSON output and nothing else.
30+
Verify that the generated JSON can be parsed using json.loads.
31+
Do not include tooltip callbacks in JSON.
32+
Always make sure that the generated json can be rendered in chart.js.
33+
Always remove any extra trailing commas.
34+
Verify and refine that JSON should not have any syntax errors like extra closing brackets.
35+
Ensure Y-axis labels are fully visible by increasing **ticks.padding**, **ticks.maxWidth**, or enabling word wrapping where necessary.
36+
Ensure bars and data points are evenly spaced and not squished or cropped at **100%** resolution by maintaining appropriate **barPercentage** and **categoryPercentage** values."""
37+
38+
project_client = AIProjectClient(
39+
endpoint=config.ai_project_endpoint,
40+
credential=DefaultAzureCredential(exclude_interactive_browser_credential=False),
41+
api_version=config.ai_project_api_version,
42+
)
43+
44+
agent = project_client.agents.create_agent(
45+
model=config.azure_openai_deployment_model,
46+
name=f"KM-ChartAgent-{config.solution_name}",
47+
instructions=instructions,
48+
)
49+
50+
return {
51+
"agent": agent,
52+
"client": project_client
53+
}
54+
55+
@classmethod
56+
async def _delete_agent_instance(cls, agent_wrapper: dict):
57+
"""
58+
Asynchronously deletes the specified chart agent instance from the Azure AI project.
59+
60+
Args:
61+
agent_wrapper (dict): Dictionary containing the 'agent' and 'client' to be removed.
62+
"""
63+
agent_wrapper["client"].agents.delete_agent(agent_wrapper["agent"].id)

src/api/app.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from agents.conversation_agent_factory import ConversationAgentFactory
1818
from agents.search_agent_factory import SearchAgentFactory
1919
from agents.sql_agent_factory import SQLAgentFactory
20+
from agents.chart_agent_factory import ChartAgentFactory
2021
from api.api_routes import router as backend_router
2122
from api.history_routes import router as history_router
2223

@@ -34,13 +35,16 @@ async def lifespan(fastapi_app: FastAPI):
3435
fastapi_app.state.agent = await ConversationAgentFactory.get_agent()
3536
fastapi_app.state.search_agent = await SearchAgentFactory.get_agent()
3637
fastapi_app.state.sql_agent = await SQLAgentFactory.get_agent()
38+
fastapi_app.state.chart_agent = await ChartAgentFactory.get_agent()
3739
yield
3840
await ConversationAgentFactory.delete_agent()
3941
await SearchAgentFactory.delete_agent()
4042
await SQLAgentFactory.delete_agent()
43+
await ChartAgentFactory.delete_agent()
4144
fastapi_app.state.sql_agent = None
4245
fastapi_app.state.search_agent = None
4346
fastapi_app.state.agent = None
47+
fastapi_app.state.chart_agent = None
4448

4549

4650
def build_app() -> FastAPI:

src/api/services/chat_service.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
from semantic_kernel.agents import AzureAIAgentThread
2222
from semantic_kernel.exceptions.agent_exceptions import AgentException
2323

24-
from azure.ai.agents.models import TruncationObject
24+
from azure.ai.agents.models import TruncationObject, MessageRole, ListSortOrder
2525

2626
from cachetools import TTLCache
2727

2828
from helpers.utils import format_stream_response
29-
from helpers.azure_openai_helper import get_azure_openai_client
3029
from common.config.config import Config
30+
from agents.chart_agent_factory import ChartAgentFactory
3131

3232
# Constants
3333
HOST_NAME = "CKM"
@@ -86,47 +86,59 @@ def __init__(self, request : Request):
8686
if ChatService.thread_cache is None:
8787
ChatService.thread_cache = ExpCache(maxsize=1000, ttl=3600.0, agent=self.agent)
8888

89-
def process_rag_response(self, rag_response, query):
89+
async def process_rag_response(self, rag_response, query):
9090
"""
91-
Parses the RAG response dynamically to extract chart data for Chart.js.
91+
Uses the ChartAgent directly (agentic call) to extract chart data for Chart.js.
9292
"""
9393
try:
94-
client = get_azure_openai_client()
95-
96-
system_prompt = """You are an assistant that helps generate valid chart data to be shown using chart.js with version 4.4.4 compatible.
97-
Include chart type and chart options.
98-
Pick the best chart type for given data.
99-
Do not generate a chart unless the input contains some numbers. Otherwise return a message that Chart cannot be generated.
100-
Only return a valid JSON output and nothing else.
101-
Verify that the generated JSON can be parsed using json.loads.
102-
Do not include tooltip callbacks in JSON.
103-
Always make sure that the generated json can be rendered in chart.js.
104-
Always remove any extra trailing commas.
105-
Verify and refine that JSON should not have any syntax errors like extra closing brackets.
106-
Ensure Y-axis labels are fully visible by increasing **ticks.padding**, **ticks.maxWidth**, or enabling word wrapping where necessary.
107-
Ensure bars and data points are evenly spaced and not squished or cropped at **100%** resolution by maintaining appropriate **barPercentage** and **categoryPercentage** values."""
10894
user_prompt = f"""Generate chart data for -
10995
{query}
11096
{rag_response}
11197
"""
112-
logger.info(">>> Processing chart data for response: %s", rag_response)
113-
114-
completion = client.chat.completions.create(
115-
model=self.azure_openai_deployment_name,
116-
messages=[
117-
{"role": "system", "content": system_prompt},
118-
{"role": "user", "content": user_prompt},
119-
],
120-
temperature=0,
98+
99+
agent_info = await ChartAgentFactory.get_agent()
100+
agent = agent_info["agent"]
101+
client = agent_info["client"]
102+
103+
thread = client.agents.threads.create()
104+
105+
client.agents.messages.create(
106+
thread_id=thread.id,
107+
role=MessageRole.USER,
108+
content=user_prompt
109+
)
110+
111+
run = client.agents.runs.create_and_process(
112+
thread_id=thread.id,
113+
agent_id=agent.id
121114
)
122115

123-
chart_data = completion.choices[0].message.content.strip().replace("```json", "").replace("```", "")
124-
logger.info(">>> Generated chart data: %s", chart_data)
116+
if run.status == "failed":
117+
print(f"[Chart Agent] Run failed: {run.last_error}")
118+
return {"error": "Chart could not be generated due to agent failure."}
119+
120+
chart_json = ""
121+
messages = client.agents.messages.list(thread_id=thread.id, order=ListSortOrder.ASCENDING)
122+
for msg in messages:
123+
if msg.role == MessageRole.AGENT and msg.text_messages:
124+
chart_json = msg.text_messages[-1].text.value.strip()
125+
break
126+
127+
client.agents.threads.delete(thread_id=thread.id)
128+
129+
chart_json = chart_json.replace("```json", "").replace("```", "").strip()
130+
chart_data = json.loads(chart_json)
131+
132+
if not chart_data or "error" in chart_data:
133+
return {
134+
"error": chart_data.get("error", "Chart could not be generated from this data."),
135+
"hint": "Try asking a question with some numerical values, like 'sales per region' or 'calls per day'."
136+
}
125137

126-
return json.loads(chart_data)
138+
return chart_data
127139

128140
except Exception as e:
129-
logger.error("Error processing RAG response: %s", e)
141+
logger.error("Agent error in chart generation: %s", e)
130142
return {"error": "Chart could not be generated from this data. Please ask a different question."}
131143

132144
async def stream_openai_text(self, conversation_id: str, query: str) -> StreamingResponse:
@@ -254,7 +266,7 @@ async def complete_chat_request(self, query, last_rag_response=None):
254266
return {"error": "A previous RAG response is required to generate a chart."}
255267

256268
# Process RAG response to generate chart data
257-
chart_data = self.process_rag_response(last_rag_response, query)
269+
chart_data = await self.process_rag_response(last_rag_response, query)
258270

259271
if not chart_data or "error" in chart_data:
260272
return {
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
from unittest.mock import AsyncMock, patch, MagicMock
3+
from agents.chart_agent_factory import ChartAgentFactory
4+
5+
6+
@pytest.mark.asyncio
7+
@patch("agents.chart_agent_factory.DefaultAzureCredential")
8+
@patch("agents.chart_agent_factory.AIProjectClient")
9+
async def test_create_agent_success(mock_ai_project_client_class, mock_credential_class):
10+
# Mock config
11+
mock_config = MagicMock()
12+
mock_config.ai_project_endpoint = "https://example-endpoint/"
13+
mock_config.ai_project_api_version = "2024-04-01-preview"
14+
mock_config.azure_openai_deployment_model = "gpt-4"
15+
mock_config.solution_name = "TestSolution"
16+
17+
# Mock client and agent
18+
mock_agent = MagicMock()
19+
mock_client = MagicMock()
20+
mock_client.agents.create_agent.return_value = mock_agent
21+
mock_ai_project_client_class.return_value = mock_client
22+
mock_credential_class.return_value = MagicMock()
23+
24+
# Call create_agent
25+
result = await ChartAgentFactory.create_agent(mock_config)
26+
27+
# Assertions
28+
assert result["agent"] == mock_agent
29+
assert result["client"] == mock_client
30+
mock_ai_project_client_class.assert_called_once_with(
31+
endpoint=mock_config.ai_project_endpoint,
32+
credential=mock_credential_class.return_value,
33+
api_version=mock_config.ai_project_api_version
34+
)
35+
mock_client.agents.create_agent.assert_called_once()
36+
37+
38+
@pytest.mark.asyncio
39+
async def test_delete_agent_instance():
40+
mock_client = MagicMock()
41+
mock_agent = MagicMock()
42+
mock_agent.id = "mock-agent-id"
43+
44+
agent_wrapper = {
45+
"agent": mock_agent,
46+
"client": mock_client
47+
}
48+
49+
await ChartAgentFactory._delete_agent_instance(agent_wrapper)
50+
51+
mock_client.agents.delete_agent.assert_called_once_with("mock-agent-id")

0 commit comments

Comments
 (0)