4747from sagemaker .workflow .parameters import ParameterInteger , ParameterString
4848from sagemaker .workflow .pipeline import Pipeline
4949from sagemaker .workflow .pipeline_context import PipelineSession
50+ from sagemaker .workflow .pipeline_definition_config import PipelineDefinitionConfig
5051from sagemaker .workflow .step_outputs import get_step
5152from sagemaker .workflow .steps import (
5253 ProcessingStep ,
111112 "Result": [1, 2, 3], "Exception": null
112113}
113114"""
115+ TEST_JOB_NAME = "test-job-name"
114116
115117
116118@pytest .fixture
@@ -188,6 +190,8 @@ def training_step(pipeline_session):
188190 sagemaker_session = pipeline_session ,
189191 output_path = "s3://a/b" ,
190192 use_spot_instances = False ,
193+ # base_job_name would be popped out if no pipeline_definition_config configured
194+ base_job_name = TEST_JOB_NAME ,
191195 )
192196 training_input = TrainingInput (s3_data = f"s3://{ BUCKET } /train_manifest" )
193197 step_args = estimator .fit (inputs = training_input )
@@ -207,6 +211,8 @@ def processing_step(pipeline_session):
207211 instance_count = 1 ,
208212 instance_type = INSTANCE_TYPE ,
209213 sagemaker_session = pipeline_session ,
214+ # base_job_name would be popped out if no pipeline_definition_config configured
215+ base_job_name = TEST_JOB_NAME ,
210216 )
211217 processing_input = [
212218 ProcessingInput (
@@ -239,6 +245,8 @@ def transform_step(pipeline_session):
239245 instance_count = 1 ,
240246 output_path = "s3://my-bucket/my-output-path" ,
241247 sagemaker_session = pipeline_session ,
248+ # base_transform_job_name would be popped out if no pipeline_definition_config configured
249+ base_transform_job_name = TEST_JOB_NAME ,
242250 )
243251 transform_inputs = TransformInput (data = "s3://my-bucket/my-data" )
244252 step_args = transformer .transform (
@@ -871,8 +879,8 @@ def depends_step():
871879 )
872880
873881
874- @patch ("sagemaker.local.image._SageMakerContainer.process" )
875- def test_execute_pipeline_processing_step (process , local_sagemaker_session , processing_step ):
882+ @patch ("sagemaker.local.image._SageMakerContainer.process" , MagicMock () )
883+ def test_execute_pipeline_processing_step (local_sagemaker_session , processing_step ):
876884 pipeline = Pipeline (
877885 name = "MyPipeline2" ,
878886 steps = [processing_step ],
@@ -1362,3 +1370,86 @@ def test_execute_pipeline_step_create_transform_job_fail(
13621370 step_execution = execution .step_execution
13631371 assert step_execution [transform_step .name ].status == _LocalExecutionStatus .FAILED .value
13641372 assert "Dummy RuntimeError" in step_execution [transform_step .name ].failure_reason
1373+
1374+
1375+ @patch (
1376+ "sagemaker.local.image._SageMakerContainer.train" ,
1377+ MagicMock (return_value = "/some/path/to/model" ),
1378+ )
1379+ @patch ("sagemaker.local.image._SageMakerContainer.process" , MagicMock ())
1380+ def test_pipeline_definition_config_in_local_mode_for_train_process_steps (
1381+ processing_step ,
1382+ training_step ,
1383+ local_sagemaker_session ,
1384+ ):
1385+ exe_steps = [processing_step , training_step ]
1386+
1387+ def _verify_execution (exe_step_name , execution , with_custom_job_prefix ):
1388+ assert not execution .failure_reason
1389+ assert execution .status == _LocalExecutionStatus .SUCCEEDED .value
1390+
1391+ step_execution = execution .step_execution [exe_step_name ]
1392+ assert step_execution .status == _LocalExecutionStatus .SUCCEEDED .value
1393+
1394+ if step_execution .type == StepTypeEnum .PROCESSING :
1395+ job_name_field = "ProcessingJobName"
1396+ elif step_execution .type == StepTypeEnum .TRAINING :
1397+ job_name_field = "TrainingJobName"
1398+
1399+ if with_custom_job_prefix :
1400+ assert step_execution .properties [job_name_field ] == TEST_JOB_NAME
1401+ else :
1402+ assert step_execution .properties [job_name_field ].startswith (step_execution .name )
1403+
1404+ for exe_step in exe_steps :
1405+ pipeline = Pipeline (
1406+ name = "MyPipelineX-" + exe_step .name ,
1407+ steps = [exe_step ],
1408+ sagemaker_session = local_sagemaker_session ,
1409+ parameters = [INSTANCE_COUNT_PIPELINE_PARAMETER ],
1410+ )
1411+
1412+ execution = LocalPipelineExecutor (
1413+ _LocalPipelineExecution ("my-execution-x-" + exe_step .name , pipeline ),
1414+ local_sagemaker_session ,
1415+ ).execute ()
1416+
1417+ _verify_execution (
1418+ exe_step_name = exe_step .name , execution = execution , with_custom_job_prefix = False
1419+ )
1420+
1421+ pipeline .pipeline_definition_config = PipelineDefinitionConfig (use_custom_job_prefix = True )
1422+ execution = LocalPipelineExecutor (
1423+ _LocalPipelineExecution ("my-execution-x-" + exe_step .name , pipeline ),
1424+ local_sagemaker_session ,
1425+ ).execute ()
1426+
1427+ _verify_execution (
1428+ exe_step_name = exe_step .name , execution = execution , with_custom_job_prefix = True
1429+ )
1430+
1431+
1432+ @patch ("sagemaker.local.local_session.LocalSagemakerClient.create_transform_job" )
1433+ def test_pipeline_definition_config_in_local_mode_for_transform_step (
1434+ create_transform_job , local_sagemaker_session , transform_step
1435+ ):
1436+ pipeline = Pipeline (
1437+ name = "MyPipelineX-" + transform_step .name ,
1438+ steps = [transform_step ],
1439+ sagemaker_session = local_sagemaker_session ,
1440+ )
1441+ LocalPipelineExecutor (
1442+ _LocalPipelineExecution ("my-execution-x-" + transform_step .name , pipeline ),
1443+ local_sagemaker_session ,
1444+ ).execute ()
1445+
1446+ assert create_transform_job .call_args .args [0 ].startswith (transform_step .name )
1447+
1448+ pipeline .pipeline_definition_config = PipelineDefinitionConfig (use_custom_job_prefix = True )
1449+
1450+ LocalPipelineExecutor (
1451+ _LocalPipelineExecution ("my-execution-x-" + transform_step .name , pipeline ),
1452+ local_sagemaker_session ,
1453+ ).execute ()
1454+
1455+ assert create_transform_job .call_args .args [0 ] == TEST_JOB_NAME
0 commit comments