1919import pytest
2020
2121from sagemaker .tensorflow import TensorFlow
22- from sagemaker .utils import unique_name_from_base
22+ from sagemaker .utils import unique_name_from_base , sagemaker_timestamp
2323
2424import tests .integ
2525from tests .integ import timeout
3939TAGS = [{"Key" : "some-key" , "Value" : "some-value" }]
4040
4141
42- def test_mnist (sagemaker_session , instance_type ):
42+ def test_mnist_with_checkpoint_config (sagemaker_session , instance_type ):
43+ checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}" .format (
44+ sagemaker_session .default_bucket (), sagemaker_timestamp ()
45+ )
46+ checkpoint_local_path = "/test/checkpoint/path"
4347 estimator = TensorFlow (
4448 entry_point = SCRIPT ,
4549 role = "SageMakerRole" ,
@@ -50,13 +54,16 @@ def test_mnist(sagemaker_session, instance_type):
5054 framework_version = TensorFlow .LATEST_VERSION ,
5155 py_version = tests .integ .PYTHON_VERSION ,
5256 metric_definitions = [{"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }],
57+ checkpoint_s3_uri = checkpoint_s3_uri ,
58+ checkpoint_local_path = checkpoint_local_path ,
5359 )
5460 inputs = estimator .sagemaker_session .upload_data (
5561 path = os .path .join (MNIST_RESOURCE_PATH , "data" ), key_prefix = "scriptmode/mnist"
5662 )
5763
64+ training_job_name = unique_name_from_base ("test-tf-sm-mnist" )
5865 with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
59- estimator .fit (inputs = inputs , job_name = unique_name_from_base ( "test-tf-sm-mnist" ) )
66+ estimator .fit (inputs = inputs , job_name = training_job_name )
6067 assert_s3_files_exist (
6168 sagemaker_session ,
6269 estimator .model_dir ,
@@ -65,33 +72,6 @@ def test_mnist(sagemaker_session, instance_type):
6572 df = estimator .training_job_analytics .dataframe ()
6673 assert df .size > 0
6774
68-
69- @pytest .mark .skipif (
70- tests .integ .test_region () != "us-east-1" ,
71- reason = "checkpoint s3 bucket is in us-east-1, ListObjectsV2 will fail in other regions" ,
72- )
73- def test_checkpoint_config (sagemaker_session , instance_type ):
74- checkpoint_s3_uri = "s3://142577830533-us-east-1-sagemaker-checkpoint"
75- checkpoint_local_path = "/test/checkpoint/path"
76- estimator = TensorFlow (
77- entry_point = SCRIPT ,
78- role = "SageMakerRole" ,
79- train_instance_count = 1 ,
80- train_instance_type = instance_type ,
81- sagemaker_session = sagemaker_session ,
82- script_mode = True ,
83- framework_version = TensorFlow .LATEST_VERSION ,
84- py_version = tests .integ .PYTHON_VERSION ,
85- checkpoint_s3_uri = checkpoint_s3_uri ,
86- checkpoint_local_path = checkpoint_local_path ,
87- )
88- inputs = estimator .sagemaker_session .upload_data (
89- path = os .path .join (MNIST_RESOURCE_PATH , "data" ), key_prefix = "script/mnist"
90- )
91- training_job_name = unique_name_from_base ("test-tf-sm-checkpoint" )
92- with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
93- estimator .fit (inputs = inputs , job_name = training_job_name )
94-
9575 expected_training_checkpoint_config = {
9676 "S3Uri" : checkpoint_s3_uri ,
9777 "LocalPath" : checkpoint_local_path ,
0 commit comments