3535from sagemaker .network import NetworkConfig
3636from sagemaker .transformer import Transformer
3737from sagemaker .workflow .properties import Properties
38+ from sagemaker .workflow .parameters import ParameterString , ParameterInteger
3839from sagemaker .workflow .steps import (
3940 ProcessingStep ,
4041 Step ,
@@ -112,16 +113,27 @@ def test_custom_step():
112113
113114
114115def test_training_step (sagemaker_session ):
116+ instance_type_parameter = ParameterString (name = "InstanceType" , default_value = "c4.4xlarge" )
117+ instance_count_parameter = ParameterInteger (name = "InstanceCount" , default_value = 1 )
118+ data_source_uri_parameter = ParameterString (
119+ name = "DataSourceS3Uri" , default_value = f"s3://{ BUCKET } /train_manifest"
120+ )
121+ training_epochs_parameter = ParameterInteger (name = "TrainingEpochs" , default_value = 5 )
122+ training_batch_size_parameter = ParameterInteger (name = "TrainingBatchSize" , default_value = 500 )
115123 estimator = Estimator (
116124 image_uri = IMAGE_URI ,
117125 role = ROLE ,
118- instance_count = 1 ,
119- instance_type = "c4.4xlarge" ,
126+ instance_count = instance_count_parameter ,
127+ instance_type = instance_type_parameter ,
120128 profiler_config = ProfilerConfig (system_monitor_interval_millis = 500 ),
129+ hyperparameters = {
130+ "batch-size" : training_batch_size_parameter ,
131+ "epochs" : training_epochs_parameter ,
132+ },
121133 rules = [],
122134 sagemaker_session = sagemaker_session ,
123135 )
124- inputs = TrainingInput (f"s3:// { BUCKET } /train_manifest" )
136+ inputs = TrainingInput (s3_data = data_source_uri_parameter )
125137 cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" )
126138 step = TrainingStep (
127139 name = "MyTrainingStep" , estimator = estimator , inputs = inputs , cache_config = cache_config
@@ -131,22 +143,26 @@ def test_training_step(sagemaker_session):
131143 "Type" : "Training" ,
132144 "Arguments" : {
133145 "AlgorithmSpecification" : {"TrainingImage" : IMAGE_URI , "TrainingInputMode" : "File" },
146+ "HyperParameters" : {
147+ "batch-size" : training_batch_size_parameter ,
148+ "epochs" : training_epochs_parameter ,
149+ },
134150 "InputDataConfig" : [
135151 {
136152 "ChannelName" : "training" ,
137153 "DataSource" : {
138154 "S3DataSource" : {
139155 "S3DataDistributionType" : "FullyReplicated" ,
140156 "S3DataType" : "S3Prefix" ,
141- "S3Uri" : f"s3:// { BUCKET } /train_manifest" ,
157+ "S3Uri" : data_source_uri_parameter ,
142158 }
143159 },
144160 }
145161 ],
146162 "OutputDataConfig" : {"S3OutputPath" : f"s3://{ BUCKET } /" },
147163 "ResourceConfig" : {
148- "InstanceCount" : 1 ,
149- "InstanceType" : "c4.4xlarge" ,
164+ "InstanceCount" : instance_count_parameter ,
165+ "InstanceType" : instance_type_parameter ,
150166 "VolumeSizeInGB" : 30 ,
151167 },
152168 "RoleArn" : ROLE ,
@@ -162,16 +178,21 @@ def test_training_step(sagemaker_session):
162178
163179
164180def test_processing_step (sagemaker_session ):
181+ processing_input_data_uri_parameter = ParameterString (
182+ name = "ProcessingInputDataUri" , default_value = f"s3://{ BUCKET } /processing_manifest"
183+ )
184+ instance_type_parameter = ParameterString (name = "InstanceType" , default_value = "ml.m4.4xlarge" )
185+ instance_count_parameter = ParameterInteger (name = "InstanceCount" , default_value = 1 )
165186 processor = Processor (
166187 image_uri = IMAGE_URI ,
167188 role = ROLE ,
168- instance_count = 1 ,
169- instance_type = "ml.m4.4xlarge" ,
189+ instance_count = instance_count_parameter ,
190+ instance_type = instance_type_parameter ,
170191 sagemaker_session = sagemaker_session ,
171192 )
172193 inputs = [
173194 ProcessingInput (
174- source = f"s3:// { BUCKET } /processing_manifest" ,
195+ source = processing_input_data_uri_parameter ,
175196 destination = "processing_manifest" ,
176197 )
177198 ]
@@ -198,14 +219,14 @@ def test_processing_step(sagemaker_session):
198219 "S3DataDistributionType" : "FullyReplicated" ,
199220 "S3DataType" : "S3Prefix" ,
200221 "S3InputMode" : "File" ,
201- "S3Uri" : "s3://my-bucket/processing_manifest" ,
222+ "S3Uri" : processing_input_data_uri_parameter ,
202223 },
203224 }
204225 ],
205226 "ProcessingResources" : {
206227 "ClusterConfig" : {
207- "InstanceCount" : 1 ,
208- "InstanceType" : "ml.m4.4xlarge" ,
228+ "InstanceCount" : instance_count_parameter ,
229+ "InstanceType" : instance_type_parameter ,
209230 "VolumeSizeInGB" : 30 ,
210231 }
211232 },
0 commit comments