@@ -46,7 +46,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
4646 """
4747
4848 def __init__ (self , role , train_instance_count , train_instance_type ,
49- train_volume_size = 30 , train_max_run = 24 * 60 * 60 , input_mode = 'File' ,
49+ train_volume_size = 30 , train_volume_kms_key = None , train_max_run = 24 * 60 * 60 , input_mode = 'File' ,
5050 output_path = None , output_kms_key = None , base_job_name = None , sagemaker_session = None , tags = None ,
5151 subnets = None , security_group_ids = None ):
5252 """Initialize an ``EstimatorBase`` instance.
@@ -61,6 +61,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
6161 train_volume_size (int): Size in GB of the EBS volume to use for storing input data
6262 during training (default: 30). Must be large enough to store training data if File Mode is used
6363 (which is the default).
64+ train_volume_kms_key (str): Optional. KMS key ID for encrypting EBS volume attached to the
65+ training instance (default: None).
6466 train_max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
6567 After this amount of time Amazon SageMaker terminates the job regardless of its current status.
6668 input_mode (str): The input mode that the algorithm supports (default: 'File'). Valid modes:
@@ -87,6 +89,7 @@ def __init__(self, role, train_instance_count, train_instance_type,
8789 self .train_instance_count = train_instance_count
8890 self .train_instance_type = train_instance_type
8991 self .train_volume_size = train_volume_size
92+ self .train_volume_kms_key = train_volume_kms_key
9093 self .train_max_run = train_max_run
9194 self .input_mode = input_mode
9295 self .tags = tags
@@ -427,9 +430,9 @@ class Estimator(EstimatorBase):
427430 """
428431
429432 def __init__ (self , image_name , role , train_instance_count , train_instance_type ,
430- train_volume_size = 30 , train_max_run = 24 * 60 * 60 , input_mode = 'File' ,
431- output_path = None , output_kms_key = None , base_job_name = None , sagemaker_session = None ,
432- hyperparameters = None , tags = None , subnets = None , security_group_ids = None ):
433+ train_volume_size = 30 , train_volume_kms_key = None , train_max_run = 24 * 60 * 60 ,
434+ input_mode = 'File' , output_path = None , output_kms_key = None , base_job_name = None ,
435+ sagemaker_session = None , hyperparameters = None , tags = None , subnets = None , security_group_ids = None ):
433436 """Initialize an ``Estimator`` instance.
434437
435438 Args:
@@ -443,6 +446,8 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
443446 train_volume_size (int): Size in GB of the EBS volume to use for storing input data
444447 during training (default: 30). Must be large enough to store training data if File Mode is used
445448 (which is the default).
449+ train_volume_kms_key (str): Optional. KMS key ID for encrypting EBS volume attached to the
450+ training instance (default: None).
446451 train_max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
447452 After this amount of time Amazon SageMaker terminates the job regardless of its current status.
448453 input_mode (str): The input mode that the algorithm supports (default: 'File'). Valid modes:
@@ -462,11 +467,16 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
462467 Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
463468 using the default AWS configuration chain.
464469 hyperparameters (dict): Dictionary containing the hyperparameters to initialize this estimator with.
470+ tags (list[dict]): List of tags for labeling a training job. For more, see
471+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
472+ subnets (list[str]): List of subnet ids. If not specified training job will be created without VPC config.
473+ security_group_ids (list[str]): List of security group ids. If not specified training job will be created
474+ without VPC config.
465475 """
466476 self .image_name = image_name
467477 self .hyperparam_dict = hyperparameters .copy () if hyperparameters else {}
468478 super (Estimator , self ).__init__ (role , train_instance_count , train_instance_type ,
469- train_volume_size , train_max_run , input_mode ,
479+ train_volume_size , train_volume_kms_key , train_max_run , input_mode ,
470480 output_path , output_kms_key , base_job_name , sagemaker_session ,
471481 tags , subnets , security_group_ids )
472482
0 commit comments