1717from pprint import pformat
1818
1919import pytest
20- from botocore .exceptions import ClientError
2120from sagemaker .model import Model
2221from sagemaker .predictor import Predictor
2322from sagemaker .serializers import JSONSerializer
24- from test_utils import clean_string , random_suffix_name , wait_for_status
23+ from test_utils import clean_string , get_hf_token , random_suffix_name , wait_for_status
24+ from test_utils .constants import INFERENCE_AMI_VERSION , SAGEMAKER_ROLE
2525
2626# To enable debugging, change logging.INFO to logging.DEBUG
2727LOGGER = logging .getLogger (__name__ )
@@ -38,23 +38,6 @@ def get_endpoint_status(sagemaker_client, endpoint_name):
3838 return response ["EndpointStatus" ]
3939
4040
def get_hf_token(aws_session):
    """Fetch the HuggingFace token from the ``test/hf_token`` secret.

    Args:
        aws_session: Session object exposing a ``secretsmanager`` attribute
            with a ``get_secret_value`` method (presumably a boto3 Secrets
            Manager client -- confirm against the session fixture).

    Returns:
        The value stored under the ``HF_TOKEN`` key of the secret's JSON
        payload, or ``None`` when the payload holds no such value.

    Raises:
        ClientError: Re-raised unchanged when the secret lookup fails.
    """
    LOGGER.info("Retrieving HuggingFace token from AWS Secrets Manager...")
    token_path = "test/hf_token"

    # Keep the try body limited to the single call that can raise ClientError.
    try:
        get_secret_value_response = aws_session.secretsmanager.get_secret_value(
            SecretId=token_path
        )
    except ClientError as e:
        LOGGER.error("Failed to retrieve HuggingFace token: %s", e)
        # Bare raise re-raises the active exception with its traceback intact.
        raise
    else:
        LOGGER.info("Successfully retrieved HuggingFace token")

    # Never log the secret payload or the token itself.
    response = json.loads(get_secret_value_response["SecretString"])
    token = response.get("HF_TOKEN")
    if token is None:
        # Surface the misconfiguration here instead of letting a None token
        # fail later in a less obvious place; the return value is unchanged.
        LOGGER.warning("Secret %s does not contain an HF_TOKEN value", token_path)
    return token
56-
57-
5841@pytest .fixture (scope = "function" )
5942def model_id (request ):
6043 # Return the model_id given by the test parameter
@@ -63,14 +46,13 @@ def model_id(request):
6346
@pytest.fixture(scope="function")
def instance_type(request):
    """Expose the instance type passed in through the test's parameter."""
    return request.param
6851
6952
7053@pytest .fixture (scope = "function" )
7154def model_package (aws_session , image_uri , model_id ):
7255 sagemaker_client = aws_session .sagemaker
73- sagemaker_role = aws_session .iam_resource .Role ("SageMakerRole" ).arn
7456 cleaned_id = clean_string (model_id .split ("/" )[1 ], "_./" )
7557 model_name = random_suffix_name (f"sglang-{ cleaned_id } -model-package" , 50 )
7658
@@ -82,7 +64,7 @@ def model_package(aws_session, image_uri, model_id):
8264 model = Model (
8365 name = model_name ,
8466 image_uri = image_uri ,
85- role = sagemaker_role ,
67+ role = SAGEMAKER_ROLE ,
8668 predictor_cls = Predictor ,
8769 env = {
8870 "SM_SGLANG_MODEL_PATH" : model_id ,
@@ -111,7 +93,7 @@ def model_endpoint(aws_session, model_package, instance_type):
11193 instance_type = instance_type ,
11294 initial_instance_count = 1 ,
11395 endpoint_name = endpoint_name ,
114- inference_ami_version = "al2-ami-sagemaker-inference-gpu-3-1" ,
96+ inference_ami_version = INFERENCE_AMI_VERSION ,
11597 serializer = JSONSerializer (),
11698 wait = True ,
11799 )
0 commit comments