Skip to content

Commit 8fd532b

Browse files
authored
Merge pull request #63 from sragrawal/62-fixes
Add the NL2SQL tool to the MySQL MCP server. This tool internally use…
2 parents c36757f + 6b60681 commit 8fd532b

File tree

2 files changed

+145
-22
lines changed

2 files changed

+145
-22
lines changed

src/mysql-mcp-server/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ A Python-based MCP (Model Context Protocol) server that provides a suite of tool
2121
- `ragify_column`: Create/populate vector columns for embeddings
2222
- `ask_ml_rag`: Retrieval-augmented generation from vector stores
2323
- `heatwave_ask_help`: Answers questions about how to use HeatWave ML
24+
- `ask_nl_sql`: Convert natural language questions into SQL queries and execute them automatically
2425

2526
- **Vector Store Management**
2627
- List files in `secure_file_priv` (local mode)
@@ -213,6 +214,7 @@ python mysql_mcp_server.py
213214
11. `list_all_compartments()`: List OCI compartments
214215
12. `object_storage_list_buckets(compartment_name | compartment_id)`: List buckets in a compartment
215216
13. `object_storage_list_objects(namespace, bucket_name)`: List objects in a bucket
217+
14. `ask_nl_sql(connection_id, question)`: Convert natural language questions into SQL queries and execute them automatically
216218

217219
## Security
218220

@@ -236,6 +238,7 @@ Here are example prompts you can use to interact with the MCP server, note that
236238
```
237239
"Generate a summary of error logs"
238240
"Ask ml_rag: Show me refund policy from the vector store"
241+
"What is the average delay incurred by flights?"
239242
```
240243

241244
### 3. Object Storage

src/mysql-mcp-server/mysql_mcp_server.py

Lines changed: 142 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@
1212
from fastmcp import FastMCP
1313
from mysql import connector
1414
from mysql.connector.abstracts import MySQLConnectionAbstract
15-
16-
from utils import DatabaseConnectionError, get_ssh_command, load_mysql_config, Mode, OciInfo
15+
from utils import (
16+
DatabaseConnectionError,
17+
Mode,
18+
OciInfo,
19+
get_ssh_command,
20+
load_mysql_config,
21+
)
1722

1823
MIN_CONTEXT_SIZE = 10
1924
DEFAULT_CONTEXT_SIZE = 20
@@ -29,20 +34,26 @@
2934
try:
3035
config = load_mysql_config()
3136
except Exception as e:
32-
config_error_msg = json.dumps({
33-
"error" : f"Error loading config. Fix configuration file and try restarting MCP server {str(e)}."
34-
})
37+
config_error_msg = json.dumps(
38+
{
39+
"error": f"Error loading config. Fix configuration file and try restarting MCP server {str(e)}."
40+
}
41+
)
3542

3643
# Setup oci connection if applicable
3744
oci_info: Optional[OciInfo] = None # None if not available, otherwise OCI config info
38-
oci_error_msg: Optional[str] = None # None if OCI available, otherwise a json formatted string
45+
oci_error_msg: Optional[str] = (
46+
None # None if OCI available, otherwise a json formatted string
47+
)
3948
try:
4049
oci_info = OciInfo()
4150
except Exception as e:
42-
oci_error_msg = json.dumps({
43-
"error" : "object store unavailable. If object store is required, the MCP server must be restarted with a valid"
44-
f" OCI config. OCI connection attempt yielded error {str(e)}."
45-
})
51+
oci_error_msg = json.dumps(
52+
{
53+
"error": "object store unavailable. If object store is required, the MCP server must be restarted with a valid"
54+
f" OCI config. OCI connection attempt yielded error {str(e)}."
55+
}
56+
)
4657

4758
# Create mcp server
4859
mcp = FastMCP("MySQL")
@@ -51,6 +62,7 @@
5162
# Finish setup
5263
###############################################################
5364

