66
77
88def generate_intelligent_default_config (metadata : str ) -> dict :
9+ has_vpc = metadata ["SecurityGroupIds" ] and metadata ["Subnets" ] and metadata ["SecurityGroupIds" ] != ['' ] and metadata ["Subnets" ] != ['' ]
10+ vpc_config = {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]} if has_vpc else None
11+
912 config = {
1013 "SchemaVersion" : "1.0" ,
1114 "SageMaker" : {
@@ -17,61 +20,45 @@ def generate_intelligent_default_config(metadata: str) -> dict:
1720 },
1821 "RemoteFunction" : {
1922 "IncludeLocalWorkDir" : True ,
20- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]},
2123 },
2224 "NotebookJob" : {
2325 "RoleArn" : metadata ["UserRoleArn" ],
2426 "S3RootUri" : f"s3://{ metadata ['S3Bucket' ]} /{ metadata ['S3ObjectKeyPrefix' ]} " ,
25- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]},
2627 },
2728 "Serve" : {"S3ModelDataUri" : f"s3://{ metadata ['S3Bucket' ]} /{ metadata ['S3ObjectKeyPrefix' ]} " },
2829 }
2930 },
30- "MonitoringSchedule" : {
31- "MonitoringScheduleConfig" : {
32- "MonitoringJobDefinition" : {
33- "NetworkConfig" : {
34- "VpcConfig" : {
35- "SecurityGroupIds" : metadata ["SecurityGroupIds" ],
36- "Subnets" : metadata ["Subnets" ],
37- }
38- }
39- }
40- }
41- },
42- "AutoMLJob" : {
43- "AutoMLJobConfig" : {
44- "SecurityConfig" : {
45- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]}
46- }
47- }
48- },
49- "AutoMLJobV2" : {
50- "SecurityConfig" : {
51- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]}
52- }
53- },
54- "CompilationJob" : {
55- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]}
56- },
5731 "Pipeline" : {"RoleArn" : metadata ["UserRoleArn" ]},
58- "Model" : {
59- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]},
60- "ExecutionRoleArn" : metadata ["UserRoleArn" ],
61- },
32+ "Model" : {"ExecutionRoleArn" : metadata ["UserRoleArn" ]},
6233 "ModelPackage" : {"ValidationSpecification" : {"ValidationRole" : metadata ["UserRoleArn" ]}},
63- "ProcessingJob" : {
64- "NetworkConfig" : {
65- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]}
66- },
67- "RoleArn" : metadata ["UserRoleArn" ],
68- },
69- "TrainingJob" : {
70- "RoleArn" : metadata ["UserRoleArn" ],
71- "VpcConfig" : {"SecurityGroupIds" : metadata ["SecurityGroupIds" ], "Subnets" : metadata ["Subnets" ]},
72- },
34+ "ProcessingJob" : {"RoleArn" : metadata ["UserRoleArn" ]},
35+ "TrainingJob" : {"RoleArn" : metadata ["UserRoleArn" ]},
7336 },
7437 }
38+
39+ if has_vpc :
40+ config ["SageMaker" ]["PythonSDK" ]["Modules" ]["RemoteFunction" ]["VpcConfig" ] = vpc_config
41+ config ["SageMaker" ]["PythonSDK" ]["Modules" ]["NotebookJob" ]["VpcConfig" ] = vpc_config
42+ config ["SageMaker" ]["MonitoringSchedule" ] = {
43+ "MonitoringScheduleConfig" : {
44+ "MonitoringJobDefinition" : {
45+ "NetworkConfig" : {"VpcConfig" : vpc_config }
46+ }
47+ }
48+ }
49+ config ["SageMaker" ]["AutoMLJob" ] = {
50+ "AutoMLJobConfig" : {
51+ "SecurityConfig" : {"VpcConfig" : vpc_config }
52+ }
53+ }
54+ config ["SageMaker" ]["AutoMLJobV2" ] = {
55+ "SecurityConfig" : {"VpcConfig" : vpc_config }
56+ }
57+ config ["SageMaker" ]["CompilationJob" ] = {"VpcConfig" : vpc_config }
58+ config ["SageMaker" ]["Model" ]["VpcConfig" ] = vpc_config
59+ config ["SageMaker" ]["ProcessingJob" ]["NetworkConfig" ] = {"VpcConfig" : vpc_config }
60+ config ["SageMaker" ]["TrainingJob" ]["VpcConfig" ] = vpc_config
61+
7562 return config
7663
7764
@@ -106,7 +93,7 @@ def generate_intelligent_default_config(metadata: str) -> dict:
10693 }
10794
10895 # Not create config file when invalid value exists in metadata
109- empty_values = [key for key , value in metadata .items () if value == "" or value == ["" ]]
96+ empty_values = [key for key , value in metadata .items () if key not in [ "SecurityGroupIds" , "Subnets" ] and ( value == "" or value == ["" ]) ]
11097 if empty_values :
11198 raise AttributeError (f"There are empty values in the metadata: { empty_values } " )
11299
0 commit comments