@@ -203,7 +203,7 @@ def default_bucket(self):
203203 return self ._default_bucket
204204
205205 def train (self , image , input_mode , input_config , role , job_name , output_config ,
206- resource_config , hyperparameters , stop_condition ):
206+ resource_config , hyperparameters , stop_condition , tags ):
207207 """Create an Amazon SageMaker training job.
208208
209209 Args:
@@ -232,6 +232,8 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
232232 keys and values, but ``str()`` will be called to convert them before training.
233233 stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
234234 service like ``MaxRuntimeInSeconds``.
235+ tags (list[dict]): List of tags for labeling a training job. For more, see
236+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
235237
236238 Returns:
237239 str: ARN of the training job, if it is created.
@@ -242,7 +244,6 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
242244 'TrainingImage' : image ,
243245 'TrainingInputMode' : input_mode
244246 },
245- # 'HyperParameters': hyperparameters,
246247 'InputDataConfig' : input_config ,
247248 'OutputDataConfig' : output_config ,
248249 'TrainingJobName' : job_name ,
@@ -253,6 +254,10 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
253254
254255 if hyperparameters and len (hyperparameters ) > 0 :
255256 train_request ['HyperParameters' ] = hyperparameters
257+
258+ if tags is not None :
259+ train_request ['Tags' ] = tags
260+
256261 LOGGER .info ('Creating training-job with name: {}' .format (job_name ))
257262 LOGGER .debug ('train request: {}' .format (json .dumps (train_request , indent = 4 )))
258263 self .sagemaker_client .create_training_job (** train_request )
0 commit comments