@@ -173,6 +173,7 @@ def __init__(
173173 instance_groups : Optional [List [InstanceGroup ]] = None ,
174174 training_repository_access_mode : Optional [Union [str , PipelineVariable ]] = None ,
175175 training_repository_credentials_provider_arn : Optional [Union [str , PipelineVariable ]] = None ,
176+ enable_infra_check : Optional [Union [bool , PipelineVariable ]] = None ,
176177 container_entry_point : Optional [List [str ]] = None ,
177178 container_arguments : Optional [List [str ]] = None ,
178179 disable_output_compression : bool = False ,
@@ -536,6 +537,8 @@ def __init__(
536537 a training job.
537538 disable_output_compression (bool): Optional. When set to true, Model is uploaded
538539 to Amazon S3 without compression after training finishes.
540+ enable_infra_check (bool or PipelineVariable): Optional.
541+ Specifies whether it is running Sagemaker built-in infra check jobs.
539542 """
540543 instance_count = renamed_kwargs (
541544 "train_instance_count" , "instance_count" , instance_count , kwargs
@@ -665,6 +668,7 @@ def __init__(
665668 training_repository_credentials_provider_arn
666669 )
667670
671+ self .enable_infra_check = enable_infra_check
668672 # container entry point / arguments configs
669673 self .container_entry_point = container_entry_point
670674 self .container_arguments = container_arguments
@@ -1904,6 +1908,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19041908 "EnableInterContainerTrafficEncryption"
19051909 ]
19061910
1911+ if "InfraCheckConfig" in job_details :
1912+ init_params ["enable_infra_check" ] = job_details ["InfraCheckConfig" ].get (
1913+ "EnableInfraCheck"
1914+ )
1915+
19071916 subnets , security_group_ids = vpc_utils .from_dict (job_details .get (vpc_utils .VPC_CONFIG_KEY ))
19081917 if subnets :
19091918 init_params ["subnets" ] = subnets
@@ -2446,6 +2455,10 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
24462455 ] = estimator .training_repository_credentials_provider_arn
24472456 train_args ["training_image_config" ] = training_image_config
24482457
2458+ if estimator .enable_infra_check is not None :
2459+ infra_check_config = {"EnableInfraCheck" : estimator .enable_infra_check }
2460+ train_args ["infra_check_config" ] = infra_check_config
2461+
24492462 if estimator .container_entry_point is not None :
24502463 train_args ["container_entry_point" ] = estimator .container_entry_point
24512464
@@ -2661,6 +2674,7 @@ def __init__(
26612674 container_entry_point : Optional [List [str ]] = None ,
26622675 container_arguments : Optional [List [str ]] = None ,
26632676 disable_output_compression : bool = False ,
2677+ enable_infra_check : Optional [Union [bool , PipelineVariable ]] = None ,
26642678 ** kwargs ,
26652679 ):
26662680 """Initialize an ``Estimator`` instance.
@@ -3020,6 +3034,8 @@ def __init__(
30203034 a training job.
30213035 disable_output_compression (bool): Optional. When set to true, Model is uploaded
30223036 to Amazon S3 without compression after training finishes.
3037+ enable_infra_check (bool or PipelineVariable): Optional.
3038+ Specifies whether it is running Sagemaker built-in infra check jobs.
30233039 """
30243040 self .image_uri = image_uri
30253041 self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
0 commit comments