@@ -830,3 +830,48 @@ def test_tf_script_mode_mpi(time, strftime, sagemaker_session):
830830
831831 actual_train_args = sagemaker_session .method_calls [0 ][2 ]
832832 assert actual_train_args == expected_train_args
833+
834+
835+ @patch ('sagemaker.utils.create_tar_file' , MagicMock ())
836+ def test_tf_script_mode_attach (sagemaker_session , tf_version ):
837+ training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py3-cpu:{}-cpu-py3' .format (tf_version )
838+ rjd = {
839+ 'AlgorithmSpecification' : {
840+ 'TrainingInputMode' : 'File' ,
841+ 'TrainingImage' : training_image
842+ },
843+ 'HyperParameters' : {
844+ 'sagemaker_submit_directory' : '"s3://some/sourcedir.tar.gz"' ,
845+ 'sagemaker_program' : '"iris-dnn-classifier.py"' ,
846+ 'sagemaker_enable_cloudwatch_metrics' : 'false' ,
847+ 'sagemaker_container_log_level' : '"logging.INFO"' ,
848+ 'sagemaker_job_name' : '"neo"'
849+ },
850+ 'RoleArn' : 'arn:aws:iam::366:role/SageMakerRole' ,
851+ 'ResourceConfig' : {
852+ 'VolumeSizeInGB' : 30 ,
853+ 'InstanceCount' : 1 ,
854+ 'InstanceType' : 'ml.c4.xlarge'
855+ },
856+ 'StoppingCondition' : {'MaxRuntimeInSeconds' : 24 * 60 * 60 },
857+ 'TrainingJobName' : 'neo' ,
858+ 'TrainingJobStatus' : 'Completed' ,
859+ 'OutputDataConfig' : {'KmsKeyId' : '' , 'S3OutputPath' : 's3://place/output/neo' },
860+ 'TrainingJobOutput' : {'S3TrainingJobOutput' : 's3://here/output.tar.gz' }}
861+ sagemaker_session .sagemaker_client .describe_training_job = Mock (name = 'describe_training_job' , return_value = rjd )
862+
863+ estimator = TensorFlow .attach (training_job_name = 'neo' , sagemaker_session = sagemaker_session )
864+ assert estimator .latest_training_job .job_name == 'neo'
865+ assert estimator .py_version == 'py3'
866+ assert estimator .framework_version == tf_version
867+ assert estimator .role == 'arn:aws:iam::366:role/SageMakerRole'
868+ assert estimator .train_instance_count == 1
869+ assert estimator .train_max_run == 24 * 60 * 60
870+ assert estimator .input_mode == 'File'
871+ assert estimator .input_mode == 'File'
872+ assert estimator .base_job_name == 'neo'
873+ assert estimator .output_path == 's3://place/output/neo'
874+ assert estimator .output_kms_key == ''
875+ assert estimator .hyperparameters () is not None
876+ assert estimator .source_dir == 's3://some/sourcedir.tar.gz'
877+ assert estimator .entry_point == 'iris-dnn-classifier.py'
0 commit comments