1515import os
1616import json
1717from mock import Mock , PropertyMock
18+ import re
1819
1920import pytest
2021import warnings
@@ -163,6 +164,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
163164
164165def test_estimator_with_parameterized_output (pipeline_session , training_input ):
165166 output_path = ParameterString (name = "OutputPath" )
167+ # XGBoost
166168 estimator = XGBoost (
167169 framework_version = "1.3-1" ,
168170 py_version = "py3" ,
@@ -174,21 +176,48 @@ def test_estimator_with_parameterized_output(pipeline_session, training_input):
174176 sagemaker_session = pipeline_session ,
175177 )
176178 step_args = estimator .fit (inputs = training_input )
177- step = TrainingStep (
178- name = "MyTrainingStep" ,
179+ step1 = TrainingStep (
180+ name = "MyTrainingStep1" ,
181+ step_args = step_args ,
182+ description = "TrainingStep description" ,
183+ display_name = "MyTrainingStep" ,
184+ )
185+
186+ # TensorFlow
187+ # If model_dir is None and output_path is a pipeline variable
188+ # a default model_dir will be generated with default bucket
189+ estimator = TensorFlow (
190+ framework_version = "2.4.1" ,
191+ py_version = "py37" ,
192+ role = ROLE ,
193+ instance_type = INSTANCE_TYPE ,
194+ instance_count = 1 ,
195+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
196+ output_path = output_path ,
197+ sagemaker_session = pipeline_session ,
198+ )
199+ step_args = estimator .fit (inputs = training_input )
200+ step2 = TrainingStep (
201+ name = "MyTrainingStep2" ,
179202 step_args = step_args ,
180203 description = "TrainingStep description" ,
181204 display_name = "MyTrainingStep" ,
182205 )
183206 pipeline = Pipeline (
184207 name = "MyPipeline" ,
185- steps = [step ],
208+ steps = [step1 , step2 ],
209+ parameters = [output_path ],
186210 sagemaker_session = pipeline_session ,
187211 )
188- step_def = json .loads (pipeline .definition ())["Steps" ][0 ]
189- assert step_def ["Arguments" ]["OutputDataConfig" ]["S3OutputPath" ] == {
190- "Get" : "Parameters.OutputPath"
191- }
212+ step_defs = json .loads (pipeline .definition ())["Steps" ]
213+ for step_def in step_defs :
214+ assert step_def ["Arguments" ]["OutputDataConfig" ]["S3OutputPath" ] == {
215+ "Get" : "Parameters.OutputPath"
216+ }
217+ if step_def ["Name" ] != "MyTrainingStep2" :
218+ continue
219+ model_dir = step_def ["Arguments" ]["HyperParameters" ]["model_dir" ]
220+ assert re .match (rf'"s3://{ BUCKET } /.*/model"' , model_dir )
192221
193222
194223@pytest .mark .parametrize (
@@ -316,7 +345,7 @@ def test_training_step_with_algorithm_base(algo_estimator, pipeline_session):
316345 sagemaker_session = pipeline_session ,
317346 )
318347 data = RecordSet (
319- "s3://{}/{}" .format (pipeline_session . default_bucket () , "dummy" ),
348+ "s3://{}/{}" .format (BUCKET , "dummy" ),
320349 num_records = 1000 ,
321350 feature_dim = 128 ,
322351 channel = "train" ,
0 commit comments