2121from sagemaker .mxnet .estimator import MXNet
2222from sagemaker .mxnet .model import MXNetModel
2323from sagemaker .utils import sagemaker_timestamp
24- from tests .integ import DATA_DIR , TRAINING_DEFAULT_TIMEOUT_MINUTES
24+ from tests .integ import DATA_DIR , PYTHON_VERSION , TRAINING_DEFAULT_TIMEOUT_MINUTES
2525from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
2626
2727
@@ -32,7 +32,7 @@ def mxnet_training_job(sagemaker_session, mxnet_full_version):
3232 data_path = os .path .join (DATA_DIR , 'mxnet_mnist' )
3333
3434 mx = MXNet (entry_point = script_path , role = 'SageMakerRole' , framework_version = mxnet_full_version ,
35- train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
35+ py_version = PYTHON_VERSION , train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
3636 sagemaker_session = sagemaker_session )
3737
3838 train_input = mx .sagemaker_session .upload_data (path = os .path .join (data_path , 'train' ),
@@ -62,7 +62,8 @@ def test_deploy_model(mxnet_training_job, sagemaker_session):
6262 desc = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = mxnet_training_job )
6363 model_data = desc ['ModelArtifacts' ]['S3ModelArtifacts' ]
6464 script_path = os .path .join (DATA_DIR , 'mxnet_mnist' , 'mnist.py' )
65- model = MXNetModel (model_data , 'SageMakerRole' , entry_point = script_path , sagemaker_session = sagemaker_session )
65+ model = MXNetModel (model_data , 'SageMakerRole' , entry_point = script_path ,
66+ py_version = PYTHON_VERSION , sagemaker_session = sagemaker_session )
6667 predictor = model .deploy (1 , 'ml.m4.xlarge' , endpoint_name = endpoint_name )
6768
6869 data = numpy .zeros (shape = (1 , 1 , 28 , 28 ))
@@ -76,7 +77,7 @@ def test_async_fit(sagemaker_session):
7677 script_path = os .path .join (DATA_DIR , 'mxnet_mnist' , 'mnist.py' )
7778 data_path = os .path .join (DATA_DIR , 'mxnet_mnist' )
7879
79- mx = MXNet (entry_point = script_path , role = 'SageMakerRole' ,
80+ mx = MXNet (entry_point = script_path , role = 'SageMakerRole' , py_version = PYTHON_VERSION ,
8081 train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
8182 sagemaker_session = sagemaker_session )
8283
@@ -105,7 +106,7 @@ def test_failed_training_job(sagemaker_session, mxnet_full_version):
105106 data_path = os .path .join (DATA_DIR , 'mxnet_mnist' )
106107
107108 mx = MXNet (entry_point = script_path , role = 'SageMakerRole' , framework_version = mxnet_full_version ,
108- train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
109+ py_version = PYTHON_VERSION , train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
109110 sagemaker_session = sagemaker_session )
110111
111112 train_input = mx .sagemaker_session .upload_data (path = os .path .join (data_path , 'train' ),
0 commit comments