|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
| 15 | +import io |
| 16 | +import sys |
| 17 | + |
15 | 18 | import pytest
|
16 | 19 |
|
17 | 20 | from sagemaker.serve.builder.model_builder import ModelBuilder
|
@@ -54,6 +57,19 @@ def happy_model_builder(sagemaker_session):
|
54 | 57 | )
|
55 | 58 |
|
56 | 59 |
|
@pytest.fixture
def meta_textgeneration_llama_2_7b_f_schema():
    """Schema builder for the meta-textgeneration-llama-2-7b-f JumpStart model.

    Supplies one representative request/response pair so ModelBuilder can
    infer the serialization format for the endpoint.
    """
    return SchemaBuilder(
        sample_input={"inputs": "Hello, I'm a language model,"},
        sample_output=[
            {
                "generated_text": (
                    "Hello, I'm a language model, and I'm here to help you with your English."
                )
            }
        ],
    )
| 72 | + |
57 | 73 | @pytest.fixture
|
58 | 74 | def happy_mms_model_builder(sagemaker_session):
|
59 | 75 | iam_client = sagemaker_session.boto_session.client("iam")
|
@@ -125,3 +141,59 @@ def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type
|
125 | 141 | )
|
126 | 142 | if caught_ex:
|
127 | 143 | raise caught_ex
|
| 144 | + |
| 145 | + |
@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="The goal of these test are to test the serving components of our feature",
)
def test_js_model_with_deployment_configs(
    meta_textgeneration_llama_2_7b_f_schema,
    sagemaker_session,
):
    """End-to-end test of JumpStart deployment configs.

    Lists the available deployment configs for the llama-2-7b-f model,
    verifies benchmark metrics can be displayed, selects the first config,
    then builds, deploys, and predicts against a SageMaker endpoint,
    cleaning up the endpoint regardless of outcome.
    """
    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
    caught_ex = None
    iam_client = sagemaker_session.boto_session.client("iam")
    role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

    model_builder = ModelBuilder(
        model="meta-textgeneration-llama-2-7b-f",
        schema_builder=meta_textgeneration_llama_2_7b_f_schema,
    )
    configs = model_builder.list_deployment_configs()

    assert len(configs) > 0

    # Capture stdout while displaying metrics. Restore it in a finally block:
    # without it, an exception in display_benchmark_metrics() would leave
    # stdout redirected and silently swallow all later test output.
    captured_output = io.StringIO()
    sys.stdout = captured_output
    try:
        model_builder.display_benchmark_metrics()
    finally:
        sys.stdout = sys.__stdout__
    # StringIO.getvalue() always returns a str (never None), so the original
    # `is not None` assertion was a tautology; assert something was printed.
    assert captured_output.getvalue()

    model_builder.set_deployment_config(
        configs[0]["ConfigName"],
        "ml.g5.2xlarge",
    )
    model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
    # The built model must reflect the config we just selected.
    assert model.config_name == configs[0]["ConfigName"]
    assert model_builder.get_deployment_config() is not None

    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        try:
            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(accept_eula=True)
            logger.info("Endpoint successfully deployed.")

            updated_sample_input = model_builder.schema_builder.sample_input

            predictor.predict(updated_sample_input)
        except Exception as e:
            # Defer the raise so cleanup always runs before the test fails.
            caught_ex = e
        finally:
            cleanup_model_resources(
                sagemaker_session=sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )
            if caught_ex:
                raise caught_ex
0 commit comments