1313# language governing permissions and limitations under the License.
1414from __future__ import absolute_import
1515
16+ import os
1617import json
1718
1819import pytest
1920import sagemaker
2021import warnings
2122
2223from sagemaker .workflow .pipeline_context import PipelineSession
24+ from sagemaker .workflow .parameters import ParameterString
2325
2426from sagemaker .workflow .steps import TrainingStep
2527from sagemaker .workflow .pipeline import Pipeline
4648from sagemaker .amazon .ntm import NTM
4749from sagemaker .amazon .object2vec import Object2Vec
4850
51+ from tests .integ import DATA_DIR
4952
5053from sagemaker .inputs import TrainingInput
5154from tests .unit .sagemaker .workflow .helpers import CustomStep
5255
5356REGION = "us-west-2"
5457IMAGE_URI = "fakeimage"
5558MODEL_NAME = "gisele"
59+ DUMMY_LOCAL_SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
5660DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
5761DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/"
5862INSTANCE_TYPE = "ml.m4.xlarge"
@@ -122,6 +126,36 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
122126 assert step .properties .TrainingJobName .expr == {"Get" : "Steps.MyTrainingStep.TrainingJobName" }
123127
124128
129+ def test_estimator_with_parameterized_output (pipeline_session , training_input ):
130+ output_path = ParameterString (name = "OutputPath" )
131+ estimator = XGBoost (
132+ framework_version = "1.3-1" ,
133+ py_version = "py3" ,
134+ role = sagemaker .get_execution_role (),
135+ instance_type = INSTANCE_TYPE ,
136+ instance_count = 1 ,
137+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
138+ output_path = output_path ,
139+ sagemaker_session = pipeline_session ,
140+ )
141+ step_args = estimator .fit (inputs = training_input )
142+ step = TrainingStep (
143+ name = "MyTrainingStep" ,
144+ step_args = step_args ,
145+ description = "TrainingStep description" ,
146+ display_name = "MyTrainingStep" ,
147+ )
148+ pipeline = Pipeline (
149+ name = "MyPipeline" ,
150+ steps = [step ],
151+ sagemaker_session = pipeline_session ,
152+ )
153+ step_def = json .loads (pipeline .definition ())["Steps" ][0 ]
154+ assert step_def ["Arguments" ]["OutputDataConfig" ]["S3OutputPath" ] == {
155+ "Get" : "Parameters.OutputPath"
156+ }
157+
158+
125159@pytest .mark .parametrize (
126160 "estimator" ,
127161 [
@@ -131,23 +165,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
131165 instance_type = INSTANCE_TYPE ,
132166 instance_count = 1 ,
133167 role = sagemaker .get_execution_role (),
134- entry_point = "entry_point.py" ,
168+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
135169 ),
136170 PyTorch (
137171 role = sagemaker .get_execution_role (),
138172 instance_type = INSTANCE_TYPE ,
139173 instance_count = 1 ,
140174 framework_version = "1.8.0" ,
141175 py_version = "py36" ,
142- entry_point = "entry_point.py" ,
176+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
143177 ),
144178 TensorFlow (
145179 role = sagemaker .get_execution_role (),
146180 instance_type = INSTANCE_TYPE ,
147181 instance_count = 1 ,
148182 framework_version = "2.0" ,
149183 py_version = "py3" ,
150- entry_point = "entry_point.py" ,
184+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
151185 ),
152186 HuggingFace (
153187 transformers_version = "4.6" ,
@@ -156,23 +190,23 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
156190 instance_type = "ml.p3.2xlarge" ,
157191 instance_count = 1 ,
158192 py_version = "py36" ,
159- entry_point = "entry_point.py" ,
193+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
160194 ),
161195 XGBoost (
162196 framework_version = "1.3-1" ,
163197 py_version = "py3" ,
164198 role = sagemaker .get_execution_role (),
165199 instance_type = INSTANCE_TYPE ,
166200 instance_count = 1 ,
167- entry_point = "entry_point.py" ,
201+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
168202 ),
169203 MXNet (
170204 framework_version = "1.4.1" ,
171205 py_version = "py3" ,
172206 role = sagemaker .get_execution_role (),
173207 instance_type = INSTANCE_TYPE ,
174208 instance_count = 1 ,
175- entry_point = "entry_point.py" ,
209+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
176210 ),
177211 RLEstimator (
178212 entry_point = "cartpole.py" ,
@@ -185,7 +219,7 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
185219 ),
186220 Chainer (
187221 role = sagemaker .get_execution_role (),
188- entry_point = "entry_point.py" ,
222+ entry_point = DUMMY_LOCAL_SCRIPT_PATH ,
189223 use_mpi = True ,
190224 num_processes = 4 ,
191225 framework_version = "5.0.0" ,
0 commit comments