1313"""Test docstring"""
1414from __future__ import absolute_import
1515
16+ from typing import Optional , Union , Dict , List
17+
1618import sagemaker
1719import sagemaker .parameter
1820from sagemaker import vpc_utils
1921from sagemaker .deserializers import BytesDeserializer
2022from sagemaker .deprecations import removed_kwargs
2123from sagemaker .estimator import EstimatorBase
24+ from sagemaker .inputs import TrainingInput , FileSystemInput
2225from sagemaker .serializers import IdentitySerializer
2326from sagemaker .transformer import Transformer
2427from sagemaker .predictor import Predictor
28+ from sagemaker .session import Session
29+ from sagemaker .workflow .entities import PipelineVariable
30+
31+ from sagemaker .workflow import is_pipeline_variable
2532
2633
2734class AlgorithmEstimator (EstimatorBase ):
@@ -37,28 +44,28 @@ class AlgorithmEstimator(EstimatorBase):
3744
3845 def __init__ (
3946 self ,
40- algorithm_arn ,
41- role ,
42- instance_count ,
43- instance_type ,
44- volume_size = 30 ,
45- volume_kms_key = None ,
46- max_run = 24 * 60 * 60 ,
47- input_mode = "File" ,
48- output_path = None ,
49- output_kms_key = None ,
50- base_job_name = None ,
51- sagemaker_session = None ,
52- hyperparameters = None ,
53- tags = None ,
54- subnets = None ,
55- security_group_ids = None ,
56- model_uri = None ,
57- model_channel_name = "model" ,
58- metric_definitions = None ,
59- encrypt_inter_container_traffic = False ,
60- use_spot_instances = False ,
61- max_wait = None ,
47+ algorithm_arn : str ,
48+ role : str ,
49+ instance_count : Optional [ Union [ int , PipelineVariable ]] = None ,
50+ instance_type : Optional [ Union [ str , PipelineVariable ]] = None ,
51+ volume_size : Union [ int , PipelineVariable ] = 30 ,
52+ volume_kms_key : Optional [ Union [ str , PipelineVariable ]] = None ,
53+ max_run : Union [ int , PipelineVariable ] = 24 * 60 * 60 ,
54+ input_mode : Union [ str , PipelineVariable ] = "File" ,
55+ output_path : Optional [ Union [ str , PipelineVariable ]] = None ,
56+ output_kms_key : Optional [ Union [ str , PipelineVariable ]] = None ,
57+ base_job_name : Optional [ str ] = None ,
58+ sagemaker_session : Optional [ Session ] = None ,
59+ hyperparameters : Optional [ Dict [ str , Union [ str , PipelineVariable ]]] = None ,
60+ tags : Optional [ List [ Dict [ str , Union [ str , PipelineVariable ]]]] = None ,
61+ subnets : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
62+ security_group_ids : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
63+ model_uri : Optional [ str ] = None ,
64+ model_channel_name : Union [ str , PipelineVariable ] = "model" ,
65+ metric_definitions : Optional [ List [ Dict [ str , Union [ str , PipelineVariable ]]]] = None ,
66+ encrypt_inter_container_traffic : Union [ bool , PipelineVariable ] = False ,
67+ use_spot_instances : Union [ bool , PipelineVariable ] = False ,
68+ max_wait : Optional [ Union [ int , PipelineVariable ]] = None ,
6269 ** kwargs # pylint: disable=W0613
6370 ):
6471 """Initialize an ``AlgorithmEstimator`` instance.
@@ -71,18 +78,21 @@ def __init__(
7178 access training data and model artifacts. After the endpoint
7279 is created, the inference code might use the IAM role, if it
7380 needs to access an AWS resource.
74- instance_count (int): Number of Amazon EC2 instances to use for training.
75- instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
76- volume_size (int): Size in GB of the EBS volume to use for
81+ instance_count (int or PipelineVariable): Number of Amazon EC2 instances to use
82+ for training.
83+ instance_type (str or PipelineVariable): Type of EC2 instance to use for training,
84+ for example, 'ml.c4.xlarge'.
85+ volume_size (int or PipelineVariable): Size in GB of the EBS volume to use for
7786 storing input data during training (default: 30). Must be large enough to store
7887 training data if File Mode is used (which is the default).
79- volume_kms_key (str): Optional. KMS key ID for encrypting EBS volume attached
80- to the training instance (default: None).
81- max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
88+ volume_kms_key (str or PipelineVariable): Optional. KMS key ID for encrypting
89+ EBS volume attached to the training instance (default: None).
90+ max_run (int or PipelineVariable): Timeout in seconds for training
91+ (default: 24 * 60 * 60).
8292 After this amount of time Amazon SageMaker terminates the
8393 job regardless of its current status.
84- input_mode (str): The input mode that the algorithm supports
85- (default: 'File'). Valid modes:
94+ input_mode (str or PipelineVariable ): The input mode that the algorithm supports
95+ (default: 'File'). Valid modes:
8696
8797 * 'File' - Amazon SageMaker copies the training dataset from
8898 the S3 location to a local directory.
@@ -92,13 +102,14 @@ def __init__(
92102 This argument can be overridden on a per-channel basis using
93103 ``sagemaker.inputs.TrainingInput.input_mode``.
94104
95- output_path (str): S3 location for saving the training result (model artifacts and
96- output files). If not specified, results are stored to a default bucket. If
105+ output_path (str or PipelineVariable): S3 location for saving the training result
106+ (model artifacts and output files). If not specified,
107+ results are stored to a default bucket. If
97108 the bucket with the specific name does not exist, the
98109 estimator creates the bucket during the
99110 :meth:`~sagemaker.estimator.EstimatorBase.fit` method
100111 execution.
101- output_kms_key (str): Optional. KMS key ID for encrypting the
112+ output_kms_key (str or PipelineVariable ): Optional. KMS key ID for encrypting the
102113 training output (default: None). base_job_name (str): Prefix for
103114 training job name when the
104115 :meth:`~sagemaker.estimator.EstimatorBase.fit`
@@ -109,9 +120,10 @@ def __init__(
109120 interactions with Amazon SageMaker APIs and any other AWS services needed. If
110121 not specified, the estimator creates one using the default
111122 AWS configuration chain.
112- tags (list[dict]): List of tags for labeling a training job. For more, see
123+ tags (list[dict[str, str]] or list[dict[str, PipelineVariable]]): List of tags for
124+ labeling a training job. For more, see
113125 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
114- subnets (list[str]): List of subnet ids. If not specified
126+ subnets (list[str] or list[PipelineVariable] ): List of subnet ids. If not specified
115127 training job will be created without VPC config.
116128 security_group_ids (list[str]): List of security group ids. If
117129 not specified training job will be created without VPC config.
@@ -122,22 +134,22 @@ def __init__(
122134 other artifacts coming from a different source.
123135 More information:
124136 https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
125- model_channel_name (str): Name of the channel where 'model_uri'
137+ model_channel_name (str or PipelineVariable ): Name of the channel where 'model_uri'
126138 will be downloaded (default: 'model'). metric_definitions
127139 (list[dict]): A list of dictionaries that defines the metric(s)
128140 used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
129141 the name of the metric, and 'Regex' for the regular
130142 expression used to extract the metric from the logs.
131- encrypt_inter_container_traffic (bool): Specifies whether traffic between training
132- containers is encrypted for the training job (default: ``False``).
133- use_spot_instances (bool): Specifies whether to use SageMaker
143+ encrypt_inter_container_traffic (bool or PipelineVariable ): Specifies whether traffic
144+ between training containers is encrypted for the training job (default: ``False``).
145+ use_spot_instances (bool or PipelineVariable ): Specifies whether to use SageMaker
134146 Managed Spot instances for training. If enabled then the
135147 `max_wait` arg should also be set.
136148
137149 More information:
138150 https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
139151 (default: ``False``).
140- max_wait (int): Timeout in seconds waiting for spot training
152+ max_wait (int or PipelineVariable ): Timeout in seconds waiting for spot training
141153 instances (default: None). After this amount of time Amazon
142154 SageMaker will stop waiting for Spot instances to become
143155 available (default: ``None``).
@@ -186,22 +198,25 @@ def validate_train_spec(self):
186198 # Check that the input mode provided is compatible with the training input modes for the
187199 # algorithm.
188200 input_modes = self ._algorithm_training_input_modes (train_spec ["TrainingChannels" ])
189- if self .input_mode not in input_modes :
201+ if not is_pipeline_variable ( self . input_mode ) and self .input_mode not in input_modes :
190202 raise ValueError (
191203 "Invalid input mode: %s. %s only supports: %s"
192204 % (self .input_mode , algorithm_name , input_modes )
193205 )
194206
195207 # Check that the training instance type is compatible with the algorithm.
196208 supported_instances = train_spec ["SupportedTrainingInstanceTypes" ]
197- if self .instance_type not in supported_instances :
209+ if (
210+ not is_pipeline_variable (self .instance_type )
211+ and self .instance_type not in supported_instances
212+ ):
198213 raise ValueError (
199214 "Invalid instance_type: %s. %s supports the following instance types: %s"
200215 % (self .instance_type , algorithm_name , supported_instances )
201216 )
202217
203218 # Verify if distributed training is supported by the algorithm
204- if (
219+ if not is_pipeline_variable ( self . instance_count ) and (
205220 self .instance_count > 1
206221 and "SupportsDistributedTraining" in train_spec
207222 and not train_spec ["SupportsDistributedTraining" ]
@@ -414,12 +429,18 @@ def _prepare_for_training(self, job_name=None):
414429
415430 super (AlgorithmEstimator , self )._prepare_for_training (job_name )
416431
def fit(
    self,
    inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
    wait: bool = True,
    logs: bool = True,
    job_name: Optional[str] = None,
):
    """Validate the supplied input channels, then run the inherited ``fit``.

    Args:
        inputs (str or dict or TrainingInput or FileSystemInput): Training
            input handed to the parent ``fit``. When truthy it is first
            checked with ``_validate_input_channels`` (default: None).
        wait (bool): Forwarded unchanged to the parent ``fit``
            (default: True).
        logs (bool): Forwarded unchanged to the parent ``fit``
            (default: True).
        job_name (str): Forwarded unchanged to the parent ``fit``
            (default: None).

    Returns:
        Whatever the parent class's ``fit`` returns. The result is
        propagated to the caller instead of being discarded.
    """
    # Falsy inputs (None, empty string/dict) skip channel validation and go
    # straight to the parent implementation.
    if inputs:
        self._validate_input_channels(inputs)

    # Arguments are passed positionally, in the parent's parameter order.
    return super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name)
423444
424445 def _validate_input_channels (self , channels ):
425446 """Placeholder docstring"""
0 commit comments