 from langchain_community.chat_message_histories import ChatMessageHistory
-from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
 from langchain_openai import ChatOpenAI
+from mcp_use import MCPAgent

 from datu.app_config import get_logger, settings
 from datu.base.llm_client import BaseLLMClient
@@ -65,15 +66,54 @@ class OpenAIClient(BaseLLMClient): |
     """

     def __init__(self):
+        """Initializes the OpenAIClient with the configured model and API key."""
+        super().__init__()
         self.model = getattr(settings, "openai_model", "gpt-4o-mini")
         self.client = ChatOpenAI(
             api_key=settings.openai_api_key,
             model=self.model,
             temperature=settings.llm_temperature,
         )
         self.history = ChatMessageHistory()
+        self.agent = None
+        if settings.enable_mcp:
+            if not self.mcp_client:
+                raise RuntimeError("MCP is enabled but mcp_client was not initialized.")
+            try:
+                self.agent = MCPAgent(
+                    llm=self.client,
+                    client=self.mcp_client,
+                    max_steps=settings.mcp.max_steps,
+                    use_server_manager=settings.mcp.use_server_manager,
+                )
+            except Exception:
+                # Prefer failing early so a misconfiguration doesn't silently degrade behavior
+                logger.exception("Failed to construct MCPAgent with provided MCP settings.")
+                raise
+
+    async def chat(self, input_text: str) -> str:
+        """Sends a chat message to the MCP agent and returns the response.
+
+        Args:
+            input_text (str): The input text to send to the agent.
+
+        Returns:
+            str: The response from the agent.
+        """

-    def chat_completion(self, messages: list[BaseMessage], system_prompt: str | None = None) -> str:
+        if not settings.enable_mcp or self.agent is None:
+            raise RuntimeError("chat() requires MCP enabled and an initialized agent.")
+        response = await self.agent.run(
+            input_text,
+            max_steps=settings.mcp.max_steps,
+        )
+        return response
+
+    async def chat_completion(self, messages: list[BaseMessage], system_prompt: str | None = None) -> str:
+        """Generates a chat completion response based on the provided messages and system prompt.
+
+        Args:
+            messages (list[BaseMessage]): A list of messages to send to the LLM.
+            system_prompt (str | None): An optional system prompt to guide the LLM's response.
+
+        Returns:
+            str: The generated response from the LLM.
+        """
         if settings.simulate_llm_response:
             return create_simulated_llm_response()
         if not messages:
@@ -114,9 +154,25 @@ def chat_completion(self, messages: list[BaseMessage], system_prompt: str | None |
         )

         self.history.add_message(HumanMessage(content=last_user_message))
-        response = self.client.invoke(self.history.messages)
-        self.history.add_message(response)
-        return response.content if response else ""
+        # Flatten the chat history to plain text; the MCP agent takes a string query
+        input_text = "\n".join(msg.content for msg in self.history.messages if hasattr(msg, "content"))
+
+        if settings.enable_mcp:
+            # Route the request through the MCP agent
+            response = await self.chat(input_text)
+        else:
+            # Direct LLM call without MCP
+            response = await self.client.ainvoke(self.history.messages)
+
+        # ainvoke returns a message object with `content`; MCPAgent.run returns a plain string
+        if hasattr(response, "content"):
+            self.history.add_message(response)
+            return response.content
+        # Record a plain-string response as the assistant's turn, not the user's
+        self.history.add_message(AIMessage(content=response))
+        return response

     def fix_sql_error(self, sql_code: str, error_msg: str, loop_count: int) -> str:
         """Generates a corrected SQL query based on the provided SQL code and error message.
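The constructor leans on several config fields (`settings.enable_mcp`, `settings.mcp.max_steps`, `settings.mcp.use_server_manager`) plus an `mcp_client` attribute presumably set up in `BaseLLMClient.__init__`. A minimal sketch of the settings shape this code implies, assuming a pydantic-settings style config; the field names come from the diff, but the `MCPSettings` class and all defaults are illustrative:

```python
from pydantic import BaseModel
from pydantic_settings import BaseSettings


class MCPSettings(BaseModel):
    """Illustrative MCP sub-config; the real values live in datu.app_config."""

    max_steps: int = 30  # upper bound on agent tool-use iterations
    use_server_manager: bool = False


class Settings(BaseSettings):
    openai_api_key: str = ""
    openai_model: str = "gpt-4o-mini"
    llm_temperature: float = 0.0
    simulate_llm_response: bool = False
    enable_mcp: bool = False  # gates MCPAgent construction in OpenAIClient.__init__
    mcp: MCPSettings = MCPSettings()
```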
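Because `chat_completion` is now a coroutine, synchronous callers must be updated to run it in an event loop. A minimal usage sketch, assuming the import path `datu.llm_clients.openai_client` (hypothetical) and MCP disabled so the call falls through to `ChatOpenAI.ainvoke`:

```python
import asyncio

from langchain_core.messages import HumanMessage

from datu.llm_clients.openai_client import OpenAIClient  # module path assumed


async def main() -> None:
    client = OpenAIClient()
    # With settings.enable_mcp = False this goes straight to ChatOpenAI.ainvoke
    reply = await client.chat_completion(
        [HumanMessage(content="List the tables in the sales schema.")],
        system_prompt="You are a helpful SQL assistant.",
    )
    print(reply)


asyncio.run(main())
```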