4848
4949
5050@pytest .fixture (scope = "session" )
51- def add_models ():
51+ def add_model_references ():
5252 # Create Model References to test in Hub
5353 hub_instance = Hub (
5454 hub_name = os .environ [ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME ], sagemaker_session = get_sm_session ()
@@ -57,27 +57,27 @@ def add_models():
5757 hub_instance .create_model_reference (model_arn = get_public_hub_model_arn (hub_instance , model ))
5858
5959
60- def test_jumpstart_hub_model (setup , add_models ):
61-
62- JUMPSTART_LOGGER .info ("starting test" )
63- JUMPSTART_LOGGER .info (f"get identity { get_sm_session ().get_caller_identity_arn ()} " )
60+ def test_jumpstart_hub_model (setup , add_model_references ):
6461
6562 model_id = "catboost-classification-model"
6663
64+ sagemaker_session = get_sm_session ()
65+
6766 model = JumpStartModel (
6867 model_id = model_id ,
69- role = get_sm_session () .get_caller_identity_arn (),
70- sagemaker_session = get_sm_session () ,
68+ role = sagemaker_session .get_caller_identity_arn (),
69+ sagemaker_session = sagemaker_session ,
7170 hub_name = os .environ [ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME ],
7271 )
7372
74- # uses ml.m5.4xlarge instance
75- model .deploy (
73+ predictor = model .deploy (
7674 tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
7775 )
7876
77+ assert sagemaker_session .endpoint_in_service_or_not (predictor .endpoint_name )
7978
80- def test_jumpstart_hub_gated_model (setup , add_models ):
79+
80+ def test_jumpstart_hub_gated_model (setup , add_model_references ):
8181
8282 model_id = "meta-textgeneration-llama-3-2-1b"
8383
@@ -88,23 +88,19 @@ def test_jumpstart_hub_gated_model(setup, add_models):
8888 hub_name = os .environ [ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME ],
8989 )
9090
91- # uses ml.g6.xlarge instance
9291 predictor = model .deploy (
9392 accept_eula = True ,
9493 tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
9594 )
9695
97- payload = {
98- "inputs" : "some-payload" ,
99- "parameters" : {"max_new_tokens" : 256 , "top_p" : 0.9 , "temperature" : 0.6 },
100- }
96+ payload = model .retrieve_example_payload ()
10197
102- response = predictor .predict (payload , custom_attributes = "accept_eula=true" )
98+ response = predictor .predict (payload )
10399
104100 assert response is not None
105101
106102
107- def test_jumpstart_gated_model_inference_component_enabled (setup , add_models ):
103+ def test_jumpstart_gated_model_inference_component_enabled (setup , add_model_references ):
108104
109105 model_id = "meta-textgeneration-llama-2-7b"
110106
@@ -125,7 +121,6 @@ def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
125121 hub_name = os .environ [ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME ],
126122 )
127123
128- # uses ml.g5.2xlarge instance
129124 model .deploy (
130125 tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
131126 accept_eula = True ,
@@ -139,10 +134,7 @@ def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
139134 hub_arn = hub_arn ,
140135 )
141136
142- payload = {
143- "inputs" : "some-payload" ,
144- "parameters" : {"max_new_tokens" : 256 , "top_p" : 0.9 , "temperature" : 0.6 },
145- }
137+ payload = model .retrieve_example_payload ()
146138
147139 response = predictor .predict (payload )
148140
@@ -156,7 +148,7 @@ def test_jumpstart_gated_model_inference_component_enabled(setup, add_models):
156148 assert model .inference_component_name == predictor .component_name
157149
158150
159- def test_instatiating_model (setup , add_models ):
151+ def test_instantiating_model (setup , add_model_references ):
160152
161153 model_id = "catboost-regression-model"
162154
0 commit comments