65+
5466
def _validate_name(name: str) -> str:
5567
"""
5668
Validate that the string is a legal SQL identifier (letters, digits, underscores).
@@ -81,9 +93,7 @@ def _get_mode(connection_id: str) -> Mode:
8193
Returns:
8294
Mode: The resolved provider mode.
8395
"""
84-
provider_result = _execute_sql_tool(
85-
connection_id, "SELECT @@rapid_cloud_provider;"
86-
)
96+
provider_result = _execute_sql_tool(connection_id, "SELECT @@rapid_cloud_provider;")
8797
if check_error(provider_result):
8898
raise Exception(
8999
f"Exception occurred while fetching cloud provider {str(provider_result)}"
@@ -230,7 +240,7 @@ def list_all_connections() -> str:
230240
{
231241
"key": connection_id,
232242
"error": str(e),
233-
"hint": f"Bastion/jump host may be down. Try starting it with {get_ssh_command(config)}"
243+
"hint": f"Bastion/jump host may be down. Try starting it with {get_ssh_command(config)}",
234244
}
235245
)
236246
return json.dumps({"valid keys": valid_keys, "invalid keys": invalid_keys})
@@ -258,6 +268,19 @@ def execute_sql_tool_by_connection_id(
258268
return _execute_sql_tool(connection_id, sql_script, params=params)
259269

260270

271+
from datetime import date, datetime
272+
from decimal import Decimal
273+
274+
275+
class CustomJSONEncoder(json.JSONEncoder):
    """
    JSON encoder for values found in SQL result sets that the stock encoder rejects.

    Conversions:
        - Decimal   -> str (keeps the exact decimal digits instead of a lossy float)
        - timedelta -> str, e.g. "1:02:03" (Connector/Python returns TIME columns as timedelta)
        - date / datetime / time -> ISO 8601 string via ``isoformat()``

    Any other unsupported type is delegated to ``json.JSONEncoder.default``,
    which raises ``TypeError``.
    """

    def default(self, o):
        # Imported locally so this class does not depend on which names the
        # module-level ``from datetime import ...`` statement happens to pull in.
        from datetime import time, timedelta

        if isinstance(o, (Decimal, timedelta)):
            return str(o)
        # datetime is a subclass of date; it is listed explicitly for readability.
        if isinstance(o, (date, datetime, time)):
            return o.isoformat()
        return super().default(o)
282+
283+
261284
def _execute_sql_tool(
262285
connection: Union[str, MySQLConnectionAbstract],
263286
sql_script: str,
@@ -309,7 +332,7 @@ def _execute_sql_tool(
309332

310333
db_connection.commit()
311334

312-
return json.dumps(results)
335+
return json.dumps(results, cls=CustomJSONEncoder)
313336

314337
except Exception as e:
315338
return json.dumps(
@@ -565,7 +588,9 @@ def load_vector_store_oci(
565588

566589

567590
@mcp.tool()
568-
def ask_ml_rag_vector_store(connection_id: str, question: str, context_size: int = DEFAULT_CONTEXT_SIZE) -> str:
591+
def ask_ml_rag_vector_store(
592+
connection_id: str, question: str, context_size: int = DEFAULT_CONTEXT_SIZE
593+
) -> str:
569594
"""
570595
[MCP Tool] Retrieve segments from the default vector store (skip_generate=true).
571596
@@ -586,16 +611,26 @@ def ask_ml_rag_vector_store(connection_id: str, question: str, context_size: int
586611
arguments: {"connection_id": "example_local_server", "question": "Find information about refunds."}
587612
"""
588613
if context_size < MIN_CONTEXT_SIZE or MAX_CONTEXT_SIZE < context_size:
589-
return json.dumps({"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"})
614+
return json.dumps(
615+
{
616+
"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"
617+
}
618+
)
590619

591620
return _ask_ml_rag_helper(
592-
connection_id, question, f"JSON_OBJECT('skip_generate', true, 'n_citations', {context_size})"
621+
connection_id,
622+
question,
623+
f"JSON_OBJECT('skip_generate', true, 'n_citations', {context_size})",
593624
)
594625

595626

596627
@mcp.tool()
597628
def ask_ml_rag_innodb(
598-
connection_id: str, question: str, segment_col: str, embedding_col: str, context_size: int = DEFAULT_CONTEXT_SIZE
629+
connection_id: str,
630+
question: str,
631+
segment_col: str,
632+
embedding_col: str,
633+
context_size: int = DEFAULT_CONTEXT_SIZE,
599634
) -> str:
600635
"""
601636
[MCP Tool] Retrieve segments from InnoDB tables using specified segment and embedding columns.
@@ -626,7 +661,11 @@ def ask_ml_rag_innodb(
626661
arguments: {"connection_id": "example_local_server", "question": "Search product docs", "segment_col": "body", "embedding_col": "embedding"}
627662
"""
628663
if context_size < MIN_CONTEXT_SIZE or MAX_CONTEXT_SIZE < context_size:
629-
return json.dumps({"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"})
664+
return json.dumps(
665+
{
666+
"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"
667+
}
668+
)
630669

631670
try:
632671
# prevent possible injection
@@ -732,6 +771,84 @@ def heatwave_ask_help(connection_id: str, question: str) -> str:
732771
return json.dumps({"error": f"Error with NL2ML: {str(e)}"})
733772

734773

774+
@mcp.tool()
def ask_nl_sql(connection_id: str, question: str) -> str:
    """
    [MCP Tool] Convert natural language questions into SQL queries and execute them automatically.

    This tool is ideal for database exploration using plain English questions like:
    - "What tables are available?"
    - "Show me the average price by category"
    - "How many users registered last month?"
    - "What are the column names in the customers table?"

    Args:
        connection_id (str): MySQL connection key.
        question (str): Natural language query.

    Returns:
        JSON object containing:

            sql_response(str): The response from executing the generated SQL query.
            sql_query(str): The generated SQL query
            schemas(json): The schemas where metadata was retrieved
            tables(json): The tables where metadata was retrieved
            is_sql_valid(bool): Whether the generated SQL statement is valid
            model_id(str): The LLM used for generation

        On failure, a JSON object with a single "error" key is returned instead.

    MCP usage example:
      - name: ask_nl_sql
        arguments: {"connection_id": "example_local_server", "question": "How many singers are there?"}

    Here is what part of the returned JSON looks like:
    {
        "tables": [
            "singer.singer",
            "singer.song",
            "concert_singer.singer",
            "concert_singer.stadium",
            "music_2.Songs",
            "music_2.Instruments",
            "music_2.Band",
            "music_2.Vocals",
            "music_2.Tracklists"
        ],
        "schemas": [
            "concert_singer",
            "music_2",
            "singer"
        ],
        "sql_query": "SELECT COUNT(`Singer_ID`) FROM `concert_singer`.`singer`;",
        "is_sql_valid": 1
    }
    """
    with _get_database_connection_cm(connection_id) as db_connection:
        # Reset the session variable so a stale value from an earlier call on
        # this connection can never be mistaken for this call's output.
        set_response = _execute_sql_tool(db_connection, "SET @response = NULL;")
        if check_error(set_response):
            return json.dumps({"error": f"Error with NL_SQL: {set_response}"})

        # The question is bound as a parameter, never interpolated into the SQL text.
        nl2sql_response = _execute_sql_tool(
            db_connection,
            "CALL sys.NL_SQL(%s, @response, NULL)",
            params=[question],
        )
        if check_error(nl2sql_response):
            return json.dumps({"error": f"Error with NL_SQL: {nl2sql_response}"})

        # The routine writes its metadata (generated query, tables, ...) into @response.
        fetch_response = _execute_sql_tool(db_connection, "SELECT @response;")
        if check_error(fetch_response):
            return json.dumps({"error": f"Error with NL_SQL: {fetch_response}"})

        try:
            response = json.loads(fetch_one(fetch_response))
            # Attach the rows produced by the CALL itself (already a JSON string).
            response["sql_response"] = nl2sql_response
            return json.dumps(response)
        except Exception:
            # Covers a NULL/empty @response, malformed JSON, or a non-object payload.
            # Deliberately not a bare ``except`` so KeyboardInterrupt/SystemExit propagate.
            return json.dumps({"error": "Unexpected response format from NL_SQL"})
850+
851+
735852
"""
736853
Object store
737854
"""
@@ -745,7 +862,7 @@ def verify_compartment_access(compartments):
745862
"compartment_id": compartment.id,
746863
"object_storage": False,
747864
"databases": False,
748-
"errors": []
865+
"errors": [],
749866
}
750867

751868
# Test Object Storage
@@ -756,10 +873,13 @@ def verify_compartment_access(compartments):
756873
)
757874
access_report[compartment.name]["object_storage"] = True
758875
except Exception as e:
759-
access_report[compartment.name]["errors"].append(f"Object Storage: {str(e)}")
876+
access_report[compartment.name]["errors"].append(
877+
f"Object Storage: {str(e)}"
878+
)
760879

761880
return access_report
762881

882+
763883
@mcp.tool()
764884
def list_all_compartments() -> str:
765885
"""

0 commit comments

Comments
 (0)