Skip to content

Commit c088bb8

Browse files
authored
Mcp connectivity (#6)
1 parent 7193f42 commit c088bb8

26 files changed

+995
-884
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
## [0.0.1](https://github.com/Datuanalytics/datu-core/tree/0.0.1) - 2025-08-18
2+
3+
### Added
4+
5+
- Basic LLM, Postgres, and MSSQL integrations
6+
- MCP Connectivity
7+
- Prompt updated to support welcome and tool-listing dialogue.
8+
- SQL generator is activated and connected by default.

changelog.d/+ab395eb1.added.md

Lines changed: 0 additions & 1 deletion
This file was deleted.

mcp_config.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"mcpServers": {
3+
"sql_generator": {
4+
"command": "python",
5+
"args": ["-m", "datu.mcp.tools.sql_generator"],
6+
"env": { "PYTHONPATH": "." }
7+
}
8+
}
9+
}

pyproject.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,18 @@ dependencies = [
2727
"types-pyyaml>=6.0.12.20250402",
2828
"sql-metadata>=2.17.0",
2929
"sentence-transformers>=2.5.1",
30+
"safetensors>=0.6.2,<0.7",
31+
"transformers<4.44.0",
3032
"torch==2.2.2",
3133
"numpy<2.0.0",
3234
"networkx>=3.0",
3335
"langgraph>=0.0.37",
3436
"langchain-core>=0.1.41",
3537
"langchain-mcp-adapters>=0.1.8",
3638
"openai>=1.30.1",
37-
"fastmcp>=2.10.5"
39+
"fastmcp>=2.10.5",
40+
"mcp-use[search]>=1.3.7",
41+
"onnxruntime==1.19.2 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
3842
]
3943

4044
[project.urls]
@@ -298,11 +302,17 @@ ignore_missing_imports = true
298302
[[tool.mypy.overrides]]
299303
module = "langchain_mcp_adapters.tools"
300304
ignore_missing_imports = true
305+
[[tool.mypy.overrides]]
306+
module = "mcp_use"
307+
ignore_missing_imports = true
308+
[[tool.mypy.overrides]]
309+
module = "mcp_use.*"
310+
ignore_missing_imports = true
301311

302312
[tool.towncrier]
303313
directory = "changelog.d"
304314
filename = "CHANGELOG.md"
305-
package = "datu-core"
315+
package = "datu"
306316
package_dir = "src"
307317
start_string = "<!-- towncrier release notes start -->\n"
308318
underlines = ["", "", ""]

src/datu/app_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pydantic_settings import BaseSettings, SettingsConfigDict
1616

1717
from datu.integrations.config import IntegrationConfigs
18+
from datu.mcp.config import MCPConfig
1819
from datu.services.config import SchemaRAGConfig
1920

2021

@@ -72,7 +73,11 @@ class DatuConfig(BaseSettings):
7273
simulate_llm_response (str): Whether to simulate LLM responses.
7374
schema_sample_limit (int): The maximum number of rows to sample from the schema.
7475
schema_categorical_threshold (int): The threshold for categorical columns in the schema.
76+
enable_mcp (bool): Whether to enable MCP integration.
77+
mcp (MCPConfig | None): Configuration settings for MCP integration.
7578
enable_schema_rag (bool): Enable RAG for schema extraction.
79+
schema_rag (SchemaRAGConfig | None): Configuration settings for schema RAG.
80+
7681
7782
"""
7883

@@ -94,6 +99,11 @@ class DatuConfig(BaseSettings):
9499
schema_categorical_detection: bool = True
95100
schema_sample_limit: int = 1000
96101
schema_categorical_threshold: int = 10
102+
enable_mcp: bool = False
103+
mcp: MCPConfig | None = Field(
104+
default_factory=MCPConfig,
105+
description="Configuration settings for MCP integration.",
106+
)
97107
enable_schema_rag: bool = False
98108
schema_rag: SchemaRAGConfig | None = Field(
99109
default_factory=SchemaRAGConfig,

src/datu/base/chat_schema.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import List, Optional
2+
3+
from pydantic import BaseModel
4+
5+
6+
class ChatMessage(BaseModel):
    """A single message exchanged in a chat conversation.

    Attributes:
        role (str): Sender role, e.g. "user" or "assistant".
        content (str): Text content of the message.
    """

    role: str
    content: str
16+
17+
18+
class ChatRequest(BaseModel):
    """A chat request: the conversation so far plus an optional system prompt.

    Attributes:
        messages (List[ChatMessage]): Ordered messages of the conversation.
        system_prompt (Optional[str]): Optional system prompt giving the
            model context for the conversation.
    """

    messages: List[ChatMessage]
    system_prompt: Optional[str] = None

src/datu/base/llm_client.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,34 @@
22

33
from abc import ABC, abstractmethod
44

5+
from mcp_use import MCPClient
6+
7+
from datu.app_config import settings
8+
59

610
class BaseLLMClient(ABC):
711
"""BaseLLMClient class to provide a common interface for LLM clients.
812
This class serves as an abstract base class for all LLM clients,
913
providing a common interface and shared functionality.
1014
"""
1115

16+
def __init__(self):
17+
"""Initializes the BaseLLMClient.
18+
Sets up the client and MCP client if enabled in the settings.
19+
20+
Attributes:
21+
client: The LLM client instance.
22+
mcp_client: The MCP client instance if MCP is enabled in the settings.
23+
agent: The agent instance if applicable.
24+
"""
25+
self.client = None
26+
self.mcp_client = None
27+
if settings.enable_mcp:
28+
self.mcp_client = MCPClient.from_config_file(settings.mcp.config_file)
29+
self.agent = None
30+
1231
@abstractmethod
13-
def chat_completion(self, messages: list, system_prompt: str | None = None) -> str:
32+
async def chat_completion(self, messages: list, system_prompt: str | None = None) -> str:
1433
"""Given a conversation (and an optional system prompt), returns the assistant's text response."""
1534

