2626from sagemaker .tensorflow .serving import Model , Predictor
2727
2828
29- @pytest .fixture (
30- scope = "session" ,
31- params = [
32- "ml.c5.xlarge" ,
33- pytest .param (
34- "ml.p3.2xlarge" ,
35- marks = pytest .mark .skipif (
36- tests .integ .test_region () in tests .integ .HOSTING_NO_P3_REGIONS ,
37- reason = "no ml.p3 instances in this region" ,
38- ),
39- ),
40- ],
41- )
42- def instance_type (request ):
43- return request .param
44-
45-
4629@pytest .fixture (scope = "module" )
47- def tfs_predictor (instance_type , sagemaker_session , tf_full_version ):
30+ def tfs_predictor (sagemaker_session , tf_full_version ):
4831 endpoint_name = sagemaker .utils .unique_name_from_base ("sagemaker-tensorflow-serving" )
4932 model_data = sagemaker_session .upload_data (
5033 path = os .path .join (tests .integ .DATA_DIR , "tensorflow-serving-test-model.tar.gz" ),
@@ -57,7 +40,7 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
5740 framework_version = tf_full_version ,
5841 sagemaker_session = sagemaker_session ,
5942 )
60- predictor = model .deploy (1 , instance_type , endpoint_name = endpoint_name )
43+ predictor = model .deploy (1 , "ml.c5.xlarge" , endpoint_name = endpoint_name )
6144 yield predictor
6245
6346
@@ -130,8 +113,6 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(
130113@pytest .fixture (scope = "module" )
131114def tfs_predictor_with_accelerator (sagemaker_session , tf_full_version ):
132115 endpoint_name = sagemaker .utils .unique_name_from_base ("sagemaker-tensorflow-serving" )
133- instance_type = "ml.c4.large"
134- accelerator_type = "ml.eia1.medium"
135116 model_data = sagemaker_session .upload_data (
136117 path = os .path .join (tests .integ .DATA_DIR , "tensorflow-serving-test-model.tar.gz" ),
137118 key_prefix = "tensorflow-serving/models" ,
@@ -144,13 +125,13 @@ def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version):
144125 sagemaker_session = sagemaker_session ,
145126 )
146127 predictor = model .deploy (
147- 1 , instance_type , endpoint_name = endpoint_name , accelerator_type = accelerator_type
128+ 1 , "ml.c4.large" , endpoint_name = endpoint_name , accelerator_type = "ml.eia1.medium"
148129 )
149130 yield predictor
150131
151132
152133@pytest .mark .canary_quick
153- def test_predict (tfs_predictor , instance_type ): # pylint: disable=W0613
134+ def test_predict (tfs_predictor ): # pylint: disable=W0613
154135 input_data = {"instances" : [1.0 , 2.0 , 5.0 ]}
155136 expected_result = {"predictions" : [3.5 , 4.0 , 5.5 ]}
156137
0 commit comments