Skip to content

Commit 2b8fc23

Browse files
authored
Merge pull request #5 from makenotion/atvaccaro/query-comment-template
feat: add set_query_context tool for runtime observability
2 parents 1cb0964 + 766e655 commit 2b8fc23

File tree

5 files changed

+818
-3
lines changed

5 files changed

+818
-3
lines changed

mcp_server_snowflake/query_manager/tools.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from mcp_server_snowflake.utils import SnowflakeException
99

1010

11-
def run_query(statement: str, snowflake_service):
11+
def run_query(
12+
statement: str, snowflake_service, tool_name: str = "run_snowflake_query"
13+
):
1214
"""
1315
Execute SQL statement and fetch all results using Snowflake connector.
1416
@@ -21,6 +23,8 @@ def run_query(statement: str, snowflake_service):
2123
SQL statement to execute
2224
snowflake_service : SnowflakeService
2325
The Snowflake service instance to use for connection
26+
tool_name : str
27+
Name of the tool executing the query (for query comments)
2428
2529
Returns
2630
-------
@@ -33,14 +37,29 @@ def run_query(statement: str, snowflake_service):
3337
If connection fails or SQL execution encounters an error
3438
"""
3539
try:
40+
# Get statement type for query comment
41+
statement_type = get_statement_type(statement)
42+
43+
# Build query comment if enabled
44+
query_comment = snowflake_service.build_query_comment(
45+
tool_name=tool_name,
46+
statement_type=statement_type,
47+
)
48+
49+
# Prepend comment to statement if enabled
50+
if query_comment:
51+
statement_with_comment = f"/* {query_comment} */\n{statement}"
52+
else:
53+
statement_with_comment = statement
54+
3655
with snowflake_service.get_connection(
3756
use_dict_cursor=True,
3857
session_parameters=snowflake_service.get_query_tag_param(),
3958
) as (
4059
con,
4160
cur,
4261
):
43-
cur.execute(statement)
62+
cur.execute(statement_with_comment)
4463
return cur.fetchall()
4564
except Exception as e:
4665
raise SnowflakeException(
@@ -63,6 +82,121 @@ def run_query_tool(
6382
):
6483
return run_query(statement, snowflake_service)
6584

85+
@server.tool(
86+
name="set_query_context",
87+
description="""Set runtime context for query comments and observability.
88+
89+
Call this tool at the start of a session to register context information that will be
90+
included in all subsequent SQL query comments. This enables tracking queries back to
91+
specific agents, models, or sessions in Snowflake's query history.
92+
93+
Common context keys:
94+
- model: The AI model name (e.g., "claude-sonnet-4-5-20250929")
95+
- agent_name: Name of the agent or application (e.g., "Claude Code")
96+
- user_email: Email of the user running the agent
97+
- user_name: Name of the user running the agent
98+
- intent: Object describing query intent (category, confidence, domains, question)
99+
- query_parameters: Object describing query details (datasets, dimensions, time_range)
100+
- session_id: A unique session identifier for grouping related queries
101+
102+
Context persists for the lifetime of the MCP server connection. Call this tool again
103+
to update intent/query_parameters for different queries.""",
104+
)
105+
def set_query_context_tool(
106+
model: Annotated[
107+
str,
108+
Field(
109+
default=None,
110+
description="AI model name (e.g., 'claude-sonnet-4-5-20250929')",
111+
),
112+
] = None,
113+
agent_name: Annotated[
114+
str,
115+
Field(
116+
default=None,
117+
description="Name of the agent or application (e.g., 'Claude Code')",
118+
),
119+
] = None,
120+
user_email: Annotated[
121+
str,
122+
Field(
123+
default=None,
124+
description="Email of the user running the agent",
125+
),
126+
] = None,
127+
user_name: Annotated[
128+
str,
129+
Field(
130+
default=None,
131+
description="Name of the user running the agent",
132+
),
133+
] = None,
134+
intent: Annotated[
135+
dict,
136+
Field(
137+
default=None,
138+
description="Query intent: {category, confidence, domains, question}",
139+
),
140+
] = None,
141+
query_parameters: Annotated[
142+
dict,
143+
Field(
144+
default=None,
145+
description="Query parameters: {datasets, dimensions, time_range}",
146+
),
147+
] = None,
148+
session_id: Annotated[
149+
str,
150+
Field(
151+
default=None,
152+
description="Unique session identifier for grouping queries",
153+
),
154+
] = None,
155+
custom_context: Annotated[
156+
dict,
157+
Field(
158+
default=None,
159+
description="Additional custom key-value pairs for context",
160+
),
161+
] = None,
162+
):
163+
"""Set query context for observability."""
164+
context = {}
165+
if model is not None:
166+
context["model"] = model
167+
if agent_name is not None:
168+
context["agent_name"] = agent_name
169+
if user_email is not None:
170+
context["user_email"] = user_email
171+
if user_name is not None:
172+
context["user_name"] = user_name
173+
if intent is not None:
174+
context["intent"] = intent
175+
if query_parameters is not None:
176+
context["query_parameters"] = query_parameters
177+
if session_id is not None:
178+
context["session_id"] = session_id
179+
if custom_context is not None:
180+
context.update(custom_context)
181+
182+
updated_context = snowflake_service.set_query_context(**context)
183+
return {
184+
"status": "success",
185+
"message": "Query context updated successfully",
186+
"context": updated_context,
187+
}
188+
189+
@server.tool(
190+
name="get_query_context",
191+
description="Get the current query context that will be included in query comments.",
192+
)
193+
def get_query_context_tool():
194+
"""Get current query context."""
195+
return {
196+
"context": snowflake_service.get_query_context(),
197+
"query_comments_enabled": snowflake_service.query_comment_enabled,
198+
}
199+
66200

67201
def get_statement_type(sql_string):
68202
"""

mcp_server_snowflake/server.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
import argparse
1313
import json
1414
import os
15+
import uuid
1516
from collections.abc import AsyncIterator
1617
from contextlib import asynccontextmanager, contextmanager
18+
from datetime import datetime, timezone
1719
from pathlib import Path
1820
from typing import Any, Dict, Generator, Literal, Optional, Tuple, cast
1921

@@ -52,6 +54,26 @@
5254
tag_minor_version = 3
5355
query_tag = {"origin": "sf_sit", "name": "mcp_server"}
5456

57+
# Default query comment template - matches dbt query tag format for observability
58+
DEFAULT_QUERY_COMMENT_TEMPLATE = {
59+
"agent": "{agent_name}",
60+
"context": {
61+
"request_id": "{request_id}",
62+
"timestamp": "{timestamp}",
63+
},
64+
"intent": "{intent}",
65+
"model_version": "{model}",
66+
"query": {
67+
"tool": "{tool_name}",
68+
"statement_type": "{statement_type}",
69+
},
70+
"query_parameters": "{query_parameters}",
71+
"user": {
72+
"email": "{user_email}",
73+
"name": "{user_name}",
74+
},
75+
}
76+
5577
logger = get_logger(server_name)
5678

5779

@@ -130,6 +152,10 @@ def __init__(
130152
self.semantic_manager = False
131153
self.default_session_parameters: Dict[str, Any] = {}
132154
self.query_tag = query_tag if query_tag is not None else None
155+
self.query_comment_template: Optional[Dict[str, Any]] = None
156+
self.query_comment_enabled = False
157+
# Runtime query context set by agents via set_query_context tool
158+
self.query_context: Dict[str, str] = {}
133159
self.tag_major_version = (
134160
tag_major_version if tag_major_version is not None else None
135161
)
@@ -184,6 +210,16 @@ def unpack_service_specs(self) -> None:
184210
self.query_manager = other_services.get("query_manager", False)
185211
self.semantic_manager = other_services.get("semantic_manager", False)
186212

213+
# Parse query comment configuration
214+
query_comment_config = service_config.get("query_comment", {})
215+
if query_comment_config:
216+
self.query_comment_enabled = query_comment_config.get("enabled", False)
217+
custom_template = query_comment_config.get("template")
218+
if custom_template:
219+
self.query_comment_template = custom_template
220+
elif self.query_comment_enabled:
221+
self.query_comment_template = DEFAULT_QUERY_COMMENT_TEMPLATE.copy()
222+
187223
except Exception as e:
188224
logger.error(f"Error extracting service specifications: {e}")
189225
raise
@@ -414,6 +450,139 @@ def get_query_tag_param(
414450
else:
415451
return None
416452

453+
def set_query_context(self, **kwargs: str) -> Dict[str, str]:
454+
"""
455+
Set runtime query context values for query comments.
456+
457+
This method allows agents to set context information (like model name,
458+
session ID, etc.) that will be included in subsequent query comments.
459+
Context values override environment variables.
460+
461+
Parameters
462+
----------
463+
**kwargs : str
464+
Key-value pairs to set in the query context.
465+
Common keys: model, session_id, agent_name, user_context
466+
467+
Returns
468+
-------
469+
Dict[str, str]
470+
The updated query context dictionary
471+
"""
472+
self.query_context.update(kwargs)
473+
return self.query_context.copy()
474+
475+
def get_query_context(self) -> Dict[str, str]:
476+
"""
477+
Get the current runtime query context.
478+
479+
Returns
480+
-------
481+
Dict[str, str]
482+
Current query context dictionary
483+
"""
484+
return self.query_context.copy()
485+
486+
def clear_query_context(self) -> None:
487+
"""Clear all runtime query context values."""
488+
self.query_context.clear()
489+
490+
def build_query_comment(
491+
self,
492+
tool_name: str = "unknown",
493+
statement_type: str = "unknown",
494+
) -> Optional[str]:
495+
"""
496+
Build a query comment string with template variable substitution.
497+
498+
Substitutes template variables in the query comment template with actual values.
499+
Supported variables:
500+
- {request_id}: Unique UUID for this request
501+
- {timestamp}: ISO 8601 timestamp
502+
- {tool_name}: Name of the MCP tool being used
503+
- {statement_type}: Type of SQL statement (Select, Insert, etc.)
504+
- {model}: AI model name (from query_context, env var, or 'unknown')
505+
- {session_id}: Session ID (from query_context or 'unknown')
506+
- {agent_name}: Agent name (from query_context or 'unknown')
507+
- {server_name}: MCP server name
508+
- {server_version}: Server version string
509+
- Any custom keys set via set_query_context
510+
511+
Parameters
512+
----------
513+
tool_name : str
514+
Name of the MCP tool making the query
515+
statement_type : str
516+
Type of SQL statement being executed
517+
518+
Returns
519+
-------
520+
str or None
521+
JSON string of the query comment, or None if disabled
522+
"""
523+
if not self.query_comment_enabled or self.query_comment_template is None:
524+
return None
525+
526+
# Build substitution values - runtime context takes precedence over env vars
527+
substitutions = {
528+
"request_id": str(uuid.uuid4()),
529+
"timestamp": datetime.now(timezone.utc).isoformat(),
530+
"tool_name": tool_name,
531+
"statement_type": statement_type,
532+
# Model: check runtime context first, then env var, then default
533+
"model": self.query_context.get(
534+
"model", os.environ.get("SNOWFLAKE_MCP_MODEL", "unknown")
535+
),
536+
# Session ID: from runtime context or default
537+
"session_id": self.query_context.get("session_id", "unknown"),
538+
# Agent name: from runtime context or default to server name
539+
"agent_name": self.query_context.get("agent_name", server_name),
540+
# User info: from runtime context
541+
"user_email": self.query_context.get(
542+
"user_email", os.environ.get("SNOWFLAKE_MCP_USER_EMAIL", "unknown")
543+
),
544+
"user_name": self.query_context.get(
545+
"user_name", os.environ.get("SNOWFLAKE_MCP_USER_NAME", "unknown")
546+
),
547+
# Intent and query_parameters: complex objects from agent (default to null)
548+
"intent": self.query_context.get("intent"),
549+
"query_parameters": self.query_context.get("query_parameters"),
550+
"server_name": server_name,
551+
"server_version": f"{tag_major_version}.{tag_minor_version}",
552+
}
553+
# Add any additional custom context values
554+
for key, value in self.query_context.items():
555+
if key not in substitutions:
556+
substitutions[key] = value
557+
558+
def substitute_value(value: Any) -> Any:
559+
"""Recursively substitute template variables in values."""
560+
if isinstance(value, str):
561+
# Check if the entire string is a single placeholder like "{intent}"
562+
# If so, return the actual value (could be dict, None, etc.)
563+
import re
564+
565+
match = re.fullmatch(r"\{(\w+)\}", value)
566+
if match:
567+
key = match.group(1)
568+
if key in substitutions:
569+
return substitutions[key]
570+
# Otherwise do string replacement
571+
result = value
572+
for key, sub_value in substitutions.items():
573+
if sub_value is not None:
574+
result = result.replace(f"{{{key}}}", str(sub_value))
575+
return result
576+
elif isinstance(value, dict):
577+
return {k: substitute_value(v) for k, v in value.items()}
578+
elif isinstance(value, list):
579+
return [substitute_value(item) for item in value]
580+
else:
581+
return value
582+
583+
comment = substitute_value(self.query_comment_template)
584+
return json.dumps(comment)
585+
417586

418587
def get_var(var_name: str, env_var_name: str, args) -> Optional[str]:
419588
"""

0 commit comments

Comments
 (0)