@@ -140,7 +140,7 @@ def __init__(
140
140
training output (default: None).
141
141
base_job_name (str): Prefix for training job name when the
142
142
:meth:`~sagemaker.estimator.EstimatorBase.fit` method launches.
143
- If not specified, the estimator generates a default job name,
143
+ If not specified, the estimator generates a default job name
144
144
based on the training image name and current timestamp.
145
145
sagemaker_session (sagemaker.session.Session): Session object which
146
146
manages interactions with Amazon SageMaker APIs and any other
@@ -328,6 +328,28 @@ def prepare_workflow_for_training(self, job_name=None):
328
328
"""
329
329
self ._prepare_for_training (job_name = job_name )
330
330
331
+ def _ensure_base_job_name (self ):
332
+ """Set ``self.base_job_name`` if it is not set already."""
333
+ # honor supplied base_job_name or generate it
334
+ if self .base_job_name is None :
335
+ self .base_job_name = base_name_from_image (self .train_image ())
336
+
337
+ def _get_or_create_name (self , name = None ):
338
+ """Generate a name based on the base job name or training image if needed.
339
+
340
+ Args:
341
+ name (str): User-supplied name. If not specified, a name is generated from
342
+ the base job name or training image.
343
+
344
+ Returns:
345
+ str: Either the user-supplied name or a generated name.
346
+ """
347
+ if name :
348
+ return name
349
+
350
+ self ._ensure_base_job_name ()
351
+ return name_from_base (self .base_job_name )
352
+
331
353
def _prepare_for_training (self , job_name = None ):
332
354
"""Set any values in the estimator that need to be set before training.
333
355
@@ -336,18 +358,7 @@ def _prepare_for_training(self, job_name=None):
336
358
specified, one is generated, using the base name given to the
337
359
constructor if applicable.
338
360
"""
339
- if job_name is not None :
340
- self ._current_job_name = job_name
341
- else :
342
- # honor supplied base_job_name or generate it
343
- if self .base_job_name :
344
- base_name = self .base_job_name
345
- elif isinstance (self , sagemaker .algorithm .AlgorithmEstimator ):
346
- base_name = self .algorithm_arn .split ("/" )[- 1 ] # pylint: disable=no-member
347
- else :
348
- base_name = base_name_from_image (self .train_image ())
349
-
350
- self ._current_job_name = name_from_base (base_name )
361
+ self ._current_job_name = self ._get_or_create_name (job_name )
351
362
352
363
# if output_path was specified we use it otherwise initialize here.
353
364
# For Local Mode with local_code=True we don't need an explicit output_path
@@ -483,7 +494,7 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
483
494
compatibility, boolean values are also accepted and converted to strings.
484
495
Only meaningful when wait is True.
485
496
job_name (str): Training job name. If not specified, the estimator generates
486
- a default job name, based on the training image name and current timestamp.
497
+ a default job name based on the training image name and current timestamp.
487
498
experiment_config (dict[str, str]): Experiment management configuration.
488
499
Dictionary contains three optional keys,
489
500
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
@@ -667,7 +678,8 @@ def deploy(
667
678
wait (bool): Whether the call should wait until the deployment of
668
679
model completes (default: True).
669
680
model_name (str): Name to use for creating an Amazon SageMaker
670
- model. If not specified, the name of the training job is used.
681
+ model. If not specified, the estimator generates a default job name
682
+ based on the training image name and current timestamp.
671
683
kms_key (str): The ARN of the KMS key that is used to encrypt the
672
684
data on the storage volume attached to the instance hosting the
673
685
endpoint.
@@ -691,8 +703,11 @@ def deploy(
691
703
endpoint and obtain inferences.
692
704
"""
693
705
self ._ensure_latest_training_job ()
694
- endpoint_name = endpoint_name or self .latest_training_job .name
695
- model_name = model_name or self .latest_training_job .name
706
+ self ._ensure_base_job_name ()
707
+ default_name = name_from_base (self .base_job_name )
708
+ endpoint_name = endpoint_name or default_name
709
+ model_name = model_name or default_name
710
+
696
711
self .deploy_instance_type = instance_type
697
712
if use_compiled_model :
698
713
family = "_" .join (instance_type .split ("." )[:- 1 ])
@@ -889,18 +904,18 @@ def transformer(
889
904
If not specified, this setting is taken from the estimator's
890
905
current configuration.
891
906
model_name (str): Name to use for creating an Amazon SageMaker
892
- model. If not specified, the name of the training job is used.
907
+ model. If not specified, the estimator generates a default job name
908
+ based on the training image name and current timestamp.
893
909
"""
894
910
tags = tags or self .tags
911
+ model_name = self ._get_or_create_name (model_name )
895
912
896
913
if self .latest_training_job is None :
897
914
logging .warning (
898
915
"No finished training job found associated with this estimator. Please make sure "
899
916
"this estimator is only used for building workflow config"
900
917
)
901
- model_name = model_name or self ._current_job_name
902
918
else :
903
- model_name = model_name or self .latest_training_job .name
904
919
if enable_network_isolation is None :
905
920
enable_network_isolation = self .enable_network_isolation ()
906
921
@@ -1984,14 +1999,16 @@ def transformer(
1984
1999
If not specified, this setting is taken from the estimator's
1985
2000
current configuration.
1986
2001
model_name (str): Name to use for creating an Amazon SageMaker
1987
- model. If not specified, the name of the training job is used.
2002
+ model. If not specified, the estimator generates a default job name
2003
+ based on the training image name and current timestamp.
1988
2004
1989
2005
Returns:
1990
2006
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
1991
2007
SageMaker Batch Transform job.
1992
2008
"""
1993
2009
role = role or self .role
1994
2010
tags = tags or self .tags
2011
+ model_name = self ._get_or_create_name (model_name )
1995
2012
1996
2013
if self .latest_training_job is not None :
1997
2014
if enable_network_isolation is None :
@@ -2008,7 +2025,6 @@ def transformer(
2008
2025
)
2009
2026
model ._create_sagemaker_model (instance_type , tags = tags )
2010
2027
2011
- model_name = model .name
2012
2028
transform_env = model .env .copy ()
2013
2029
if env is not None :
2014
2030
transform_env .update (env )
@@ -2017,7 +2033,6 @@ def transformer(
2017
2033
"No finished training job found associated with this estimator. Please make sure "
2018
2034
"this estimator is only used for building workflow config"
2019
2035
)
2020
- model_name = model_name or self ._current_job_name
2021
2036
transform_env = env or {}
2022
2037
2023
2038
return Transformer (
0 commit comments