1635
@abstractmethod

src/datu/llm_clients/openai_client.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from langchain_community.chat_message_histories import ChatMessageHistory
99
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
1010
from langchain_openai import ChatOpenAI
11+
from mcp_use import MCPAgent
1112

1213
from datu.app_config import get_logger, settings
1314
from datu.base.llm_client import BaseLLMClient
@@ -65,15 +66,54 @@ class OpenAIClient(BaseLLMClient):
6566
"""
6667

6768
def __init__(self):
69+
"""Initializes the OpenAIClient with the configured model and API key."""
70+
super().__init__()
6871
self.model = getattr(settings, "openai_model", "gpt-4o-mini")
6972
self.client = ChatOpenAI(
7073
api_key=settings.openai_api_key,
7174
model=self.model,
7275
temperature=settings.llm_temperature,
7376
)
7477
self.history = ChatMessageHistory()
78+
self.agent = None
79+
if settings.enable_mcp:
80+
if not self.mcp_client:
81+
raise RuntimeError("MCP is enabled but mcp_client was not initialized. ")
82+
try:
83+
self.agent = MCPAgent(
84+
llm=self.client,
85+
client=self.mcp_client,
86+
max_steps=settings.mcp.max_steps,
87+
use_server_manager=settings.mcp.use_server_manager,
88+
)
89+
except Exception:
90+
# Prefer failing early so misconfig doesn’t silently degrade behavior
91+
logger.exception("Failed to construct MCPAgent with provided MCP settings.")
92+
raise
93+
94+
async def chat(self, input_text: str) -> str:
95+
"""Sends a chat message to the MCP agent and returns the response.
96+
Args:
97+
input_text (str): The input text to send to the agent.
98+
Returns:
99+
str: The response from the agent."""
75100

76-
def chat_completion(self, messages: list[BaseMessage], system_prompt: str | None = None) -> str:
101+
if not settings.enable_mcp or self.agent is None:
102+
raise RuntimeError("chat() requires MCP enabled and an initialized agent.")
103+
response = await self.agent.run(
104+
input_text,
105+
max_steps=30,
106+
)
107+
return response
108+
109+
async def chat_completion(self, messages: list[BaseMessage], system_prompt: str | None = None) -> str:
110+
"""Generates a chat completion response based on the provided messages and system prompt.
111+
Args:
112+
messages (list[BaseMessage]): A list of messages to send to the LLM.
113+
system_prompt (str | None): An optional system prompt to guide the LLM's response.
114+
Returns:
115+
str: The generated response from the LLM.
116+
"""
77117
if settings.simulate_llm_response:
78118
return create_simulated_llm_response()
79119
if not messages:
@@ -114,9 +154,25 @@ def chat_completion(self, messages: list[BaseMessage], system_prompt: str | None
114154
)
115155

116156
self.history.add_message(HumanMessage(content=last_user_message))
117-
response = self.client.invoke(self.history.messages)
118-
self.history.add_message(response)
119-
return response.content if response else ""
157+
# Convert entire history messages to plain text for chat()
158+
# Adjust this if your llm_with_tools expects different format
159+
input_text = "\n".join(msg.content for msg in self.history.messages if hasattr(msg, "content"))
160+
161+
if settings.enable_mcp:
162+
# uses MCP agent
163+
response = await self.chat(input_text)
164+
else:
165+
# direct LLM call without MCP
166+
response = await self.client.ainvoke(self.history.messages)
167+
168+
# Assuming response is a BaseMessage or similar with 'content'
169+
if hasattr(response, "content"):
170+
self.history.add_message(response)
171+
return response.content
172+
else:
173+
# If response is plain text string
174+
self.history.add_message(HumanMessage(content=response))
175+
return response
120176

121177
def fix_sql_error(self, sql_code: str, error_msg: str, loop_count: int) -> str:
122178
"""Generates a corrected SQL query based on the provided SQL code and error message.

src/datu/main.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,14 @@
1111
from fastapi.staticfiles import StaticFiles
1212

1313
from datu.app_config import get_logger, settings
14-
from datu.mcp.launcher import launch_mcp_server
1514
from datu.routers import chat, metadata, transformations
1615
from datu.schema_extractor.schema_cache import load_schema_cache
1716

1817
logger = get_logger(__name__)
1918

2019
# Optionally load schema and graph-rag in cache for use in prompts or logging
2120
if settings.app_environment != "test":
22-
if settings.enable_schema_rag:
23-
launch_mcp_server("schema_rag_server")
24-
else:
25-
schema_data = load_schema_cache()
21+
schema_data = load_schema_cache()
2622

2723

2824
# Create the FastAPI application instance.

src/datu/mcp/client.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

0 commit comments

Comments
 (0)