@@ -83,15 +83,16 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
8383 inputs = TrainingInput (s3_data = input_path )
8484
8585 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
86- instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
8786
87+ # If image_uri is not provided, the instance_type should not be a pipeline variable
88+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
8889 pytorch_estimator = PyTorch (
8990 entry_point = entry_point ,
9091 role = role ,
9192 framework_version = "1.5.0" ,
9293 py_version = "py3" ,
9394 instance_count = instance_count ,
94- instance_type = instance_type ,
95+ instance_type = "ml.m5.xlarge" ,
9596 sagemaker_session = pipeline_session ,
9697 )
9798 train_step_args = pytorch_estimator .fit (inputs = inputs )
@@ -140,7 +141,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
140141 )
141142 pipeline = Pipeline (
142143 name = pipeline_name ,
143- parameters = [instance_count , instance_type ],
144+ parameters = [instance_count ],
144145 steps = [step_train , step_model_regis , step_model_create , step_fail ],
145146 sagemaker_session = pipeline_session ,
146147 )
@@ -203,15 +204,16 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference(
203204 inputs = TrainingInput (s3_data = input_path )
204205
205206 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
206- instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
207207
208+ # If image_uri is not provided, the instance_type should not be a pipeline variable
209+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
208210 pytorch_estimator = PyTorch (
209211 entry_point = entry_point ,
210212 role = role ,
211213 framework_version = "1.5.0" ,
212214 py_version = "py3" ,
213215 instance_count = instance_count ,
214- instance_type = instance_type ,
216+ instance_type = "ml.m5.xlarge" ,
215217 sagemaker_session = pipeline_session ,
216218 output_kms_key = kms_key ,
217219 )
@@ -267,7 +269,7 @@ def test_pytorch_training_model_registration_and_creation_with_custom_inference(
267269 )
268270 pipeline = Pipeline (
269271 name = pipeline_name ,
270- parameters = [instance_count , instance_type ],
272+ parameters = [instance_count ],
271273 steps = [step_train , step_model_regis , step_model_create , step_fail ],
272274 sagemaker_session = pipeline_session ,
273275 )
@@ -400,7 +402,6 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
400402 pipeline_name ,
401403):
402404 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
403- instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
404405
405406 # upload model data to s3
406407 model_local_path = os .path .join (DATA_DIR , "mxnet_mnist/model.tar.gz" )
@@ -488,10 +489,12 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
488489 ),
489490 )
490491 customer_metadata_properties = {"key1" : "value1" }
492+ # If image_uri is not provided, the instance_type should not be a pipeline variable
493+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
491494 estimator = XGBoost (
492495 entry_point = "training.py" ,
493496 source_dir = os .path .join (DATA_DIR , "sip" ),
494- instance_type = instance_type ,
497+ instance_type = "ml.m5.xlarge" ,
495498 instance_count = instance_count ,
496499 framework_version = "0.90-2" ,
497500 sagemaker_session = pipeline_session ,
@@ -524,7 +527,6 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics(
524527 parameters = [
525528 model_uri_param ,
526529 metrics_uri_param ,
527- instance_type ,
528530 instance_count ,
529531 ],
530532 steps = [step_model_register ],
@@ -606,13 +608,14 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
606608 )
607609 inputs = TrainingInput (s3_data = input_path )
608610 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
609- instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
610611
612+ # If image_uri is not provided, the instance_type should not be a pipeline variable
613+ # since instance_type is used to retrieve image_uri in compile time (PySDK)
611614 tensorflow_estimator = TensorFlow (
612615 entry_point = entry_point ,
613616 role = role ,
614617 instance_count = instance_count ,
615- instance_type = instance_type ,
618+ instance_type = "ml.m5.xlarge" ,
616619 framework_version = tf_full_version ,
617620 py_version = tf_full_py_version ,
618621 sagemaker_session = pipeline_session ,
@@ -645,10 +648,7 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model(
645648 )
646649 pipeline = Pipeline (
647650 name = pipeline_name ,
648- parameters = [
649- instance_count ,
650- instance_type ,
651- ],
651+ parameters = [instance_count ],
652652 steps = [step_train , step_register_model ],
653653 sagemaker_session = pipeline_session ,
654654 )
0 commit comments