Skip to content

Commit 1a8482b

Browse files
Update PySDK config to work correctly when subnetds/securitygroups are not present (#801)
1 parent ff9b859 commit 1a8482b

File tree

2 files changed

+62
-88
lines changed

2 files changed

+62
-88
lines changed

template/v2/dirs/etc/sagemaker/sm_pysdk_default_config.py

Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77

88
def generate_intelligent_default_config(metadata: str) -> dict:
9+
has_vpc = metadata["SecurityGroupIds"] and metadata["Subnets"] and metadata["SecurityGroupIds"] != [''] and metadata["Subnets"] != ['']
10+
911
config = {
1012
"SchemaVersion": "1.0",
1113
"SageMaker": {
@@ -17,61 +19,46 @@ def generate_intelligent_default_config(metadata: str) -> dict:
1719
},
1820
"RemoteFunction": {
1921
"IncludeLocalWorkDir": True,
20-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]},
2122
},
2223
"NotebookJob": {
2324
"RoleArn": metadata["UserRoleArn"],
2425
"S3RootUri": f"s3://{metadata['S3Bucket']}/{metadata['S3ObjectKeyPrefix']}",
25-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]},
2626
},
2727
"Serve": {"S3ModelDataUri": f"s3://{metadata['S3Bucket']}/{metadata['S3ObjectKeyPrefix']}"},
2828
}
2929
},
30-
"MonitoringSchedule": {
31-
"MonitoringScheduleConfig": {
32-
"MonitoringJobDefinition": {
33-
"NetworkConfig": {
34-
"VpcConfig": {
35-
"SecurityGroupIds": metadata["SecurityGroupIds"],
36-
"Subnets": metadata["Subnets"],
37-
}
38-
}
39-
}
40-
}
41-
},
42-
"AutoMLJob": {
43-
"AutoMLJobConfig": {
44-
"SecurityConfig": {
45-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]}
46-
}
47-
}
48-
},
49-
"AutoMLJobV2": {
50-
"SecurityConfig": {
51-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]}
52-
}
53-
},
54-
"CompilationJob": {
55-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]}
56-
},
5730
"Pipeline": {"RoleArn": metadata["UserRoleArn"]},
58-
"Model": {
59-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]},
60-
"ExecutionRoleArn": metadata["UserRoleArn"],
61-
},
31+
"Model": {"ExecutionRoleArn": metadata["UserRoleArn"]},
6232
"ModelPackage": {"ValidationSpecification": {"ValidationRole": metadata["UserRoleArn"]}},
63-
"ProcessingJob": {
64-
"NetworkConfig": {
65-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]}
66-
},
67-
"RoleArn": metadata["UserRoleArn"],
68-
},
69-
"TrainingJob": {
70-
"RoleArn": metadata["UserRoleArn"],
71-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]},
72-
},
33+
"ProcessingJob": {"RoleArn": metadata["UserRoleArn"]},
34+
"TrainingJob": {"RoleArn": metadata["UserRoleArn"]},
7335
},
7436
}
37+
38+
if has_vpc:
39+
vpc_config = {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]}
40+
config["SageMaker"]["PythonSDK"]["Modules"]["RemoteFunction"]["VpcConfig"] = vpc_config
41+
config["SageMaker"]["PythonSDK"]["Modules"]["NotebookJob"]["VpcConfig"] = vpc_config
42+
config["SageMaker"]["MonitoringSchedule"] = {
43+
"MonitoringScheduleConfig": {
44+
"MonitoringJobDefinition": {
45+
"NetworkConfig": {"VpcConfig": vpc_config}
46+
}
47+
}
48+
}
49+
config["SageMaker"]["AutoMLJob"] = {
50+
"AutoMLJobConfig": {
51+
"SecurityConfig": {"VpcConfig": vpc_config}
52+
}
53+
}
54+
config["SageMaker"]["AutoMLJobV2"] = {
55+
"SecurityConfig": {"VpcConfig": vpc_config}
56+
}
57+
config["SageMaker"]["CompilationJob"] = {"VpcConfig": vpc_config}
58+
config["SageMaker"]["Model"]["VpcConfig"] = vpc_config
59+
config["SageMaker"]["ProcessingJob"]["NetworkConfig"] = {"VpcConfig": vpc_config}
60+
config["SageMaker"]["TrainingJob"]["VpcConfig"] = vpc_config
61+
7562
return config
7663

7764

@@ -106,7 +93,7 @@ def generate_intelligent_default_config(metadata: str) -> dict:
10693
}
10794

10895
# Not create config file when invalid value exists in metadata
109-
empty_values = [key for key, value in metadata.items() if value == "" or value == [""]]
96+
empty_values = [key for key, value in metadata.items() if key not in ["SecurityGroupIds", "Subnets"] and (value == "" or value == [""])]
11097
if empty_values:
11198
raise AttributeError(f"There are empty values in the metadata: {empty_values}")
11299

