Skip to content

Commit 5b0c844

Browse files
authored
Merge branch 'main' into gate
2 parents e4b6e38 + 6a78c4e commit 5b0c844

File tree

5 files changed

+158
-253
lines changed

5 files changed

+158
-253
lines changed

.github/workflows/pr-vllm.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,4 +935,5 @@ jobs:
935935
- name: Run sagemaker endpoint test
936936
run: |
937937
source .venv/bin/activate
938-
python test/vllm/sagemaker/test_sm_endpoint.py --image-uri ${{ needs.set-sagemaker-test-environment.outputs.image-uri }} --endpoint-name test-sm-vllm-endpoint-${{ github.sha }}
938+
cd test/
939+
python3 -m pytest -vs -rA --image-uri ${{ needs.set-sagemaker-test-environment.outputs.image-uri }} vllm/sagemaker

test/sglang/sagemaker/test_sm_endpoint.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from pprint import pformat
1818

1919
import pytest
20-
from botocore.exceptions import ClientError
2120
from sagemaker.model import Model
2221
from sagemaker.predictor import Predictor
2322
from sagemaker.serializers import JSONSerializer
24-
from test_utils import clean_string, random_suffix_name, wait_for_status
23+
from test_utils import clean_string, get_hf_token, random_suffix_name, wait_for_status
24+
from test_utils.constants import INFERENCE_AMI_VERSION, SAGEMAKER_ROLE
2525

2626
# To enable debugging, change logging.INFO to logging.DEBUG
2727
LOGGER = logging.getLogger(__name__)
@@ -38,23 +38,6 @@ def get_endpoint_status(sagemaker_client, endpoint_name):
3838
return response["EndpointStatus"]
3939

4040

41-
def get_hf_token(aws_session):
42-
LOGGER.info("Retrieving HuggingFace token from AWS Secrets Manager...")
43-
token_path = "test/hf_token"
44-
45-
try:
46-
get_secret_value_response = aws_session.secretsmanager.get_secret_value(SecretId=token_path)
47-
LOGGER.info("Successfully retrieved HuggingFace token")
48-
except ClientError as e:
49-
LOGGER.error(f"Failed to retrieve HuggingFace token: {e}")
50-
raise e
51-
52-
# Do not print secrets token in logs
53-
response = json.loads(get_secret_value_response["SecretString"])
54-
token = response.get("HF_TOKEN")
55-
return token
56-
57-
5841
@pytest.fixture(scope="function")
5942
def model_id(request):
6043
# Return the model_id given by the test parameter
@@ -63,14 +46,13 @@ def model_id(request):
6346

6447
@pytest.fixture(scope="function")
6548
def instance_type(request):
66-
# Return the model_id given by the test parameter
49+
# Return the instance_type given by the test parameter
6750
return request.param
6851

6952

7053
@pytest.fixture(scope="function")
7154
def model_package(aws_session, image_uri, model_id):
7255
sagemaker_client = aws_session.sagemaker
73-
sagemaker_role = aws_session.iam_resource.Role("SageMakerRole").arn
7456
cleaned_id = clean_string(model_id.split("/")[1], "_./")
7557
model_name = random_suffix_name(f"sglang-{cleaned_id}-model-package", 50)
7658

@@ -82,7 +64,7 @@ def model_package(aws_session, image_uri, model_id):
8264
model = Model(
8365
name=model_name,
8466
image_uri=image_uri,
85-
role=sagemaker_role,
67+
role=SAGEMAKER_ROLE,
8668
predictor_cls=Predictor,
8769
env={
8870
"SM_SGLANG_MODEL_PATH": model_id,
@@ -111,7 +93,7 @@ def model_endpoint(aws_session, model_package, instance_type):
11193
instance_type=instance_type,
11294
initial_instance_count=1,
11395
endpoint_name=endpoint_name,
114-
inference_ami_version="al2-ami-sagemaker-inference-gpu-3-1",
96+
inference_ami_version=INFERENCE_AMI_VERSION,
11597
serializer=JSONSerializer(),
11698
wait=True,
11799
)

test/test_utils/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@
1616
When necessary, use docstrings to explain the functions' mechanisms.
1717
"""
1818

19+
import json
1920
import logging
2021
import random
2122
import string
2223
import time
2324
from collections.abc import Callable
2425
from typing import Any
2526

27+
from botocore.exceptions import ClientError
28+
29+
from .aws import AWSSessionManager
30+
2631
LOGGER = logging.getLogger(__name__)
2732
LOGGER.setLevel(logging.INFO)
2833

@@ -58,3 +63,20 @@ def wait_for_status(
5863

5964
LOGGER.error(f"Wait for status: {expected_status} timed out. Actual status: {actual_status}")
6065
return False
66+
67+
68+
def get_hf_token(aws_session: AWSSessionManager) -> str:
    """Fetch the HuggingFace access token from AWS Secrets Manager.

    Reads the secret stored under ``test/hf_token`` through
    ``aws_session.secretsmanager``, parses its ``SecretString`` as JSON and
    returns the value stored under the ``HF_TOKEN`` key. The token value is
    never written to the logs.

    Args:
        aws_session: Session wrapper exposing a ``secretsmanager`` client.

    Returns:
        The HuggingFace token.

    Raises:
        ClientError: If the Secrets Manager call fails (logged, then re-raised).
        KeyError: If the secret payload has no ``HF_TOKEN`` entry.
    """
    LOGGER.info("Retrieving HuggingFace token from AWS Secrets Manager...")
    token_path = "test/hf_token"

    try:
        secret_response = aws_session.secretsmanager.get_secret_value(SecretId=token_path)
    except ClientError as e:
        LOGGER.error("Failed to retrieve HuggingFace token: %s", e)
        # Bare raise re-raises the active exception without rebinding it.
        raise
    LOGGER.info("Successfully retrieved HuggingFace token")

    # Do not print secrets token in logs
    payload = json.loads(secret_response["SecretString"])
    token = payload.get("HF_TOKEN")
    if token is None:
        # Fail here rather than returning None against the ``-> str`` contract,
        # so the error points at the malformed secret and not at a later consumer.
        # Only the key name and secret path are reported, never the payload.
        raise KeyError(f"'HF_TOKEN' key not found in secret '{token_path}'")
    return token

test/test_utils/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
DEFAULT_REGION = "us-west-2"
2+
SAGEMAKER_ROLE = "SageMakerRole"
3+
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html
4+
INFERENCE_AMI_VERSION = "al2-ami-sagemaker-inference-gpu-3-1"

0 commit comments

Comments
 (0)