|
12 | 12 | import argparse |
13 | 13 | import json |
14 | 14 | import os |
| 15 | +import uuid |
15 | 16 | from collections.abc import AsyncIterator |
16 | 17 | from contextlib import asynccontextmanager, contextmanager |
| 18 | +from datetime import datetime, timezone |
17 | 19 | from pathlib import Path |
18 | 20 | from typing import Any, Dict, Generator, Literal, Optional, Tuple, cast |
19 | 21 |
|
|
52 | 54 | tag_minor_version = 3 |
53 | 55 | query_tag = {"origin": "sf_sit", "name": "mcp_server"} |
54 | 56 |
|
| 57 | +# Default query comment template - matches dbt query tag format for observability |
| 58 | +DEFAULT_QUERY_COMMENT_TEMPLATE = { |
| 59 | + "agent": "{agent_name}", |
| 60 | + "context": { |
| 61 | + "request_id": "{request_id}", |
| 62 | + "timestamp": "{timestamp}", |
| 63 | + }, |
| 64 | + "intent": "{intent}", |
| 65 | + "model_version": "{model}", |
| 66 | + "query": { |
| 67 | + "tool": "{tool_name}", |
| 68 | + "statement_type": "{statement_type}", |
| 69 | + }, |
| 70 | + "query_parameters": "{query_parameters}", |
| 71 | + "user": { |
| 72 | + "email": "{user_email}", |
| 73 | + "name": "{user_name}", |
| 74 | + }, |
| 75 | +} |
| 76 | + |
55 | 77 | logger = get_logger(server_name) |
56 | 78 |
|
57 | 79 |
|
@@ -130,6 +152,10 @@ def __init__( |
130 | 152 | self.semantic_manager = False |
131 | 153 | self.default_session_parameters: Dict[str, Any] = {} |
132 | 154 | self.query_tag = query_tag if query_tag is not None else None |
| 155 | + self.query_comment_template: Optional[Dict[str, Any]] = None |
| 156 | + self.query_comment_enabled = False |
| 157 | + # Runtime query context set by agents via set_query_context tool |
| 158 | + self.query_context: Dict[str, str] = {} |
133 | 159 | self.tag_major_version = ( |
134 | 160 | tag_major_version if tag_major_version is not None else None |
135 | 161 | ) |
@@ -184,6 +210,16 @@ def unpack_service_specs(self) -> None: |
184 | 210 | self.query_manager = other_services.get("query_manager", False) |
185 | 211 | self.semantic_manager = other_services.get("semantic_manager", False) |
186 | 212 |
|
| 213 | + # Parse query comment configuration |
| 214 | + query_comment_config = service_config.get("query_comment", {}) |
| 215 | + if query_comment_config: |
| 216 | + self.query_comment_enabled = query_comment_config.get("enabled", False) |
| 217 | + custom_template = query_comment_config.get("template") |
| 218 | + if custom_template: |
| 219 | + self.query_comment_template = custom_template |
| 220 | + elif self.query_comment_enabled: |
| 221 | + self.query_comment_template = DEFAULT_QUERY_COMMENT_TEMPLATE.copy() |
| 222 | + |
187 | 223 | except Exception as e: |
188 | 224 | logger.error(f"Error extracting service specifications: {e}") |
189 | 225 | raise |
@@ -414,6 +450,139 @@ def get_query_tag_param( |
414 | 450 | else: |
415 | 451 | return None |
416 | 452 |
|
| 453 | + def set_query_context(self, **kwargs: str) -> Dict[str, str]: |
| 454 | + """ |
| 455 | + Set runtime query context values for query comments. |
| 456 | +
|
| 457 | + This method allows agents to set context information (like model name, |
| 458 | + session ID, etc.) that will be included in subsequent query comments. |
| 459 | + Context values override environment variables. |
| 460 | +
|
| 461 | + Parameters |
| 462 | + ---------- |
| 463 | + **kwargs : str |
| 464 | + Key-value pairs to set in the query context. |
| 465 | + Common keys: model, session_id, agent_name, user_context |
| 466 | +
|
| 467 | + Returns |
| 468 | + ------- |
| 469 | + Dict[str, str] |
| 470 | + The updated query context dictionary |
| 471 | + """ |
| 472 | + self.query_context.update(kwargs) |
| 473 | + return self.query_context.copy() |
| 474 | + |
| 475 | + def get_query_context(self) -> Dict[str, str]: |
| 476 | + """ |
| 477 | + Get the current runtime query context. |
| 478 | +
|
| 479 | + Returns |
| 480 | + ------- |
| 481 | + Dict[str, str] |
| 482 | + Current query context dictionary |
| 483 | + """ |
| 484 | + return self.query_context.copy() |
| 485 | + |
| 486 | + def clear_query_context(self) -> None: |
| 487 | + """Clear all runtime query context values.""" |
| 488 | + self.query_context.clear() |
| 489 | + |
| 490 | + def build_query_comment( |
| 491 | + self, |
| 492 | + tool_name: str = "unknown", |
| 493 | + statement_type: str = "unknown", |
| 494 | + ) -> Optional[str]: |
| 495 | + """ |
| 496 | + Build a query comment string with template variable substitution. |
| 497 | +
|
| 498 | + Substitutes template variables in the query comment template with actual values. |
| 499 | + Supported variables: |
| 500 | + - {request_id}: Unique UUID for this request |
| 501 | + - {timestamp}: ISO 8601 timestamp |
| 502 | + - {tool_name}: Name of the MCP tool being used |
| 503 | + - {statement_type}: Type of SQL statement (Select, Insert, etc.) |
| 504 | + - {model}: AI model name (from query_context, env var, or 'unknown') |
| 505 | + - {session_id}: Session ID (from query_context or 'unknown') |
| 506 | + - {agent_name}: Agent name (from query_context or 'unknown') |
| 507 | + - {server_name}: MCP server name |
| 508 | + - {server_version}: Server version string |
| 509 | + - Any custom keys set via set_query_context |
| 510 | +
|
| 511 | + Parameters |
| 512 | + ---------- |
| 513 | + tool_name : str |
| 514 | + Name of the MCP tool making the query |
| 515 | + statement_type : str |
| 516 | + Type of SQL statement being executed |
| 517 | +
|
| 518 | + Returns |
| 519 | + ------- |
| 520 | + str or None |
| 521 | + JSON string of the query comment, or None if disabled |
| 522 | + """ |
| 523 | + if not self.query_comment_enabled or self.query_comment_template is None: |
| 524 | + return None |
| 525 | + |
| 526 | + # Build substitution values - runtime context takes precedence over env vars |
| 527 | + substitutions = { |
| 528 | + "request_id": str(uuid.uuid4()), |
| 529 | + "timestamp": datetime.now(timezone.utc).isoformat(), |
| 530 | + "tool_name": tool_name, |
| 531 | + "statement_type": statement_type, |
| 532 | + # Model: check runtime context first, then env var, then default |
| 533 | + "model": self.query_context.get( |
| 534 | + "model", os.environ.get("SNOWFLAKE_MCP_MODEL", "unknown") |
| 535 | + ), |
| 536 | + # Session ID: from runtime context or default |
| 537 | + "session_id": self.query_context.get("session_id", "unknown"), |
| 538 | + # Agent name: from runtime context or default to server name |
| 539 | + "agent_name": self.query_context.get("agent_name", server_name), |
| 540 | + # User info: from runtime context |
| 541 | + "user_email": self.query_context.get( |
| 542 | + "user_email", os.environ.get("SNOWFLAKE_MCP_USER_EMAIL", "unknown") |
| 543 | + ), |
| 544 | + "user_name": self.query_context.get( |
| 545 | + "user_name", os.environ.get("SNOWFLAKE_MCP_USER_NAME", "unknown") |
| 546 | + ), |
| 547 | + # Intent and query_parameters: complex objects from agent (default to null) |
| 548 | + "intent": self.query_context.get("intent"), |
| 549 | + "query_parameters": self.query_context.get("query_parameters"), |
| 550 | + "server_name": server_name, |
| 551 | + "server_version": f"{tag_major_version}.{tag_minor_version}", |
| 552 | + } |
| 553 | + # Add any additional custom context values |
| 554 | + for key, value in self.query_context.items(): |
| 555 | + if key not in substitutions: |
| 556 | + substitutions[key] = value |
| 557 | + |
| 558 | + def substitute_value(value: Any) -> Any: |
| 559 | + """Recursively substitute template variables in values.""" |
| 560 | + if isinstance(value, str): |
| 561 | + # Check if the entire string is a single placeholder like "{intent}" |
| 562 | + # If so, return the actual value (could be dict, None, etc.) |
| 563 | + import re |
| 564 | + |
| 565 | + match = re.fullmatch(r"\{(\w+)\}", value) |
| 566 | + if match: |
| 567 | + key = match.group(1) |
| 568 | + if key in substitutions: |
| 569 | + return substitutions[key] |
| 570 | + # Otherwise do string replacement |
| 571 | + result = value |
| 572 | + for key, sub_value in substitutions.items(): |
| 573 | + if sub_value is not None: |
| 574 | + result = result.replace(f"{{{key}}}", str(sub_value)) |
| 575 | + return result |
| 576 | + elif isinstance(value, dict): |
| 577 | + return {k: substitute_value(v) for k, v in value.items()} |
| 578 | + elif isinstance(value, list): |
| 579 | + return [substitute_value(item) for item in value] |
| 580 | + else: |
| 581 | + return value |
| 582 | + |
| 583 | + comment = substitute_value(self.query_comment_template) |
| 584 | + return json.dumps(comment) |
| 585 | + |
417 | 586 |
|
418 | 587 | def get_var(var_name: str, env_var_name: str, args) -> Optional[str]: |
419 | 588 | """ |
|
0 commit comments