@@ -31,8 +31,42 @@ def gpu_instance_type(request):
3131 return "ml.p3.2xlarge"
3232
3333
@pytest.fixture(scope="module")
def imagenet_val_set(request, sagemaker_session, tmpdir_factory):
    """
    Stage the ImageNet validation TFRecords for the training compiler tests.

    The objects under ``Imagenet/TFRecords/validation`` in the
    ``collection-of-ml-datasets`` bucket are downloaded into a temporary
    directory created through ``tmpdir_factory`` and then re-uploaded with
    ``sagemaker_session.upload_data`` under the key prefix
    ``integ-test-data/trcomp/tensorflow/imagenet/val``. No destination bucket
    is passed, so the session decides where the copy lands.

    Returns:
        Whatever ``upload_data`` returns (presumably the S3 URI of the
        uploaded copy -- confirm against the SageMaker SDK); it is meant to
        be handed to an estimator's ``fit`` as the training input.
    """
    # Module scope means the download/upload round trip happens once per module.
    staging_dir = tmpdir_factory.mktemp("trcomp_imagenet_val_set")

    sagemaker_session.download_data(
        path=staging_dir,
        bucket="collection-of-ml-datasets",
        key_prefix="Imagenet/TFRecords/validation",
    )

    return sagemaker_session.upload_data(
        path=staging_dir,
        key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val",
    )
50+
51+
@pytest.fixture(scope="module")
def huggingface_dummy_dataset(request, sagemaker_session):
    """
    Upload the dummy HuggingFace training data shipped with the test suite.

    The ``huggingface/train`` directory under ``DATA_DIR`` is uploaded with
    ``sagemaker_session.upload_data`` under the key prefix
    ``integ-test-data/trcomp/huggingface/dummy/train``. No destination bucket
    is passed, so the session decides where the data lands.

    Returns:
        Whatever ``upload_data`` returns (presumably the S3 URI of the
        uploaded data -- confirm against the SageMaker SDK); it is meant to
        be handed to an estimator's ``fit`` as the training input.
    """
    train_dir = os.path.join(DATA_DIR, "huggingface", "train")
    return sagemaker_session.upload_data(
        path=train_dir,
        key_prefix="integ-test-data/trcomp/huggingface/dummy/train",
    )
63+
64+
3465@pytest .fixture (scope = "module" , autouse = True )
3566def skip_if_incompatible (request ):
67+ """
68+ These tests are for training compiler enabled images/estimators only.
69+ """
3670 if integ .test_region () not in integ .TRAINING_COMPILER_SUPPORTED_REGIONS :
3771 pytest .skip ("SageMaker Training Compiler is not supported in this region" )
3872 if integ .test_region () in integ .TRAINING_NO_P3_REGIONS :
@@ -45,7 +79,11 @@ def test_huggingface_pytorch(
4579 gpu_instance_type ,
4680 huggingface_training_compiler_latest_version ,
4781 huggingface_training_compiler_pytorch_latest_version ,
82+ huggingface_dummy_dataset ,
4883):
84+ """
85+ Test the HuggingFace estimator with PyTorch
86+ """
4987 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
5088 data_path = os .path .join (DATA_DIR , "huggingface" )
5189
@@ -73,12 +111,7 @@ def test_huggingface_pytorch(
73111 compiler_config = HFTrainingCompilerConfig (),
74112 )
75113
76- train_input = hf .sagemaker_session .upload_data (
77- path = os .path .join (data_path , "train" ),
78- key_prefix = "integ-test-data/huggingface/train" ,
79- )
80-
81- hf .fit (train_input )
114+ hf .fit (huggingface_dummy_dataset )
82115
83116
84117@pytest .mark .release
@@ -87,7 +120,11 @@ def test_huggingface_tensorflow(
87120 gpu_instance_type ,
88121 huggingface_training_compiler_latest_version ,
89122 huggingface_training_compiler_tensorflow_latest_version ,
123+ huggingface_dummy_dataset ,
90124):
125+ """
126+ Test the HuggingFace estimator with TensorFlow
127+ """
91128 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
92129 data_path = os .path .join (DATA_DIR , "huggingface" )
93130
@@ -112,19 +149,19 @@ def test_huggingface_tensorflow(
112149 compiler_config = HFTrainingCompilerConfig (),
113150 )
114151
115- train_input = hf .sagemaker_session .upload_data (
116- path = os .path .join (data_path , "train" ), key_prefix = "integ-test-data/huggingface/train"
117- )
118-
119- hf .fit (train_input )
152+ hf .fit (huggingface_dummy_dataset )
120153
121154
122155@pytest .mark .release
123156def test_tensorflow (
124157 sagemaker_session ,
125158 gpu_instance_type ,
126159 tensorflow_training_latest_version ,
160+ imagenet_val_set ,
127161):
162+ """
163+ Test the TensorFlow estimator
164+ """
128165 if version .parse (tensorflow_training_latest_version ) < version .parse ("2.9" ):
129166 pytest .skip ("Training Compiler only supports TF >= 2.9" )
130167 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
@@ -173,7 +210,7 @@ def test_tensorflow(
173210 )
174211
175212 tf .fit (
176- inputs = "s3://collection-of-ml-datasets/Imagenet/TFRecords/validation" ,
213+ inputs = imagenet_val_set ,
177214 logs = True ,
178215 wait = True ,
179216 )
0 commit comments