Skip to content

Commit 7152db2

Browse files
author
Jonathan Makunga
committed
Integration tests
1 parent d997612 commit 7152db2

File tree

3 files changed

+117
-0
lines changed

3 files changed

+117
-0
lines changed

src/sagemaker/jumpstart/enums.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@ class JumpStartTag(str, Enum):
9393
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
9494
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
9595
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
96+
9697
INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name"
9798
TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name"
99+
98100
HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn"
99101

100102

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
15+
import io
1416
import os
17+
import sys
1518
import time
1619
from unittest import mock
1720

@@ -349,3 +352,43 @@ def test_register_gated_jumpstart_model(setup):
349352
predictor.delete_predictor()
350353

351354
assert response is not None
355+
356+
357+
def test_jumpstart_model_with_deployment_configs(setup):
    """End-to-end check of the JumpStart deployment-config APIs.

    Exercises ``display_benchmark_metrics``, ``list_deployment_configs``,
    ``set_deployment_config``, then deploys the pinned config and runs a
    smoke-test prediction against the endpoint.
    """
    model_id = "meta-textgeneration-llama-2-7b-f"

    model = JumpStartModel(
        model_id=model_id,
        model_version="*",
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )

    # display_benchmark_metrics() prints its table to stdout; capture it to
    # verify output is produced. Restore stdout in a finally block so a
    # failure inside the call cannot leave stdout redirected for the rest of
    # the test session, and restore the *previous* stream (not sys.__stdout__)
    # so pytest's own capture is preserved.
    captured_output = io.StringIO()
    original_stdout = sys.stdout
    sys.stdout = captured_output
    try:
        model.display_benchmark_metrics()
    finally:
        sys.stdout = original_stdout
    # getvalue() always returns a str, so "is not None" was vacuous; assert
    # the captured table is non-empty instead.
    assert captured_output.getvalue()

    configs = model.list_deployment_configs()
    assert len(configs) > 0

    model.set_deployment_config(
        configs[0]["ConfigName"],
        "ml.g5.2xlarge",
    )
    assert model.config_name == configs[0]["ConfigName"]

    predictor = model.deploy(
        accept_eula=True,
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )
    try:
        payload = {
            "inputs": "some-payload",
            "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
        }

        response = predictor.predict(payload, custom_attributes="accept_eula=true")

        assert response is not None
    finally:
        # Match the sibling integration tests: always tear the endpoint down,
        # even when the prediction or assertion fails, to avoid leaking
        # billable endpoints.
        predictor.delete_predictor()

tests/integ/sagemaker/serve/test_serve_js_happy.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import io
16+
import sys
17+
1518
import pytest
1619

1720
from sagemaker.serve.builder.model_builder import ModelBuilder
@@ -54,6 +57,19 @@ def happy_model_builder(sagemaker_session):
5457
)
5558

5659

60+
@pytest.fixture
def meta_textgeneration_llama_2_7b_f_schema():
    """SchemaBuilder seeded with a sample request/response pair for Llama-2-7b-f."""
    sample_input = {"inputs": "Hello, I'm a language model,"}
    sample_output = [
        {
            "generated_text": (
                "Hello, I'm a language model, and I'm here to help you with your English."
            )
        }
    ]
    return SchemaBuilder(sample_input=sample_input, sample_output=sample_output)
72+
5773
@pytest.fixture
5874
def happy_mms_model_builder(sagemaker_session):
5975
iam_client = sagemaker_session.boto_session.client("iam")
@@ -125,3 +141,59 @@ def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type
125141
)
126142
if caught_ex:
127143
raise caught_ex
144+
145+
146+
@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="The goal of these test are to test the serving components of our feature",
)
def test_js_model_with_deployment_configs(
    meta_textgeneration_llama_2_7b_f_schema,
    sagemaker_session,
):
    """End-to-end check of ModelBuilder deployment-config APIs in SAGEMAKER_ENDPOINT mode.

    Lists deployment configs, captures the benchmark-metrics display, pins a
    config, builds the model, then deploys and runs a smoke-test prediction
    before cleaning up the endpoint.
    """
    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
    caught_ex = None
    iam_client = sagemaker_session.boto_session.client("iam")
    role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

    model_builder = ModelBuilder(
        model="meta-textgeneration-llama-2-7b-f",
        schema_builder=meta_textgeneration_llama_2_7b_f_schema,
    )
    configs = model_builder.list_deployment_configs()

    assert len(configs) > 0

    # Capture stdout while the benchmark table is printed. Restore in a
    # finally block so an exception cannot leave stdout redirected, and
    # restore the previous stream (not sys.__stdout__) so pytest's own
    # capture machinery is preserved.
    captured_output = io.StringIO()
    original_stdout = sys.stdout
    sys.stdout = captured_output
    try:
        model_builder.display_benchmark_metrics()
    finally:
        sys.stdout = original_stdout
    # getvalue() always returns a str, so "is not None" was vacuous; assert
    # the captured table is non-empty instead.
    assert captured_output.getvalue()

    model_builder.set_deployment_config(
        configs[0]["ConfigName"],
        "ml.g5.2xlarge",
    )
    model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
    assert model.config_name == configs[0]["ConfigName"]
    assert model_builder.get_deployment_config() is not None

    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        try:
            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(accept_eula=True)
            logger.info("Endpoint successfully deployed.")

            updated_sample_input = model_builder.schema_builder.sample_input

            predictor.predict(updated_sample_input)
        except Exception as e:  # pylint: disable=broad-except
            # Record the failure but still run cleanup; re-raised below.
            caught_ex = e
        finally:
            cleanup_model_resources(
                sagemaker_session=sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )
    if caught_ex:
        raise caught_ex

0 commit comments

Comments
 (0)