Skip to content

Commit ce2639c

Browse files
committed
adding test cases for GPU local model inference with HF TGI and DJL lmi test cases
1 parent 1d7254c commit ce2639c

File tree

1 file changed

+233
-0
lines changed

1 file changed

+233
-0
lines changed
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import time
15+
from typing import Union
16+
17+
18+
import os
19+
import re
20+
import pytest
21+
import subprocess
22+
import logging
23+
import sagemaker
24+
import boto3
25+
import urllib3
26+
from pathlib import Path
27+
from sagemaker.huggingface import (
28+
HuggingFaceModel,
29+
get_huggingface_llm_image_uri
30+
)
31+
from sagemaker.deserializers import JSONDeserializer
32+
from sagemaker.local import LocalSession
33+
from sagemaker.serializers import JSONSerializer
34+
35+
36+
# IAM execution role passed to the SageMaker models below. Local mode does not
# actually assume this role, but the SDK requires one to be supplied.
# Replace this role ARN with an appropriate role for your environment
ROLE = "arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001"
38+
39+
40+
def ensure_docker_compose_installed():
    """
    Ensure the Docker Compose v2 CLI plugin is installed and working.

    Downloads the Docker Compose plugin binary into ~/.docker/cli-plugins
    if it is not already present, marks it executable, and verifies the
    installation by checking that the output of 'docker compose version'
    matches the pattern 'Docker Compose version vX.Y.Z'.

    Raises:
        AssertionError: if 'docker compose version' cannot be run or its
            output does not contain a recognizable version string.
        subprocess.CalledProcessError: if the download or chmod commands fail.
    """
    cli_plugins_path = Path.home() / ".docker" / "cli-plugins"
    cli_plugins_path.mkdir(parents=True, exist_ok=True)

    compose_binary_path = cli_plugins_path / "docker-compose"
    if not compose_binary_path.exists():
        # -S: show errors; -L: follow GitHub's release-asset redirect.
        subprocess.run(
            [
                "curl",
                "-SL",
                "https://github.com/docker/compose/releases/download/v2.3.3/docker-compose-linux-x86_64",
                "-o",
                str(compose_binary_path),
            ],
            check=True,
        )
        subprocess.run(["chmod", "+x", str(compose_binary_path)], check=True)

    # Verify Docker Compose responds and reports a semantic version.
    try:
        output = subprocess.check_output(
            ["docker", "compose", "version"], stderr=subprocess.STDOUT
        )
        output_decoded = output.decode("utf-8").strip()
        logging.info(f"'docker compose version' output: {output_decoded}")

        # Example expected format: "Docker Compose version v2.3.3".
        # NOTE: the previous pattern r"Docker Compose version+" only repeated
        # the trailing 'n' — it never actually checked for a version number.
        pattern = r"Docker Compose version v\d+\.\d+\.\d+"
        match = re.search(pattern, output_decoded)
        assert (
            match is not None
        ), f"Could not find a Docker Compose version string matching '{pattern}' in: {output_decoded}"

    except subprocess.CalledProcessError as e:
        # Chain the original failure so the root cause stays visible.
        raise AssertionError(f"Failed to verify Docker Compose: {e}") from e
