@@ -34,10 +34,14 @@ def prepare_framework(estimator, s3_operations):
3434 if estimator .code_location is not None :
3535 bucket , key = fw_utils .parse_s3_url (estimator .code_location )
3636 key = os .path .join (key , estimator ._current_job_name , "source" , "sourcedir.tar.gz" )
37+ elif estimator .uploaded_code is not None :
38+ bucket , key = fw_utils .parse_s3_url (estimator .uploaded_code .s3_prefix )
3739 else :
3840 bucket = estimator .sagemaker_session ._default_bucket
3941 key = os .path .join (estimator ._current_job_name , "source" , "sourcedir.tar.gz" )
42+
4043 script = os .path .basename (estimator .entry_point )
44+
4145 if estimator .source_dir and estimator .source_dir .lower ().startswith ("s3://" ):
4246 code_dir = estimator .source_dir
4347 estimator .uploaded_code = fw_utils .UploadedCode (s3_prefix = code_dir , script_name = script )
@@ -96,7 +100,7 @@ def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None):
96100 estimator .mini_batch_size = mini_batch_size
97101
98102
99- def training_base_config (estimator , inputs = None , job_name = None , mini_batch_size = None ):
103+ def training_base_config (estimator , inputs = None , job_name = None , mini_batch_size = None ): # noqa: C901
100104 """Export Airflow base training config from an estimator
101105
102106 Args:
@@ -134,6 +138,13 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
134138 dict: Training config that can be directly used by
135139 SageMakerTrainingOperator in Airflow.
136140 """
141+ if isinstance (estimator , sagemaker .amazon .amazon_estimator .AmazonAlgorithmEstimatorBase ):
142+ estimator .prepare_workflow_for_training (
143+ records = inputs , mini_batch_size = mini_batch_size , job_name = job_name
144+ )
145+ else :
146+ estimator .prepare_workflow_for_training (job_name = job_name )
147+
137148 default_bucket = estimator .sagemaker_session .default_bucket ()
138149 s3_operations = {}
139150
@@ -528,6 +539,7 @@ def model_config_from_estimator(
528539 model_server_workers = model_server_workers ,
529540 role = role ,
530541 vpc_config_override = vpc_config_override ,
542+ entry_point = estimator .entry_point ,
531543 )
532544 else :
533545 raise TypeError (
0 commit comments