Skip to content

Commit 19dcb17

Browse files
committed
feat: add SageMaker MCP configuration and server setup scripts
1 parent b75682c commit 19dcb17

File tree

6 files changed

+668
-0
lines changed

6 files changed

+668
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"mcpServers": {
3+
"smus-local-mcp": {
4+
"command": "python",
5+
"args": ["/etc/sagemaker-mcp/smus-mcp.py"]
6+
}
7+
}
8+
}
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""
2+
SageMaker Unified Studio Project Context MCP Server in stdio transport.
3+
4+
This is a self-contained MCP Server, ready to be used directly.
5+
Dependencies:
6+
pip install mcp[cli]
7+
pip install sagemaker_studio
8+
"""
9+
10+
import json
11+
import logging
12+
from typing import Any, Dict, Optional
13+
14+
import boto3
15+
from mcp.server.fastmcp import FastMCP
16+
from sagemaker_studio import ClientConfig, Project
17+
18+
# Configure logging
19+
logging.basicConfig(level=logging.INFO)
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class ProjectContext:
    """
    Encapsulates the AWS session, SageMaker Unified Studio project, and region.

    Reads the DataZone domain id, project id and region from the SageMaker
    instance metadata file, then builds a boto3 session and a ``Project``
    handle from them.

    Raises:
        RuntimeError: if the metadata file is missing/unreadable, an expected
            metadata key is absent, or the project handle cannot be created.
            The original exception is chained as ``__cause__``.
    """

    # Path of the metadata file SageMaker writes into every app container.
    METADATA_PATH = "/opt/ml/metadata/resource-metadata.json"

    def __init__(self):
        try:
            with open(self.METADATA_PATH, "r") as metadata_file:
                metadata = json.load(metadata_file)
            additional = metadata["AdditionalMetadata"]
            self.domain_id = additional["DataZoneDomainId"]
            self.project_id = additional["DataZoneProjectId"]
            self.region = additional["DataZoneDomainRegion"]

            logger.info(f"Read self.domain: {self.domain_id}")

            self.session = boto3.Session(region_name=self.region)
            client_conf = ClientConfig(session=self.session, region=self.region)
            self.project = Project(id=self.project_id, domain_id=self.domain_id, config=client_conf)
        except Exception as e:
            # Chain the original exception so the root cause is not lost.
            raise RuntimeError(f"Failed to initialize project: {e}") from e
45+
46+
47+
def safe_get_attr(obj: Any, attr: str, default: Any = None) -> Any:
    """Safely read attribute *attr* from *obj*, returning *default* on any failure.

    If the attribute value is callable, it is invoked with no arguments and
    the result is returned instead (SageMaker SDK objects expose zero-arg
    accessors this way). Errors from attribute access or from the call are
    logged at DEBUG level and swallowed.

    Args:
        obj: Object to read from; ``None`` short-circuits to *default*.
        attr: Attribute name to look up.
        default: Value returned when the attribute is missing or errors.

    Returns:
        The attribute value (or the result of calling it), else *default*.
    """
    if obj is None:
        return default

    try:
        if not hasattr(obj, attr):
            return default
        value = getattr(obj, attr)
        if callable(value):
            try:
                return value()
            # RuntimeError is a subclass of Exception; the original
            # (RuntimeError, Exception) tuple was redundant.
            except Exception as e:
                logger.debug(f"Error calling attribute {attr}: {e}")
                return default
        return value
    except Exception as e:
        logger.debug(f"Error getting attribute {attr}: {e}")
        return default
67+
68+
69+
def create_smus_context_identifiers_response(domain_id: str, project_id: str, region: str) -> str:
70+
71+
return f"""Selectively use the below parameters only when the parameter is required.
72+
<parameter>
73+
domain identifier: "{domain_id}"
74+
project identifier: "{project_id}"
75+
region: "{region}"
76+
</parameter>
77+
Again, include only required parameters. Any extra parameters may cause the API to fail. Stick strictly to the schema."""
78+
79+
80+
async def aws_context_provider() -> Dict[str, Any]:
    """
    AWS Context Provider - MUST BE CALLED BEFORE ANY use_aws OPERATIONS

    This tool provides essential AWS context parameters that are required by subsequent AWS operations.
    It returns configuration details including domain identifiers, project information, and region
    settings that would otherwise need to be manually specified with each use_aws call.

    The returned parameters include:
    - domain identifier: Unique identifier for the AWS DataZone domain
    - project identifier: Identifier for the specific project being worked on
    - region: AWS region where operations should be performed

    Returns:
        dict: {"response": <identifier block>}; on failure an "error" key is
        added with the exception message and "response" is empty.
    """
    identifiers_response = ""
    try:
        ctx = ProjectContext()
        domain_id = safe_get_attr(ctx, "domain_id", "")
        project_id = safe_get_attr(ctx, "project_id", "")
        region = safe_get_attr(ctx, "region", "")
        identifiers_response = create_smus_context_identifiers_response(
            domain_id, project_id, region
        )
        return {"response": identifiers_response}
    except Exception as e:
        # Tool output must never raise; report the failure in-band.
        logger.error(f"Error providing SMUS context identifiers: {e}")
        return {"response": identifiers_response, "error": str(e)}
111+
112+
113+
async def list_tables(
    catalog_id: Optional[str] = None, database_name: Optional[str] = None
) -> dict:
    """List all available tables, optionally filtered by catalog ID and database name.

    Walks every IAM/LAKEHOUSE connection of the current project, then each
    catalog -> database -> table, collecting name, database name, catalog id
    and (when available) storage location and column schema.

    Args:
        catalog_id: If given, only tables from this catalog are returned.
        database_name: If given, only tables from this database are returned.

    Returns:
        dict: {"tables": [...]}; on failure an "error" key is added and the
        table list is empty.
    """
    try:
        ctx = ProjectContext()
        connections = safe_get_attr(ctx.project, "connections", [])
        tables_list = []

        for conn in connections:
            conn_type = safe_get_attr(conn, "type", "")
            # Only IAM and LAKEHOUSE connections expose catalogs here.
            if conn_type not in ("IAM", "LAKEHOUSE"):
                continue

            # Copy the lakehouse catalog list: appending the Glue catalog
            # below must not mutate the connection's own attribute.
            catalogs = list(safe_get_attr(conn, "catalogs", []))
            # The connection's default Glue catalog is exposed separately.
            catalogs.append(conn.catalog())

            for catalog in catalogs:
                current_catalog_id = safe_get_attr(catalog, "id")
                # Skip if catalog_id is provided and doesn't match
                if catalog_id and current_catalog_id != catalog_id:
                    continue

                for db in safe_get_attr(catalog, "databases", []):
                    current_db_name = safe_get_attr(db, "name")
                    # Skip if database_name is provided and doesn't match
                    if database_name and current_db_name != database_name:
                        continue

                    for table in safe_get_attr(db, "tables", []):
                        table_name = safe_get_attr(table, "name")
                        if not table_name:
                            continue

                        table_info = {
                            "name": table_name,
                            "database_name": current_db_name,
                            "catalog_id": current_catalog_id,
                        }

                        # Add location if available
                        location = safe_get_attr(table, "location")
                        if location:
                            table_info["location"] = location

                        # Add column schema if available
                        columns_info = []
                        for column in safe_get_attr(table, "columns", []):
                            col_name = safe_get_attr(column, "name")
                            col_type = safe_get_attr(column, "type")
                            if col_name and col_type:
                                columns_info.append({"name": col_name, "type": col_type})
                        if columns_info:
                            table_info["columns"] = columns_info

                        tables_list.append(table_info)

        return {"tables": tables_list}
    except Exception as e:
        logger.error(f"Error listing tables: {e}")
        return {"tables": [], "error": str(e)}
183+
184+
185+
async def get_table_schema(catalog_id: str, database_name: str, table_name: str) -> dict:
    """Get schema information for a specific table.

    Searches IAM/LAKEHOUSE connections of the current project for the first
    table matching the given catalog, database and table names.

    Args:
        catalog_id: Catalog id the table belongs to.
        database_name: Database the table belongs to.
        table_name: Name of the table.

    Returns:
        dict: table name, database, catalog id, columns (name/type and, when
        present, comment) and optional location; or {"error": ...} when the
        table is not found or the lookup fails.
    """
    try:
        ctx = ProjectContext()
        connections = safe_get_attr(ctx.project, "connections", [])

        for conn in connections:
            conn_type = safe_get_attr(conn, "type", "")
            # Only IAM and LAKEHOUSE connections expose catalogs here.
            if conn_type not in ("IAM", "LAKEHOUSE"):
                continue

            # Copy the lakehouse catalog list so appending the connection's
            # default Glue catalog cannot mutate the underlying attribute.
            catalogs = list(safe_get_attr(conn, "catalogs", []))
            catalogs.append(conn.catalog())

            for catalog in catalogs:
                if safe_get_attr(catalog, "id") != catalog_id:
                    continue

                for db in safe_get_attr(catalog, "databases", []):
                    if safe_get_attr(db, "name") != database_name:
                        continue

                    for table in safe_get_attr(db, "tables", []):
                        if safe_get_attr(table, "name") != table_name:
                            continue

                        # Collect column schema (comment is optional).
                        columns_info = []
                        for column in safe_get_attr(table, "columns", []):
                            col_name = safe_get_attr(column, "name")
                            col_type = safe_get_attr(column, "type")
                            if col_name and col_type:
                                column_info = {"name": col_name, "type": col_type}
                                comment = safe_get_attr(column, "comment")
                                if comment:
                                    column_info["comment"] = comment
                                columns_info.append(column_info)

                        table_info = {
                            "name": table_name,
                            "database_name": database_name,
                            "catalog_id": catalog_id,
                            "columns": columns_info,
                        }
                        location = safe_get_attr(table, "location")
                        if location:
                            table_info["location"] = location

                        return table_info

        return {"error": f"Table not found: {catalog_id}.{database_name}.{table_name}"}
    except Exception as e:
        logger.error(f"Error listing table schema: {e}")
        return {"error": str(e)}
245+
246+
247+
def create_mcp_server():
    """
    Build and return a fresh FastMCP server with all tools registered.

    A new instance per call keeps each Lambda invocation stateless.
    """
    server: FastMCP = FastMCP(
        stateless_http=True,
    )

    # aws_context_provider keeps its own docstring as the tool description;
    # do not override it here.
    server.tool()(aws_context_provider)

    server.tool(
        description="List all available tables, optionally filtered by catalog ID and database name"
    )(list_tables)
    server.tool(description="Get schema information for a specific table")(get_table_schema)

    return server
267+
268+
269+
# For local development only
if __name__ == "__main__":
    # Build a fresh server instance and serve it over stdio.
    mcp = create_mcp_server()
    mcp.run(transport="stdio")

template/v2/dirs/etc/sagemaker-ui/sagemaker_ui_post_startup.sh

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,59 @@ else
224224
sed -i '/^export AMAZON_Q_SIGV4=/d' ~/.bashrc
225225
fi
226226

227+
# Setup SageMaker MCP configuration: install /etc/sagemaker-mcp/mcp.json into
# the user's Amazon Q MCP config, merging with any existing configuration.
echo "Setting up SageMaker MCP configuration..."
mkdir -p "$HOME/.aws/amazonq/"
target_file="$HOME/.aws/amazonq/mcp.json"
source_file="/etc/sagemaker-mcp/mcp.json"

if [ -f "$source_file" ]; then
    if [ -f "$target_file" ]; then
        # Target file exists - merge configurations
        echo "Existing MCP configuration found, merging configurations..."

        # Check if it's valid JSON first
        if jq empty "$target_file" 2>/dev/null; then
            # Initialize mcpServers object if it doesn't exist
            if ! jq -e '.mcpServers' "$target_file" >/dev/null 2>&1; then
                echo "Creating mcpServers object in existing configuration"
                jq '. + {"mcpServers":{}}' "$target_file" > "$target_file.tmp"
                mv "$target_file.tmp" "$target_file"
            fi

            # -r emits raw (unquoted) strings; no need to strip quotes with tr.
            servers=$(jq -r '.mcpServers | keys[]' "$source_file")

            # Add each server from source to target if it doesn't exist
            for server in $servers; do
                if ! jq -e ".mcpServers.\"$server\"" "$target_file" >/dev/null 2>&1; then
                    server_config=$(jq ".mcpServers.\"$server\"" "$source_file")
                    jq --arg name "$server" --argjson config "$server_config" \
                        '.mcpServers[$name] = $config' "$target_file" > "$target_file.tmp"
                    mv "$target_file.tmp" "$target_file"
                    echo "Added server '$server' to existing configuration"
                else
                    echo "Server '$server' already exists in configuration"
                fi
            done
        else
            echo "Warning: Existing MCP configuration is not valid JSON, replacing with default configuration"
            cp "$source_file" "$target_file"
        fi
    else
        # File doesn't exist, copy our configuration
        cp "$source_file" "$target_file"
        echo "Created new MCP configuration with default servers"
    fi

    # Set proper ownership and permissions (quote to survive unset/odd values)
    chown "$NB_USER:$NB_GID" "$target_file"
    chmod 644 "$target_file"
    echo "Successfully configured MCP for SageMaker"
else
    echo "Warning: MCP configuration file not found at $source_file"
fi
279+
227280
# Generate sagemaker pysdk intelligent default config
228281
nohup python /etc/sagemaker/sm_pysdk_default_config.py &
229282
# Only run the following commands if SAGEMAKER_APP_TYPE_LOWERCASE is jupyterlab
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"mcpServers": {
3+
"smus-local-mcp": {
4+
"command": "python",
5+
"args": ["/etc/sagemaker-mcp/smus-mcp.py"]
6+
}
7+
}
8+
}

0 commit comments

Comments
 (0)