|
8 | 8 | from langchain_community.chat_message_histories import ChatMessageHistory |
9 | 9 | from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage |
10 | 10 | from langchain_openai import ChatOpenAI |
| 11 | +from mcp_use import MCPAgent |
11 | 12 |
|
12 | 13 | from datu.app_config import get_logger, settings |
13 | 14 | from datu.base.llm_client import BaseLLMClient |
@@ -65,15 +66,37 @@ class OpenAIClient(BaseLLMClient): |
65 | 66 | """ |
66 | 67 |
|
67 | 68 | def __init__(self): |
| 69 | + super().__init__() |
68 | 70 | self.model = getattr(settings, "openai_model", "gpt-4o-mini") |
69 | 71 | self.client = ChatOpenAI( |
70 | 72 | api_key=settings.openai_api_key, |
71 | 73 | model=self.model, |
72 | 74 | temperature=settings.llm_temperature, |
73 | 75 | ) |
74 | 76 | self.history = ChatMessageHistory() |
| 77 | + if settings.enable_mcp: |
| 78 | + if not self.mcp_client: |
| 79 | + raise RuntimeError("MCP is enabled but mcp_client was not initialized. ") |
| 80 | + try: |
| 81 | + self.agent = MCPAgent( |
| 82 | + llm=self.client, |
| 83 | + client=self.mcp_client, |
| 84 | + max_steps=settings.mcp.max_steps, |
| 85 | + use_server_manager=settings.mcp.use_server_manager, |
| 86 | + ) |
| 87 | + except Exception: |
| 88 | + # Prefer failing early so misconfig doesn’t silently degrade behavior |
| 89 | + logger.exception("Failed to construct MCPAgent with provided MCP settings.") |
| 90 | + raise |
| 91 | + |
| 92 | + async def chat(self, input_text: str) -> str: |
| 93 | + response = await self.agent.run( |
| 94 | + input_text, |
| 95 | + max_steps=30, |
| 96 | + ) |
| 97 | + return response |
75 | 98 |
|
76 | | - def chat_completion(self, messages: list[BaseMessage], system_prompt: str | None = None) -> str: |
| 99 | + async def chat_completion(self, messages: list[BaseMessage], system_prompt: str | None = None) -> str: |
77 | 100 | if settings.simulate_llm_response: |
78 | 101 | return create_simulated_llm_response() |
79 | 102 | if not messages: |
@@ -114,9 +137,20 @@ def chat_completion(self, messages: list[BaseMessage], system_prompt: str | None |
114 | 137 | ) |
115 | 138 |
|
116 | 139 | self.history.add_message(HumanMessage(content=last_user_message)) |
117 | | - response = self.client.invoke(self.history.messages) |
118 | | - self.history.add_message(response) |
119 | | - return response.content if response else "" |
| 140 | + # Convert entire history messages to plain text for chat() |
| 141 | + # Adjust this if your llm_with_tools expects different format |
| 142 | + input_text = "\n".join(msg.content for msg in self.history.messages if hasattr(msg, "content")) |
| 143 | + |
| 144 | + response = await self.chat(input_text) |
| 145 | + |
| 146 | + # Assuming response is a BaseMessage or similar with 'content' |
| 147 | + if hasattr(response, "content"): |
| 148 | + self.history.add_message(response) |
| 149 | + return response.content |
| 150 | + else: |
| 151 | + # If response is plain text string |
| 152 | + self.history.add_message(HumanMessage(content=response)) |
| 153 | + return response |
120 | 154 |
|
121 | 155 | def fix_sql_error(self, sql_code: str, error_msg: str, loop_count: int) -> str: |
122 | 156 | """Generates a corrected SQL query based on the provided SQL code and error message. |
|
0 commit comments