@@ -855,6 +855,148 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
855855 pass
856856
857857
def test_steps_with_map_params_pipeline(
    sagemaker_session, role, script_dir, pipeline_name, region_name, athena_dataset_definition
):
    """Exercise map-style property references across a three-step pipeline.

    Builds processing -> training -> condition steps where later steps
    reference earlier steps' properties through map/key lookups
    (``Outputs["train_data"]``, ``HyperParameters["batch-size"]``), asserts
    those references serialize to ``{"Get": ...}`` expressions in the
    pipeline definition JSON, then creates the pipeline and cleans it up.

    Args:
        sagemaker_session: SageMaker session fixture used for all SDK calls.
        role: IAM role ARN used by the processor, estimator, and pipeline.
        script_dir: Directory containing ``preprocessing.py`` and ``train.py``.
        pipeline_name: Unique name for the pipeline under test.
        region_name: AWS region, used for the sample-data bucket and ARN check.
        athena_dataset_definition: Athena dataset definition fixture for the
            second processing input.
    """
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    framework_version = "0.20.0"
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
    output_prefix = ParameterString(name="OutputPrefix", default_value="output")
    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"

    sklearn_processor = SKLearnProcessor(
        framework_version=framework_version,
        instance_type=instance_type,
        instance_count=instance_count,
        base_job_name="test-sklearn",
        sagemaker_session=sagemaker_session,
        role=role,
    )
    step_process = ProcessingStep(
        name="my-process",
        display_name="ProcessingStep",
        description="description for Processing step",
        processor=sklearn_processor,
        inputs=[
            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
            ProcessingInput(dataset_definition=athena_dataset_definition),
        ],
        outputs=[
            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
            ProcessingOutput(
                output_name="test_data",
                source="/opt/ml/processing/test",
                # Destination is assembled at execution time from a pipeline
                # parameter and an execution variable via Join.
                destination=Join(
                    on="/",
                    values=[
                        "s3:/",
                        sagemaker_session.default_bucket(),
                        "test-sklearn",
                        output_prefix,
                        ExecutionVariables.PIPELINE_EXECUTION_ID,
                    ],
                ),
            ),
        ],
        code=os.path.join(script_dir, "preprocessing.py"),
    )

    sklearn_train = SKLearn(
        framework_version=framework_version,
        entry_point=os.path.join(script_dir, "train.py"),
        instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        role=role,
        hyperparameters={
            "batch-size": 500,
            "epochs": 5,
        },
    )
    step_train = TrainingStep(
        name="my-train",
        display_name="TrainingStep",
        description="description for Training step",
        estimator=sklearn_train,
        # Map-keyed property reference into the processing step's outputs;
        # must serialize to a Get expression (asserted below).
        inputs=TrainingInput(
            s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
                "train_data"
            ].S3Output.S3Uri
        ),
    )

    model = Model(
        image_uri=sklearn_train.image_uri,
        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
        sagemaker_session=sagemaker_session,
        role=role,
    )
    model_inputs = CreateModelInput(
        instance_type="ml.m5.large",
        accelerator_type="ml.eia1.medium",
    )
    step_model = CreateModelStep(
        name="my-model",
        display_name="ModelStep",
        description="description for Model step",
        model=model,
        inputs=model_inputs,
    )

    # Condition step branching on a map-keyed hyperparameter property
    # reference. (Renamed from cond_lte: the comparison is >=, not <=.)
    cond_gte = ConditionGreaterThanOrEqualTo(
        left=step_train.properties.HyperParameters["batch-size"],
        right=6.0,
    )

    step_cond = ConditionStep(
        name="CustomerChurnAccuracyCond",
        conditions=[cond_gte],
        if_steps=[],
        else_steps=[step_model],
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_type, instance_count, output_prefix],
        steps=[step_process, step_train, step_cond],
        sagemaker_session=sagemaker_session,
    )

    definition = json.loads(pipeline.definition())
    assert definition["Version"] == "2020-12-01"

    steps = definition["Steps"]
    assert len(steps) == 3
    training_args = {}
    condition_args = {}
    for step in steps:
        if step["Type"] == "Training":
            training_args = step["Arguments"]
        if step["Type"] == "Condition":
            condition_args = step["Arguments"]

    # Both map-keyed property references must have serialized to Get
    # expressions rather than being evaluated locally.
    assert training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == {
        "Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
    }
    assert condition_args["Conditions"][0]["LeftValue"] == {
        "Get": "Steps.my-train.HyperParameters['batch-size']"
    }

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )
    finally:
        # Best-effort cleanup; swallow delete failures so they don't mask
        # the real test outcome.
        try:
            pipeline.delete()
        except Exception:
            pass
999+
8581000def test_two_step_callback_pipeline_with_output_reference (
8591001 sagemaker_session , role , pipeline_name , region_name
8601002):
0 commit comments