@@ -52,7 +52,7 @@ def __init__(self, role, train_instance_count, train_instance_type,
5252 train_volume_size = 30 , train_volume_kms_key = None , train_max_run = 24 * 60 * 60 , input_mode = 'File' ,
5353 output_path = None , output_kms_key = None , base_job_name = None , sagemaker_session = None , tags = None ,
5454 subnets = None , security_group_ids = None , model_uri = None , model_channel_name = 'model' ,
55- metric_definitions = None ):
55+ metric_definitions = None , encrypt_inter_container_traffic = False ):
5656 """Initialize an ``EstimatorBase`` instance.
5757
5858 Args:
@@ -103,6 +103,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
103103 training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
104104 the regular expression used to extract the metric from the logs. This should be defined only
105105 for jobs that don't use an Amazon algorithm.
106+ encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers is encrypted
107+ for the training job (default: ``False``).
106108 """
107109 self .role = role
108110 self .train_instance_count = train_instance_count
@@ -138,6 +140,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
138140 self .subnets = subnets
139141 self .security_group_ids = security_group_ids
140142
143+ self .encrypt_inter_container_traffic = encrypt_inter_container_traffic
144+
141145 @abstractmethod
142146 def train_image (self ):
143147 """Return the Docker image to use for training.
@@ -429,6 +433,10 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
429433 if 'MetricDefinitons' in job_details ['AlgorithmSpecification' ]:
430434 init_params ['metric_definitions' ] = job_details ['AlgorithmSpecification' ]['MetricsDefinition' ]
431435
436+ if 'EnableInterContainerTrafficEncryption' in job_details :
437+ init_params ['encrypt_inter_container_traffic' ] = \
438+ job_details ['EnableInterContainerTrafficEncryption' ]
439+
432440 subnets , security_group_ids = vpc_utils .from_dict (job_details .get (vpc_utils .VPC_CONFIG_KEY ))
433441 if subnets :
434442 init_params ['subnets' ] = subnets
@@ -555,6 +563,9 @@ def start_new(cls, estimator, inputs):
555563 if estimator .enable_network_isolation ():
556564 train_args ['enable_network_isolation' ] = True
557565
566+ if estimator .encrypt_inter_container_traffic :
567+ train_args ['encrypt_inter_container_traffic' ] = True
568+
558569 if isinstance (estimator , sagemaker .algorithm .AlgorithmEstimator ):
559570 train_args ['algorithm_arn' ] = estimator .algorithm_arn
560571 else :
@@ -585,7 +596,8 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
585596 train_volume_size = 30 , train_volume_kms_key = None , train_max_run = 24 * 60 * 60 ,
586597 input_mode = 'File' , output_path = None , output_kms_key = None , base_job_name = None ,
587598 sagemaker_session = None , hyperparameters = None , tags = None , subnets = None , security_group_ids = None ,
588- model_uri = None , model_channel_name = 'model' , metric_definitions = None ):
599+ model_uri = None , model_channel_name = 'model' , metric_definitions = None ,
600+ encrypt_inter_container_traffic = False ):
589601 """Initialize an ``Estimator`` instance.
590602
591603 Args:
@@ -640,14 +652,17 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
640652 training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
641653 the regular expression used to extract the metric from the logs. This should be defined only
642654 for jobs that don't use an Amazon algorithm.
655+ encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers is encrypted
656+ for the training job (default: ``False``).
643657 """
644658 self .image_name = image_name
645659 self .hyperparam_dict = hyperparameters .copy () if hyperparameters else {}
646660 super (Estimator , self ).__init__ (role , train_instance_count , train_instance_type ,
647661 train_volume_size , train_volume_kms_key , train_max_run , input_mode ,
648662 output_path , output_kms_key , base_job_name , sagemaker_session ,
649663 tags , subnets , security_group_ids , model_uri = model_uri ,
650- model_channel_name = model_channel_name , metric_definitions = metric_definitions )
664+ model_channel_name = model_channel_name , metric_definitions = metric_definitions ,
665+ encrypt_inter_container_traffic = encrypt_inter_container_traffic )
651666
652667 def train_image (self ):
653668 """
@@ -743,7 +758,7 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
743758 entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
744759 as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
745760 source_dir (str): Path (absolute or relative) to a directory with any other training
746- source code dependencies aside from tne entry point file (default: None). Structure within this
761+ source code dependencies aside from the entry point file (default: None). Structure within this
747762 directory are preserved when training on Amazon SageMaker.
748763 dependencies (list[str]): A list of paths to directories (absolute or relative) with
749764 any additional libraries that will be exported to the container (default: []).
0 commit comments