1919from sagemaker .huggingface import HuggingFace , HuggingFaceProcessor
2020from sagemaker .huggingface .model import HuggingFaceModel , HuggingFacePredictor
2121from sagemaker .utils import unique_name_from_base
22- from tests import integ
23- from tests .integ .utils import gpu_list , retry_with_instance_list
2422from tests .integ import DATA_DIR , TRAINING_DEFAULT_TIMEOUT_MINUTES
2523from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
2624
2725ROLE = "SageMakerRole"
2826
2927
3028@pytest .mark .release
31- @pytest .mark .skipif (
32- integ .test_region () in integ .TRAINING_NO_P2_REGIONS
33- and integ .test_region () in integ .TRAINING_NO_P3_REGIONS ,
34- reason = "no ml.p2 or ml.p3 instances in this region" ,
35- )
36- @retry_with_instance_list (gpu_list (integ .test_region ()))
3729def test_framework_processing_job_with_deps (
3830 sagemaker_session ,
3931 huggingface_training_latest_version ,
4032 huggingface_training_pytorch_latest_version ,
4133 huggingface_pytorch_latest_training_py_version ,
42- ** kwargs ,
34+ gpu_pytorch_instance_type ,
4335):
4436 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
4537 code_path = os .path .join (DATA_DIR , "dummy_code_bundle_with_reqs" )
@@ -51,7 +43,7 @@ def test_framework_processing_job_with_deps(
5143 py_version = huggingface_pytorch_latest_training_py_version ,
5244 role = ROLE ,
5345 instance_count = 1 ,
54- instance_type = kwargs [ "instance_type" ] ,
46+ instance_type = gpu_pytorch_instance_type ,
5547 sagemaker_session = sagemaker_session ,
5648 base_job_name = "test-huggingface" ,
5749 )
@@ -64,18 +56,12 @@ def test_framework_processing_job_with_deps(
6456
6557
6658@pytest .mark .release
67- @pytest .mark .skipif (
68- integ .test_region () in integ .TRAINING_NO_P2_REGIONS
69- and integ .test_region () in integ .TRAINING_NO_P3_REGIONS ,
70- reason = "no ml.p2 or ml.p3 instances in this region" ,
71- )
72- @retry_with_instance_list (gpu_list (integ .test_region ()))
7359def test_huggingface_training (
7460 sagemaker_session ,
7561 huggingface_training_latest_version ,
7662 huggingface_training_pytorch_latest_version ,
7763 huggingface_pytorch_latest_training_py_version ,
78- ** kwargs ,
64+ gpu_pytorch_instance_type ,
7965):
8066 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
8167 data_path = os .path .join (DATA_DIR , "huggingface" )
@@ -87,7 +73,7 @@ def test_huggingface_training(
8773 transformers_version = huggingface_training_latest_version ,
8874 pytorch_version = huggingface_training_pytorch_latest_version ,
8975 instance_count = 1 ,
90- instance_type = kwargs [ "instance_type" ] ,
76+ instance_type = gpu_pytorch_instance_type ,
9177 hyperparameters = {
9278 "model_name_or_path" : "distilbert-base-cased" ,
9379 "task_name" : "wnli" ,
@@ -111,14 +97,6 @@ def test_huggingface_training(
11197
11298
11399@pytest .mark .release
114- @pytest .mark .skipif (
115- integ .test_region () in integ .TRAINING_NO_P2_REGIONS
116- and integ .test_region () in integ .TRAINING_NO_P3_REGIONS ,
117- reason = "no ml.p2 or ml.p3 instances in this region" ,
118- )
119- @pytest .mark .skip (
120- reason = "need to re enable it later t.corp:V609860141" ,
121- )
122100def test_huggingface_training_tf (
123101 sagemaker_session ,
124102 gpu_instance_type ,
@@ -161,7 +139,7 @@ def test_huggingface_training_tf(
161139)
162140def test_huggingface_inference (
163141 sagemaker_session ,
164- gpu_instance_type ,
142+ gpu_pytorch_instance_type ,
165143 huggingface_inference_latest_version ,
166144 huggingface_inference_pytorch_latest_version ,
167145 huggingface_pytorch_latest_inference_py_version ,
@@ -182,7 +160,9 @@ def test_huggingface_inference(
182160 )
183161 with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
184162 model .deploy (
185- instance_type = gpu_instance_type , initial_instance_count = 1 , endpoint_name = endpoint_name
163+ instance_type = gpu_pytorch_instance_type ,
164+ initial_instance_count = 1 ,
165+ endpoint_name = endpoint_name ,
186166 )
187167
188168 predictor = HuggingFacePredictor (endpoint_name = endpoint_name )
0 commit comments