2323from sagemaker .parameter import IntegerParameter
2424from sagemaker .tuner import HyperparameterTuner
2525from sagemaker .workflow .pipeline_context import PipelineSession
26+ from tests .unit .sagemaker .workflow .helpers import CustomStep
2627
2728from sagemaker .workflow .steps import TransformStep , TransformInput
2829from sagemaker .workflow .pipeline import Pipeline
2930from sagemaker .workflow .parameters import ParameterString
31+ from sagemaker .workflow .functions import Join
32+ from sagemaker .workflow import is_pipeline_variable
3033
3134from sagemaker .transformer import Transformer
3235
@@ -53,6 +56,7 @@ def client():
5356 client_mock ._client_config .user_agent = (
5457 "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
5558 )
59+ client_mock .describe_model .return_value = {"PrimaryContainer" : {}, "Containers" : {}}
5660 return client_mock
5761
5862
@@ -80,18 +84,44 @@ def pipeline_session(boto_session, client):
8084 )
8185
8286
83- def test_transform_step_with_transformer (pipeline_session ):
84- model_name = ParameterString ("ModelName" )
87+ @pytest .mark .parametrize (
88+ "model_name" ,
89+ [
90+ "my-model" ,
91+ ParameterString ("ModelName" ),
92+ ParameterString ("ModelName" , default_value = "my-model" ),
93+ Join (on = "-" , values = ["my" , "model" ]),
94+ CustomStep (name = "custom-step" ).properties .RoleArn ,
95+ ],
96+ )
97+ @pytest .mark .parametrize (
98+ "data" ,
99+ [
100+ "s3://my-bucket/my-data" ,
101+ ParameterString ("MyTransformInput" ),
102+ ParameterString ("MyTransformInput" , default_value = "s3://my-model" ),
103+ Join (on = "/" , values = ["s3://my-bucket" , "my-transform-data" , "input" ]),
104+ CustomStep (name = "custom-step" ).properties .OutputDataConfig .S3OutputPath ,
105+ ],
106+ )
107+ @pytest .mark .parametrize (
108+ "output_path" ,
109+ [
110+ "s3://my-bucket/my-output-path" ,
111+ ParameterString ("MyOutputPath" ),
112+ ParameterString ("MyOutputPath" , default_value = "s3://my-output" ),
113+ Join (on = "/" , values = ["s3://my-bucket" , "my-transform-data" , "output" ]),
114+ CustomStep (name = "custom-step" ).properties .OutputDataConfig .S3OutputPath ,
115+ ],
116+ )
117+ def test_transform_step_with_transformer (model_name , data , output_path , pipeline_session ):
85118 transformer = Transformer (
86119 model_name = model_name ,
87120 instance_type = "ml.m5.xlarge" ,
88121 instance_count = 1 ,
89- output_path = f"s3:// { pipeline_session . default_bucket () } /Transform" ,
122+ output_path = output_path ,
90123 sagemaker_session = pipeline_session ,
91124 )
92- data = ParameterString (
93- name = "Data" , default_value = f"s3://{ pipeline_session .default_bucket ()} /batch-data"
94- )
95125 transform_inputs = TransformInput (data = data )
96126
97127 with warnings .catch_warnings (record = True ) as w :
@@ -123,13 +153,27 @@ def test_transform_step_with_transformer(pipeline_session):
123153 parameters = [model_name , data ],
124154 sagemaker_session = pipeline_session ,
125155 )
126- step_args .args ["ModelName" ] = model_name .expr
127- step_args .args ["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = data .expr
128- assert json .loads (pipeline .definition ())["Steps" ][0 ] == {
129- "Name" : "MyTransformStep" ,
130- "Type" : "Transform" ,
131- "Arguments" : step_args .args ,
132- }
156+ step_args = step_args .args
157+ step_def = json .loads (pipeline .definition ())["Steps" ][0 ]
158+ step_args ["ModelName" ] = model_name .expr if is_pipeline_variable (model_name ) else model_name
159+ step_args ["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = (
160+ data .expr if is_pipeline_variable (data ) else data
161+ )
162+ step_args ["TransformOutput" ]["S3OutputPath" ] = (
163+ output_path .expr if is_pipeline_variable (output_path ) else output_path
164+ )
165+
166+ del (
167+ step_args ["ModelName" ],
168+ step_args ["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ],
169+ step_args ["TransformOutput" ]["S3OutputPath" ],
170+ )
171+ del (
172+ step_def ["Arguments" ]["ModelName" ],
173+ step_def ["Arguments" ]["TransformInput" ]["DataSource" ]["S3DataSource" ]["S3Uri" ],
174+ step_def ["Arguments" ]["TransformOutput" ]["S3OutputPath" ],
175+ )
176+ assert step_def == {"Name" : "MyTransformStep" , "Type" : "Transform" , "Arguments" : step_args }
133177
134178
135179@pytest .mark .parametrize (
0 commit comments