Skip to content

Commit d0c5708

Browse files
authored
feat: Add Q CLI and local SMUS MCP (#736)
2 parents b3a6aac + e7b5300 commit d0c5708

File tree

12 files changed

+426
-0
lines changed

12 files changed

+426
-0
lines changed

template/v2/Dockerfile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ RUN apt-get update && apt-get upgrade -y && \
6262
sudo ./aws/install && \
6363
rm -rf aws awscliv2.zip && \
6464
: && \
65+
# Install Q CLI
66+
curl --proto '=https' --tlsv1.2 -sSf "https://desktop-release.q.us-east-1.amazonaws.com/latest/q-x86_64-linux.zip" -o "q.zip" && \
67+
unzip q.zip && \
68+
Q_INSTALL_GLOBAL=true ./q/install.sh --no-confirm && \
69+
rm -rf q q.zip && \
70+
: && \
6571
echo "source /usr/local/bin/_activate_current_env.sh" | tee --append /etc/profile && \
6672
# CodeEditor - create server, user data dirs
6773
mkdir -p /opt/amazon/sagemaker/sagemaker-code-editor-server-data /opt/amazon/sagemaker/sagemaker-code-editor-user-data \
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"mcpServers": {
3+
"smus-local-mcp": {
4+
"command": "python",
5+
"args": ["/etc/sagemaker-ui/sagemaker-mcp/smus-mcp.py"]
6+
}
7+
}
8+
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
SageMaker Unified Studio Project Context MCP Server in stdio transport.
3+
4+
"""
5+
6+
import json
7+
import logging
8+
import os
9+
import re
10+
from typing import Any, Dict
11+
12+
from mcp.server.fastmcp import FastMCP
13+
14+
# Configure logging
15+
logging.basicConfig(level=logging.INFO)
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class ProjectContext:
20+
"""
21+
A class that encapsulates AWS session, project object, and region.
22+
This class simplifies the common pattern of setting up an AWS session,
23+
extracting credentials, and getting a project.
24+
"""
25+
26+
def __init__(self):
27+
try:
28+
datazone_domain_id = os.getenv("AmazonDataZoneDomain")
29+
datazone_project_id = os.getenv("AmazonDataZoneProject")
30+
aws_region = os.getenv("AWS_REGION")
31+
if datazone_domain_id and datazone_project_id and aws_region:
32+
self.domain_id = datazone_domain_id
33+
self.project_id = datazone_project_id
34+
self.region = aws_region
35+
else:
36+
with open("/opt/ml/metadata/resource-metadata.json", "r") as metadata_file:
37+
metadata = json.load(metadata_file)
38+
self.domain_id = metadata["AdditionalMetadata"]["DataZoneDomainId"]
39+
self.project_id = metadata["AdditionalMetadata"]["DataZoneProjectId"]
40+
self.region = metadata["AdditionalMetadata"]["DataZoneDomainRegion"]
41+
except Exception as e:
42+
raise RuntimeError(f"Failed to initialize project: {e}")
43+
44+
if not re.match("^dzd[-_][a-zA-Z0-9_-]{1,36}$", self.domain_id):
45+
raise RuntimeError(f"Invalid domain id")
46+
if not re.match("^[a-zA-Z0-9_-]{1,36}$", self.project_id):
47+
raise RuntimeError(f"Invalid project id")
48+
if not re.match("^[a-z]{2}-[a-z]{4,10}-\\d$", self.region):
49+
raise RuntimeError(f"Invalid region")
50+
51+
52+
def safe_get_attr(obj: Any, attr: str, default: Any = None) -> Any:
53+
"""Safely get an attribute from an object."""
54+
if obj is None:
55+
return default
56+
57+
try:
58+
if hasattr(obj, attr):
59+
value = getattr(obj, attr)
60+
# Handle case where attribute access might throw RuntimeError
61+
if callable(value):
62+
try:
63+
return value()
64+
except (RuntimeError, Exception) as e:
65+
logger.error(f"Error calling attribute {attr}: {e}")
66+
return default
67+
return value
68+
return default
69+
except Exception as e:
70+
logger.error(f"Error getting attribute {attr}: {e}")
71+
return default
72+
73+
74+
def create_smus_context_identifiers_response(domain_id: str, project_id: str, region: str) -> str:
75+
76+
return f"""Selectively use the below parameters only when the parameter is required.
77+
<parameter>
78+
domain identifier: "{domain_id}"
79+
project identifier: "{project_id}"
80+
region: "{region}"
81+
aws profiles: "DomainExecutionRoleCreds, default"
82+
</parameter>
83+
Again, include only required parameters. Any extra parameters may cause the API to fail. Stick strictly to the schema."""
84+
85+
86+
async def aws_context_provider() -> Dict[str, Any]:
87+
"""
88+
AWS Context Provider - MUST BE CALLED BEFORE ANY use_aws OPERATIONS
89+
90+
This tool provides essential AWS context parameters that are required by subsequent AWS operations.
91+
It returns configuration details including domain identifiers, project information, and region
92+
settings that would otherwise need to be manually specified with each use_aws call.
93+
94+
The returned parameters include:
95+
- domain identifier: Unique identifier for the AWS DataZone domain
96+
- project identifier: Identifier for the specific project being worked on
97+
- profile name: Name of the AWS profile to use for credentials
98+
- region: AWS region where operations should be performed
99+
- aws profiles: use the aws profile named DomainExecutionRoleCreds for calling datazone APIs; otherwise use default AWS profile
100+
101+
Returns:
102+
dict: Parameter context to be used with subsequent use_aws operations
103+
"""
104+
identifiers_response = ""
105+
try:
106+
ctx = ProjectContext()
107+
domain_id = safe_get_attr(ctx, "domain_id", "")
108+
project_id = safe_get_attr(ctx, "project_id", "")
109+
region = safe_get_attr(ctx, "region", "")
110+
identifiers_response = create_smus_context_identifiers_response(domain_id, project_id, region)
111+
return {"response": identifiers_response}
112+
except Exception as e:
113+
logger.error(f"Error providing SMUS context identifiers: {e}")
114+
return {"response": identifiers_response, "error": "Error providing SMUS context identifiers"}
115+
116+
117+
if __name__ == "__main__":
118+
mcp: FastMCP = FastMCP("SageMakerUnififedStudio Project Context MCP Server")
119+
120+
# Register the tools from tools.py
121+
mcp.tool()(aws_context_provider) # use the doc string of the function as description, do not overwrite here.
122+
123+
mcp.run(transport="stdio")

template/v2/dirs/etc/sagemaker-ui/sagemaker_ui_post_startup.sh

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,80 @@ else
208208
echo readonly LOGNAME >> ~/.bashrc
209209
fi
210210

211+
# Setup Q CLI auth mode
212+
q_settings_file="$HOME/.aws/amazon_q/settings.json"
213+
if [ -f "$q_settings_file" ]; then
214+
q_auth_mode=$(jq -r '.auth_mode' < $q_settings_file)
215+
if [ "$q_auth_mode" == "IAM" ]; then
216+
export AMAZON_Q_SIGV4=true
217+
else
218+
export AMAZON_Q_SIGV4=false
219+
fi
220+
else
221+
export AMAZON_Q_SIGV4=true
222+
fi
223+
224+
if $AMAZON_Q_SIGV4; then
225+
if grep -q "^export AMAZON_Q_SIGV4=" ~/.bashrc; then
226+
echo "AMAZON_Q_SIGV4 is defined in the env"
227+
else
228+
echo export AMAZON_Q_SIGV4=$AMAZON_Q_SIGV4 >> ~/.bashrc
229+
fi
230+
else
231+
# Remove from .bashrc if it exists
232+
sed -i '/^export AMAZON_Q_SIGV4=/d' ~/.bashrc
233+
fi
234+
235+
# Setup SageMaker MCP configuration
236+
echo "Setting up SageMaker MCP configuration..."
237+
mkdir -p $HOME/.aws/amazonq/
238+
target_file="$HOME/.aws/amazonq/mcp.json"
239+
source_file="/etc/sagemaker-ui/sagemaker-mcp/mcp.json"
240+
241+
if [ -f "$source_file" ]; then
242+
# Extract all servers from source configuration
243+
if [ -f "$target_file" ]; then
244+
# Target file exists - merge configurations
245+
echo "Existing MCP configuration found, merging configurations..."
246+
247+
# Check if it's valid JSON first
248+
if jq empty "$target_file" 2>/dev/null; then
249+
# Initialize mcpServers object if it doesn't exist
250+
if ! jq -e '.mcpServers' "$target_file" >/dev/null 2>&1; then
251+
echo "Creating mcpServers object in existing configuration"
252+
jq '. + {"mcpServers":{}}' "$target_file" > "$target_file.tmp"
253+
mv "$target_file.tmp" "$target_file"
254+
fi
255+
256+
servers=$(jq '.mcpServers | keys[]' "$source_file" | tr -d '"')
257+
258+
# Add each server from source to target if it doesn't exist
259+
for server in $servers; do
260+
if ! jq -e ".mcpServers.\"$server\"" "$target_file" >/dev/null 2>&1; then
261+
server_config=$(jq ".mcpServers.\"$server\"" "$source_file")
262+
jq --arg name "$server" --argjson config "$server_config" \
263+
'.mcpServers[$name] = $config' "$target_file" > "$target_file.tmp"
264+
mv "$target_file.tmp" "$target_file"
265+
echo "Added server '$server' to existing configuration"
266+
else
267+
echo "Server '$server' already exists in configuration"
268+
fi
269+
done
270+
else
271+
echo "Warning: Existing MCP configuration is not valid JSON, replacing with default configuration"
272+
cp "$source_file" "$target_file"
273+
fi
274+
else
275+
# File doesn't exist, copy our configuration
276+
cp "$source_file" "$target_file"
277+
echo "Created new MCP configuration with default servers"
278+
fi
279+
280+
echo "Successfully configured MCP for SageMaker"
281+
else
282+
echo "Warning: MCP configuration file not found at $source_file"
283+
fi
284+
211285
# Generate sagemaker pysdk intelligent default config
212286
nohup python /etc/sagemaker/sm_pysdk_default_config.py &
213287
# Only run the following commands if SAGEMAKER_APP_TYPE_LOWERCASE is jupyterlab

template/v2/dirs/usr/local/bin/entrypoint-sagemaker-ui-code-editor

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ micromamba activate base
1111

1212
export SAGEMAKER_APP_TYPE_LOWERCASE=$(echo $SAGEMAKER_APP_TYPE | tr '[:upper:]' '[:lower:]')
1313
export SERVICE_NAME='SageMakerUnifiedStudio'
14+
export Q_CLI_CLIENT_APPLICATION='SMUS_CODE_EDITOR'
1415

1516
mkdir -p $STUDIO_LOGGING_DIR/$SAGEMAKER_APP_TYPE_LOWERCASE/supervisord
1617
exec supervisord -c /etc/supervisor/conf.d/supervisord-sagemaker-ui-code-editor.conf -n

template/v2/dirs/usr/local/bin/entrypoint-sagemaker-ui-jupyter-server

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ fi
2121

2222
export SAGEMAKER_APP_TYPE_LOWERCASE=$(echo $SAGEMAKER_APP_TYPE | tr '[:upper:]' '[:lower:]')
2323
export SERVICE_NAME='SageMakerUnifiedStudio'
24+
export Q_CLI_CLIENT_APPLICATION='SMUS_JUPYTER_LAB'
2425

2526
mkdir -p $STUDIO_LOGGING_DIR/$SAGEMAKER_APP_TYPE_LOWERCASE/supervisord
2627
exec supervisord -c /etc/supervisor/conf.d/supervisord-sagemaker-ui.conf -n

template/v3/Dockerfile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ RUN apt-get update && apt-get upgrade -y && \
6262
sudo ./aws/install && \
6363
rm -rf aws awscliv2.zip && \
6464
: && \
65+
# Install Q CLI
66+
curl --proto '=https' --tlsv1.2 -sSf "https://desktop-release.q.us-east-1.amazonaws.com/latest/q-x86_64-linux.zip" -o "q.zip" && \
67+
unzip q.zip && \
68+
Q_INSTALL_GLOBAL=true ./q/install.sh --no-confirm && \
69+
rm -rf q q.zip && \
70+
: && \
6571
echo "source /usr/local/bin/_activate_current_env.sh" | tee --append /etc/profile && \
6672
# CodeEditor - create server, user data dirs
6773
mkdir -p /opt/amazon/sagemaker/sagemaker-code-editor-server-data /opt/amazon/sagemaker/sagemaker-code-editor-user-data \
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"mcpServers": {
3+
"smus-local-mcp": {
4+
"command": "python",
5+
"args": ["/etc/sagemaker-ui/sagemaker-mcp/smus-mcp.py"]
6+
}
7+
}
8+
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
SageMaker Unified Studio Project Context MCP Server in stdio transport.
3+
4+
"""
5+
6+
import json
7+
import logging
8+
import os
9+
import re
10+
from typing import Any, Dict
11+
12+
from mcp.server.fastmcp import FastMCP
13+
14+
# Configure logging
15+
logging.basicConfig(level=logging.INFO)
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class ProjectContext:
20+
"""
21+
A class that encapsulates AWS session, project object, and region.
22+
This class simplifies the common pattern of setting up an AWS session,
23+
extracting credentials, and getting a project.
24+
"""
25+
26+
def __init__(self):
27+
try:
28+
datazone_domain_id = os.getenv("AmazonDataZoneDomain")
29+
datazone_project_id = os.getenv("AmazonDataZoneProject")
30+
aws_region = os.getenv("AWS_REGION")
31+
if datazone_domain_id and datazone_project_id and aws_region:
32+
self.domain_id = datazone_domain_id
33+
self.project_id = datazone_project_id
34+
self.region = aws_region
35+
else:
36+
with open("/opt/ml/metadata/resource-metadata.json", "r") as metadata_file:
37+
metadata = json.load(metadata_file)
38+
self.domain_id = metadata["AdditionalMetadata"]["DataZoneDomainId"]
39+
self.project_id = metadata["AdditionalMetadata"]["DataZoneProjectId"]
40+
self.region = metadata["AdditionalMetadata"]["DataZoneDomainRegion"]
41+
except Exception as e:
42+
raise RuntimeError(f"Failed to initialize project: {e}")
43+
44+
if not re.match("^dzd[-_][a-zA-Z0-9_-]{1,36}$", self.domain_id):
45+
raise RuntimeError(f"Invalid domain id")
46+
if not re.match("^[a-zA-Z0-9_-]{1,36}$", self.project_id):
47+
raise RuntimeError(f"Invalid project id")
48+
if not re.match("^[a-z]{2}-[a-z]{4,10}-\\d$", self.region):
49+
raise RuntimeError(f"Invalid region")
50+
51+
52+
def safe_get_attr(obj: Any, attr: str, default: Any = None) -> Any:
53+
"""Safely get an attribute from an object."""
54+
if obj is None:
55+
return default
56+
57+
try:
58+
if hasattr(obj, attr):
59+
value = getattr(obj, attr)
60+
# Handle case where attribute access might throw RuntimeError
61+
if callable(value):
62+
try:
63+
return value()
64+
except (RuntimeError, Exception) as e:
65+
logger.error(f"Error calling attribute {attr}: {e}")
66+
return default
67+
return value
68+
return default
69+
except Exception as e:
70+
logger.error(f"Error getting attribute {attr}: {e}")
71+
return default
72+
73+
74+
def create_smus_context_identifiers_response(domain_id: str, project_id: str, region: str) -> str:
75+
76+
return f"""Selectively use the below parameters only when the parameter is required.
77+
<parameter>
78+
domain identifier: "{domain_id}"
79+
project identifier: "{project_id}"
80+
region: "{region}"
81+
aws profiles: "DomainExecutionRoleCreds, default"
82+
</parameter>
83+
Again, include only required parameters. Any extra parameters may cause the API to fail. Stick strictly to the schema."""
84+
85+
86+
async def aws_context_provider() -> Dict[str, Any]:
87+
"""
88+
AWS Context Provider - MUST BE CALLED BEFORE ANY use_aws OPERATIONS
89+
90+
This tool provides essential AWS context parameters that are required by subsequent AWS operations.
91+
It returns configuration details including domain identifiers, project information, and region
92+
settings that would otherwise need to be manually specified with each use_aws call.
93+
94+
The returned parameters include:
95+
- domain identifier: Unique identifier for the AWS DataZone domain
96+
- project identifier: Identifier for the specific project being worked on
97+
- profile name: Name of the AWS profile to use for credentials
98+
- region: AWS region where operations should be performed
99+
- aws profiles: use the aws profile named DomainExecutionRoleCreds for calling datazone APIs; otherwise use default AWS profile
100+
101+
Returns:
102+
dict: Parameter context to be used with subsequent use_aws operations
103+
"""
104+
identifiers_response = ""
105+
try:
106+
ctx = ProjectContext()
107+
domain_id = safe_get_attr(ctx, "domain_id", "")
108+
project_id = safe_get_attr(ctx, "project_id", "")
109+
region = safe_get_attr(ctx, "region", "")
110+
identifiers_response = create_smus_context_identifiers_response(domain_id, project_id, region)
111+
return {"response": identifiers_response}
112+
except Exception as e:
113+
logger.error(f"Error providing SMUS context identifiers: {e}")
114+
return {"response": identifiers_response, "error": "Error providing SMUS context identifiers"}
115+
116+
117+
if __name__ == "__main__":
118+
mcp: FastMCP = FastMCP("SageMakerUnififedStudio Project Context MCP Server")
119+
120+
# Register the tools from tools.py
121+
mcp.tool()(aws_context_provider) # use the doc string of the function as description, do not overwrite here.
122+
123+
mcp.run(transport="stdio")

0 commit comments

Comments
 (0)