template/v3/dirs/etc/sagemaker/sm_pysdk_default_config.py

Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77

88
def generate_intelligent_default_config(metadata: str) -> dict:
9+
has_vpc = metadata["SecurityGroupIds"] and metadata["Subnets"] and metadata["SecurityGroupIds"] != [''] and metadata["Subnets"] != ['']
10+
vpc_config = {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]} if has_vpc else None
11+
912
config = {
1013
"SchemaVersion": "1.0",
1114
"SageMaker": {
@@ -17,61 +20,45 @@ def generate_intelligent_default_config(metadata: str) -> dict:
1720
},
1821
"RemoteFunction": {
1922
"IncludeLocalWorkDir": True,
20-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]},
2123
},
2224
"NotebookJob": {
2325
"RoleArn": metadata["UserRoleArn"],
2426
"S3RootUri": f"s3://{metadata['S3Bucket']}/{metadata['S3ObjectKeyPrefix']}",
25-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]},
2627
},
2728
"Serve": {"S3ModelDataUri": f"s3://{metadata['S3Bucket']}/{metadata['S3ObjectKeyPrefix']}"},
2829
}
2930
},
30-
"MonitoringSchedule": {
31-
"MonitoringScheduleConfig": {
32-
"MonitoringJobDefinition": {
33-
"NetworkConfig": {
34-
"VpcConfig": {
35-
"SecurityGroupIds": metadata["SecurityGroupIds"],
36-
"Subnets": metadata["Subnets"],
37-
}
38-
}
39-
}
40-
}
41-
},
42-
"AutoMLJob": {
43-
"AutoMLJobConfig": {
44-
"SecurityConfig": {
45-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]}
46-
}
47-
}
48-
},
49-
"AutoMLJobV2": {
50-
"SecurityConfig": {
51-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]}
52-
}
53-
},
54-
"CompilationJob": {
55-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]}
56-
},
5731
"Pipeline": {"RoleArn": metadata["UserRoleArn"]},
58-
"Model": {
59-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]},
60-
"ExecutionRoleArn": metadata["UserRoleArn"],
61-
},
32+
"Model": {"ExecutionRoleArn": metadata["UserRoleArn"]},
6233
"ModelPackage": {"ValidationSpecification": {"ValidationRole": metadata["UserRoleArn"]}},
63-
"ProcessingJob": {
64-
"NetworkConfig": {
65-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]}
66-
},
67-
"RoleArn": metadata["UserRoleArn"],
68-
},
69-
"TrainingJob": {
70-
"RoleArn": metadata["UserRoleArn"],
71-
"VpcConfig": {"SecurityGroupIds": metadata["SecurityGroupIds"], "Subnets": metadata["Subnets"]},
72-
},
34+
"ProcessingJob": {"RoleArn": metadata["UserRoleArn"]},
35+
"TrainingJob": {"RoleArn": metadata["UserRoleArn"]},
7336
},
7437
}
38+
39+
if has_vpc:
40+
config["SageMaker"]["PythonSDK"]["Modules"]["RemoteFunction"]["VpcConfig"] = vpc_config
41+
config["SageMaker"]["PythonSDK"]["Modules"]["NotebookJob"]["VpcConfig"] = vpc_config
42+
config["SageMaker"]["MonitoringSchedule"] = {
43+
"MonitoringScheduleConfig": {
44+
"MonitoringJobDefinition": {
45+
"NetworkConfig": {"VpcConfig": vpc_config}
46+
}
47+
}
48+
}
49+
config["SageMaker"]["AutoMLJob"] = {
50+
"AutoMLJobConfig": {
51+
"SecurityConfig": {"VpcConfig": vpc_config}
52+
}
53+
}
54+
config["SageMaker"]["AutoMLJobV2"] = {
55+
"SecurityConfig": {"VpcConfig": vpc_config}
56+
}
57+
config["SageMaker"]["CompilationJob"] = {"VpcConfig": vpc_config}
58+
config["SageMaker"]["Model"]["VpcConfig"] = vpc_config
59+
config["SageMaker"]["ProcessingJob"]["NetworkConfig"] = {"VpcConfig": vpc_config}
60+
config["SageMaker"]["TrainingJob"]["VpcConfig"] = vpc_config
61+
7562
return config
7663

7764

@@ -106,7 +93,7 @@ def generate_intelligent_default_config(metadata: str) -> dict:
10693
}
10794

10895
# Not create config file when invalid value exists in metadata
109-
empty_values = [key for key, value in metadata.items() if value == "" or value == [""]]
96+
empty_values = [key for key, value in metadata.items() if key not in ["SecurityGroupIds", "Subnets"] and (value == "" or value == [""])]
11097
if empty_values:
11198
raise AttributeError(f"There are empty values in the metadata: {empty_values}")
11299

0 commit comments

Comments
 (0)