79+
80+
81+
"""
82+
Local Model: HuggingFace LLM Inference
83+
"""
84+
@pytest.mark.local
85+
def test_huggingfacellm_local_model_inference():
86+
"""
87+
Test local mode inference with DJL-LMI inference containers
88+
without a model_data path provided at runtime. This test should
89+
be run on a GPU only machine with instance set to local_gpu.
90+
"""
91+
ensure_docker_compose_installed()
92+
93+
# 1. Create a local session for inference
94+
sagemaker_session = LocalSession()
95+
sagemaker_session.config = {"local": {"local_code": True}}
96+
97+
djllmi_model = sagemaker.Model(
98+
image_uri="763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124",
99+
env={
100+
"HF_MODEL_ID": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
101+
"OPTION_MAX_MODEL_LEN": "10000",
102+
"OPTION_GPU_MEMORY_UTILIZATION": "0.95",
103+
"OPTION_ENABLE_STREAMING": "false",
104+
"OPTION_ROLLING_BATCH": "auto",
105+
"OPTION_MODEL_LOADING_TIMEOUT": "3600",
106+
"OPTION_PAGED_ATTENTION": "false",
107+
"OPTION_DTYPE": "fp16",
108+
},
109+
role=ROLE,
110+
sagemaker_session=sagemaker_session
111+
)
112+
113+
logging.warning('Deploying endpoint in local mode')
114+
logging.warning(
115+
'Note: if launching for the first time in local mode, container image download might take a few minutes to complete.'
116+
)
117+
118+
endpoint_name = "test-djl"
119+
djllmi_model.deploy(
120+
endpoint_name=endpoint_name,
121+
initial_instance_count=1,
122+
instance_type="local_gpu",
123+
container_startup_health_check_timeout=600,
124+
)
125+
predictor = sagemaker.Predictor(
126+
endpoint_name=endpoint_name,
127+
sagemaker_session=sagemaker_session,
128+
serializer=JSONSerializer(),
129+
deserializer=JSONDeserializer(),
130+
)
131+
test_response = predictor.predict(
132+
{
133+
"inputs": """<|begin_of_text|>
134+
<|start_header_id|>system<|end_header_id|>
135+
You are a helpful assistant that thinks and reasons before answering.
136+
<|eot_id|>
137+
<|start_header_id|>user<|end_header_id|>
138+
What's 2x2?
139+
<|eot_id|>
140+
141+
<|start_header_id|>assistant<|end_header_id|>
142+
"""
143+
}
144+
)
145+
logging.warning(test_response)
146+
gen_text = test_response['generated_text']
147+
logging.warning(f"\n=======\nmodel response: {gen_text}\n=======\n")
148+
149+
assert type(test_response) == dict, f"invalid model response format: {gen_text}"
150+
assert type(gen_text) == str, f"assistant response format: {gen_text}"
151+
152+
logging.warning('About to delete the endpoint')
153+
predictor.delete_endpoint()
154+
155+
156+
"""
157+
Local Model: HuggingFace TGI Inference
158+
"""
159+
@pytest.mark.local
160+
def test_huggingfacetgi_local_model_inference():
161+
"""
162+
Test local mode inference with HuggingFace TGI inference containers
163+
without a model_data path provided at runtime. This test should
164+
be run on a GPU only machine with instance set to local_gpu.
165+
"""
166+
ensure_docker_compose_installed()
167+
168+
# 1. Create a local session for inference
169+
sagemaker_session = LocalSession()
170+
sagemaker_session.config = {"local": {"local_code": True}}
171+
172+
huggingface_model = HuggingFaceModel(
173+
image_uri=get_huggingface_llm_image_uri(
174+
"huggingface",
175+
version="2.3.1"
176+
),
177+
env={
178+
"HF_MODEL_ID": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
179+
"ENDPOINT_SERVER_TIMEOUT": "3600",
180+
"MESSAGES_API_ENABLED": "true",
181+
"OPTION_ENTRYPOINT": "inference.py",
182+
"SAGEMAKER_ENV": "1",
183+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
184+
"SAGEMAKER_PROGRAM": "inference.py",
185+
"SM_NUM_GPUS": "1",
186+
"MAX_TOTAL_TOKENS": "1024",
187+
"MAX_INPUT_TOKENS": "800",
188+
"MAX_BATCH_PREFILL_TOKENS": "900",
189+
"DTYPE": "bfloat16",
190+
"PORT": "8080"
191+
},
192+
role=ROLE,
193+
sagemaker_session=sagemaker_session
194+
)
195+
196+
logging.warning('Deploying endpoint in local mode')
197+
logging.warning(
198+
'Note: if launching for the first time in local mode, container image download might take a few minutes to complete.'
199+
)
200+
201+
endpoint_name = "test-hf"
202+
huggingface_model.deploy(
203+
endpoint_name=endpoint_name,
204+
initial_instance_count=1,
205+
instance_type="local_gpu",
206+
container_startup_health_check_timeout=600,
207+
)
208+
predictor = sagemaker.Predictor(
209+
endpoint_name=endpoint_name,
210+
sagemaker_session=sagemaker_session,
211+
serializer=JSONSerializer(),
212+
deserializer=JSONDeserializer(),
213+
)
214+
test_response = predictor.predict(
215+
{
216+
"messages": [
217+
{"role": "system", "content": "You are a helpful assistant." },
218+
{"role": "user", "content": "What is 2x2?"}
219+
]
220+
}
221+
)
222+
logging.warning(test_response)
223+
gen_text = test_response['choices'][0]['message']
224+
logging.warning(f"\n=======\nmodel response: {gen_text}\n=======\n")
225+
226+
assert type(gen_text) == dict, f"invalid model response: {gen_text}"
227+
assert gen_text['role'] == 'assistant', f"assistant response missing: {gen_text}"
228+
229+
logging.warning('About to delete the endpoint')
230+
predictor.delete_endpoint()
231+
232+
233+

0 commit comments

Comments
 (0)