@@ -441,6 +441,44 @@ def test_tuning_mxnet(sagemaker_session, mxnet_full_version):
441441 predictor .predict (data )
442442
443443
444+ @pytest .mark .canary_quick
445+ def test_tuning_tf_script_mode (sagemaker_session ):
446+ resource_path = os .path .join (DATA_DIR , 'tensorflow_mnist' )
447+ script_path = os .path .join (resource_path , 'mnist.py' )
448+
449+ estimator = TensorFlow (entry_point = script_path ,
450+ role = 'SageMakerRole' ,
451+ train_instance_count = 1 ,
452+ train_instance_type = 'ml.m4.xlarge' ,
453+ script_mode = True ,
454+ sagemaker_session = sagemaker_session ,
455+ py_version = PYTHON_VERSION ,
456+ framework_version = TensorFlow .LATEST_VERSION )
457+
458+ hyperparameter_ranges = {'epochs' : IntegerParameter (1 , 2 )}
459+ objective_metric_name = 'accuracy'
460+ metric_definitions = [{'Name' : objective_metric_name , 'Regex' : 'accuracy = ([0-9\\ .]+)' }]
461+
462+ tuner = HyperparameterTuner (estimator ,
463+ objective_metric_name ,
464+ hyperparameter_ranges ,
465+ metric_definitions ,
466+ max_jobs = 2 ,
467+ max_parallel_jobs = 2 )
468+
469+ with timeout (minutes = TUNING_DEFAULT_TIMEOUT_MINUTES ):
470+ inputs = estimator .sagemaker_session .upload_data (path = os .path .join (resource_path , 'data' ),
471+ key_prefix = 'scriptmode/mnist' )
472+
473+ tuning_job_name = unique_name_from_base ('tune-tf-script-mode' , max_length = 32 )
474+ tuner .fit (inputs , job_name = tuning_job_name )
475+
476+ print ('Started hyperparameter tuning job with name: ' + tuning_job_name )
477+
478+ time .sleep (15 )
479+ tuner .wait ()
480+
481+
444482@pytest .mark .canary_quick
445483@pytest .mark .skipif (PYTHON_VERSION != 'py2' , reason = "TensorFlow image supports only python 2." )
446484def test_tuning_tf (sagemaker_session ):
0 commit comments