3838from sagemaker .workflow .conditions import ConditionGreaterThanOrEqualTo
3939from sagemaker .workflow .condition_step import ConditionStep
4040from sagemaker .dataset_definition .inputs import DatasetDefinition , AthenaDatasetDefinition
41+ from sagemaker .workflow .execution_variables import ExecutionVariables
42+ from sagemaker .workflow .functions import Join
4143from sagemaker .workflow .parameters import (
4244 ParameterInteger ,
4345 ParameterString ,
@@ -72,16 +74,9 @@ def role(sagemaker_session):
7274 return get_execution_role (sagemaker_session )
7375
7476
75- # TODO-reinvent-2020: remove use of specific region and this session
7677@pytest .fixture (scope = "module" )
77- def region ():
78- return "us-east-2"
79-
80-
81- # TODO-reinvent-2020: remove use of specific region and this session
82- @pytest .fixture (scope = "module" )
83- def workflow_session (region ):
84- boto_session = boto3 .Session (region_name = region )
78+ def workflow_session (region_name ):
79+ boto_session = boto3 .Session (region_name = region_name )
8580
8681 sagemaker_client_config = dict ()
8782 sagemaker_client_config .setdefault ("config" , Config (retries = dict (max_attempts = 2 )))
@@ -134,6 +129,7 @@ def test_three_step_definition(
134129 framework_version = "0.20.0"
135130 instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
136131 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
132+ output_prefix = ParameterString (name = "OutputPrefix" , default_value = "output" )
137133
138134 input_data = f"s3://sagemaker-sample-data-{ region_name } /processing/census/census-income.csv"
139135
@@ -154,7 +150,20 @@ def test_three_step_definition(
154150 ],
155151 outputs = [
156152 ProcessingOutput (output_name = "train_data" , source = "/opt/ml/processing/train" ),
157- ProcessingOutput (output_name = "test_data" , source = "/opt/ml/processing/test" ),
153+ ProcessingOutput (
154+ output_name = "test_data" ,
155+ source = "/opt/ml/processing/test" ,
156+ destination = Join (
157+ on = "/" ,
158+ values = [
159+ "s3:/" ,
160+ sagemaker_session .default_bucket (),
161+ "test-sklearn" ,
162+ output_prefix ,
163+ ExecutionVariables .PIPELINE_EXECUTION_ID ,
164+ ],
165+ ),
166+ ),
158167 ],
159168 code = os .path .join (script_dir , "preprocessing.py" ),
160169 )
@@ -194,7 +203,7 @@ def test_three_step_definition(
194203
195204 pipeline = Pipeline (
196205 name = pipeline_name ,
197- parameters = [instance_type , instance_count ],
206+ parameters = [instance_type , instance_count , output_prefix ],
198207 steps = [step_process , step_train , step_model ],
199208 sagemaker_session = workflow_session ,
200209 )
@@ -208,6 +217,7 @@ def test_three_step_definition(
208217 {"Name" : "InstanceType" , "Type" : "String" , "DefaultValue" : "ml.m5.xlarge" }.items ()
209218 ),
210219 tuple ({"Name" : "InstanceCount" , "Type" : "Integer" , "DefaultValue" : 1 }.items ()),
220+ tuple ({"Name" : "OutputPrefix" , "Type" : "String" , "DefaultValue" : "output" }.items ()),
211221 ]
212222 )
213223
@@ -251,17 +261,28 @@ def test_three_step_definition(
251261 assert model_args ["PrimaryContainer" ]["ModelDataUrl" ] == {
252262 "Get" : "Steps.my-train.ModelArtifacts.S3ModelArtifacts"
253263 }
264+ try :
265+ response = pipeline .create (role )
266+ create_arn = response ["PipelineArn" ]
267+ assert re .match (
268+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
269+ create_arn ,
270+ )
271+ finally :
272+ try :
273+ pipeline .delete ()
274+ except Exception :
275+ pass
254276
255277
256- # TODO-reinvent-2020: Modify use of the workflow client
257278def test_one_step_sklearn_processing_pipeline (
258279 sagemaker_session ,
259280 workflow_session ,
260281 role ,
261282 sklearn_latest_version ,
262283 cpu_instance_type ,
263284 pipeline_name ,
264- region ,
285+ region_name ,
265286 athena_dataset_definition ,
266287):
267288 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 2 )
@@ -305,21 +326,21 @@ def test_one_step_sklearn_processing_pipeline(
305326 response = pipeline .create (role )
306327 create_arn = response ["PipelineArn" ]
307328 assert re .match (
308- fr"arn:aws:sagemaker:{ region } :\d{{12}}:pipeline/{ pipeline_name } " ,
329+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
309330 create_arn ,
310331 )
311332
312333 pipeline .parameters = [ParameterInteger (name = "InstanceCount" , default_value = 1 )]
313334 response = pipeline .update (role )
314335 update_arn = response ["PipelineArn" ]
315336 assert re .match (
316- fr"arn:aws:sagemaker:{ region } :\d{{12}}:pipeline/{ pipeline_name } " ,
337+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
317338 update_arn ,
318339 )
319340
320341 execution = pipeline .start (parameters = {})
321342 assert re .match (
322- fr"arn:aws:sagemaker:{ region } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
343+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
323344 execution .arn ,
324345 )
325346
@@ -340,14 +361,13 @@ def test_one_step_sklearn_processing_pipeline(
340361 pass
341362
342363
343- # TODO-reinvent-2020: Modify use of the workflow client
344364def test_conditional_pytorch_training_model_registration (
345365 sagemaker_session ,
346366 workflow_session ,
347367 role ,
348368 cpu_instance_type ,
349369 pipeline_name ,
350- region ,
370+ region_name ,
351371):
352372 base_dir = os .path .join (DATA_DIR , "pytorch_mnist" )
353373 entry_point = os .path .join (base_dir , "mnist.py" )
@@ -420,18 +440,18 @@ def test_conditional_pytorch_training_model_registration(
420440 response = pipeline .create (role )
421441 create_arn = response ["PipelineArn" ]
422442 assert re .match (
423- fr"arn:aws:sagemaker:{ region } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
443+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
424444 )
425445
426446 execution = pipeline .start (parameters = {})
427447 assert re .match (
428- fr"arn:aws:sagemaker:{ region } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
448+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
429449 execution .arn ,
430450 )
431451
432452 execution = pipeline .start (parameters = {"GoodEnoughInput" : 0 })
433453 assert re .match (
434- fr"arn:aws:sagemaker:{ region } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
454+ fr"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
435455 execution .arn ,
436456 )
437457 finally :
0 commit comments