@@ -83,16 +83,15 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
8383 inputs = TrainingInput (s3_data = input_path )
8484
8585 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
86+ instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
8687
87- # If image_uri is not provided, the instance_type should not be a pipeline variable
88- # since instance_type is used to retrieve image_uri in compile time (PySDK)
8988 pytorch_estimator = PyTorch (
9089 entry_point = entry_point ,
9190 role = role ,
9291 framework_version = "1.5.0" ,
9392 py_version = "py3" ,
9493 instance_count = instance_count ,
95- instance_type = "ml.m5.xlarge" ,
94+ instance_type = instance_type ,
9695 sagemaker_session = pipeline_session ,
9796 )
9897 train_step_args = pytorch_estimator .fit (inputs = inputs )
@@ -141,7 +140,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
141140 )
142141 pipeline = Pipeline (
143142 name = pipeline_name ,
144- parameters = [instance_count ],
143+ parameters = [instance_count , instance_type ],
145144 steps = [step_train , step_model_regis , step_model_create , step_fail ],
146145 sagemaker_session = pipeline_session ,
147146 )
@@ -204,16 +203,15 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference(
204203 inputs = TrainingInput (s3_data = input_path )
205204
206205 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
206+ instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
207207
208- # If image_uri is not provided, the instance_type should not be a pipeline variable
209- # since instance_type is used to retrieve image_uri in compile time (PySDK)
210208 pytorch_estimator = PyTorch (
211209 entry_point = entry_point ,
212210 role = role ,
213211 framework_version = "1.5.0" ,
214212 py_version = "py3" ,
215213 instance_count = instance_count ,
216- instance_type = "ml.m5.xlarge" ,
214+ instance_type = instance_type ,
217215 sagemaker_session = pipeline_session ,
218216 output_kms_key = kms_key ,
219217 )
@@ -269,7 +267,7 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference(
269267 )
270268 pipeline = Pipeline (
271269 name = pipeline_name ,
272- parameters = [instance_count ],
270+ parameters = [instance_count , instance_type ],
273271 steps = [step_train , step_model_regis , step_model_create , step_fail ],
274272 sagemaker_session = pipeline_session ,
275273 )
@@ -402,6 +400,7 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
402400 pipeline_name ,
403401):
404402 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
403+ instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
405404
406405 # upload model data to s3
407406 model_local_path = os .path .join (DATA_DIR , "mxnet_mnist/model.tar.gz" )
@@ -489,12 +488,10 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
489488 ),
490489 )
491490 customer_metadata_properties = {"key1" : "value1" }
492- # If image_uri is not provided, the instance_type should not be a pipeline variable
493- # since instance_type is used to retrieve image_uri in compile time (PySDK)
494491 estimator = XGBoost (
495492 entry_point = "training.py" ,
496493 source_dir = os .path .join (DATA_DIR , "sip" ),
497- instance_type = "ml.m5.xlarge" ,
494+ instance_type = instance_type ,
498495 instance_count = instance_count ,
499496 framework_version = "0.90-2" ,
500497 sagemaker_session = pipeline_session ,
@@ -527,6 +524,7 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
527524 parameters = [
528525 model_uri_param ,
529526 metrics_uri_param ,
527+ instance_type ,
530528 instance_count ,
531529 ],
532530 steps = [step_model_register ],
@@ -608,14 +606,13 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
608606 )
609607 inputs = TrainingInput (s3_data = input_path )
610608 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
609+ instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
611610
612- # If image_uri is not provided, the instance_type should not be a pipeline variable
613- # since instance_type is used to retrieve image_uri in compile time (PySDK)
614611 tensorflow_estimator = TensorFlow (
615612 entry_point = entry_point ,
616613 role = role ,
617614 instance_count = instance_count ,
618- instance_type = "ml.m5.xlarge" ,
615+ instance_type = instance_type ,
619616 framework_version = tf_full_version ,
620617 py_version = tf_full_py_version ,
621618 sagemaker_session = pipeline_session ,
@@ -648,7 +645,10 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
648645 )
649646 pipeline = Pipeline (
650647 name = pipeline_name ,
651- parameters = [instance_count ],
648+ parameters = [
649+ instance_count ,
650+ instance_type ,
651+ ],
652652 steps = [step_train , step_register_model ],
653653 sagemaker_session = pipeline_session ,
654654 )
0 commit comments