1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
1414
15- import os
16-
1715import numpy
16+ import os
1817import pytest
19- from tests .integ import DATA_DIR , PYTHON_VERSION , TRAINING_DEFAULT_TIMEOUT_MINUTES
20- from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
21-
18+ from sagemaker .pytorch .defaults import LATEST_PY2_VERSION
2219from sagemaker .pytorch .estimator import PyTorch
2320from sagemaker .pytorch .model import PyTorchModel
24- from sagemaker .pytorch .defaults import LATEST_PY2_VERSION
2521from sagemaker .utils import sagemaker_timestamp
2622
23+ from tests .integ import (
24+ test_region ,
25+ DATA_DIR ,
26+ PYTHON_VERSION ,
27+ TRAINING_DEFAULT_TIMEOUT_MINUTES ,
28+ EI_SUPPORTED_REGIONS ,
29+ )
30+ from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
31+
2732MNIST_DIR = os .path .join (DATA_DIR , "pytorch_mnist" )
2833MNIST_SCRIPT = os .path .join (MNIST_DIR , "mnist.py" )
2934
@@ -120,6 +125,9 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
120125
121126
122127@pytest .mark .skipif (PYTHON_VERSION == "py2" , reason = "PyTorch EIA does not support Python 2." )
128+ @pytest .mark .skipif (
129+ test_region () not in EI_SUPPORTED_REGIONS , reason = "EI isn't supported in that specific region."
130+ )
123131def test_deploy_model_with_accelerator (sagemaker_session , cpu_instance_type ):
124132 endpoint_name = "test-pytorch-deploy-eia-{}" .format (sagemaker_timestamp ())
125133 model_data = sagemaker_session .upload_data (path = EIA_MODEL )
@@ -134,7 +142,7 @@ def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
134142 predictor = pytorch .deploy (
135143 initial_instance_count = 1 ,
136144 instance_type = cpu_instance_type ,
137- accelerator_type = "ml.eia2 .medium" ,
145+ accelerator_type = "ml.eia1 .medium" ,
138146 endpoint_name = endpoint_name ,
139147 )
140148
0 commit comments