Commit 7a33185

Add network validation script executed in the sagemaker_ui_post_startup script (#713)
Co-authored-by: Marco Friaz <[email protected]>
1 parent: dd1016b

File tree: 4 files changed, +400 −0 lines changed
network_validation.sh

Lines changed: 181 additions & 0 deletions

@@ -0,0 +1,181 @@
#!/bin/bash
set -eux

# Input parameters with defaults:
# Default to 1 (Git storage) if no parameter is passed.
is_s3_storage=${1:-"1"}
# Output file path for unreachable services JSON
network_validation_file=${2:-"/tmp/.network_validation.json"}

# Function to write unreachable services to a JSON file
write_unreachable_services_to_file() {
  local value="$1"
  local file="$network_validation_file"

  # Create the file if it doesn't exist
  if [ ! -f "$file" ]; then
    touch "$file" || {
      echo "Failed to create $file" >&2
      return 0
    }
  fi

  # Check that the file is writable
  if [ ! -w "$file" ]; then
    echo "Error: $file is not writable" >&2
    return 0
  fi

  # Write a JSON object with the UnreachableServices key and the comma-separated list value
  jq -n --arg value "$value" '{"UnreachableServices": $value}' > "$file"
}

# Configure the AWS CLI region using the environment variable REGION_NAME
aws configure set region "${REGION_NAME}"
echo "Successfully configured region to ${REGION_NAME}"

# Metadata file location containing DataZone info
sourceMetaData=/opt/ml/metadata/resource-metadata.json

# Extract the necessary DataZone metadata fields via jq
dataZoneDomainId=$(jq -r '.AdditionalMetadata.DataZoneDomainId' < "$sourceMetaData")
dataZoneProjectId=$(jq -r '.AdditionalMetadata.DataZoneProjectId' < "$sourceMetaData")
dataZoneEndPoint=$(jq -r '.AdditionalMetadata.DataZoneEndpoint' < "$sourceMetaData")
dataZoneDomainRegion=$(jq -r '.AdditionalMetadata.DataZoneDomainRegion' < "$sourceMetaData")
s3Path=$(jq -r '.AdditionalMetadata.ProjectS3Path' < "$sourceMetaData")

# Extract the bucket name, falling back to an empty string if not found
s3ValidationBucket=$(echo "${s3Path:-}" | sed -E 's#s3://([^/]+).*#\1#')

# Call AWS CLI list-connections, including the endpoint if specified
if [ -n "$dataZoneEndPoint" ]; then
  response=$(aws datazone list-connections \
    --endpoint-url "$dataZoneEndPoint" \
    --domain-identifier "$dataZoneDomainId" \
    --project-identifier "$dataZoneProjectId" \
    --region "$dataZoneDomainRegion")
else
  response=$(aws datazone list-connections \
    --domain-identifier "$dataZoneDomainId" \
    --project-identifier "$dataZoneProjectId" \
    --region "$dataZoneDomainRegion")
fi

# Extract each connection item as a compact JSON string
connection_items=$(echo "$response" | jq -c '.items[]')

# Required AWS services for compute connections and Git.
# Initialize SERVICE_COMMANDS with the always-needed STS and S3 checks.
declare -A SERVICE_COMMANDS=(
  ["STS"]="aws sts get-caller-identity"
  ["S3"]="aws s3api list-objects --bucket \"$s3ValidationBucket\" --max-items 1"
)

# Track connection types found, for the conditional checks below
declare -A seen_types=()

# Iterate over each connection to populate service commands conditionally
while IFS= read -r item; do
  # Extract the connection type
  type=$(echo "$item" | jq -r '.type')
  seen_types["$type"]=1

  # For SPARK connections, check for Glue and EMR properties
  if [[ "$type" == "SPARK" ]]; then
    # If sparkGlueProperties is present, add the Glue check
    if echo "$item" | jq -e '.props.sparkGlueProperties' > /dev/null; then
      SERVICE_COMMANDS["Glue"]="aws glue get-databases --max-items 1"
    fi

    # Check for emr-serverless in sparkEmrProperties.computeArn for the EMR Serverless check
    emr_arn=$(echo "$item" | jq -r '.props.sparkEmrProperties.computeArn // empty')
    if [[ "$emr_arn" == *"emr-serverless"* && "$emr_arn" == *"/applications/"* ]]; then
      # Extract the application ID from the ARN
      emr_app_id=$(echo "$emr_arn" | sed -E 's#.*/applications/([^/]+)#\1#')

      # Only set the service command if the application ID is non-empty
      if [[ -n "$emr_app_id" ]]; then
        SERVICE_COMMANDS["EMR Serverless"]="aws emr-serverless get-application --application-id \"$emr_app_id\""
      fi
    fi
  fi
done <<< "$connection_items"

# Add Athena if an ATHENA connection was found.
# The :- default guards against an "unbound variable" error under set -u when the key is absent.
[[ -n "${seen_types["ATHENA"]:-}" ]] && SERVICE_COMMANDS["Athena"]="aws athena list-data-catalogs --max-items 1"

# Add Redshift checks if a REDSHIFT connection was found
if [[ -n "${seen_types["REDSHIFT"]:-}" ]]; then
  SERVICE_COMMANDS["Redshift Clusters"]="aws redshift describe-clusters --max-records 20"
  SERVICE_COMMANDS["Redshift Serverless"]="aws redshift-serverless list-namespaces --max-results 1"
fi

# If using Git storage (S3 storage flag == 1), check CodeConnections connectivity.
# The Domain Execution role contains permissions for CodeConnections.
if [[ "$is_s3_storage" == "1" ]]; then
  SERVICE_COMMANDS["CodeConnections"]="aws codeconnections list-connections --max-results 1 --profile DomainExecutionRoleCreds"
fi

# Timeout (seconds) for each API call
api_time_out_limit=10
# Array to accumulate unreachable services
unreachable_services=()
# Create a temporary directory to store individual service results
temp_dir=$(mktemp -d)

# Launch all service API checks as parallel background jobs
for service in "${!SERVICE_COMMANDS[@]}"; do
  {
    # Run the command with a timeout, discarding stdout/stderr
    if timeout "${api_time_out_limit}s" bash -c "${SERVICE_COMMANDS[$service]}" > /dev/null 2>&1; then
      # Success: write OK to the temp file
      echo "OK" > "$temp_dir/$service"
    else
      # Capture the exit code to differentiate a timeout from other errors
      exit_code=$?
      if [ "$exit_code" -eq 124 ]; then
        # Timeout exit code
        echo "TIMEOUT" > "$temp_dir/$service"
      else
        # Other errors (e.g., permission denied)
        echo "ERROR" > "$temp_dir/$service"
      fi
    fi
  } &
done

# Wait for all background jobs to complete before continuing
wait

# Process each service's result file to identify unreachable ones
for service in "${!SERVICE_COMMANDS[@]}"; do
  result_file="$temp_dir/$service"
  if [ -f "$result_file" ]; then
    result=$(<"$result_file")
    if [[ "$result" == "TIMEOUT" ]]; then
      echo "$service API did NOT resolve within ${api_time_out_limit}s. Marking as unreachable."
      unreachable_services+=("$service")
    elif [[ "$result" == "OK" ]]; then
      echo "$service API is reachable."
    else
      echo "$service API returned an error (but not a timeout). Ignored for network check."
    fi
  else
    echo "$service check did not produce a result file. Skipping."
  fi
done

# Clean up the temporary directory
rm -rf "$temp_dir"

# Write unreachable services to the file if any, else write an empty string
if (( ${#unreachable_services[@]} > 0 )); then
  joined_services=$(IFS=','; echo "${unreachable_services[*]}")
  # Add spaces after commas for readability
  joined_services_with_spaces=${joined_services//,/,\ }
  write_unreachable_services_to_file "$joined_services_with_spaces"
  echo "Unreachable AWS Services: ${joined_services_with_spaces}"
else
  write_unreachable_services_to_file ""
  echo "All required AWS services reachable within ${api_time_out_limit}s"
fi
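
For reference, a minimal sketch of a manual run, assuming the AWS CLI, jq, and the resource metadata file are already in place (the path and the example output value are illustrative, not taken from the commit):

# Run the probes with the Git-storage default ("1"); results land in the JSON file
bash /etc/sagemaker-ui/network_validation.sh "1" /tmp/.network_validation.json

# Inspect the result; an empty string means every probed service answered in time
jq -r '.UnreachableServices' /tmp/.network_validation.json
# e.g. prints: Redshift Clusters, Athena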

template/v2/dirs/etc/sagemaker-ui/sagemaker_ui_post_startup.sh

Lines changed: 19 additions & 0 deletions
@@ -224,4 +224,23 @@ if [ "${SAGEMAKER_APP_TYPE_LOWERCASE}" = "jupyterlab" ]; then
  bash /etc/sagemaker-ui/workflows/sm-spark-cli-install.sh
fi

# Execute the network validation script to check whether any required AWS services are unreachable
echo "Starting network validation script..."

network_validation_file="/tmp/.network_validation.json"

# Run the validation script; check for unreachable services only if it succeeds
if bash /etc/sagemaker-ui/network_validation.sh "$is_s3_storage_flag" "$network_validation_file"; then
  # Read unreachable services from the JSON file
  failed_services=$(jq -r '.UnreachableServices // empty' "$network_validation_file" || echo "")
  if [[ -n "$failed_services" ]]; then
    error_message="$failed_services are unreachable. Please contact your admin."
    # Example error message: Redshift Clusters, Athena, STS, Glue are unreachable. Please contact your admin.
    write_status_to_file "error" "$error_message"
    echo "$error_message"
  fi
else
  echo "Warning: network_validation.sh failed, skipping unreachable services check."
fi

write_status_to_file_on_script_complete
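
One detail worth noting: the "// empty" fallback in the jq filter turns a missing or null key into no output at all, rather than the literal string "null", so the -n test above fires only when the validation script actually recorded failures. A quick illustration with hypothetical values:

# Missing key: "// empty" produces no output, so failed_services stays empty
echo '{}' | jq -r '.UnreachableServices // empty'
# Recorded failures come through as the comma-separated list
echo '{"UnreachableServices": "STS, Glue"}' | jq -r '.UnreachableServices // empty'
# prints: STS, Glue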
Lines changed: 181 additions & 0 deletions

(The commit adds a second, identical copy of network_validation.sh under another template directory; its contents match the listing above.)
