6
6
7
7
8
8
def generate_intelligent_default_config (metadata : str ) -> dict :
9
+ has_vpc = metadata ["SecurityGroupIds" ] and metadata ["Subnets" ] and metadata ["SecurityGroupIds" ] != ['' ] and metadata ["Subnets" ] != ['' ]
10
+ vpc_config = {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]} if has_vpc else None
11
+
9
12
config = {
10
13
"SchemaVersion" : "1.0" ,
11
14
"SageMaker" : {
@@ -17,61 +20,45 @@ def generate_intelligent_default_config(metadata: str) -> dict:
17
20
},
18
21
"RemoteFunction" : {
19
22
"IncludeLocalWorkDir" : True ,
20
- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]},
21
23
},
22
24
"NotebookJob" : {
23
25
"RoleArn" : metadata ["UserRoleArn" ],
24
26
"S3RootUri" : f"s3://{ metadata ['S3Bucket' ]} /{ metadata ['S3ObjectKeyPrefix' ]} " ,
25
- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]},
26
27
},
27
28
"Serve" : {"S3ModelDataUri" : f"s3://{ metadata ['S3Bucket' ]} /{ metadata ['S3ObjectKeyPrefix' ]} " },
28
29
}
29
30
},
30
- "MonitoringSchedule" : {
31
- "MonitoringScheduleConfig" : {
32
- "MonitoringJobDefinition" : {
33
- "NetworkConfig" : {
34
- "VpcConfig" : {
35
- "SecurityGroupIds" : metadata ["SecurityGroupIds" ],
36
- "Subnets" : metadata ["Subnets" ],
37
- }
38
- }
39
- }
40
- }
41
- },
42
- "AutoMLJob" : {
43
- "AutoMLJobConfig" : {
44
- "SecurityConfig" : {
45
- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]}
46
- }
47
- }
48
- },
49
- "AutoMLJobV2" : {
50
- "SecurityConfig" : {
51
- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]}
52
- }
53
- },
54
- "CompilationJob" : {
55
- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]}
56
- },
57
31
"Pipeline" : {"RoleArn" : metadata ["UserRoleArn" ]},
58
- "Model" : {
59
- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]},
60
- "ExecutionRoleArn" : metadata ["UserRoleArn" ],
61
- },
32
+ "Model" : {"ExecutionRoleArn" : metadata ["UserRoleArn" ]},
62
33
"ModelPackage" : {"ValidationSpecification" : {"ValidationRole" : metadata ["UserRoleArn" ]}},
63
- "ProcessingJob" : {
64
- "NetworkConfig" : {
65
- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]}
66
- },
67
- "RoleArn" : metadata ["UserRoleArn" ],
68
- },
69
- "TrainingJob" : {
70
- "RoleArn" : metadata ["UserRoleArn" ],
71
- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]},
72
- },
34
+ "ProcessingJob" : {"RoleArn" : metadata ["UserRoleArn" ]},
35
+ "TrainingJob" : {"RoleArn" : metadata ["UserRoleArn" ]},
73
36
},
74
37
}
38
+
39
+ if has_vpc :
40
+ config ["SageMaker" ]["PythonSDK" ]["Modules" ]["RemoteFunction" ]["VpcConfig" ] = vpc_config
41
+ config ["SageMaker" ]["PythonSDK" ]["Modules" ]["NotebookJob" ]["VpcConfig" ] = vpc_config
42
+ config ["SageMaker" ]["MonitoringSchedule" ] = {
43
+ "MonitoringScheduleConfig" : {
44
+ "MonitoringJobDefinition" : {
45
+ "NetworkConfig" : {"VpcConfig" : vpc_config }
46
+ }
47
+ }
48
+ }
49
+ config ["SageMaker" ]["AutoMLJob" ] = {
50
+ "AutoMLJobConfig" : {
51
+ "SecurityConfig" : {"VpcConfig" : vpc_config }
52
+ }
53
+ }
54
+ config ["SageMaker" ]["AutoMLJobV2" ] = {
55
+ "SecurityConfig" : {"VpcConfig" : vpc_config }
56
+ }
57
+ config ["SageMaker" ]["CompilationJob" ] = {"VpcConfig" : vpc_config }
58
+ config ["SageMaker" ]["Model" ]["VpcConfig" ] = vpc_config
59
+ config ["SageMaker" ]["ProcessingJob" ]["NetworkConfig" ] = {"VpcConfig" : vpc_config }
60
+ config ["SageMaker" ]["TrainingJob" ]["VpcConfig" ] = vpc_config
61
+
75
62
return config
76
63
77
64
@@ -106,7 +93,7 @@ def generate_intelligent_default_config(metadata: str) -> dict:
106
93
}
107
94
108
95
# Not create config file when invalid value exists in metadata
109
- empty_values = [key for key , value in metadata .items () if value == "" or value == ["" ]]
96
+ empty_values = [key for key , value in metadata .items () if key not in [ "SecurityGroupIds" , "Subnets" ] and ( value == "" or value == ["" ]) ]
110
97
if empty_values :
111
98
raise AttributeError (f"There are empty values in the metadata: { empty_values } " )
112
99
0 commit comments